package pgx
import (
"fmt"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
// formats for an extended query.
type ExtendedQueryBuilder struct {
ParamValues [][]byte
paramValueBytes []byte
ParamFormats []int16
ResultFormats []int16
}
// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
// sd is nil then QueryExecModeExec behavior will be used.
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
eqb.reset()
if sd == nil {
for i := range args {
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
return nil
}
if len(sd.ParamOIDs) != len(args) {
return fmt.Errorf("mismatched param and argument count")
}
for i := range args {
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
for i := range sd.Fields {
eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
return nil
}
// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
// must be an untyped nil.
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
if format == -1 {
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
if preferredErr == nil {
return nil
}
var otherFormat int16
if preferredFormat == TextFormatCode {
otherFormat = BinaryFormatCode
} else {
otherFormat = TextFormatCode
}
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
if otherErr == nil {
return nil
}
return preferredErr // return the error from the preferred format
}
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
if err != nil {
return err
}
eqb.ParamFormats = append(eqb.ParamFormats, format)
eqb.ParamValues = append(eqb.ParamValues, v)
return nil
}
// appendResultFormat appends a result format to the query.
func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
eqb.ResultFormats = append(eqb.ResultFormats, format)
}
// reset readies eqb to build another query.
func (eqb *ExtendedQueryBuilder) reset() {
eqb.ParamValues = eqb.ParamValues[0:0]
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
eqb.ParamFormats = eqb.ParamFormats[0:0]
eqb.ResultFormats = eqb.ResultFormats[0:0]
if cap(eqb.ParamValues) > 64 {
eqb.ParamValues = make([][]byte, 0, 64)
}
if cap(eqb.paramValueBytes) > 256 {
eqb.paramValueBytes = make([]byte, 0, 256)
}
if cap(eqb.ParamFormats) > 64 {
eqb.ParamFormats = make([]int16, 0, 64)
}
if cap(eqb.ResultFormats) > 64 {
eqb.ResultFormats = make([]int16, 0, 64)
}
}
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
if eqb.paramValueBytes == nil {
eqb.paramValueBytes = make([]byte, 0, 128)
}
pos := len(eqb.paramValueBytes)
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
eqb.paramValueBytes = buf
return eqb.paramValueBytes[pos:], nil
}
// chooseParameterFormatCode determines the correct format code for an
// argument to a prepared statement. It defaults to TextFormatCode if no
// determination can be made.
func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
switch arg.(type) {
case string, *string:
return TextFormatCode
}
return m.FormatCodeForOID(oid)
}