diff --git a/types/fields.go b/types/fields.go new file mode 100644 index 0000000..21de707 --- /dev/null +++ b/types/fields.go @@ -0,0 +1,64 @@ +package types + +import ( + "encoding/json" + "fmt" + "regexp" + "sort" +) + +type ( + // FieldList is a parameter type to represent a list of fields in the database + // to return. It can be used with query.FieldListToSelect to build a + // select clause to return the data that was requested. + FieldList struct { + fields map[string]bool + } +) + +var ( + // fieldIdentifier filters field names to allow only sane values + // and at the same time make them save for database queries. + fieldIdentifier = regexp.MustCompile(`\A[a-zA-Z]+[a-zA-Z0-9_\-]`) +) + +func NewFieldList(fields ...string) FieldList { + fl := FieldList{ + fields: map[string]bool{}, + } + for _, field := range fields { + fl.fields[field] = true + } + return fl +} + +func (fl *FieldList) UnmarshalJSON(raw []byte) error { + fields := []string{} + if err := json.Unmarshal(raw, &fields); err != nil { + return err + } + fl.fields = map[string]bool{} + for _, field := range fields { + if !fieldIdentifier.Match([]byte(field)) { + return fmt.Errorf("`%s` is not an allowed field name. Allowed is only alpha numerical", field) + } + fl.fields[field] = true + } + return nil +} + +// Contains returns true, when the string is found in the list. +func (fl FieldList) Contains(in string) bool { + _, found := fl.fields[in] + return found +} + +// Fields returns a sorted list of the requested fields. +func (fl FieldList) Fields() []string { + res := []string{} + for name, _ := range fl.fields { + res = append(res, name) + } + sort.Strings(res) + return res +}