diff --git a/layer3domain_attr.go b/layer3domain_attr.go new file mode 100644 index 0000000..41e7607 --- /dev/null +++ b/layer3domain_attr.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + + "dim/query" + "dim/types" +) + +type ( + Layer3DomainSetOptions struct { + Attributes types.FieldMap `json:"attributes"` + } +) + +func layer3DomainSetAttr(c *Context, req Request, res *Response) error { + name := "" + attrs := types.FieldMap{} + if err := req.ParseAtLeast(2, &name, &attrs); err != nil { + res.AddMessage(LevelError, "could not parse options: %s", err) + return nil + } + if name == "" { + res.AddMessage(LevelError, "empty name was provided") + return nil + } + if attrs.Size() == 0 { + res.AddMessage(LevelError, "no key/value pairs provided to update") + return nil + } + + setClause, args, err := query.FieldMapToUpdate(attrs, layer3DomainListMap) + if err != nil { + res.AddMessage(LevelError, "could not encode requested attributes: %s", err) + return nil + } + queryStr := fmt.Sprintf("update layer3domains l set %s where name = $%d", setClause, len(args)+1) + args = append(args, name) // don't forget to add the where clause parameter + if _, err := c.tx.Exec(queryStr, args...); err != nil { + res.AddMessage(LevelError, "could not set attributes") + c.Logf(LevelError, "could not set attributes on layer3domain '%s': %s - query: `%s` - args: `%#v`", name, err, queryStr, args) + return nil + } + return nil +} diff --git a/layer3domain_create.go b/layer3domain_create.go index 8e478ba..47b92d6 100644 --- a/layer3domain_create.go +++ b/layer3domain_create.go @@ -8,6 +8,7 @@ type ( Layer3DomainCreateOptions string ) +// Layer3DomainCreate creates a new layer3domain. func layer3DomainCreate(c *Context, req Request, res *Response) error { name := "" options := Layer3DomainCreateOptions("{}") diff --git a/layer3domain_list.go b/layer3domain_list.go new file mode 100644 index 0000000..ebe0a8d --- /dev/null +++ b/layer3domain_list.go @@ -0,0 +1,52 @@ +package main + +import ( + "fmt" + + "dim/query" + "dim/types" +) + +type ( + Layer3DomainListOptions struct { + Attributes types.FieldList `json:"attributes"` + } +) + +var ( + layer3DomainListMap = map[string]string{ + "name": "l.name", + "modified_by": "l.modified_by", + "modified_at": "l.modified_at", + "created_by": "l.created_by", + "created_at": "l.created_at", + } +) + +// Layer3DomainList lists all registered layer3domains. +func layer3DomainList(c *Context, req Request, res *Response) error { + options := Layer3DomainListOptions{ + Attributes: types.NewFieldList("name"), + } + if err := req.ParseAtLeast(0, &options); err != nil { + res.AddMessage(LevelError, "could not parse options: %s", err) + return nil + } + + selClause := query.FieldListToSelect("l", options.Attributes, layer3DomainListMap) + from := "layer3domains l" + queryStr := fmt.Sprintf(`select %s from %s`, selClause, from) + rows, err := c.tx.Query(queryStr) + if err != nil { + res.AddMessage(LevelError, "could not return result") + return fmt.Errorf("could not get layer3domain list: %s - query %s", err, queryStr) + } + defer rows.Close() + res.Result, err = query.RowsToMap(rows) + if err != nil { + res.Result = nil + res.AddMessage(LevelError, "could not return result") + return fmt.Errorf("could not parse layer3domain list: %#v", err) + } + return nil +} diff --git a/query/query.go b/query/query.go index 744ee6f..bc17f26 100644 --- a/query/query.go +++ b/query/query.go @@ -7,6 +7,7 @@ with the necessary parameter keys. package query import ( + "encoding/json" "fmt" "strings" @@ -46,3 +47,68 @@ func nameToAttrPath(tabName, name string) string { } return fmt.Sprintf("%s.attributes->%s", tabName, strings.Join(parts, "->")) } + +// FieldMapToUpdate generates the necessary elements for an update. +// +// It returns the set clause for the update statement and the arguments for the placeholders. +// The index will start with 1, so every other parameter not included in the update needs to +// use the size of the field map + 1 as the next index. +// If the key points is not found in the nameMap, the value will be joined with the attributes +// column of the table. +// An error is returned when the attribute values can't be encoded correctly. +func FieldMapToUpdate(fm types.FieldMap, nameMap map[string]string) (string, []interface{}, error) { + setClause := []string{} + args := []interface{}{} + attrVals := map[string]interface{}{} + i := 0 + for key, val := range fm.Fields() { + i++ + if name, found := nameMap[key]; found { + setClause = append(setClause, fmt.Sprintf("%s = $%d", name, i)) + if val == "" { + args = append(args, nil) + } else { + args = append(args, val) + } + } else { + parts := strings.Split(key, ".") + attrVals = setJSONPath(attrVals, parts, val) + } + } + if len(attrVals) > 0 { + setClause = append( + setClause, + fmt.Sprintf("attributes = jsonb_strip_nulls(attributes || $%d::jsonb)", len(args)+1), + ) + raw, err := json.Marshal(attrVals) + if err != nil { + return "", []interface{}{}, fmt.Errorf("could not encode attributes: %#v", err) + } + args = append(args, string(raw)) + } + return strings.Join(setClause, ","), args, nil +} + +// Set a value to a nested map structure. +// The path must be a list of steps to traverse the map structure. +func setJSONPath(target map[string]interface{}, path []string, val interface{}) map[string]interface{} { + res := target + if len(path) > 1 { + raw, found := res[path[0]] + + if !found { + res[path[0]] = map[string]interface{}{} + raw = res[path[0]] + } else { + values, worked := raw.(map[string]interface{}) + if !worked { + values = map[string]interface{}{} + res[path[0]] = values + } + } + res[path[0]] = setJSONPath(res[path[0]].(map[string]interface{}), path[1:], val) + return res + } + res[path[0]] = val + return res +} diff --git a/query/query_test.go b/query/query_test.go index 04ea498..94cbcfa 100644 --- a/query/query_test.go +++ b/query/query_test.go @@ -54,3 +54,47 @@ func TestNameToAttrPath(t *testing.T) { } } } + +func TestFieldMapToUpdate(t *testing.T) { + tests := []struct { + table string + vals types.FieldMap + mapping map[string]string + set string // expected set clause + args []interface{} // expected arguments + }{ + { // check for normal field mapping + "zoo", + types.NewFieldMap(map[string]interface{}{"key": "value"}), + map[string]string{"key": "field"}, + "zoo.field = $1", + []interface{}{"value"}, + }, + { // generate attributes field + "zoo", + types.NewFieldMap(map[string]interface{}{"key2": "value"}), + map[string]string{"key": "field"}, + "zoo.attributes->'key2' = $1", + []interface{}{"value"}, + }, + { // mixed mapped and unmapped field + "zoo", + types.NewFieldMap(map[string]interface{}{"key2": "value", "key": "value"}), + map[string]string{"key": "field"}, + "zoo.attributes->'key2' = $1,zoo.field = $2", + []interface{}{"value", "value"}, + }, + } + + for _, test := range tests { + set, args := FieldMapToUpdate(test.table, test.vals, test.mapping) + if set != test.set { + t.Errorf("expected set clause `%s`, got `%s`", test.set, set) + } + for i, arg := range args { + if arg != test.args[i] { + t.Errorf("expected argument at pos %d to be %#v, but was %#v", i, test.args[i], arg) + } + } + } +} diff --git a/types/fields.go b/types/fields.go index 21de707..a67ff53 100644 --- a/types/fields.go +++ b/types/fields.go @@ -14,12 +14,18 @@ type ( FieldList struct { fields map[string]bool } + // FieldMap is a set of key/value pairs. + // It can be used to with query.FieldMapToUpdate to build a + // set clause for an update statement. + FieldMap struct { + fields map[string]interface{} + } ) var ( // fieldIdentifier filters field names to allow only sane values // and at the same time make them save for database queries. - fieldIdentifier = regexp.MustCompile(`\A[a-zA-Z]+[a-zA-Z0-9_\-]`) + fieldIdentifier = regexp.MustCompile(`\A[a-zA-Z]+([a-zA-Z0-9_\-.]*)`) ) func NewFieldList(fields ...string) FieldList { @@ -62,3 +68,34 @@ func (fl FieldList) Fields() []string { sort.Strings(res) return res } + +// NewFieldMap builds a FieldMap with the provided defaults. +func NewFieldMap(fields map[string]interface{}) FieldMap { + return FieldMap{fields: fields} +} + +// UnmarshalJSON implements the json decoding interface so that it can be used with +// with the request parsing functions. +func (fm *FieldMap) UnmarshalJSON(raw []byte) error { + fields := map[string]interface{}{} + if err := json.Unmarshal(raw, &fields); err != nil { + return err + } + for k, _ := range fields { + if !fieldIdentifier.Match([]byte(k)) { + return fmt.Errorf("`%s` is not an allowed field name. Allowed is only alpha numerical", k) + } + } + fm.fields = fields + return nil +} + +// Fields returns all key/value pairs. +func (fm FieldMap) Fields() map[string]interface{} { + return fm.fields +} + +// Size returns the number of keys. +func (fm FieldMap) Size() int { + return len(fm.fields) +}