aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/rows.go
diff options
context:
space:
mode:
authorGibheer <gibheer+git@zero-knowledge.org>2024-09-05 19:38:25 +0200
committerGibheer <gibheer+git@zero-knowledge.org>2024-09-05 19:38:25 +0200
commit6ea4d2c82de80efc87708e5e182034b7c6c2019e (patch)
tree35c0856a929040216c82153ca62d43b27530a887 /vendor/github.com/jackc/pgx/v5/rows.go
parent6f64eeace1b66639b9380b44e88a8d54850a4306 (diff)
switch from github.com/lib/pq to github.com/jackc/pgx/v5HEAD20240905master
lib/pq is out of maintenance for some time now, so switch to the newer more active library. Looks like it finally stabilized after a long time.
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/rows.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/rows.go851
1 files changed, 851 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go
new file mode 100644
index 0000000..d4f7a90
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/rows.go
@@ -0,0 +1,851 @@
+package pgx
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jackc/pgx/v5/pgconn"
+ "github.com/jackc/pgx/v5/pgtype"
+)
+
+// Rows is the result set returned from *Conn.Query. Rows must be closed before
+// the *Conn can be used again. Rows are closed by explicitly calling Close(),
+// calling Next() until it returns false, or when a fatal error occurs.
+//
+// Once a Rows is closed the only methods that may be called are Close(), Err(),
+// and CommandTag().
+//
+// Rows is an interface instead of a struct to allow tests to mock Query. However,
+// adding a method to an interface is technically a breaking change. Because of this
+// the Rows interface is partially excluded from semantic version requirements.
+// Methods will not be removed or changed, but new methods may be added.
+type Rows interface {
+ // Close closes the rows, making the connection ready for use again. It is safe
+ // to call Close after rows is already closed.
+ Close()
+
+ // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by
+ // calling Close or by Next returning false). If it is called early it may return nil even if there was an error
+ // executing the query.
+ Err() error
+
+ // CommandTag returns the command tag from this query. It is only available after Rows is closed.
+ CommandTag() pgconn.CommandTag
+
+ // FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur
+ // when there was an error executing the query.
+ FieldDescriptions() []pgconn.FieldDescription
+
+ // Next prepares the next row for reading. It returns true if there is another
+ // row and false if no more rows are available or a fatal error has occurred.
+ // It automatically closes rows when all rows are read.
+ //
+ // Callers should check rows.Err() after rows.Next() returns false to detect
+ // whether result-set reading ended prematurely due to an error. See
+ // Conn.Query for details.
+ //
+ // For simpler error handling, consider using the higher-level pgx v5
+ // CollectRows() and ForEachRow() helpers instead.
+ Next() bool
+
+ // Scan reads the values from the current row into dest values positionally.
+ // dest can include pointers to core types, values implementing the Scanner
+ // interface, and nil. nil will skip the value entirely. It is an error to
+ // call Scan without first calling Next() and checking that it returned true.
+ Scan(dest ...any) error
+
+ // Values returns the decoded row values. As with Scan(), it is an error to
+ // call Values without first calling Next() and checking that it returned
+ // true.
+ Values() ([]any, error)
+
+ // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next
+ // call or the Rows is closed.
+ RawValues() [][]byte
+
+ // Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a
+ // *Conn (e.g. if it was created by RowsFromResultReader)
+ Conn() *Conn
+}
+
+// Row is a convenience wrapper over Rows that is returned by QueryRow.
+//
+// Row is an interface instead of a struct to allow tests to mock QueryRow. However,
+// adding a method to an interface is technically a breaking change. Because of this
+// the Row interface is partially excluded from semantic version requirements.
+// Methods will not be removed or changed, but new methods may be added.
+type Row interface {
+ // Scan works the same as Rows. with the following exceptions. If no
+ // rows were found it returns ErrNoRows. If multiple rows are returned it
+ // ignores all but the first.
+ Scan(dest ...any) error
+}
+
+// RowScanner scans an entire row at a time into the RowScanner.
+type RowScanner interface {
+ // ScanRows scans the row.
+ ScanRow(rows Rows) error
+}
+
+// connRow implements the Row interface for Conn.QueryRow.
+type connRow baseRows
+
+func (r *connRow) Scan(dest ...any) (err error) {
+ rows := (*baseRows)(r)
+
+ if rows.Err() != nil {
+ return rows.Err()
+ }
+
+ for _, d := range dest {
+ if _, ok := d.(*pgtype.DriverBytes); ok {
+ rows.Close()
+ return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
+ }
+ }
+
+ if !rows.Next() {
+ if rows.Err() == nil {
+ return ErrNoRows
+ }
+ return rows.Err()
+ }
+
+ rows.Scan(dest...)
+ rows.Close()
+ return rows.Err()
+}
+
+// baseRows implements the Rows interface for Conn.Query.
+type baseRows struct {
+ typeMap *pgtype.Map
+ resultReader *pgconn.ResultReader
+
+ values [][]byte
+
+ commandTag pgconn.CommandTag
+ err error
+ closed bool
+
+ scanPlans []pgtype.ScanPlan
+ scanTypes []reflect.Type
+
+ conn *Conn
+ multiResultReader *pgconn.MultiResultReader
+
+ queryTracer QueryTracer
+ batchTracer BatchTracer
+ ctx context.Context
+ startTime time.Time
+ sql string
+ args []any
+ rowCount int
+}
+
+func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription {
+ return rows.resultReader.FieldDescriptions()
+}
+
+func (rows *baseRows) Close() {
+ if rows.closed {
+ return
+ }
+
+ rows.closed = true
+
+ if rows.resultReader != nil {
+ var closeErr error
+ rows.commandTag, closeErr = rows.resultReader.Close()
+ if rows.err == nil {
+ rows.err = closeErr
+ }
+ }
+
+ if rows.multiResultReader != nil {
+ closeErr := rows.multiResultReader.Close()
+ if rows.err == nil {
+ rows.err = closeErr
+ }
+ }
+
+ if rows.err != nil && rows.conn != nil && rows.sql != "" {
+ if sc := rows.conn.statementCache; sc != nil {
+ sc.Invalidate(rows.sql)
+ }
+
+ if sc := rows.conn.descriptionCache; sc != nil {
+ sc.Invalidate(rows.sql)
+ }
+ }
+
+ if rows.batchTracer != nil {
+ rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err})
+ } else if rows.queryTracer != nil {
+ rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err})
+ }
+}
+
+func (rows *baseRows) CommandTag() pgconn.CommandTag {
+ return rows.commandTag
+}
+
+func (rows *baseRows) Err() error {
+ return rows.err
+}
+
+// fatal signals an error occurred after the query was sent to the server. It
+// closes the rows automatically.
+func (rows *baseRows) fatal(err error) {
+ if rows.err != nil {
+ return
+ }
+
+ rows.err = err
+ rows.Close()
+}
+
+func (rows *baseRows) Next() bool {
+ if rows.closed {
+ return false
+ }
+
+ if rows.resultReader.NextRow() {
+ rows.rowCount++
+ rows.values = rows.resultReader.Values()
+ return true
+ } else {
+ rows.Close()
+ return false
+ }
+}
+
+func (rows *baseRows) Scan(dest ...any) error {
+ m := rows.typeMap
+ fieldDescriptions := rows.FieldDescriptions()
+ values := rows.values
+
+ if len(fieldDescriptions) != len(values) {
+ err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
+ rows.fatal(err)
+ return err
+ }
+
+ if len(dest) == 1 {
+ if rc, ok := dest[0].(RowScanner); ok {
+ err := rc.ScanRow(rows)
+ if err != nil {
+ rows.fatal(err)
+ }
+ return err
+ }
+ }
+
+ if len(fieldDescriptions) != len(dest) {
+ err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
+ rows.fatal(err)
+ return err
+ }
+
+ if rows.scanPlans == nil {
+ rows.scanPlans = make([]pgtype.ScanPlan, len(values))
+ rows.scanTypes = make([]reflect.Type, len(values))
+ for i := range dest {
+ rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
+ rows.scanTypes[i] = reflect.TypeOf(dest[i])
+ }
+ }
+
+ for i, dst := range dest {
+ if dst == nil {
+ continue
+ }
+
+ if rows.scanTypes[i] != reflect.TypeOf(dst) {
+ rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
+ rows.scanTypes[i] = reflect.TypeOf(dest[i])
+ }
+
+ err := rows.scanPlans[i].Scan(values[i], dst)
+ if err != nil {
+ err = ScanArgError{ColumnIndex: i, Err: err}
+ rows.fatal(err)
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (rows *baseRows) Values() ([]any, error) {
+ if rows.closed {
+ return nil, errors.New("rows is closed")
+ }
+
+ values := make([]any, 0, len(rows.FieldDescriptions()))
+
+ for i := range rows.FieldDescriptions() {
+ buf := rows.values[i]
+ fd := &rows.FieldDescriptions()[i]
+
+ if buf == nil {
+ values = append(values, nil)
+ continue
+ }
+
+ if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok {
+ value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf)
+ if err != nil {
+ rows.fatal(err)
+ }
+ values = append(values, value)
+ } else {
+ switch fd.Format {
+ case TextFormatCode:
+ values = append(values, string(buf))
+ case BinaryFormatCode:
+ newBuf := make([]byte, len(buf))
+ copy(newBuf, buf)
+ values = append(values, newBuf)
+ default:
+ rows.fatal(errors.New("unknown format code"))
+ }
+ }
+
+ if rows.Err() != nil {
+ return nil, rows.Err()
+ }
+ }
+
+ return values, rows.Err()
+}
+
+func (rows *baseRows) RawValues() [][]byte {
+ return rows.values
+}
+
+func (rows *baseRows) Conn() *Conn {
+ return rows.conn
+}
+
+type ScanArgError struct {
+ ColumnIndex int
+ Err error
+}
+
+func (e ScanArgError) Error() string {
+ return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
+}
+
+func (e ScanArgError) Unwrap() error {
+ return e.Err
+}
+
+// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface.
+//
+// typeMap - OID to Go type mapping.
+// fieldDescriptions - OID and format of values
+// values - the raw data as returned from the PostgreSQL server
+// dest - the destination that values will be decoded into
+func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error {
+ if len(fieldDescriptions) != len(values) {
+ return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
+ }
+ if len(fieldDescriptions) != len(dest) {
+ return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
+ }
+
+ for i, d := range dest {
+ if d == nil {
+ continue
+ }
+
+ err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
+ if err != nil {
+ return ScanArgError{ColumnIndex: i, Err: err}
+ }
+ }
+
+ return nil
+}
+
+// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used
+// to read from the lower level pgconn interface.
+func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows {
+ return &baseRows{
+ typeMap: typeMap,
+ resultReader: resultReader,
+ }
+}
+
+// ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row
+// fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed
+// when ForEachRow returns.
+func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) {
+ defer rows.Close()
+
+ for rows.Next() {
+ err := rows.Scan(scans...)
+ if err != nil {
+ return pgconn.CommandTag{}, err
+ }
+
+ err = fn()
+ if err != nil {
+ return pgconn.CommandTag{}, err
+ }
+ }
+
+ if err := rows.Err(); err != nil {
+ return pgconn.CommandTag{}, err
+ }
+
+ return rows.CommandTag(), nil
+}
+
+// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call.
+type CollectableRow interface {
+ FieldDescriptions() []pgconn.FieldDescription
+ Scan(dest ...any) error
+ Values() ([]any, error)
+ RawValues() [][]byte
+}
+
+// RowToFunc is a function that scans or otherwise converts row to a T.
+type RowToFunc[T any] func(row CollectableRow) (T, error)
+
+// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
+//
+// This function closes the rows automatically on return.
+func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
+ defer rows.Close()
+
+ for rows.Next() {
+ value, err := fn(rows)
+ if err != nil {
+ return nil, err
+ }
+ slice = append(slice, value)
+ }
+
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return slice, nil
+}
+
+// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
+//
+// This function closes the rows automatically on return.
+func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
+ return AppendRows([]T{}, rows, fn)
+}
+
+// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
+// CollectOneRow is to CollectRows as QueryRow is to Query.
+//
+// This function closes the rows automatically on return.
+func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
+ defer rows.Close()
+
+ var value T
+ var err error
+
+ if !rows.Next() {
+ if err = rows.Err(); err != nil {
+ return value, err
+ }
+ return value, ErrNoRows
+ }
+
+ value, err = fn(rows)
+ if err != nil {
+ return value, err
+ }
+
+ rows.Close()
+ return value, rows.Err()
+}
+
+// CollectExactlyOneRow calls fn for the first row in rows and returns the result.
+// - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
+// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
+//
+// This function closes the rows automatically on return.
+func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
+ defer rows.Close()
+
+ var (
+ err error
+ value T
+ )
+
+ if !rows.Next() {
+ if err = rows.Err(); err != nil {
+ return value, err
+ }
+
+ return value, ErrNoRows
+ }
+
+ value, err = fn(rows)
+ if err != nil {
+ return value, err
+ }
+
+ if rows.Next() {
+ var zero T
+
+ return zero, ErrTooManyRows
+ }
+
+ return value, rows.Err()
+}
+
+// RowTo returns a T scanned from row.
+func RowTo[T any](row CollectableRow) (T, error) {
+ var value T
+ err := row.Scan(&value)
+ return value, err
+}
+
+// RowTo returns a the address of a T scanned from row.
+func RowToAddrOf[T any](row CollectableRow) (*T, error) {
+ var value T
+ err := row.Scan(&value)
+ return &value, err
+}
+
+// RowToMap returns a map scanned from row.
+func RowToMap(row CollectableRow) (map[string]any, error) {
+ var value map[string]any
+ err := row.Scan((*mapRowScanner)(&value))
+ return value, err
+}
+
+type mapRowScanner map[string]any
+
+func (rs *mapRowScanner) ScanRow(rows Rows) error {
+ values, err := rows.Values()
+ if err != nil {
+ return err
+ }
+
+ *rs = make(mapRowScanner, len(values))
+
+ for i := range values {
+ (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i]
+ }
+
+ return nil
+}
+
+// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
+// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be
+// ignored.
+func RowToStructByPos[T any](row CollectableRow) (T, error) {
+ var value T
+ err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
+ return value, err
+}
+
+// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
+// public fields as row has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then
+// the field will be ignored.
+func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
+ var value T
+ err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
+ return &value, err
+}
+
+type positionalStructRowScanner struct {
+ ptrToStruct any
+}
+
+func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
+ typ := reflect.TypeOf(rs.ptrToStruct).Elem()
+ fields := lookupStructFields(typ)
+ if len(rows.RawValues()) > len(fields) {
+ return fmt.Errorf(
+ "got %d values, but dst struct has only %d fields",
+ len(rows.RawValues()),
+ len(fields),
+ )
+ }
+ scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
+ return rows.Scan(scanTargets...)
+}
+
+// Map from reflect.Type -> []structRowField
+var positionalStructFieldMap sync.Map
+
+func lookupStructFields(t reflect.Type) []structRowField {
+ if cached, ok := positionalStructFieldMap.Load(t); ok {
+ return cached.([]structRowField)
+ }
+
+ fieldStack := make([]int, 0, 1)
+ fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
+ fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
+ return fieldsIface.([]structRowField)
+}
+
+func computeStructFields(
+ t reflect.Type,
+ fields []structRowField,
+ fieldStack *[]int,
+) []structRowField {
+ tail := len(*fieldStack)
+ *fieldStack = append(*fieldStack, 0)
+ for i := 0; i < t.NumField(); i++ {
+ sf := t.Field(i)
+ (*fieldStack)[tail] = i
+ // Handle anonymous struct embedding, but do not try to handle embedded pointers.
+ if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
+ fields = computeStructFields(sf.Type, fields, fieldStack)
+ } else if sf.PkgPath == "" {
+ dbTag, _ := sf.Tag.Lookup(structTagKey)
+ if dbTag == "-" {
+ // Field is ignored, skip it.
+ continue
+ }
+ fields = append(fields, structRowField{
+ path: append([]int(nil), *fieldStack...),
+ })
+ }
+ }
+ *fieldStack = (*fieldStack)[:tail]
+ return fields
+}
+
+// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
+// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
+// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
+func RowToStructByName[T any](row CollectableRow) (T, error) {
+ var value T
+ err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
+ return value, err
+}
+
+// RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number
+// of named public fields as row has fields. The row and T fields will be matched by name. The match is
+// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
+// then the field will be ignored.
+func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
+ var value T
+ err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
+ return &value, err
+}
+
+// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public
+// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
+// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
+func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
+ var value T
+ err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
+ return value, err
+}
+
+// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or
+// equal number of named public fields as row has fields. The row and T fields will be matched by name. The match is
+// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
+// then the field will be ignored.
+func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
+ var value T
+ err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
+ return &value, err
+}
+
+type namedStructRowScanner struct {
+ ptrToStruct any
+ lax bool
+}
+
+func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
+ typ := reflect.TypeOf(rs.ptrToStruct).Elem()
+ fldDescs := rows.FieldDescriptions()
+ namedStructFields, err := lookupNamedStructFields(typ, fldDescs)
+ if err != nil {
+ return err
+ }
+ if !rs.lax && namedStructFields.missingField != "" {
+ return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
+ }
+ fields := namedStructFields.fields
+ scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
+ return rows.Scan(scanTargets...)
+}
+
+// Map from namedStructFieldMap -> *namedStructFields
+var namedStructFieldMap sync.Map
+
+type namedStructFieldsKey struct {
+ t reflect.Type
+ colNames string
+}
+
+type namedStructFields struct {
+ fields []structRowField
+ // missingField is the first field from the struct without a corresponding row field.
+ // This is used to construct the correct error message for non-lax queries.
+ missingField string
+}
+
+func lookupNamedStructFields(
+ t reflect.Type,
+ fldDescs []pgconn.FieldDescription,
+) (*namedStructFields, error) {
+ key := namedStructFieldsKey{
+ t: t,
+ colNames: joinFieldNames(fldDescs),
+ }
+ if cached, ok := namedStructFieldMap.Load(key); ok {
+ return cached.(*namedStructFields), nil
+ }
+
+ // We could probably do two-levels of caching, where we compute the key -> fields mapping
+ // for a type only once, cache it by type, then use that to compute the column -> fields
+ // mapping for a given set of columns.
+ fieldStack := make([]int, 0, 1)
+ fields, missingField := computeNamedStructFields(
+ fldDescs,
+ t,
+ make([]structRowField, len(fldDescs)),
+ &fieldStack,
+ )
+ for i, f := range fields {
+ if f.path == nil {
+ return nil, fmt.Errorf(
+ "struct doesn't have corresponding row field %s",
+ fldDescs[i].Name,
+ )
+ }
+ }
+
+ fieldsIface, _ := namedStructFieldMap.LoadOrStore(
+ key,
+ &namedStructFields{fields: fields, missingField: missingField},
+ )
+ return fieldsIface.(*namedStructFields), nil
+}
+
+func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
+ switch len(fldDescs) {
+ case 0:
+ return ""
+ case 1:
+ return fldDescs[0].Name
+ }
+
+ totalSize := len(fldDescs) - 1 // Space for separator bytes.
+ for _, d := range fldDescs {
+ totalSize += len(d.Name)
+ }
+ var b strings.Builder
+ b.Grow(totalSize)
+ b.WriteString(fldDescs[0].Name)
+ for _, d := range fldDescs[1:] {
+ b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
+ b.WriteString(d.Name)
+ }
+ return b.String()
+}
+
+func computeNamedStructFields(
+ fldDescs []pgconn.FieldDescription,
+ t reflect.Type,
+ fields []structRowField,
+ fieldStack *[]int,
+) ([]structRowField, string) {
+ var missingField string
+ tail := len(*fieldStack)
+ *fieldStack = append(*fieldStack, 0)
+ for i := 0; i < t.NumField(); i++ {
+ sf := t.Field(i)
+ (*fieldStack)[tail] = i
+ if sf.PkgPath != "" && !sf.Anonymous {
+ // Field is unexported, skip it.
+ continue
+ }
+ // Handle anonymous struct embedding, but do not try to handle embedded pointers.
+ if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
+ var missingSubField string
+ fields, missingSubField = computeNamedStructFields(
+ fldDescs,
+ sf.Type,
+ fields,
+ fieldStack,
+ )
+ if missingField == "" {
+ missingField = missingSubField
+ }
+ } else {
+ dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
+ if dbTagPresent {
+ dbTag, _, _ = strings.Cut(dbTag, ",")
+ }
+ if dbTag == "-" {
+ // Field is ignored, skip it.
+ continue
+ }
+ colName := dbTag
+ if !dbTagPresent {
+ colName = sf.Name
+ }
+ fpos := fieldPosByName(fldDescs, colName)
+ if fpos == -1 {
+ if missingField == "" {
+ missingField = colName
+ }
+ continue
+ }
+ fields[fpos] = structRowField{
+ path: append([]int(nil), *fieldStack...),
+ }
+ }
+ }
+ *fieldStack = (*fieldStack)[:tail]
+
+ return fields, missingField
+}
+
+const structTagKey = "db"
+
+func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
+ i = -1
+ for i, desc := range fldDescs {
+
+ // Snake case support.
+ field = strings.ReplaceAll(field, "_", "")
+ descName := strings.ReplaceAll(desc.Name, "_", "")
+
+ if strings.EqualFold(descName, field) {
+ return i
+ }
+ }
+ return
+}
+
+// structRowField describes a field of a struct.
+//
+// TODO: It would be a bit more efficient to track the path using the pointer
+// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
+// construct references when scanning rows. However, it's not clear it's worth
+// using unsafe for this.
+type structRowField struct {
+ path []int
+}
+
+func setupStructScanTargets(receiver any, fields []structRowField) []any {
+ scanTargets := make([]any, len(fields))
+ v := reflect.ValueOf(receiver).Elem()
+ for i, f := range fields {
+ scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
+ }
+ return scanTargets
+}