aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/extended_query_builder.go
blob: 526b0e953be4fdf1c6c5c65095254091ab981b52 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package pgx

import (
	"fmt"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
)

// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
// formats for an extended query.
//
// Build populates the exported fields, which are then meant to be handed to *PgConn.ExecParams or
// *PgConn.ExecPrepared:
//   - ParamValues holds one encoded value per argument. Each entry is a sub-slice of the internal paramValueBytes
//     buffer (or of an earlier backing array of it, if the buffer grew while encoding). An entry is nil when
//     pgtype.Map.Encode returned a nil buffer for that argument (presumably SQL NULL).
//   - ParamFormats holds the format code that was used to encode each entry of ParamValues, in the same order.
//   - ResultFormats holds one requested format code per result column of the statement description. It is left
//     empty when Build is called without a statement description.
//
// The slices and their backing arrays are reused by the next call to Build, so their contents are only valid until
// then. Nothing here synchronizes access; presumably a builder must not be shared between goroutines.
type ExtendedQueryBuilder struct {
	ParamValues     [][]byte
	paramValueBytes []byte
	ParamFormats    []int16
	ResultFormats   []int16
}

// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
// sd is nil then QueryExecModeExec behavior will be used.
//
// With a nil sd, every argument is encoded in text format with OID 0 and no result formats are requested. With a
// non-nil sd, the number of args must equal len(sd.ParamOIDs). Each argument is encoded against its parameter OID
// with the format chosen automatically, and one result format per sd.Fields entry is taken from
// m.FormatCodeForOID.
//
// On error the builder may be left partially populated. The next Build call resets it.
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
	eqb.reset()

	if sd == nil {
		// No statement description: the parameter types are unknown, so send everything as text.
		for idx, arg := range args {
			if err := eqb.appendParam(m, 0, pgtype.TextFormatCode, arg); err != nil {
				return fmt.Errorf("failed to encode args[%d]: %w", idx, err)
			}
		}
		return nil
	}

	if len(args) != len(sd.ParamOIDs) {
		return fmt.Errorf("mismatched param and argument count")
	}

	for idx, arg := range args {
		// A format of -1 asks appendParam to pick the best format (with fallback) for this OID and value.
		if err := eqb.appendParam(m, sd.ParamOIDs[idx], -1, arg); err != nil {
			return fmt.Errorf("failed to encode args[%d]: %w", idx, err)
		}
	}

	for idx := range sd.Fields {
		columnOID := sd.Fields[idx].DataTypeOID
		eqb.appendResultFormat(m.FormatCodeForOID(columnOID))
	}

	return nil
}

// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
// must be an untyped nil.
//
// For an explicit format, arg is encoded in that format. On success the value and its format are appended to
// ParamValues and ParamFormats. On failure the error is returned and nothing is appended.
//
// For format -1, the format from chooseParameterFormatCode is tried first. If encoding in it fails, the opposite
// format is tried: binary when the preferred one was text, otherwise text. If both fail, the error from the
// preferred attempt is returned.
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
	if format != -1 {
		encoded, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
		if err != nil {
			return err
		}

		eqb.ParamFormats = append(eqb.ParamFormats, format)
		eqb.ParamValues = append(eqb.ParamValues, encoded)
		return nil
	}

	preferred := eqb.chooseParameterFormatCode(m, oid, arg)
	firstErr := eqb.appendParam(m, oid, preferred, arg)
	if firstErr == nil {
		return nil
	}

	// A failed attempt appends nothing, so retrying with the other format cannot leave a duplicate entry.
	fallback := int16(TextFormatCode)
	if preferred == TextFormatCode {
		fallback = BinaryFormatCode
	}
	if eqb.appendParam(m, oid, fallback, arg) == nil {
		return nil
	}

	// Both attempts failed; report the error from the preferred format.
	return firstErr
}

// appendResultFormat records format as the requested format code for the next result column, appending it to
// ResultFormats.
func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
	formats := eqb.ResultFormats
	formats = append(formats, format)
	eqb.ResultFormats = formats
}

// reset readies eqb to build another query.
//
// All slices are truncated to length zero so their backing arrays can be reused. Any backing array whose capacity
// has grown beyond a fixed threshold (64 entries, or 256 bytes for paramValueBytes) is replaced with a fresh one of
// that size. That way the builder does not keep an oversized allocation around after a single large query.
func (eqb *ExtendedQueryBuilder) reset() {
	// ParamValues entries are sub-slices of paramValueBytes' backing arrays, including older arrays left behind when
	// the buffer grew during a build. Truncating with [:0] alone would leave those slice headers in ParamValues'
	// backing array. There they keep the old byte buffers reachable even after paramValueBytes is replaced below,
	// until a later build happens to overwrite the slot. Nil them out first so the shrink below actually frees
	// memory.
	for i := range eqb.ParamValues {
		eqb.ParamValues[i] = nil
	}
	eqb.ParamValues = eqb.ParamValues[:0]
	eqb.paramValueBytes = eqb.paramValueBytes[:0]
	eqb.ParamFormats = eqb.ParamFormats[:0]
	eqb.ResultFormats = eqb.ResultFormats[:0]

	if cap(eqb.ParamValues) > 64 {
		eqb.ParamValues = make([][]byte, 0, 64)
	}

	if cap(eqb.paramValueBytes) > 256 {
		eqb.paramValueBytes = make([]byte, 0, 256)
	}

	if cap(eqb.ParamFormats) > 64 {
		eqb.ParamFormats = make([]int16, 0, 64)
	}
	if cap(eqb.ResultFormats) > 64 {
		eqb.ResultFormats = make([]int16, 0, 64)
	}
}

// encodeExtendedParamValue encodes arg for the given oid and formatCode by appending it to the shared
// paramValueBytes buffer. It returns the newly appended bytes as a sub-slice of that buffer.
//
// If m.Encode fails, the error is returned and paramValueBytes keeps its previous length. If m.Encode returns a nil
// buffer, nil is returned (presumably SQL NULL) and paramValueBytes is left unchanged.
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
	// Lazily give the shared buffer an initial capacity on first use.
	if eqb.paramValueBytes == nil {
		eqb.paramValueBytes = make([]byte, 0, 128)
	}

	start := len(eqb.paramValueBytes)
	encoded, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
	switch {
	case err != nil:
		return nil, err
	case encoded == nil:
		return nil, nil
	}

	eqb.paramValueBytes = encoded
	return encoded[start:], nil
}

// chooseParameterFormatCode picks the preferred format code for encoding arg as a parameter of type oid.
//
// Go string and *string values always prefer TextFormatCode. Every other value uses whatever m.FormatCodeForOID
// reports for oid; presumably that falls back to text when no better determination can be made.
func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
	_, isString := arg.(string)
	_, isStringPtr := arg.(*string)
	if isString || isStringPtr {
		return TextFormatCode
	}

	return m.FormatCodeForOID(oid)
}