// dim/server.go
// (listing metadata: 250 lines, 7.2 KiB, Go)

package main
import (
"crypto/rand"
"database/sql"
"encoding/json"
"fmt"
"log"
"net/http"
)
// LevelFatal marks problems that prevented the request from being processed
// at all. Levels are used as keys in Response.Messages and as the level
// field of log lines written by Context.Logf and Context.Debugf.
const LevelFatal = "FATAL"

// LevelError marks a failure the client should know about.
const LevelError = "ERROR"

// LevelInfo marks purely informational output.
const LevelInfo = "INFO"
// Server is the central handler of all incoming requests. For each request
// it builds a Context and dispatches to the Handler registered for the
// requested method.
type Server struct {
	db     *sql.DB
	routes map[string]Handler
	debug  bool
}

// Handler processes a single request. It is expected to inspect req and
// fill resp; returning a non-nil error marks the request as failed.
type Handler func(c *Context, req Request, resp *Response) error

// Context is the per-request state every Handler receives. It carries the
// request ID, the HTTP request and writer, the database transaction to use
// and details such as the acting user name.
type Context struct {
	id       string
	req      *http.Request
	w        http.ResponseWriter
	debug    bool // enables Debugf output
	username string
	tx       *sql.Tx
}

// Request contains the method name and parameters requested by the client.
type Request struct {
	// Method is the name used to route the request to its Handler.
	Method string `json:"method"`
	// Params is the list of positional parameters. They can be read
	// directly or decoded with Parse, ParseAtLeast or ParseAt.
	Params []json.RawMessage `json:"params"`
}

// Response is returned to the client. It can carry messages grouped by
// level and/or a result.
type Response struct {
	ID       string              `json:"id"`
	Messages map[string][]string `json:"messages,omitempty"`
	Result   interface{}         `json:"result,omitempty"`
}

