package main

import (
	"crypto/rand"
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// Levels used as prefix for log lines and as keys of Response.Messages.
const (
	LevelFatal = `FATAL`
	LevelError = `ERROR`
	LevelInfo  = `INFO`
)

type (
	// Server is the central handler of all incoming requests. It generates
	// the context to call handlers, which then process the requests.
	Server struct {
		db     *sql.DB
		routes map[string]Handler
		debug  bool
	}

	// Handler is a function receiving a Context to process a request.
	// It is expected to investigate the Request and fill the Response.
	// When it returns an error, Server.Handle answers with status 500 and
	// the transaction of the request is rolled back.
	Handler func(c *Context, req Request, resp *Response) error

	// Context is the pre filled global context every handler receives.
	// It contains a prepared transaction for usage and important details like
	// the user account.
	Context struct {
		id       string
		req      *http.Request
		w        http.ResponseWriter
		debug    bool // print debug output to the console
		username string
		tx       *sql.Tx
	}

	// Request contains the method name and parameters requested by the client.
	Request struct {
		// Method is the name to route to the correct function.
		Method string `json:"method"`

		// Params is the list of parameters in the request. These can be
		// read directly or parsed using any of the parse methods.
		Params []json.RawMessage `json:"params"`
	}

	// Response can have messages and/or a result to return to the client.
	// ID carries the request ID, Messages is keyed by level and Result is
	// whatever the handler stored for the client.
	Response struct {
		ID       string              `json:"id"`
		Messages map[string][]string `json:"messages,omitempty"`
		Result   interface{}         `json:"result,omitempty"`
	}

	// ident is a raw byte identifier. It is not referenced by the code in
	// this file section.
	ident []byte
)

// NewServer creates a new server handler working on the given database.
// It returns an error when db is nil. The debug flag is handed to every
// Context and enables the output of Debugf.
func NewServer(db *sql.DB, debug bool) (*Server, error) {
	if db == nil {
		return nil, fmt.Errorf("database connection is not set")
	}
	return &Server{
		db:     db,
		routes: map[string]Handler{},
		debug:  debug,
	}, nil
}

// Register takes a new handler which will be called when the name is called.
// Registering the same name twice terminates the program through log.Fatalf.
// Register does no locking, so it must not run concurrently with Handle.
func (s *Server) Register(name string, handler Handler) {
	if _, found := s.routes[name]; found {
		log.Fatalf("route with name %s already exists", name)
	}
	// A Server that was not built by NewServer has no route table yet and
	// writing into a nil map would panic.
	if s.routes == nil {
		s.routes = map[string]Handler{}
	}
	s.routes[name] = handler
}

// Handle implements http.HandleFunc to serve content for the standard http
// interface.
//
// Only POST requests are accepted. Every request gets its own ID and its own
// database transaction. The JSON payload is decoded into a Request and routed
// by its method name to the registered Handler. The transaction is committed
// only when the handler returned without an error, in every other case it is
// rolled back. The Response is always rendered as JSON, the HTTP status is
// 400 for an unreadable payload, 404 for an unknown method and 500 for
// database or handler failures.
func (s *Server) Handle(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		// A 405 response has to name the methods which are supported.
		w.Header().Set("Allow", http.MethodPost)
		w.WriteHeader(http.StatusMethodNotAllowed)
		w.Write([]byte("only POST requests allowed"))
		return
	}

	id, err := newIdent()
	if err != nil {
		log.Printf("could not generate request id: %s", err)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("sorry, can't currently process your request"))
		return
	}

	c := &Context{
		id:       id,
		req:      r,
		w:        w,
		debug:    s.debug,
		username: "unknown",
	}
	req := Request{}
	res := &Response{
		ID:       c.id,
		Messages: map[string][]string{},
		Result:   map[string]interface{}{},
	}

	// From here on every answer is a JSON document. The header has to be
	// set before the first WriteHeader call to have any effect.
	w.Header().Set("Content-Type", "application/json")

	tx, err := s.db.Begin()
	if err != nil {
		c.Logf(LevelError, "could not create transaction: %s", err)
		w.WriteHeader(http.StatusInternalServerError)
		res.AddMessage(LevelFatal, "database problems occured")
		c.render(res)
		return
	}
	// Make a rollback, just in case something goes wrong. After Commit was
	// called it only returns sql.ErrTxDone, which is ignored on purpose.
	defer tx.Rollback()
	c.tx = tx

	// set username for transaction
	// TODO check username to be ASCII and nothing else
	// NOTE(review): both values are pasted into the statement without any
	// escaping. That is only safe as long as they are generated by the
	// server, as they are right now. Validate or escape the username before
	// it is taken from the client.
	_, err = c.tx.Exec(
		fmt.Sprintf(`set local dim.username to '%s'; set local dim.transaction = '%s'`, c.username, c.id),
	)
	if err != nil {
		c.Logf(LevelError, "could not set transaction username: %s", err)
		w.WriteHeader(http.StatusInternalServerError)
		res.AddMessage(LevelFatal, "could not create transaction")
		c.render(res)
		return
	}

	defer r.Body.Close()
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// The error is passed as argument and not as part of the format
		// string, so a '%' in the decoder message is printed verbatim.
		res.AddMessage(LevelError, "could not parse payload: %s", err)
		c.w.WriteHeader(http.StatusBadRequest)
		c.render(res)
		return
	}

	handler, found := s.routes[req.Method]
	if !found {
		res.AddMessage(LevelError, "method %s does not exist", req.Method)
		c.w.WriteHeader(http.StatusNotFound)
		c.render(res)
		return
	}

	c.Logf(LevelInfo, "method '%s' called with '%s'", req.Method, req.Params)
	if err := handler(c, req, res); err != nil {
		c.Logf(LevelError, "method '%s' returned an error: %s", req.Method, err)
		c.w.WriteHeader(http.StatusInternalServerError)
		c.render(res)
		// Do not commit the work of a failed handler, the deferred rollback
		// discards everything it did in the transaction.
		return
	}

	if err := tx.Commit(); err != nil {
		c.Logf(LevelFatal, "could not commit changes: %s", err)
		res.AddMessage(LevelError, "changes were not committed")
		// The request did not take effect, so it must not look successful.
		c.w.WriteHeader(http.StatusInternalServerError)
	}
	c.render(res)
}

