aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/internal
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/internal')
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go70
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/pgio/README.md6
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/pgio/doc.go6
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/pgio/write.go40
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go331
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go112
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go45
-rw-r--r--vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go77
8 files changed, 687 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go
new file mode 100644
index 0000000..89e0c22
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go
@@ -0,0 +1,70 @@
+// Package iobufpool implements a global segregated-fit pool of buffers for IO.
+//
+// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
+// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
+package iobufpool
+
+import "sync"
+
+const minPoolExpOf2 = 8
+
+var pools [18]*sync.Pool
+
+func init() {
+ for i := range pools {
+ bufLen := 1 << (minPoolExpOf2 + i)
+ pools[i] = &sync.Pool{
+ New: func() any {
+ buf := make([]byte, bufLen)
+ return &buf
+ },
+ }
+ }
+}
+
+// Get gets a []byte of len size with cap <= size*2.
+func Get(size int) *[]byte {
+ i := getPoolIdx(size)
+ if i >= len(pools) {
+ buf := make([]byte, size)
+ return &buf
+ }
+
+ ptrBuf := (pools[i].Get().(*[]byte))
+ *ptrBuf = (*ptrBuf)[:size]
+
+ return ptrBuf
+}
+
+func getPoolIdx(size int) int {
+ size--
+ size >>= minPoolExpOf2
+ i := 0
+ for size > 0 {
+ size >>= 1
+ i++
+ }
+
+ return i
+}
+
+// Put returns buf to the pool.
+func Put(buf *[]byte) {
+ i := putPoolIdx(cap(*buf))
+ if i < 0 {
+ return
+ }
+
+ pools[i].Put(buf)
+}
+
+func putPoolIdx(size int) int {
+ minPoolSize := 1 << minPoolExpOf2
+ for i := range pools {
+ if size == minPoolSize<<i {
+ return i
+ }
+ }
+
+ return -1
+}
diff --git a/vendor/github.com/jackc/pgx/v5/internal/pgio/README.md b/vendor/github.com/jackc/pgx/v5/internal/pgio/README.md
new file mode 100644
index 0000000..b2fc580
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/pgio/README.md
@@ -0,0 +1,6 @@
+# pgio
+
+Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
+
+pgio provides functions for appending integers to a []byte while doing byte
+order conversion.
diff --git a/vendor/github.com/jackc/pgx/v5/internal/pgio/doc.go b/vendor/github.com/jackc/pgx/v5/internal/pgio/doc.go
new file mode 100644
index 0000000..ef2dcc7
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/pgio/doc.go
@@ -0,0 +1,6 @@
+// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
+/*
+pgio provides functions for appending integers to a []byte while doing byte
+order conversion.
+*/
+package pgio
diff --git a/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go b/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go
new file mode 100644
index 0000000..96aedf9
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go
@@ -0,0 +1,40 @@
+package pgio
+
+import "encoding/binary"
+
+func AppendUint16(buf []byte, n uint16) []byte {
+ wp := len(buf)
+ buf = append(buf, 0, 0)
+ binary.BigEndian.PutUint16(buf[wp:], n)
+ return buf
+}
+
+func AppendUint32(buf []byte, n uint32) []byte {
+ wp := len(buf)
+ buf = append(buf, 0, 0, 0, 0)
+ binary.BigEndian.PutUint32(buf[wp:], n)
+ return buf
+}
+
+func AppendUint64(buf []byte, n uint64) []byte {
+ wp := len(buf)
+ buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
+ binary.BigEndian.PutUint64(buf[wp:], n)
+ return buf
+}
+
+func AppendInt16(buf []byte, n int16) []byte {
+ return AppendUint16(buf, uint16(n))
+}
+
+func AppendInt32(buf []byte, n int32) []byte {
+ return AppendUint32(buf, uint32(n))
+}
+
+func AppendInt64(buf []byte, n int64) []byte {
+ return AppendUint64(buf, uint64(n))
+}
+
+func SetInt32(buf []byte, n int32) {
+ binary.BigEndian.PutUint32(buf, uint32(n))
+}
diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go
new file mode 100644
index 0000000..df58c44
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go
@@ -0,0 +1,331 @@
+package sanitize
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+ "unicode/utf8"
+)
+
+// Part is either a string or an int. A string is raw SQL. An int is a
+// argument placeholder.
+type Part any
+
+type Query struct {
+ Parts []Part
+}
+
+// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
+// character. utf8.RuneError is not an error if it is also width 3.
+//
+// https://github.com/jackc/pgx/issues/1380
+const replacementcharacterwidth = 3
+
+func (q *Query) Sanitize(args ...any) (string, error) {
+ argUse := make([]bool, len(args))
+ buf := &bytes.Buffer{}
+
+ for _, part := range q.Parts {
+ var str string
+ switch part := part.(type) {
+ case string:
+ str = part
+ case int:
+ argIdx := part - 1
+
+ if argIdx < 0 {
+ return "", fmt.Errorf("first sql argument must be > 0")
+ }
+
+ if argIdx >= len(args) {
+ return "", fmt.Errorf("insufficient arguments")
+ }
+ arg := args[argIdx]
+ switch arg := arg.(type) {
+ case nil:
+ str = "null"
+ case int64:
+ str = strconv.FormatInt(arg, 10)
+ case float64:
+ str = strconv.FormatFloat(arg, 'f', -1, 64)
+ case bool:
+ str = strconv.FormatBool(arg)
+ case []byte:
+ str = QuoteBytes(arg)
+ case string:
+ str = QuoteString(arg)
+ case time.Time:
+ str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
+ default:
+ return "", fmt.Errorf("invalid arg type: %T", arg)
+ }
+ argUse[argIdx] = true
+
+ // Prevent SQL injection via Line Comment Creation
+ // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
+ str = " " + str + " "
+ default:
+ return "", fmt.Errorf("invalid Part type: %T", part)
+ }
+ buf.WriteString(str)
+ }
+
+ for i, used := range argUse {
+ if !used {
+ return "", fmt.Errorf("unused argument: %d", i)
+ }
+ }
+ return buf.String(), nil
+}
+
+func NewQuery(sql string) (*Query, error) {
+ l := &sqlLexer{
+ src: sql,
+ stateFn: rawState,
+ }
+
+ for l.stateFn != nil {
+ l.stateFn = l.stateFn(l)
+ }
+
+ query := &Query{Parts: l.parts}
+
+ return query, nil
+}
+
+func QuoteString(str string) string {
+ return "'" + strings.ReplaceAll(str, "'", "''") + "'"
+}
+
+func QuoteBytes(buf []byte) string {
+ return `'\x` + hex.EncodeToString(buf) + "'"
+}
+
+type sqlLexer struct {
+ src string
+ start int
+ pos int
+ nested int // multiline comment nesting level.
+ stateFn stateFn
+ parts []Part
+}
+
+type stateFn func(*sqlLexer) stateFn
+
+func rawState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case 'e', 'E':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '\'' {
+ l.pos += width
+ return escapeStringState
+ }
+ case '\'':
+ return singleQuoteState
+ case '"':
+ return doubleQuoteState
+ case '$':
+ nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
+ if '0' <= nextRune && nextRune <= '9' {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos-width])
+ }
+ l.start = l.pos
+ return placeholderState
+ }
+ case '-':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '-' {
+ l.pos += width
+ return oneLineCommentState
+ }
+ case '/':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '*' {
+ l.pos += width
+ return multilineCommentState
+ }
+ case utf8.RuneError:
+ if width != replacementcharacterwidth {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+ }
+}
+
+func singleQuoteState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '\'':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '\'' {
+ return rawState
+ }
+ l.pos += width
+ case utf8.RuneError:
+ if width != replacementcharacterwidth {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+ }
+}
+
+func doubleQuoteState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '"':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '"' {
+ return rawState
+ }
+ l.pos += width
+ case utf8.RuneError:
+ if width != replacementcharacterwidth {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+ }
+}
+
+// placeholderState consumes a placeholder value. The $ must have already has
+// already been consumed. The first rune must be a digit.
+func placeholderState(l *sqlLexer) stateFn {
+ num := 0
+
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ if '0' <= r && r <= '9' {
+ num *= 10
+ num += int(r - '0')
+ } else {
+ l.parts = append(l.parts, num)
+ l.pos -= width
+ l.start = l.pos
+ return rawState
+ }
+ }
+}
+
+func escapeStringState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '\\':
+ _, width = utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+ case '\'':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '\'' {
+ return rawState
+ }
+ l.pos += width
+ case utf8.RuneError:
+ if width != replacementcharacterwidth {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+ }
+}
+
+func oneLineCommentState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '\\':
+ _, width = utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+ case '\n', '\r':
+ return rawState
+ case utf8.RuneError:
+ if width != replacementcharacterwidth {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+ }
+}
+
+func multilineCommentState(l *sqlLexer) stateFn {
+ for {
+ r, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ l.pos += width
+
+ switch r {
+ case '/':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune == '*' {
+ l.pos += width
+ l.nested++
+ }
+ case '*':
+ nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
+ if nextRune != '/' {
+ continue
+ }
+
+ l.pos += width
+ if l.nested == 0 {
+ return rawState
+ }
+ l.nested--
+
+ case utf8.RuneError:
+ if width != replacementcharacterwidth {
+ if l.pos-l.start > 0 {
+ l.parts = append(l.parts, l.src[l.start:l.pos])
+ l.start = l.pos
+ }
+ return nil
+ }
+ }
+ }
+}
+
+// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
+// as necessary. This function is only safe when standard_conforming_strings is
+// on.
+func SanitizeSQL(sql string, args ...any) (string, error) {
+ query, err := NewQuery(sql)
+ if err != nil {
+ return "", err
+ }
+ return query.Sanitize(args...)
+}
diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go
new file mode 100644
index 0000000..dec83f4
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go
@@ -0,0 +1,112 @@
+package stmtcache
+
+import (
+ "container/list"
+
+ "github.com/jackc/pgx/v5/pgconn"
+)
+
+// LRUCache implements Cache with a Least Recently Used (LRU) cache.
+type LRUCache struct {
+ cap int
+ m map[string]*list.Element
+ l *list.List
+ invalidStmts []*pgconn.StatementDescription
+}
+
+// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
+func NewLRUCache(cap int) *LRUCache {
+ return &LRUCache{
+ cap: cap,
+ m: make(map[string]*list.Element),
+ l: list.New(),
+ }
+}
+
+// Get returns the statement description for sql. Returns nil if not found.
+func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
+ if el, ok := c.m[key]; ok {
+ c.l.MoveToFront(el)
+ return el.Value.(*pgconn.StatementDescription)
+ }
+
+ return nil
+
+}
+
+// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
+// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
+func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
+ if sd.SQL == "" {
+ panic("cannot store statement description with empty SQL")
+ }
+
+ if _, present := c.m[sd.SQL]; present {
+ return
+ }
+
+ // The statement may have been invalidated but not yet handled. Do not readd it to the cache.
+ for _, invalidSD := range c.invalidStmts {
+ if invalidSD.SQL == sd.SQL {
+ return
+ }
+ }
+
+ if c.l.Len() == c.cap {
+ c.invalidateOldest()
+ }
+
+ el := c.l.PushFront(sd)
+ c.m[sd.SQL] = el
+}
+
+// Invalidate invalidates statement description identified by sql. Does nothing if not found.
+func (c *LRUCache) Invalidate(sql string) {
+ if el, ok := c.m[sql]; ok {
+ delete(c.m, sql)
+ c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
+ c.l.Remove(el)
+ }
+}
+
+// InvalidateAll invalidates all statement descriptions.
+func (c *LRUCache) InvalidateAll() {
+ el := c.l.Front()
+ for el != nil {
+ c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
+ el = el.Next()
+ }
+
+ c.m = make(map[string]*list.Element)
+ c.l = list.New()
+}
+
+// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
+func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
+ return c.invalidStmts
+}
+
+// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
+// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
+// never seen by the call to GetInvalidated.
+func (c *LRUCache) RemoveInvalidated() {
+ c.invalidStmts = nil
+}
+
+// Len returns the number of cached prepared statement descriptions.
+func (c *LRUCache) Len() int {
+ return c.l.Len()
+}
+
+// Cap returns the maximum number of cached prepared statement descriptions.
+func (c *LRUCache) Cap() int {
+ return c.cap
+}
+
+func (c *LRUCache) invalidateOldest() {
+ oldest := c.l.Back()
+ sd := oldest.Value.(*pgconn.StatementDescription)
+ c.invalidStmts = append(c.invalidStmts, sd)
+ delete(c.m, sd.SQL)
+ c.l.Remove(oldest)
+}
diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go
new file mode 100644
index 0000000..d57bdd2
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go
@@ -0,0 +1,45 @@
+// Package stmtcache is a cache for statement descriptions.
+package stmtcache
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+
+ "github.com/jackc/pgx/v5/pgconn"
+)
+
+// StatementName returns a statement name that will be stable for sql across multiple connections and program
+// executions.
+func StatementName(sql string) string {
+ digest := sha256.Sum256([]byte(sql))
+ return "stmtcache_" + hex.EncodeToString(digest[0:24])
+}
+
+// Cache caches statement descriptions.
+type Cache interface {
+ // Get returns the statement description for sql. Returns nil if not found.
+ Get(sql string) *pgconn.StatementDescription
+
+ // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
+ Put(sd *pgconn.StatementDescription)
+
+ // Invalidate invalidates statement description identified by sql. Does nothing if not found.
+ Invalidate(sql string)
+
+ // InvalidateAll invalidates all statement descriptions.
+ InvalidateAll()
+
+ // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
+ GetInvalidated() []*pgconn.StatementDescription
+
+ // RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
+ // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
+ // never seen by the call to GetInvalidated.
+ RemoveInvalidated()
+
+ // Len returns the number of cached prepared statement descriptions.
+ Len() int
+
+ // Cap returns the maximum number of cached prepared statement descriptions.
+ Cap() int
+}
diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go
new file mode 100644
index 0000000..6964132
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go
@@ -0,0 +1,77 @@
+package stmtcache
+
+import (
+ "math"
+
+ "github.com/jackc/pgx/v5/pgconn"
+)
+
+// UnlimitedCache implements Cache with no capacity limit.
+type UnlimitedCache struct {
+ m map[string]*pgconn.StatementDescription
+ invalidStmts []*pgconn.StatementDescription
+}
+
+// NewUnlimitedCache creates a new UnlimitedCache.
+func NewUnlimitedCache() *UnlimitedCache {
+ return &UnlimitedCache{
+ m: make(map[string]*pgconn.StatementDescription),
+ }
+}
+
+// Get returns the statement description for sql. Returns nil if not found.
+func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
+ return c.m[sql]
+}
+
+// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
+func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
+ if sd.SQL == "" {
+ panic("cannot store statement description with empty SQL")
+ }
+
+ if _, present := c.m[sd.SQL]; present {
+ return
+ }
+
+ c.m[sd.SQL] = sd
+}
+
+// Invalidate invalidates statement description identified by sql. Does nothing if not found.
+func (c *UnlimitedCache) Invalidate(sql string) {
+ if sd, ok := c.m[sql]; ok {
+ delete(c.m, sql)
+ c.invalidStmts = append(c.invalidStmts, sd)
+ }
+}
+
+// InvalidateAll invalidates all statement descriptions.
+func (c *UnlimitedCache) InvalidateAll() {
+ for _, sd := range c.m {
+ c.invalidStmts = append(c.invalidStmts, sd)
+ }
+
+ c.m = make(map[string]*pgconn.StatementDescription)
+}
+
+// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
+func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
+ return c.invalidStmts
+}
+
+// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
+// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
+// never seen by the call to GetInvalidated.
+func (c *UnlimitedCache) RemoveInvalidated() {
+ c.invalidStmts = nil
+}
+
+// Len returns the number of cached prepared statement descriptions.
+func (c *UnlimitedCache) Len() int {
+ return len(c.m)
+}
+
+// Cap returns the maximum number of cached prepared statement descriptions.
+func (c *UnlimitedCache) Cap() int {
+ return math.MaxInt
+}