From 6ea4d2c82de80efc87708e5e182034b7c6c2019e Mon Sep 17 00:00:00 2001 From: Gibheer Date: Thu, 5 Sep 2024 19:38:25 +0200 Subject: switch from github.com/lib/pq to github.com/jackc/pgx/v5 lib/pq is out of maintenance for some time now, so switch to the newer more active library. Looks like it finally stabilized after a long time. --- vendor/github.com/jackc/pgx/v5/conn.go | 1395 ++++++++++++++++++++++++++++++++ 1 file changed, 1395 insertions(+) create mode 100644 vendor/github.com/jackc/pgx/v5/conn.go (limited to 'vendor/github.com/jackc/pgx/v5/conn.go') diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go new file mode 100644 index 0000000..3117214 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -0,0 +1,1395 @@ +package pgx + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/v5/internal/sanitize" + "github.com/jackc/pgx/v5/internal/stmtcache" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" +) + +// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and +// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. +type ConnConfig struct { + pgconn.Config + + Tracer QueryTracer + + // Original connection string that was parsed into config. + connString string + + // StatementCacheCapacity is maximum size of the statement cache used when executing a query with "cache_statement" + // query exec mode. + StatementCacheCapacity int + + // DescriptionCacheCapacity is the maximum size of the description cache used when executing a query with + // "cache_describe" query exec mode. + DescriptionCacheCapacity int + + // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol + // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as + // PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same + // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. + DefaultQueryExecMode QueryExecMode + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// ParseConfigOptions contains options that control how a config is built such as getsslpassword. +type ParseConfigOptions struct { + pgconn.ParseConfigOptions +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (cc *ConnConfig) Copy() *ConnConfig { + newConfig := new(ConnConfig) + *newConfig = *cc + newConfig.Config = *newConfig.Config.Copy() + return newConfig +} + +// ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig. +func (cc *ConnConfig) ConnString() string { return cc.connString } + +// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access +// to multiple database connections from multiple goroutines. +type Conn struct { + pgConn *pgconn.PgConn + config *ConnConfig // config used when establishing this connection + preparedStatements map[string]*pgconn.StatementDescription + statementCache stmtcache.Cache + descriptionCache stmtcache.Cache + + queryTracer QueryTracer + batchTracer BatchTracer + copyFromTracer CopyFromTracer + prepareTracer PrepareTracer + + notifications []*pgconn.Notification + + doneChan chan struct{} + closedChan chan error + + typeMap *pgtype.Map + + wbuf []byte + eqb ExtendedQueryBuilder +} + +// Identifier a PostgreSQL identifier or name. Identifiers can be composed of +// multiple parts such as ["schema", "table"] or ["table", "column"]. +type Identifier []string + +// Sanitize returns a sanitized string safe for SQL interpolation. +func (ident Identifier) Sanitize() string { + parts := make([]string, len(ident)) + for i := range ident { + s := strings.ReplaceAll(ident[i], string([]byte{0}), "") + parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"` + } + return strings.Join(parts, ".") +} + +var ( + // ErrNoRows occurs when rows are expected but none are returned. + ErrNoRows = errors.New("no rows in result set") + // ErrTooManyRows occurs when more rows than expected are returned. + ErrTooManyRows = errors.New("too many rows in result set") +) + +var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") +var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") + +// Connect establishes a connection with a PostgreSQL server with a connection string. See +// pgconn.Connect for details. +func Connect(ctx context.Context, connString string) (*Conn, error) { + connConfig, err := ParseConfig(connString) + if err != nil { + return nil, err + } + return connect(ctx, connConfig) +} + +// ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to +// provide a GetSSLPassword function. +func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) { + connConfig, err := ParseConfigWithOptions(connString, options) + if err != nil { + return nil, err + } + return connect(ctx, connConfig) +} + +// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. +// connConfig must have been created by ParseConfig. +func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { + // In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other + // connections with the same config. See https://github.com/jackc/pgx/issues/618. + connConfig = connConfig.Copy() + + return connect(ctx, connConfig) +} + +// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is +// only used to provide a GetSSLPassword function. +func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) { + config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions) + if err != nil { + return nil, err + } + + statementCacheCapacity := 512 + if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { + delete(config.RuntimeParams, "statement_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err) + } + statementCacheCapacity = int(n) + } + + descriptionCacheCapacity := 512 + if s, ok := config.RuntimeParams["description_cache_capacity"]; ok { + delete(config.RuntimeParams, "description_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err) + } + descriptionCacheCapacity = int(n) + } + + defaultQueryExecMode := QueryExecModeCacheStatement + if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { + delete(config.RuntimeParams, "default_query_exec_mode") + switch s { + case "cache_statement": + defaultQueryExecMode = QueryExecModeCacheStatement + case "cache_describe": + defaultQueryExecMode = QueryExecModeCacheDescribe + case "describe_exec": + defaultQueryExecMode = QueryExecModeDescribeExec + case "exec": + defaultQueryExecMode = QueryExecModeExec + case "simple_protocol": + defaultQueryExecMode = QueryExecModeSimpleProtocol + default: + return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s) + } + } + + connConfig := &ConnConfig{ + Config: *config, + createdByParseConfig: true, + StatementCacheCapacity: statementCacheCapacity, + DescriptionCacheCapacity: descriptionCacheCapacity, + DefaultQueryExecMode: defaultQueryExecMode, + connString: connString, + } + + return connConfig, nil +} + +// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig] +// does. In addition, it accepts the following options: +// +// - default_query_exec_mode. +// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See +// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". +// +// - statement_cache_capacity. +// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode. +// Default: 512. +// +// - description_cache_capacity. +// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode. +// Default: 512. +func ParseConfig(connString string) (*ConnConfig, error) { + return ParseConfigWithOptions(connString, ParseConfigOptions{}) +} + +// connect connects to a database. connect takes ownership of config. The caller must not use or access it again. +func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { + if connectTracer, ok := config.Tracer.(ConnectTracer); ok { + ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config}) + defer func() { + connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err}) + }() + } + + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + + c = &Conn{ + config: config, + typeMap: pgtype.NewMap(), + queryTracer: config.Tracer, + } + + if t, ok := c.queryTracer.(BatchTracer); ok { + c.batchTracer = t + } + if t, ok := c.queryTracer.(CopyFromTracer); ok { + c.copyFromTracer = t + } + if t, ok := c.queryTracer.(PrepareTracer); ok { + c.prepareTracer = t + } + + // Only install pgx notification system if no other callback handler is present. + if config.Config.OnNotification == nil { + config.Config.OnNotification = c.bufferNotifications + } + + c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) + if err != nil { + return nil, err + } + + c.preparedStatements = make(map[string]*pgconn.StatementDescription) + c.doneChan = make(chan struct{}) + c.closedChan = make(chan error) + c.wbuf = make([]byte, 0, 1024) + + if c.config.StatementCacheCapacity > 0 { + c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity) + } + + if c.config.DescriptionCacheCapacity > 0 { + c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) + } + + return c, nil +} + +// Close closes a connection. It is safe to call Close on an already closed +// connection. +func (c *Conn) Close(ctx context.Context) error { + if c.IsClosed() { + return nil + } + + err := c.pgConn.Close(ctx) + return err +} + +// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These +// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and +// Exec to execute the statement. It can also be used with Batch.Queue. +// +// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if +// name == sql. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This +// allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared. +func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.prepareTracer != nil { + ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) + } + + if name != "" { + var ok bool + if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql { + if c.prepareTracer != nil { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true}) + } + return sd, nil + } + } + + if c.prepareTracer != nil { + defer func() { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err}) + }() + } + + var psName, psKey string + if name == sql { + digest := sha256.Sum256([]byte(sql)) + psName = "stmt_" + hex.EncodeToString(digest[0:24]) + psKey = sql + } else { + psName = name + psKey = name + } + + sd, err = c.pgConn.Prepare(ctx, psName, sql, nil) + if err != nil { + return nil, err + } + + if psKey != "" { + c.preparedStatements[psKey] = sd + } + + return sd, nil +} + +// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed. +func (c *Conn) Deallocate(ctx context.Context, name string) error { + var psName string + sd := c.preparedStatements[name] + if sd != nil { + psName = sd.Name + } else { + psName = name + } + + err := c.pgConn.Deallocate(ctx, psName) + if err != nil { + return err + } + + if sd != nil { + delete(c.preparedStatements, name) + } + + return nil +} + +// DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache. +func (c *Conn) DeallocateAll(ctx context.Context) error { + c.preparedStatements = map[string]*pgconn.StatementDescription{} + if c.config.StatementCacheCapacity > 0 { + c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity) + } + if c.config.DescriptionCacheCapacity > 0 { + c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) + } + _, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll() + return err +} + +func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) { + c.notifications = append(c.notifications, n) +} + +// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a +// slightly more convenient form. +func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) { + var n *pgconn.Notification + + // Return already received notification immediately + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + return n, nil + } + + err := c.pgConn.WaitForNotification(ctx) + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + } + return n, err +} + +// IsClosed reports if the connection has been closed. +func (c *Conn) IsClosed() bool { + return c.pgConn.IsClosed() +} + +func (c *Conn) die(err error) { + if c.IsClosed() { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // force immediate hard cancel + c.pgConn.Close(ctx) +} + +func quoteIdentifier(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` +} + +// Ping delegates to the underlying *pgconn.PgConn.Ping. +func (c *Conn) Ping(ctx context.Context) error { + return c.pgConn.Ping(ctx) +} + +// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the +// PostgreSQL connection than pgx exposes. +// +// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn +// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. +func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } + +// TypeMap returns the connection info used for this connection. +func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap } + +// Config returns a copy of config that was used to establish this connection. +func (c *Conn) Config() *ConnConfig { return c.config.Copy() } + +// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced +// positionally from the sql string as $1, $2, etc. +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments}) + } + + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return pgconn.CommandTag{}, err + } + + commandTag, err := c.exec(ctx, sql, arguments...) + + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err}) + } + + return commandTag, err +} + +func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { + mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter + +optionLoop: + for len(arguments) > 0 { + switch arg := arguments[0].(type) { + case QueryExecMode: + mode = arg + arguments = arguments[1:] + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] + default: + break optionLoop + } + } + + if queryRewriter != nil { + sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + if err != nil { + return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err) + } + } + + // Always use simple protocol when there are no arguments. + if len(arguments) == 0 { + mode = QueryExecModeSimpleProtocol + } + + if sd, ok := c.preparedStatements[sql]; ok { + return c.execPrepared(ctx, sd, arguments) + } + + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return pgconn.CommandTag{}, errDisabledStatementCache + } + sd := c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql) + if err != nil { + return pgconn.CommandTag{}, err + } + c.statementCache.Put(sd) + } + + return c.execPrepared(ctx, sd, arguments) + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return pgconn.CommandTag{}, errDisabledDescriptionCache + } + sd := c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err + } + c.descriptionCache.Put(sd) + } + + return c.execParams(ctx, sd, arguments) + case QueryExecModeDescribeExec: + sd, err := c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err + } + return c.execPrepared(ctx, sd, arguments) + case QueryExecModeExec: + return c.execSQLParams(ctx, sql, arguments) + case QueryExecModeSimpleProtocol: + return c.execSimpleProtocol(ctx, sql, arguments) + default: + return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode) + } +} + +func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) { + if len(arguments) > 0 { + sql, err = c.sanitizeForSimpleQuery(sql, arguments...) + if err != nil { + return pgconn.CommandTag{}, err + } + } + + mrr := c.pgConn.Exec(ctx, sql) + for mrr.NextResult() { + commandTag, _ = mrr.ResultReader().Close() + } + err = mrr.Close() + return commandTag, err +} + +func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, sd, arguments) + if err != nil { + return pgconn.CommandTag{}, err + } + + result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} + +func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, sd, arguments) + if err != nil { + return pgconn.CommandTag{}, err + } + + result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} + +type unknownArgumentTypeQueryExecModeExecError struct { + arg any +} + +func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { + return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) +} + +func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, nil, args) + if err != nil { + return pgconn.CommandTag{}, err + } + + result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} + +func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { + r := &baseRows{} + + r.ctx = ctx + r.queryTracer = c.queryTracer + r.typeMap = c.typeMap + r.startTime = time.Now() + r.sql = sql + r.args = args + r.conn = c + + return r +} + +type QueryExecMode int32 + +const ( + _ QueryExecMode = iota + + // Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round + // trip after the statement is cached. This is the default. If the database schema is modified or the search_path is + // changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the + // number of columns returned by a "SELECT *" changes or the type of a column is changed. + QueryExecModeCacheStatement + + // Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended + // protocol. Queries are executed in a single round trip after the description is cached. If the database schema is + // modified or the search_path is changed after a statement is cached then the first execution of a previously cached + // query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed. + QueryExecModeCacheDescribe + + // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips + // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the + // statement description on the first round trip and then uses it to execute the query on the second round trip. This + // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe + // even when the database schema is modified concurrently. + QueryExecModeDescribeExec + + // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol + // with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be + // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are + // unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know + // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. + QueryExecModeExec + + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. + // Queries are executed in a single round trip. Type mappings can be registered with + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use + // a map[string]string directly as an argument. This mode cannot. + // + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor + // exceptions such as behavior when multiple result returning queries are erroneously sent in a single string. + // + // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol + // should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does + // not support the extended protocol. + QueryExecModeSimpleProtocol +) + +func (m QueryExecMode) String() string { + switch m { + case QueryExecModeCacheStatement: + return "cache statement" + case QueryExecModeCacheDescribe: + return "cache describe" + case QueryExecModeDescribeExec: + return "describe exec" + case QueryExecModeExec: + return "exec" + case QueryExecModeSimpleProtocol: + return "simple protocol" + default: + return "invalid" + } +} + +// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. +type QueryResultFormats []int16 + +// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. +type QueryResultFormatsByOID map[uint32]int16 + +// QueryRewriter rewrites a query when used as the first arguments to a query method. +type QueryRewriter interface { + RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) +} + +// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query +// and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to +// determine if the query executed successfully. +// +// The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the +// returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It +// is allowed to ignore the error returned from Query and handle it in Rows. +// +// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not +// return an error. +// +// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be +// collected before processing rather than processed while receiving each row. This avoids the possibility of the +// application processing rows from a query that the server rejected. The CollectRows function is useful here. +// +// An implementor of QueryRewriter may be passed as the first element of args. It can rewrite the sql and change or +// replace args. For example, NamedArgs is QueryRewriter that implements named arguments. +// +// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args}) + } + + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err}) + } + return &baseRows{err: err, closed: true}, err + } + + var resultFormats QueryResultFormats + var resultFormatsByOID QueryResultFormatsByOID + mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter + +optionLoop: + for len(args) > 0 { + switch arg := args[0].(type) { + case QueryResultFormats: + resultFormats = arg + args = args[1:] + case QueryResultFormatsByOID: + resultFormatsByOID = arg + args = args[1:] + case QueryExecMode: + mode = arg + args = args[1:] + case QueryRewriter: + queryRewriter = arg + args = args[1:] + default: + break optionLoop + } + } + + if queryRewriter != nil { + var err error + originalSQL := sql + originalArgs := args + sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args) + if err != nil { + rows := c.getRows(ctx, originalSQL, originalArgs) + err = fmt.Errorf("rewrite query failed: %w", err) + rows.fatal(err) + return rows, err + } + } + + // Bypass any statement caching. + if sql == "" { + mode = QueryExecModeSimpleProtocol + } + + c.eqb.reset() + rows := c.getRows(ctx, sql, args) + + var err error + sd, explicitPreparedStatement := c.preparedStatements[sql] + if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { + if sd == nil { + sd, err = c.getStatementDescription(ctx, mode, sql) + if err != nil { + rows.fatal(err) + return rows, err + } + } + + if len(sd.ParamOIDs) != len(args) { + rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) + return rows, rows.err + } + + rows.sql = sd.SQL + + err = c.eqb.Build(c.typeMap, sd, args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(sd.Fields)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] + } + } + + if resultFormats == nil { + resultFormats = c.eqb.ResultFormats + } + + if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) + } + } else if mode == QueryExecModeExec { + err := c.eqb.Build(c.typeMap, nil, args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else if mode == QueryExecModeSimpleProtocol { + sql, err = c.sanitizeForSimpleQuery(sql, args...) + if err != nil { + rows.fatal(err) + return rows, err + } + + mrr := c.pgConn.Exec(ctx, sql) + if mrr.NextResult() { + rows.resultReader = mrr.ResultReader() + rows.multiResultReader = mrr + } else { + err = mrr.Close() + rows.fatal(err) + return rows, err + } + + return rows, nil + } else { + err = fmt.Errorf("unknown QueryExecMode: %v", mode) + rows.fatal(err) + return rows, rows.err + } + + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + + return rows, rows.err +} + +// getStatementDescription returns the statement description of the sql query +// according to the given mode. +// +// If the mode is one that doesn't require to know the param and result OIDs +// then nil is returned without error. +func (c *Conn) getStatementDescription( + ctx context.Context, + mode QueryExecMode, + sql string, +) (sd *pgconn.StatementDescription, err error) { + + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return nil, errDisabledStatementCache + } + sd = c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql) + if err != nil { + return nil, err + } + c.statementCache.Put(sd) + } + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return nil, errDisabledDescriptionCache + } + sd = c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return nil, err + } + c.descriptionCache.Put(sd) + } + case QueryExecModeDescribeExec: + return c.Prepare(ctx, "", sql) + } + return sd, err +} + +// QueryRow is a convenience wrapper over Query. Any error that occurs while +// querying is deferred until calling Scan on the returned Row. That Row will +// error with ErrNoRows if no rows are returned. +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { + rows, _ := c.Query(ctx, sql, args...) + return (*connRow)(rows.(*baseRows)) +} + +// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless +// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection +// is used again. +func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if c.batchTracer != nil { + ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) + defer func() { + err := br.(interface{ earlyError() error }).earlyError() + if err != nil { + c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err}) + } + }() + } + + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + + for _, bi := range b.QueuedQueries { + var queryRewriter QueryRewriter + sql := bi.SQL + arguments := bi.Arguments + + optionLoop: + for len(arguments) > 0 { + // Update Batch.Queue function comment when additional options are implemented + switch arg := arguments[0].(type) { + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] + default: + break optionLoop + } + } + + if queryRewriter != nil { + var err error + sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)} + } + } + + bi.SQL = sql + bi.Arguments = arguments + } + + // TODO: changing mode per batch? Update Batch.Queue function comment when implemented + mode := c.config.DefaultQueryExecMode + if mode == QueryExecModeSimpleProtocol { + return c.sendBatchQueryExecModeSimpleProtocol(ctx, b) + } + + // All other modes use extended protocol and thus can use prepared statements. + for _, bi := range b.QueuedQueries { + if sd, ok := c.preparedStatements[bi.SQL]; ok { + bi.sd = sd + } + } + + switch mode { + case QueryExecModeExec: + return c.sendBatchQueryExecModeExec(ctx, b) + case QueryExecModeCacheStatement: + return c.sendBatchQueryExecModeCacheStatement(ctx, b) + case QueryExecModeCacheDescribe: + return c.sendBatchQueryExecModeCacheDescribe(ctx, b) + case QueryExecModeDescribeExec: + return c.sendBatchQueryExecModeDescribeExec(ctx, b) + default: + panic("unknown QueryExecMode") + } +} + +func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { + var sb strings.Builder + for i, bi := range b.QueuedQueries { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, + } +} + +func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { + batch := &pgconn.Batch{} + + for _, bi := range b.QueuedQueries { + sd := bi.sd + if sd != nil { + err := c.eqb.Build(c.typeMap, sd, bi.Arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + err := c.eqb.Build(c.typeMap, nil, bi.Arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) + } + } + + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + + mrr := c.pgConn.ExecBatch(ctx, batch) + + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, + } +} + +func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.statementCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.QueuedQueries { + if bi.sd == nil { + sd := c.statementCache.Get(bi.SQL) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + Name: stmtcache.StatementName(bi.SQL), + SQL: bi.SQL, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache) +} + +func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.descriptionCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.QueuedQueries { + if bi.sd == nil { + sd := c.descriptionCache.Get(bi.SQL) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + SQL: bi.SQL, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache) +} + +func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.QueuedQueries { + if bi.sd == nil { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd := &pgconn.StatementDescription{ + SQL: bi.SQL, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil) +} + +func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { + pipeline := c.pgConn.StartPipeline(ctx) + defer func() { + if pbr != nil && pbr.err != nil { + pipeline.Close() + } + }() + + // Prepare any needed queries + if len(distinctNewQueries) > 0 { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } + + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + } + + _, ok := results.(*pgconn.PipelineSync) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} + } + } + + // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } + + // Queue the queries. + for _, bi := range b.QueuedQueries { + err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments) + if err != nil { + // we wrap the error so we the user can understand which query failed inside the batch + err = fmt.Errorf("error building query %s: %w", bi.SQL, err) + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + } + + if bi.sd.Name == "" { + pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + } + + return &pipelineBatchResults{ + ctx: ctx, + conn: c, + pipeline: pipeline, + b: b, + } +} + +func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { + if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") + } + + if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") + } + + var err error + valueArgs := make([]any, len(args)) + for i, a := range args { + valueArgs[i], err = convertSimpleArgument(c.typeMap, a) + if err != nil { + return "", err + } + } + + return sanitize.SanitizeSQL(sql, valueArgs...) +} + +// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be +// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular, +// typeName must be one of the following: +// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered. +// - A composite type name where all field types are already registered. +// - A domain type name where the base type is already registered. +// - An enum type name. +// - A range type name where the element type is already registered. +// - A multirange type name where the element type is already registered. +func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) { + var oid uint32 + + err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return nil, err + } + + var typtype string + var typbasetype uint32 + + err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype) + if err != nil { + return nil, err + } + + switch typtype { + case "b": // array + elementOID, err := c.getArrayElementOID(ctx, oid) + if err != nil { + return nil, err + } + + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("array element OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil + case "c": // composite + fields, err := c.getCompositeFields(ctx, oid) + if err != nil { + return nil, err + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil + case "d": // domain + dt, ok := c.TypeMap().TypeForOID(typbasetype) + if !ok { + return nil, errors.New("domain base type OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil + case "e": // enum + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil + case "r": // range + elementOID, err := c.getRangeElementOID(ctx, oid) + if err != nil { + return nil, err + } + + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("range element OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil + case "m": // multirange + elementOID, err := c.getMultiRangeElementOID(ctx, oid) + if err != nil { + return nil, err + } + + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("multirange element OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil + default: + return &pgtype.Type{}, errors.New("unknown typtype") + } +} + +func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) { + var typrelid uint32 + + err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) + if err != nil { + return nil, err + } + + var fields []pgtype.CompositeCodecField + var fieldName string + var fieldOID uint32 + rows, _ := c.Query(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 + and not attisdropped + and attnum > 0 +order by attnum`, + typrelid, + ) + _, err = ForEachRow(rows, []any{&fieldName, &fieldOID}, func() error { + dt, ok := c.TypeMap().TypeForOID(fieldOID) + if !ok { + return fmt.Errorf("unknown composite type field OID: %v", fieldOID) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + return nil + }) + if err != nil { + return nil, err + } + + return fields, nil +} + +func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { + if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' { + return nil + } + + if c.descriptionCache != nil { + c.descriptionCache.RemoveInvalidated() + } + + var invalidatedStatements []*pgconn.StatementDescription + if c.statementCache != nil { + invalidatedStatements = c.statementCache.GetInvalidated() + } + + if len(invalidatedStatements) == 0 { + return nil + } + + pipeline := c.pgConn.StartPipeline(ctx) + defer pipeline.Close() + + for _, sd := range invalidatedStatements { + pipeline.SendDeallocate(sd.Name) + } + + err := pipeline.Sync() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + err = pipeline.Close() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + c.statementCache.RemoveInvalidated() + for _, sd := range invalidatedStatements { + delete(c.preparedStatements, sd.Name) + } + + return nil +} -- cgit v1.2.3-70-g09d2