From 6ea4d2c82de80efc87708e5e182034b7c6c2019e Mon Sep 17 00:00:00 2001 From: Gibheer Date: Thu, 5 Sep 2024 19:38:25 +0200 Subject: switch from github.com/lib/pq to github.com/jackc/pgx/v5 lib/pq is out of maintenance for some time now, so switch to the newer more active library. Looks like it finally stabilized after a long time. --- .../jackc/pgx/v5/internal/iobufpool/iobufpool.go | 70 +++++ .../jackc/pgx/v5/internal/pgio/README.md | 6 + .../github.com/jackc/pgx/v5/internal/pgio/doc.go | 6 + .../github.com/jackc/pgx/v5/internal/pgio/write.go | 40 +++ .../jackc/pgx/v5/internal/sanitize/sanitize.go | 331 +++++++++++++++++++++ .../jackc/pgx/v5/internal/stmtcache/lru_cache.go | 112 +++++++ .../jackc/pgx/v5/internal/stmtcache/stmtcache.go | 45 +++ .../pgx/v5/internal/stmtcache/unlimited_cache.go | 77 +++++ 8 files changed, 687 insertions(+) create mode 100644 vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go create mode 100644 vendor/github.com/jackc/pgx/v5/internal/pgio/README.md create mode 100644 vendor/github.com/jackc/pgx/v5/internal/pgio/doc.go create mode 100644 vendor/github.com/jackc/pgx/v5/internal/pgio/write.go create mode 100644 vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go create mode 100644 vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go create mode 100644 vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go create mode 100644 vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go (limited to 'vendor/github.com/jackc/pgx/v5/internal') diff --git a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go new file mode 100644 index 0000000..89e0c22 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go @@ -0,0 +1,70 @@ +// Package iobufpool implements a global segregated-fit pool of buffers for IO. +// +// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid +// an allocation is purposely not documented. https://github.com/golang/go/issues/16323 +package iobufpool + +import "sync" + +const minPoolExpOf2 = 8 + +var pools [18]*sync.Pool + +func init() { + for i := range pools { + bufLen := 1 << (minPoolExpOf2 + i) + pools[i] = &sync.Pool{ + New: func() any { + buf := make([]byte, bufLen) + return &buf + }, + } + } +} + +// Get gets a []byte of len size with cap <= size*2. +func Get(size int) *[]byte { + i := getPoolIdx(size) + if i >= len(pools) { + buf := make([]byte, size) + return &buf + } + + ptrBuf := (pools[i].Get().(*[]byte)) + *ptrBuf = (*ptrBuf)[:size] + + return ptrBuf +} + +func getPoolIdx(size int) int { + size-- + size >>= minPoolExpOf2 + i := 0 + for size > 0 { + size >>= 1 + i++ + } + + return i +} + +// Put returns buf to the pool. +func Put(buf *[]byte) { + i := putPoolIdx(cap(*buf)) + if i < 0 { + return + } + + pools[i].Put(buf) +} + +func putPoolIdx(size int) int { + minPoolSize := 1 << minPoolExpOf2 + for i := range pools { + if size == minPoolSize< 0") + } + + if argIdx >= len(args) { + return "", fmt.Errorf("insufficient arguments") + } + arg := args[argIdx] + switch arg := arg.(type) { + case nil: + str = "null" + case int64: + str = strconv.FormatInt(arg, 10) + case float64: + str = strconv.FormatFloat(arg, 'f', -1, 64) + case bool: + str = strconv.FormatBool(arg) + case []byte: + str = QuoteBytes(arg) + case string: + str = QuoteString(arg) + case time.Time: + str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + default: + return "", fmt.Errorf("invalid arg type: %T", arg) + } + argUse[argIdx] = true + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + str = " " + str + " " + default: + return "", fmt.Errorf("invalid Part type: %T", part) + } + buf.WriteString(str) + } + + for i, used := range argUse { + if !used { + return "", fmt.Errorf("unused argument: %d", i) + } + } + return buf.String(), nil +} + +func NewQuery(sql string) (*Query, error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + query := &Query{Parts: l.parts} + + return query, nil +} + +func QuoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + +func QuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []Part +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '$': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if '0' <= nextRune && nextRune <= '9' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return placeholderState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +// placeholderState consumes a placeholder value. The $ must have already has +// already been consumed. The first rune must be a digit. +func placeholderState(l *sqlLexer) stateFn { + num := 0 + + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if '0' <= r && r <= '9' { + num *= 10 + num += int(r - '0') + } else { + l.parts = append(l.parts, num) + l.pos -= width + l.start = l.pos + return rawState + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +// SanitizeSQL replaces placeholder values with args. It quotes and escapes args +// as necessary. This function is only safe when standard_conforming_strings is +// on. +func SanitizeSQL(sql string, args ...any) (string, error) { + query, err := NewQuery(sql) + if err != nil { + return "", err + } + return query.Sanitize(args...) +} diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go new file mode 100644 index 0000000..dec83f4 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go @@ -0,0 +1,112 @@ +package stmtcache + +import ( + "container/list" + + "github.com/jackc/pgx/v5/pgconn" +) + +// LRUCache implements Cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + cap int + m map[string]*list.Element + l *list.List + invalidStmts []*pgconn.StatementDescription +} + +// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. +func NewLRUCache(cap int) *LRUCache { + return &LRUCache{ + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *LRUCache) Get(key string) *pgconn.StatementDescription { + if el, ok := c.m[key]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription) + } + + return nil + +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or +// sd.SQL has been invalidated and HandleInvalidated has not been called yet. +func (c *LRUCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + // The statement may have been invalidated but not yet handled. Do not readd it to the cache. + for _, invalidSD := range c.invalidStmts { + if invalidSD.SQL == sd.SQL { + return + } + } + + if c.l.Len() == c.cap { + c.invalidateOldest() + } + + el := c.l.PushFront(sd) + c.m[sd.SQL] = el +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *LRUCache) Invalidate(sql string) { + if el, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + c.l.Remove(el) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *LRUCache) InvalidateAll() { + el := c.l.Front() + for el != nil { + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + el = el.Next() + } + + c.m = make(map[string]*list.Element) + c.l = list.New() +} + +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *LRUCache) RemoveInvalidated() { + c.invalidStmts = nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +func (c *LRUCache) invalidateOldest() { + oldest := c.l.Back() + sd := oldest.Value.(*pgconn.StatementDescription) + c.invalidStmts = append(c.invalidStmts, sd) + delete(c.m, sd.SQL) + c.l.Remove(oldest) +} diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go new file mode 100644 index 0000000..d57bdd2 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go @@ -0,0 +1,45 @@ +// Package stmtcache is a cache for statement descriptions. +package stmtcache + +import ( + "crypto/sha256" + "encoding/hex" + + "github.com/jackc/pgx/v5/pgconn" +) + +// StatementName returns a statement name that will be stable for sql across multiple connections and program +// executions. +func StatementName(sql string) string { + digest := sha256.Sum256([]byte(sql)) + return "stmtcache_" + hex.EncodeToString(digest[0:24]) +} + +// Cache caches statement descriptions. +type Cache interface { + // Get returns the statement description for sql. Returns nil if not found. + Get(sql string) *pgconn.StatementDescription + + // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. + Put(sd *pgconn.StatementDescription) + + // Invalidate invalidates statement description identified by sql. Does nothing if not found. + Invalidate(sql string) + + // InvalidateAll invalidates all statement descriptions. + InvalidateAll() + + // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. + GetInvalidated() []*pgconn.StatementDescription + + // RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a + // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were + // never seen by the call to GetInvalidated. + RemoveInvalidated() + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int +} diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go new file mode 100644 index 0000000..6964132 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go @@ -0,0 +1,77 @@ +package stmtcache + +import ( + "math" + + "github.com/jackc/pgx/v5/pgconn" +) + +// UnlimitedCache implements Cache with no capacity limit. +type UnlimitedCache struct { + m map[string]*pgconn.StatementDescription + invalidStmts []*pgconn.StatementDescription +} + +// NewUnlimitedCache creates a new UnlimitedCache. +func NewUnlimitedCache() *UnlimitedCache { + return &UnlimitedCache{ + m: make(map[string]*pgconn.StatementDescription), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { + return c.m[sql] +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + c.m[sd.SQL] = sd +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *UnlimitedCache) Invalidate(sql string) { + if sd, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, sd) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *UnlimitedCache) InvalidateAll() { + for _, sd := range c.m { + c.invalidStmts = append(c.invalidStmts, sd) + } + + c.m = make(map[string]*pgconn.StatementDescription) +} + +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *UnlimitedCache) RemoveInvalidated() { + c.invalidStmts = nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *UnlimitedCache) Len() int { + return len(c.m) +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *UnlimitedCache) Cap() int { + return math.MaxInt +} -- cgit v1.2.3-70-g09d2