// render converts the Response to json and sends it back to the client.
// An encoding problem is only logged, as parts of the response may already
// be on the wire.
func (c *Context) render(res *Response) {
	if err := json.NewEncoder(c.w).Encode(res); err != nil {
		c.Logf(LevelError, "could not encode result: %s\n%+v", err, res)
	}
}

// Logf logs a message through the standard logger, outfitted with the
// request ID and the level.
func (c *Context) Logf(level, msg string, args ...interface{}) {
	log.Printf("%s - %s - %s", c.id, level, fmt.Sprintf(msg, args...))
}

// Debugf logs output only when the server is set into debug mode.
// The line has the same format as the output of Logf.
func (c *Context) Debugf(level, msg string, args ...interface{}) {
	if c.debug {
		c.Logf(level, msg, args...)
	}
}

// newIdent generates a useable request ID, so that it can be found in the
// logs. The ID consists of 16 random bytes, printed as upper case hex in five
// groups of 8-4-4-4-12 characters. An error is returned when the random
// source can not be read.
func newIdent() (string, error) {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return fmt.Sprintf("%X-%X-%X-%X-%X", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]), nil
}

// AddMessage adds a new message to the result.
// When render is called, these messages will be sent to the client.
// The message is always formatted with fmt.Sprintf, therefore text which is
// not under control of the caller has to be passed in args and not in msg.
func (r *Response) AddMessage(level string, msg string, args ...interface{}) {
	// A Response which was not prepared by Handle has no message map yet and
	// writing into a nil map would panic.
	if r.Messages == nil {
		r.Messages = map[string][]string{}
	}
	r.Messages[level] = append(r.Messages[level], fmt.Sprintf(msg, args...))
}

// Len returns the length of the parameter list.
func (r Request) Len() int {
	return len(r.Params)
}

// Parse will parse the exact number of parameters into arguments.
// The boundaries are checked before. Every argument has to be a pointer,
// as it is handed to json.Unmarshal.
func (r Request) Parse(args ...interface{}) error {
	num := len(args)
	if r.Len() != num {
		return fmt.Errorf("expected %d parameters, got %d", num, r.Len())
	}
	return r.ParseAtLeast(num, args...)
}

// ParseAtLeast parses at least *num* many parameters into arguments.
// This function checks the boundaries before beginning to parse and returns
// errors if any boundary does not match. There must be at least num
// arguments and num parameters, but not more parameters than arguments.
// The error of a failed parameter wraps the error of json.Unmarshal.
func (r Request) ParseAtLeast(num int, args ...interface{}) error {
	if len(args) < num {
		return fmt.Errorf("requested %d arguments to parse, but only %d arguments given", num, len(args))
	}
	if r.Len() < num {
		return fmt.Errorf("need %d parameters, but only got %d", num, r.Len())
	}
	if r.Len() > len(args) {
		return fmt.Errorf("found %d parameters, when only %d are required", r.Len(), len(args))
	}

	// The checks above guarantee that every parameter has an argument.
	for i, param := range r.Params {
		if err := json.Unmarshal(param, args[i]); err != nil {
			return fmt.Errorf("argument at position %d can't be parsed: %w", i, err)
		}
	}
	return nil
}

// ParseAt unmarshalls the argument at *pos* into the container.
// If the position is outside the size of incoming arguments, which includes
// negative positions, an error is raised. The error of a failed parameter
// wraps the error of json.Unmarshal.
func (r Request) ParseAt(pos int, container interface{}) error {
	if pos < 0 || pos >= r.Len() {
		return fmt.Errorf("out of bounds") // TODO make generic ErrOutOfBounds
	}
	if err := json.Unmarshal(r.Params[pos], container); err != nil {
		return fmt.Errorf("could not unmarshal parameter '%d': %w", pos, err)
	}
	return nil
}