// ident is a raw byte identifier.
// NOTE(review): not referenced in the visible part of this file — confirm
// it is still needed.
type ident []byte
// NewServer returns a Server that serves requests against db, with an empty
// route table. debug is passed on to every request Context. An error is
// returned when db is nil.
func NewServer(db *sql.DB, debug bool) (*Server, error) {
	if db != nil {
		srv := &Server{db: db, debug: debug}
		srv.routes = make(map[string]Handler)
		return srv, nil
	}
	return nil, fmt.Errorf("database connection is not set")
}
// Register binds handler to the method called name. Registering the same
// name twice is treated as a programming error and terminates the process
// via log.Fatalf.
func (s *Server) Register(name string, handler Handler) {
	if _, dup := s.routes[name]; !dup {
		s.routes[name] = handler
		return
	}
	log.Fatalf("route with name %s already exists", name)
}
// Handle has the http.HandlerFunc signature and serves requests for the
// standard net/http interface.
//
// Only POST is accepted. The body must be a JSON-encoded Request; it is
// decoded and routed before a database transaction is opened, so malformed
// or unknown requests never hold a connection. The registered Handler then
// runs inside a transaction that is committed only when the Handler returns
// no error; otherwise the deferred Rollback discards its changes. Apart from
// the two early plain-text failures, the Response is always rendered as
// JSON.
func (s *Server) Handle(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		w.WriteHeader(http.StatusMethodNotAllowed)
		w.Write([]byte("only POST requests allowed"))
		return
	}
	id, err := newIdent()
	if err != nil {
		log.Printf("could not generate request id: %s", err)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("sorry, can't currently process your request"))
		return
	}
	c := &Context{
		id:       id,
		req:      r,
		w:        w,
		debug:    s.debug,
		username: "unknown",
	}
	res := &Response{
		ID:       c.id,
		Messages: map[string][]string{},
		Result:   map[string]interface{}{},
	}
	// Every remaining path renders JSON. Headers must be set before the
	// first WriteHeader call, otherwise net/http ignores them.
	w.Header().Set("Content-Type", "application/json")

	// fail sends the messages collected so far with the given status code.
	fail := func(status int) {
		c.w.WriteHeader(status)
		c.render(res)
	}

	// Decode and route before touching the database.
	req := Request{}
	defer r.Body.Close()
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// err is passed as an argument, not pre-formatted: AddMessage treats
		// its msg as a format string and would mangle a '%' in the error.
		res.AddMessage(LevelError, "could not parse payload: %s", err)
		fail(http.StatusBadRequest)
		return
	}
	handler, found := s.routes[req.Method]
	if !found {
		res.AddMessage(LevelError, "method %s does not exist", req.Method)
		fail(http.StatusNotFound)
		return
	}

	tx, err := s.db.Begin()
	if err != nil {
		c.Logf(LevelError, "could not create transaction: %s", err)
		res.AddMessage(LevelFatal, "database problems occured")
		fail(http.StatusInternalServerError)
		return
	}
	// Discard everything unless Commit below succeeds. After a successful
	// Commit this returns sql.ErrTxDone, which is deliberately ignored.
	defer tx.Rollback()
	c.tx = tx

	// Tag the transaction with the user name and request ID. The values are
	// bound as query parameters instead of being spliced into the SQL text,
	// so quotes in a username cannot inject SQL. set_config(name, value,
	// true) is the function form of PostgreSQL's SET LOCAL.
	// TODO check username to be ASCII and nothing else
	if _, err := c.tx.Exec(
		`select set_config('dim.username', $1, true), set_config('dim.transaction', $2, true)`,
		c.username, c.id,
	); err != nil {
		c.Logf(LevelError, "could not set transaction username: %s", err)
		res.AddMessage(LevelFatal, "could not create transaction")
		fail(http.StatusInternalServerError)
		return
	}

	c.Logf(LevelInfo, "method '%s' called with '%s'", req.Method, req.Params)
	if err := handler(c, req, res); err != nil {
		// Do not commit: a failed handler must not persist partial changes.
		// The deferred Rollback cleans up.
		c.Logf(LevelError, "method '%s' returned an error: %s", req.Method, err)
		fail(http.StatusInternalServerError)
		return
	}
	if err := tx.Commit(); err != nil {
		c.Logf(LevelFatal, "could not commit changes: %s", err)
		res.AddMessage(LevelError, "changes were not committed")
		fail(http.StatusInternalServerError)
		return
	}
	c.render(res)
}
// render encodes res as JSON and writes it to the client. If the caller has
// not written a status code yet, net/http sends 200. Encoding or write
// failures are only logged.
func (c *Context) render(res *Response) {
	if err := json.NewEncoder(c.w).Encode(res); err != nil {
		// Logf already prefixes the request ID, so it is not repeated here;
		// %+v prints every field, unlike %s, which produces %!s noise for
		// non-string values.
		c.Logf(LevelError, "could not encode result: %s\n%+v", err, res)
	}
}
// Logf formats msg with args (fmt.Sprintf semantics) and writes it through
// the standard log package as "<request id> - <level> - <message>".
func (c *Context) Logf(level, msg string, args ...interface{}) {
	text := fmt.Sprintf(msg, args...)
	log.Printf("%s - %s - %s", c.id, level, text)
}
// Debugf behaves like Logf but only writes when the Context was created with
// debug enabled; otherwise it does nothing and msg is not formatted.
func (c *Context) Debugf(level, msg string, args ...interface{}) {
	if !c.debug {
		return
	}
	text := fmt.Sprintf(msg, args...)
	log.Printf("%s - %s - %s", c.id, level, text)
}
// newIdent returns a random request ID so that all log lines of one request
// can be found. It consists of 16 bytes from crypto/rand, rendered as
// uppercase hex in 8-4-4-4-12 groups. This follows the UUID text layout but
// does not set UUID version/variant bits.
func newIdent() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return fmt.Sprintf("%X-%X-%X-%X-%X",
		raw[:4], raw[4:6], raw[6:8], raw[8:10], raw[10:16]), nil
}
// AddMessage formats msg with args (fmt.Sprintf semantics) and appends the
// result to the messages stored under level. The messages are sent to the
// client when the Response is rendered.
func (r *Response) AddMessage(level string, msg string, args ...interface{}) {
	// Create the map on demand so a zero-value Response is usable; writing
	// to a nil map would panic.
	if r.Messages == nil {
		r.Messages = map[string][]string{}
	}
	r.Messages[level] = append(r.Messages[level], fmt.Sprintf(msg, args...))
}
// Len reports how many positional parameters the client sent.
func (r Request) Len() int {
	count := len(r.Params)
	return count
}
// Parse decodes the request parameters into args and requires exactly one
// argument per parameter. It returns an error when the counts differ or
// when a parameter cannot be decoded.
func (r Request) Parse(args ...interface{}) error {
	if want, got := len(args), r.Len(); want != got {
		return fmt.Errorf("expected %d parameters, got %d", want, got)
	}
	return r.ParseAtLeast(len(args), args...)
}
// ParseAtLeast decodes the request parameters positionally into args.
// The request must carry at least num parameters and no more than len(args),
// and args must hold at least num targets. Targets beyond the number of
// parameters are left untouched, which allows trailing optional arguments.
// All bounds are checked before anything is decoded; decoding stops at the
// first parameter that fails.
func (r Request) ParseAtLeast(num int, args ...interface{}) error {
	have, slots := r.Len(), len(args)
	switch {
	case slots < num:
		return fmt.Errorf("requested %d arguments to parse, but only %d arguments given", num, slots)
	case have < num:
		return fmt.Errorf("need %d parameters, but only got %d", num, have)
	case have > slots:
		return fmt.Errorf("found %d parameters, when only %d are required", have, slots)
	}
	for i := 0; i < have; i++ {
		if err := json.Unmarshal(r.Params[i], args[i]); err != nil {
			return fmt.Errorf("argument at position %d can't be parsed: %v", i, err)
		}
	}
	return nil
}
// ParseAt unmarshals the parameter at position pos into container, which
// must be a non-nil pointer as required by json.Unmarshal. It returns an
// error when pos lies outside the parameter list (negative positions
// included) or when the parameter cannot be decoded.
func (r Request) ParseAt(pos int, container interface{}) error {
	// Reject negative positions as well; otherwise r.Params[pos] panics.
	if pos < 0 || pos >= r.Len() {
		return fmt.Errorf("out of bounds") // TODO make generic ErrOutOfBounds
	}
	if err := json.Unmarshal(r.Params[pos], container); err != nil {
		return fmt.Errorf("could not unmarshal parameter '%d': %s", pos, err)
	}
	return nil
}