package types import ( "encoding/json" "fmt" "regexp" "sort" ) type ( // FieldList is a parameter type to represent a list of fields in the database // to return. It can be used with query.FieldListToSelect to build a // select clause to return the data that was requested. FieldList struct { fields map[string]bool } // FieldMap is a set of key/value pairs. // It can be used to with query.FieldMapToUpdate to build a // set clause for an update statement. FieldMap struct { fields map[string]interface{} } ) var ( // fieldIdentifier filters field names to allow only sane values // and at the same time make them save for database queries. fieldIdentifier = regexp.MustCompile(`\A[a-zA-Z]+([a-zA-Z0-9_\-.]*)`) ) func NewFieldList(fields ...string) FieldList { fl := FieldList{ fields: map[string]bool{}, } for _, field := range fields { fl.fields[field] = true } return fl } func (fl *FieldList) UnmarshalJSON(raw []byte) error { fields := []string{} if err := json.Unmarshal(raw, &fields); err != nil { return err } fl.fields = map[string]bool{} for _, field := range fields { if !fieldIdentifier.Match([]byte(field)) { return fmt.Errorf("`%s` is not an allowed field name. Allowed is only alpha numerical", field) } fl.fields[field] = true } return nil } // Contains returns true, when the string is found in the list. func (fl FieldList) Contains(in string) bool { _, found := fl.fields[in] return found } // Fields returns a sorted list of the requested fields. func (fl FieldList) Fields() []string { res := []string{} for name, _ := range fl.fields { res = append(res, name) } sort.Strings(res) return res } // NewFieldMap builds a FieldMap with the provided defaults. func NewFieldMap(fields map[string]interface{}) FieldMap { return FieldMap{fields: fields} } // UnmarshalJSON implements the json decoding interface so that it can be used with // with the request parsing functions. func (fm *FieldMap) UnmarshalJSON(raw []byte) error { fields := map[string]interface{}{} if err := json.Unmarshal(raw, &fields); err != nil { return err } for k, _ := range fields { if !fieldIdentifier.Match([]byte(k)) { return fmt.Errorf("`%s` is not an allowed field name. Allowed is only alpha numerical", k) } } fm.fields = fields return nil } // Fields returns all key/value pairs. func (fm FieldMap) Fields() map[string]interface{} { return fm.fields } // Size returns the number of keys. func (fm FieldMap) Size() int { return len(fm.fields) }