diff options
author | Gibheer <gibheer+git@zero-knowledge.org> | 2024-09-05 19:38:25 +0200 |
---|---|---|
committer | Gibheer <gibheer+git@zero-knowledge.org> | 2024-09-05 19:38:25 +0200 |
commit | 6ea4d2c82de80efc87708e5e182034b7c6c2019e (patch) | |
tree | 35c0856a929040216c82153ca62d43b27530a887 /vendor/github.com/jackc/pgx/v5/named_args.go | |
parent | 6f64eeace1b66639b9380b44e88a8d54850a4306 (diff) |
lib/pq is out of maintenance for some time now, so switch to the newer
more active library. Looks like it finally stabilized after a long time.
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/named_args.go')
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/named_args.go | 295 |
1 files changed, 295 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/named_args.go b/vendor/github.com/jackc/pgx/v5/named_args.go new file mode 100644 index 0000000..c88991e --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/named_args.go @@ -0,0 +1,295 @@ +package pgx + +import ( + "context" + "fmt" + "strconv" + "strings" + "unicode/utf8" +) + +// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$' +// ordinal placeholder and construct the appropriate arguments. +// +// For example, the following two queries are equivalent: +// +// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) +// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) +// +// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be +// letters, numbers, or underscores. +type NamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(na, sql, false) +} + +// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all +// named arguments that the sql query uses, and no extra arguments. +type StrictNamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(sna, sql, true) +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + nameToOrdinal: make(map[namedArg]int, len(na)), + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + sb := strings.Builder{} + for _, p := range l.parts { + switch p := p.(type) { + case string: + sb.WriteString(p) + case namedArg: + sb.WriteRune('$') + sb.WriteString(strconv.Itoa(l.nameToOrdinal[p])) + } + } + + newArgs = make([]any, len(l.nameToOrdinal)) + for name, ordinal := range l.nameToOrdinal { + var found bool + newArgs[ordinal-1], found = na[string(name)] + if isStrict && !found { + return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name) + } + } + + if isStrict { + for name := range na { + if _, found := l.nameToOrdinal[namedArg(name)]; !found { + return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name) + } + } + } + + return sb.String(), newArgs, nil +} + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '@': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if isLetter(nextRune) || nextRune == '_' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return namedArgState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func isLetter(r rune) bool { + return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') +} + +func namedArgState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if r == utf8.RuneError { + if l.pos-l.start > 0 { + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, na) + l.start = l.pos + } + return nil + } else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') { + l.pos -= width + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, namedArg(na)) + l.start = l.pos + return rawState + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} |