switch from github.com/lib/pq to github.com/jackc/pgx/v5

lib/pq has been unmaintained for some time now, so switch to the newer,
more active library. It looks like pgx has finally stabilized after a
long time.
Gibheer 2024-09-05 19:38:25 +02:00
parent 6f64eeace1
commit 6ea4d2c82d
656 changed files with 176723 additions and 16580 deletions
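
The mechanical part of the switch is the same in every binary touched below: the blank import moves from lib/pq to pgx's database/sql compatibility layer, and sql.Open gets the new driver name. A minimal sketch of the pattern (the DSN is a placeholder, not taken from this repo):

package main

import (
	"database/sql"
	"log"

	_ "github.com/jackc/pgx/v5/stdlib" // registers the "pgx" database/sql driver
)

func main() {
	// Before: _ "github.com/lib/pq" and sql.Open("postgres", dsn).
	db, err := sql.Open("pgx", "postgres://user:pass@localhost:5432/monzero")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// sql.Open only validates its arguments; Ping confirms the server is reachable.
	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}
}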

View File

@ -18,6 +18,7 @@ import (
"time"
"git.zero-knowledge.org/gibheer/monzero"
_ "github.com/jackc/pgx/v5/stdlib"
)
var (
@ -73,7 +74,7 @@ func main() {
os.Exit(1)
}
db, err := sql.Open("postgres", config.DB)
db, err := sql.Open("pgx", config.DB)
if err != nil {
logger.Error("could not open database connection", "error", err)
os.Exit(1)

View File

@ -4,8 +4,6 @@ import (
"database/sql"
"net/http"
"time"
"github.com/lib/pq"
)
type (
@ -62,7 +60,7 @@ type (
State int
Output string
Inserted time.Time
Sent pq.NullTime
Sent sql.NullTime
NotifierName string
MappingId int
}
@ -91,7 +89,7 @@ func showCheck(con *Context) {
err := DB.QueryRow(query, id[0]).Scan(&cd.Id, &cd.Name, &cd.Message, &cd.Enabled,
&cd.Updated, &cd.LastRefresh, &cd.MappingId, &cd.MappingName, &cd.NodeId,
&cd.NodeName, &cd.NodeMessage, &cd.CommandId, &cd.CommandName, &cd.CommandMessage,
pq.Array(&cd.CommandLine), pq.Array(&cd.States), &cd.Notice, &cd.NextTime,
&cd.CommandLine, &cd.States, &cd.Notice, &cd.NextTime,
&cd.CheckerID, &cd.CheckerName, &cd.CheckerMsg)
if err != nil && err == sql.ErrNoRows {
con.w.Header()["Location"] = []string{"/"}
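
pq.NullTime predates the standard library's nullable time type; since Go 1.13, database/sql's sql.NullTime is a drop-in replacement, which is all the Sent field change amounts to. A minimal sketch of how such a field is scanned and checked (table and column names are illustrative, not from this repo):

package main

import (
	"database/sql"
)

func lastSent(db *sql.DB, id int64) (string, error) {
	var sent sql.NullTime // previously pq.NullTime
	err := db.QueryRow(`select sent from notifications where id = $1`, id).Scan(&sent)
	if err != nil {
		return "", err
	}
	if !sent.Valid { // the column was NULL
		return "never", nil
	}
	return sent.Time.String(), nil
}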

View File

@ -18,7 +18,7 @@ import (
"time"
"github.com/BurntSushi/toml"
"github.com/lib/pq"
_ "github.com/jackc/pgx/v5/stdlib"
"golang.org/x/term"
)
@ -107,7 +107,7 @@ func main() {
logger := parseLogger(config)
db, err := sql.Open("postgres", config.DB)
db, err := sql.Open("pgx", config.DB)
if err != nil {
log.Fatalf("could not open database connection: %s", err)
}
@ -278,7 +278,7 @@ func checkAction(con *Context) {
case "disable":
setClause = "enabled = false, updated = now()"
case "delete_check":
if _, err := DB.Exec(`delete from checks where id = any ($1::bigint[])`, pq.Array(checks)); err != nil {
if _, err := DB.Exec(`delete from checks where id = any ($1::bigint[])`, checks); err != nil {
con.log.Info("could not delete checks", "checks", checks, "error", err)
con.Error = "could not delete checks"
returnError(http.StatusInternalServerError, con, con.w)
@ -321,7 +321,7 @@ func checkAction(con *Context) {
cn.notifier_id, $2
from checks_notify cn
join active_checks ac on cn.check_id = ac.check_id
where cn.check_id = any ($1::bigint[])`, pq.Array(&checks), &hostname); err != nil {
where cn.check_id = any ($1::bigint[])`, &checks, &hostname); err != nil {
con.log.Info("could not acknowledge check", "error", err)
con.Error = "could not acknowledge check"
returnError(http.StatusInternalServerError, con, con.w)
@ -334,7 +334,7 @@ func checkAction(con *Context) {
}
_, err := DB.Exec(
"update active_checks set notice = $2 where check_id = any ($1::bigint[]);",
pq.Array(&checks),
&checks,
comment)
if err != nil {
con.w.WriteHeader(http.StatusInternalServerError)
@ -346,7 +346,7 @@ func checkAction(con *Context) {
return
case "uncomment":
_, err := DB.Exec(`update active_checks set notice = null where check_id = any($1::bigint[]);`,
pq.Array(&checks))
&checks)
if err != nil {
con.Error = "could not uncomment checks"
returnError(http.StatusInternalServerError, con, con.w)
@ -366,7 +366,7 @@ func checkAction(con *Context) {
}
sql := "update " + setTable + " set " + setClause + " where " + whereColumn + " = any($1::bigint[])"
whereVals = append([]any{pq.Array(&checks)}, whereVals...)
whereVals = append([]any{&checks}, whereVals...)
if len(whereFields) > 0 {
for i, column := range whereFields {
sql = sql + " and " + column + fmt.Sprintf(" = $%d", i+1)
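
The pq.Array wrappers in this file can go away because pgx's stdlib driver implements driver.NamedValueChecker and hands query arguments through to pgx's own type mapping, so a Go slice binds directly to a bigint[] parameter. A minimal sketch of the pattern, assuming a db handle opened as above:

package main

import (
	"database/sql"
	"fmt"
)

// deleteChecks mirrors the delete_check branch above: with lib/pq the slice
// had to be wrapped as pq.Array(checks); with pgx it is passed as-is.
func deleteChecks(db *sql.DB, checks []int64) error {
	if _, err := db.Exec(`delete from checks where id = any($1::bigint[])`, checks); err != nil {
		return fmt.Errorf("could not delete checks: %w", err)
	}
	return nil
}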

View File

@ -11,7 +11,7 @@ import (
"text/template"
"time"
"github.com/lib/pq"
_ "github.com/jackc/pgx/v5/stdlib"
)
var (
@ -42,7 +42,7 @@ func main() {
log.Fatalf("could not parse check interval: %s", err)
}
db, err := sql.Open("postgres", config.DB)
db, err := sql.Open("pgx", config.DB)
if err != nil {
log.Fatalf("could not open database connection: %s", err)
}
@ -167,7 +167,7 @@ func startConfigGen(db *sql.DB, checkInterval time.Duration) {
time.Sleep(checkInterval)
continue
}
if _, err := tx.Exec(SQLRefreshActiveCheck, check_id, pq.Array(stringToShellFields(cmd.Bytes()))); err != nil {
if _, err := tx.Exec(SQLRefreshActiveCheck, check_id, stringToShellFields(cmd.Bytes())); err != nil {
tx.Rollback()
log.Printf("could not refresh check '%d': %s", check_id, err)
continue

go.mod
View File

@ -3,12 +3,17 @@ module git.zero-knowledge.org/gibheer/monzero
go 1.18
require (
github.com/BurntSushi/toml v1.2.1
github.com/lib/pq v1.10.7
golang.org/x/crypto v0.2.0
github.com/BurntSushi/toml v1.4.0
github.com/jackc/pgx/v5 v5.6.0
golang.org/x/crypto v0.27.0
golang.org/x/term v0.24.0
)
require (
golang.org/x/sys v0.2.0 // indirect
golang.org/x/term v0.2.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
)

go.sum
View File

@ -1,10 +1,31 @@
github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak=
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
golang.org/x/crypto v0.2.0 h1:BRXPfhNivWL5Yq0BGQ39a2sW6t44aODpfxkWjYdzewE=
golang.org/x/crypto v0.2.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.2.0 h1:z85xZCsEl7bi/KwbNADeBYoOP0++7W1ipu+aGnpwzRM=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@ -9,7 +9,7 @@ See the [releases page](https://github.com/BurntSushi/toml/releases) for a
changelog; this information is also in the git tag annotations (e.g. `git show
v0.4.0`).
This library requires Go 1.13 or newer; add it to your go.mod with:
This library requires Go 1.18 or newer; add it to your go.mod with:
% go get github.com/BurntSushi/toml@latest

View File

@ -6,7 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"io/fs"
"math"
"os"
"reflect"
@ -18,13 +18,13 @@ import (
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
UnmarshalTOML(any) error
}
// Unmarshal decodes the contents of data in TOML format into a pointer v.
//
// See [Decoder] for a description of the decoding process.
func Unmarshal(data []byte, v interface{}) error {
func Unmarshal(data []byte, v any) error {
_, err := NewDecoder(bytes.NewReader(data)).Decode(v)
return err
}
@ -32,12 +32,12 @@ func Unmarshal(data []byte, v interface{}) error {
// Decode the TOML data in to the pointer v.
//
// See [Decoder] for a description of the decoding process.
func Decode(data string, v interface{}) (MetaData, error) {
func Decode(data string, v any) (MetaData, error) {
return NewDecoder(strings.NewReader(data)).Decode(v)
}
// DecodeFile reads the contents of a file and decodes it with [Decode].
func DecodeFile(path string, v interface{}) (MetaData, error) {
func DecodeFile(path string, v any) (MetaData, error) {
fp, err := os.Open(path)
if err != nil {
return MetaData{}, err
@ -46,6 +46,17 @@ func DecodeFile(path string, v interface{}) (MetaData, error) {
return NewDecoder(fp).Decode(v)
}
// DecodeFS reads the contents of a file from [fs.FS] and decodes it with
// [Decode].
func DecodeFS(fsys fs.FS, path string, v any) (MetaData, error) {
fp, err := fsys.Open(path)
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
}
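
DecodeFS is now compiled unconditionally (the go1.16-gated file that used to provide it is deleted further down, since the module now requires Go 1.18). A hedged usage sketch with an embedded config file (file name and struct are illustrative):

package main

import (
	"embed"
	"fmt"
	"log"

	"github.com/BurntSushi/toml"
)

//go:embed config.toml
var configFS embed.FS // assumes a config.toml sits next to this source file

func main() {
	var cfg struct {
		DB string `toml:"db"`
	}
	if _, err := toml.DecodeFS(configFS, "config.toml", &cfg); err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.DB)
}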
// Primitive is a TOML value that hasn't been decoded into a Go value.
//
// This type can be used for any value, which will cause decoding to be delayed.
@ -58,7 +69,7 @@ func DecodeFile(path string, v interface{}) (MetaData, error) {
// overhead of reflection. They can be useful when you don't know the exact type
// of TOML data until runtime.
type Primitive struct {
undecoded interface{}
undecoded any
context Key
}
@ -91,7 +102,7 @@ const (
// UnmarshalText method. See the Unmarshaler example for a demonstration with
// email addresses.
//
// ### Key mapping
// # Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go struct.
// The special `toml` struct tag can be used to map TOML keys to struct fields
@ -122,7 +133,7 @@ var (
)
// Decode TOML data in to the pointer `v`.
func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
func (dec *Decoder) Decode(v any) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
s := "%q"
@ -136,8 +147,8 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v))
}
// Check if this is a supported type: struct, map, interface{}, or something
// that implements UnmarshalTOML or UnmarshalText.
// Check if this is a supported type: struct, map, any, or something that
// implements UnmarshalTOML or UnmarshalText.
rv = indirect(rv)
rt := rv.Type()
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
@ -148,7 +159,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
// TODO: parser should read from io.Reader? Or at the very least, make it
// read from []byte rather than string
data, err := ioutil.ReadAll(dec.r)
data, err := io.ReadAll(dec.r)
if err != nil {
return MetaData{}, err
}
@ -179,7 +190,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
// will only reflect keys that were decoded. Namely, any keys hidden behind a
// Primitive will be considered undecoded. Executing this method will update the
// undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
func (md *MetaData) PrimitiveDecode(primValue Primitive, v any) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
@ -190,7 +201,7 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
func (md *MetaData) unify(data any, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
// TODO: #76 would make this superfluous after implemented.
if rv.Type() == primitiveType {
@ -207,7 +218,11 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
rvi := rv.Interface()
if v, ok := rvi.(Unmarshaler); ok {
return v.UnmarshalTOML(data)
err := v.UnmarshalTOML(data)
if err != nil {
return md.parseErr(err)
}
return nil
}
if v, ok := rvi.(encoding.TextUnmarshaler); ok {
return md.unifyText(data, v)
@ -227,14 +242,6 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
@ -248,7 +255,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
if rv.NumMethod() > 0 { // Only support empty interfaces are supported.
if rv.NumMethod() > 0 { /// Only empty interfaces are supported.
return md.e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
@ -258,14 +265,13 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
return md.e("unsupported type %s", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
func (md *MetaData) unifyStruct(mapping any, rv reflect.Value) error {
tmap, ok := mapping.(map[string]any)
if !ok {
if mapping == nil {
return nil
}
return md.e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
return md.e("type mismatch for %s: expected table but found %s", rv.Type().String(), fmtType(mapping))
}
for key, datum := range tmap {
@ -304,14 +310,14 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
func (md *MetaData) unifyMap(mapping any, rv reflect.Value) error {
keyType := rv.Type().Key().Kind()
if keyType != reflect.String && keyType != reflect.Interface {
return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)",
keyType, rv.Type())
}
tmap, ok := mapping.(map[string]interface{})
tmap, ok := mapping.(map[string]any)
if !ok {
if tmap == nil {
return nil
@ -347,7 +353,7 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
func (md *MetaData) unifyArray(data any, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
@ -361,7 +367,7 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
func (md *MetaData) unifySlice(data any, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
@ -388,7 +394,7 @@ func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
return nil
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
func (md *MetaData) unifyString(data any, rv reflect.Value) error {
_, ok := rv.Interface().(json.Number)
if ok {
if i, ok := data.(int64); ok {
@ -408,7 +414,7 @@ func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
return md.badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
func (md *MetaData) unifyFloat64(data any, rv reflect.Value) error {
rvk := rv.Kind()
if num, ok := data.(float64); ok {
@ -429,7 +435,7 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) ||
(rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
return md.parseErr(errUnsafeFloat{i: num, size: rvk.String()})
}
rv.SetFloat(float64(num))
return nil
@ -438,7 +444,7 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
return md.badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
func (md *MetaData) unifyInt(data any, rv reflect.Value) error {
_, ok := rv.Interface().(time.Duration)
if ok {
// Parse as string duration, and fall back to regular integer parsing
@ -481,7 +487,7 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
return nil
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
func (md *MetaData) unifyBool(data any, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
@ -489,12 +495,12 @@ func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
return md.badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
func (md *MetaData) unifyAnything(data any, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) error {
func (md *MetaData) unifyText(data any, v encoding.TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case Marshaler:
@ -523,13 +529,13 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
return md.badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
return md.parseErr(err)
}
return nil
}
func (md *MetaData) badtype(dst string, data interface{}) error {
return md.e("incompatible types: TOML value has type %T; destination has type %s", data, dst)
func (md *MetaData) badtype(dst string, data any) error {
return md.e("incompatible types: TOML value has type %s; destination has type %s", fmtType(data), dst)
}
func (md *MetaData) parseErr(err error) error {
@ -543,7 +549,7 @@ func (md *MetaData) parseErr(err error) error {
}
}
func (md *MetaData) e(format string, args ...interface{}) error {
func (md *MetaData) e(format string, args ...any) error {
f := "toml: "
if len(md.context) > 0 {
f = fmt.Sprintf("toml: (last key %q): ", md.context)
@ -556,7 +562,7 @@ func (md *MetaData) e(format string, args ...interface{}) error {
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
func rvalue(v any) reflect.Value {
return indirect(reflect.ValueOf(v))
}
@ -600,3 +606,8 @@ func isUnifiable(rv reflect.Value) bool {
}
return false
}
// fmt %T with "interface {}" replaced with "any", which is far more readable.
func fmtType(t any) string {
return strings.ReplaceAll(fmt.Sprintf("%T", t), "interface {}", "any")
}

View File

@ -1,19 +0,0 @@
//go:build go1.16
// +build go1.16
package toml
import (
"io/fs"
)
// DecodeFS reads the contents of a file from [fs.FS] and decodes it with
// [Decode].
func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) {
fp, err := fsys.Open(path)
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
}

View File

@ -5,17 +5,25 @@ import (
"io"
)
// TextMarshaler is an alias for encoding.TextMarshaler.
//
// Deprecated: use encoding.TextMarshaler
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is an alias for encoding.TextUnmarshaler.
//
// Deprecated: use encoding.TextUnmarshaler
type TextUnmarshaler encoding.TextUnmarshaler
// DecodeReader is an alias for NewDecoder(r).Decode(v).
//
// Deprecated: use NewDecoder(reader).Decode(&value).
func DecodeReader(r io.Reader, v any) (MetaData, error) { return NewDecoder(r).Decode(v) }
// PrimitiveDecode is an alias for MetaData.PrimitiveDecode().
//
// Deprecated: use MetaData.PrimitiveDecode.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
func PrimitiveDecode(primValue Primitive, v any) error {
md := MetaData{decoded: make(map[string]struct{})}
return md.unify(primValue.undecoded, rvalue(v))
}
// Deprecated: use NewDecoder(reader).Decode(&value).
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { return NewDecoder(r).Decode(v) }

View File

@ -2,9 +2,6 @@
//
// This package supports TOML v1.0.0, as specified at https://toml.io
//
// There is also support for delaying decoding with the Primitive type, and
// querying the set of keys in a TOML document with the MetaData type.
//
// The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator,
// and can be used to verify if TOML document is valid. It can also be used to
// print the type of each key.

View File

@ -2,6 +2,7 @@ package toml
import (
"bufio"
"bytes"
"encoding"
"encoding/json"
"errors"
@ -76,6 +77,17 @@ type Marshaler interface {
MarshalTOML() ([]byte, error)
}
// Marshal returns a TOML representation of the Go value.
//
// See [Encoder] for a description of the encoding process.
func Marshal(v any) ([]byte, error) {
buff := new(bytes.Buffer)
if err := NewEncoder(buff).Encode(v); err != nil {
return nil, err
}
return buff.Bytes(), nil
}
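
Marshal is new in this version of the vendored library; as the added body shows, it is a thin wrapper over NewEncoder(buf).Encode(v). A short usage sketch (the struct is illustrative):

package main

import (
	"fmt"
	"log"

	"github.com/BurntSushi/toml"
)

func main() {
	cfg := struct {
		DB       string `toml:"db"`
		Interval string `toml:"interval"`
	}{DB: "postgres://localhost/monzero", Interval: "30s"}

	out, err := toml.Marshal(cfg) // same result as toml.NewEncoder(&buf).Encode(cfg)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Print(string(out))
}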
// Encoder encodes a Go value to a TOML document.
//
// The mapping between Go values and TOML values should be precisely the same as
@ -115,28 +127,24 @@ type Marshaler interface {
// NOTE: only exported keys are encoded due to the use of reflection. Unexported
// keys are silently discarded.
type Encoder struct {
// String to use for a single indentation level; default is two spaces.
Indent string
w *bufio.Writer
Indent string // string for a single indentation level; default is two spaces.
hasWritten bool // written any output to w yet?
w *bufio.Writer
}
// NewEncoder creates a new Encoder.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
return &Encoder{w: bufio.NewWriter(w), Indent: " "}
}
// Encode writes a TOML representation of the Go value to the [Encoder]'s writer.
//
// An error is returned if the value given cannot be encoded to a valid TOML
// document.
func (enc *Encoder) Encode(v interface{}) error {
func (enc *Encoder) Encode(v any) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
err := enc.safeEncode(Key([]string{}), rv)
if err != nil {
return err
}
return enc.w.Flush()
@ -279,18 +287,30 @@ func (enc *Encoder) eElement(rv reflect.Value) {
case reflect.Float32:
f := rv.Float()
if math.IsNaN(f) {
if math.Signbit(f) {
enc.wf("-")
}
enc.wf("nan")
} else if math.IsInf(f, 0) {
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
if math.Signbit(f) {
enc.wf("-")
}
enc.wf("inf")
} else {
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 32)))
}
case reflect.Float64:
f := rv.Float()
if math.IsNaN(f) {
if math.Signbit(f) {
enc.wf("-")
}
enc.wf("nan")
} else if math.IsInf(f, 0) {
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
if math.Signbit(f) {
enc.wf("-")
}
enc.wf("inf")
} else {
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 64)))
}
@ -303,7 +323,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
case reflect.Interface:
enc.eElement(rv.Elem())
default:
encPanic(fmt.Errorf("unexpected type: %T", rv.Interface()))
encPanic(fmt.Errorf("unexpected type: %s", fmtType(rv.Interface())))
}
}
@ -457,6 +477,16 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
frv := eindirect(rv.Field(i))
if is32Bit {
// Copy so it works correct on 32bit archs; not clear why this
// is needed. See #314, and https://www.reddit.com/r/golang/comments/pnx8v4
// This also works fine on 64bit, but 32bit archs are somewhat
// rare and this is a wee bit faster.
copyStart := make([]int, len(start))
copy(copyStart, start)
start = copyStart
}
// Treat anonymous struct fields with tag names as though they are
// not anonymous, like encoding/json does.
//
@ -470,44 +500,37 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
if typeIsTable(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
// Copy so it works correct on 32bit archs; not clear why this
// is needed. See #314, and https://www.reddit.com/r/golang/comments/pnx8v4
// This also works fine on 64bit, but 32bit archs are somewhat
// rare and this is a wee bit faster.
if is32Bit {
copyStart := make([]int, len(start))
copy(copyStart, start)
fieldsDirect = append(fieldsDirect, append(copyStart, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
}
addFields(rt, rv, nil)
writeFields := func(fields [][]int) {
for _, fieldIndex := range fields {
fieldType := rt.FieldByIndex(fieldIndex)
fieldVal := eindirect(rv.FieldByIndex(fieldIndex))
if isNil(fieldVal) { /// Don't write anything for nil fields.
continue
}
fieldVal := rv.FieldByIndex(fieldIndex)
opts := getOptions(fieldType.Tag)
if opts.skip {
continue
}
if opts.omitempty && isEmpty(fieldVal) {
continue
}
fieldVal = eindirect(fieldVal)
if isNil(fieldVal) { /// Don't write anything for nil fields.
continue
}
keyName := fieldType.Name
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && enc.isEmpty(fieldVal) {
continue
}
if opts.omitzero && isZero(fieldVal) {
continue
}
@ -649,7 +672,7 @@ func isZero(rv reflect.Value) bool {
return false
}
func (enc *Encoder) isEmpty(rv reflect.Value) bool {
func isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return rv.Len() == 0
@ -664,13 +687,15 @@ func (enc *Encoder) isEmpty(rv reflect.Value) bool {
// type b struct{ s []string }
// s := a{field: b{s: []string{"AAA"}}}
for i := 0; i < rv.NumField(); i++ {
if !enc.isEmpty(rv.Field(i)) {
if !isEmpty(rv.Field(i)) {
return false
}
}
return true
case reflect.Bool:
return !rv.Bool()
case reflect.Ptr:
return rv.IsNil()
}
return false
}
@ -693,8 +718,11 @@ func (enc *Encoder) newline() {
// v v v v vv
// key = {k = 1, k2 = 2}
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
/// Marshaler used on top-level document; call eElement() to just call
/// Marshal{TOML,Text}.
if len(key) == 0 {
encPanic(errNoKey)
enc.eElement(val)
return
}
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
@ -703,7 +731,7 @@ func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
}
}
func (enc *Encoder) wf(format string, v ...interface{}) {
func (enc *Encoder) wf(format string, v ...any) {
_, err := fmt.Fprintf(enc.w, format, v...)
if err != nil {
encPanic(err)

View File

@ -84,7 +84,7 @@ func (pe ParseError) Error() string {
pe.Position.Line, pe.LastKey, msg)
}
// ErrorWithUsage() returns the error with detailed location context.
// ErrorWithPosition returns the error with detailed location context.
//
// See the documentation on [ParseError].
func (pe ParseError) ErrorWithPosition() string {
@ -114,17 +114,26 @@ func (pe ParseError) ErrorWithPosition() string {
msg, pe.Position.Line, col, col+pe.Position.Len)
}
if pe.Position.Line > 2 {
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-2, lines[pe.Position.Line-3])
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-2, expandTab(lines[pe.Position.Line-3]))
}
if pe.Position.Line > 1 {
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-1, lines[pe.Position.Line-2])
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-1, expandTab(lines[pe.Position.Line-2]))
}
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line, lines[pe.Position.Line-1])
fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col), strings.Repeat("^", pe.Position.Len))
/// Expand tabs, so that the ^^^s are at the correct position, but leave
/// "column 10-13" intact. Adjusting this to the visual column would be
/// better, but we don't know the tabsize of the user in their editor, which
/// can be 8, 4, 2, or something else. We can't know. So leaving it as the
/// character index is probably the "most correct".
expanded := expandTab(lines[pe.Position.Line-1])
diff := len(expanded) - len(lines[pe.Position.Line-1])
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line, expanded)
fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col+diff), strings.Repeat("^", pe.Position.Len))
return b.String()
}
// ErrorWithUsage() returns the error with detailed location context and usage
// ErrorWithUsage returns the error with detailed location context and usage
// guidance.
//
// See the documentation on [ParseError].
@ -159,18 +168,48 @@ func (pe ParseError) column(lines []string) int {
return col
}
func expandTab(s string) string {
var (
b strings.Builder
l int
fill = func(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = ' '
}
return string(b)
}
)
b.Grow(len(s))
for _, r := range s {
switch r {
case '\t':
tw := 8 - l%8
b.WriteString(fill(tw))
l += tw
default:
b.WriteRune(r)
l += 1
}
}
return b.String()
}
type (
errLexControl struct{ r rune }
errLexEscape struct{ r rune }
errLexUTF8 struct{ b byte }
errLexInvalidNum struct{ v string }
errLexInvalidDate struct{ v string }
errParseDate struct{ v string }
errLexInlineTableNL struct{}
errLexStringNL struct{}
errParseRange struct {
i interface{} // int or float
i any // int or float
size string // "int64", "uint16", etc.
}
errUnsafeFloat struct {
i interface{} // float32 or float64
size string // "float32" or "float64"
}
errParseDuration struct{ d string }
)
@ -183,16 +222,18 @@ func (e errLexEscape) Error() string { return fmt.Sprintf(`invalid escape
func (e errLexEscape) Usage() string { return usageEscape }
func (e errLexUTF8) Error() string { return fmt.Sprintf("invalid UTF-8 byte: 0x%02x", e.b) }
func (e errLexUTF8) Usage() string { return "" }
func (e errLexInvalidNum) Error() string { return fmt.Sprintf("invalid number: %q", e.v) }
func (e errLexInvalidNum) Usage() string { return "" }
func (e errLexInvalidDate) Error() string { return fmt.Sprintf("invalid date: %q", e.v) }
func (e errLexInvalidDate) Usage() string { return "" }
func (e errParseDate) Error() string { return fmt.Sprintf("invalid datetime: %q", e.v) }
func (e errParseDate) Usage() string { return usageDate }
func (e errLexInlineTableNL) Error() string { return "newlines not allowed within inline tables" }
func (e errLexInlineTableNL) Usage() string { return usageInlineNewline }
func (e errLexStringNL) Error() string { return "strings cannot contain newlines" }
func (e errLexStringNL) Usage() string { return usageStringNewline }
func (e errParseRange) Error() string { return fmt.Sprintf("%v is out of range for %s", e.i, e.size) }
func (e errParseRange) Usage() string { return usageIntOverflow }
func (e errUnsafeFloat) Error() string {
return fmt.Sprintf("%v is out of the safe %s range", e.i, e.size)
}
func (e errUnsafeFloat) Usage() string { return usageUnsafeFloat }
func (e errParseDuration) Error() string { return fmt.Sprintf("invalid duration: %q", e.d) }
func (e errParseDuration) Usage() string { return usageDuration }
@ -251,19 +292,35 @@ bug in the program that uses too small of an integer.
The maximum and minimum values are:
size lowest highest
int8 -128 127
int16 -32,768 32,767
int32 -2,147,483,648 2,147,483,647
int64 -9.2 × 10¹⁸ 9.2 × 10¹⁸
uint8 0 255
uint16 0 65535
uint32 0 4294967295
uint16 0 65,535
uint32 0 4,294,967,295
uint64 0 1.8 × 10¹⁹
int refers to int32 on 32-bit systems and int64 on 64-bit systems.
`
const usageUnsafeFloat = `
This number is outside of the "safe" range for floating point numbers; whole
(non-fractional) numbers outside the below range can not always be represented
accurately in a float, leading to some loss of accuracy.
Explicitly mark a number as a fractional unit by adding ".0", which will incur
some loss of accuracy; for example:
f = 2_000_000_000.0
Accuracy ranges:
float32 = 16,777,215
float64 = 9,007,199,254,740,991
`
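
The quoted cutoffs are 2²⁴−1 and 2⁵³−1: every whole number up to (and including) the next power of two has an exact float representation, and one step past that rounding sets in. A quick demonstration of the effect the message warns about:

package main

import "fmt"

func main() {
	// 2^24 + 1 is the first whole number a float32 cannot hold exactly.
	const n = 16777217
	fmt.Println(int64(float32(n)) == n) // false: float32 rounds to 16777216

	// Likewise 2^53 + 1 for float64.
	const m = 9007199254740993
	fmt.Println(int64(float64(m)) == m) // false: rounds to 9007199254740992
}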
const usageDuration = `
A duration must be as "number<unit>", without any spaces. Valid units are:
@ -277,3 +334,23 @@ A duration must be as "number<unit>", without any spaces. Valid units are:
You can combine multiple units; for example "5m10s" for 5 minutes and 10
seconds.
`
const usageDate = `
A TOML datetime must be in one of the following formats:
2006-01-02T15:04:05Z07:00 Date and time, with timezone.
2006-01-02T15:04:05 Date and time, but without timezone.
2006-01-02 Date without a time or timezone.
15:04:05 Just a time, without any timezone.
Seconds may optionally have a fraction, up to nanosecond precision:
15:04:05.123
15:04:05.856018510
`
// TOML 1.1:
// The seconds part in times is optional, and may be omitted:
// 2006-01-02T15:04Z07:00
// 2006-01-02T15:04
// 15:04

View File

@ -17,6 +17,7 @@ const (
itemEOF
itemText
itemString
itemStringEsc
itemRawString
itemMultilineString
itemRawMultilineString
@ -52,6 +53,8 @@ type lexer struct {
line int
state stateFn
items chan item
tomlNext bool
esc bool
// Allow for backing up up to 4 runes. This is necessary because TOML
// contains 3-rune tokens (""" and ''').
@ -87,13 +90,14 @@ func (lx *lexer) nextItem() item {
}
}
func lex(input string) *lexer {
func lex(input string, tomlNext bool) *lexer {
lx := &lexer{
input: input,
state: lexTop,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
line: 1,
tomlNext: tomlNext,
}
return lx
}
@ -162,7 +166,7 @@ func (lx *lexer) next() (r rune) {
}
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
if r == utf8.RuneError {
if r == utf8.RuneError && w == 1 {
lx.error(errLexUTF8{lx.input[lx.pos]})
return utf8.RuneError
}
@ -268,7 +272,7 @@ func (lx *lexer) errorPos(start, length int, err error) stateFn {
}
// errorf is like error, and creates a new error.
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
func (lx *lexer) errorf(format string, values ...any) stateFn {
if lx.atEOF {
pos := lx.getPos()
pos.Line--
@ -331,9 +335,7 @@ func lexTopEnd(lx *lexer) stateFn {
lx.emit(itemEOF)
return nil
}
return lx.errorf(
"expected a top-level item to end with a newline, comment, or EOF, but got %q instead",
r)
return lx.errorf("expected a top-level item to end with a newline, comment, or EOF, but got %q instead", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
@ -408,7 +410,7 @@ func lexTableNameEnd(lx *lexer) stateFn {
// Lexes only one part, e.g. only 'a' inside 'a.b'.
func lexBareName(lx *lexer) stateFn {
r := lx.next()
if isBareKeyChar(r) {
if isBareKeyChar(r, lx.tomlNext) {
return lexBareName
}
lx.backup()
@ -618,6 +620,9 @@ func lexInlineTableValue(lx *lexer) stateFn {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValue)
case isNL(r):
if lx.tomlNext {
return lexSkip(lx, lexInlineTableValue)
}
return lx.errorPrevLine(errLexInlineTableNL{})
case r == '#':
lx.push(lexInlineTableValue)
@ -640,6 +645,9 @@ func lexInlineTableValueEnd(lx *lexer) stateFn {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValueEnd)
case isNL(r):
if lx.tomlNext {
return lexSkip(lx, lexInlineTableValueEnd)
}
return lx.errorPrevLine(errLexInlineTableNL{})
case r == '#':
lx.push(lexInlineTableValueEnd)
@ -648,6 +656,9 @@ func lexInlineTableValueEnd(lx *lexer) stateFn {
lx.ignore()
lx.skip(isWhitespace)
if lx.peek() == '}' {
if lx.tomlNext {
return lexInlineTableValueEnd
}
return lx.errorf("trailing comma not allowed in inline tables")
}
return lexInlineTableValue
@ -687,7 +698,12 @@ func lexString(lx *lexer) stateFn {
return lexStringEscape
case r == '"':
lx.backup()
if lx.esc {
lx.esc = false
lx.emit(itemStringEsc)
} else {
lx.emit(itemString)
}
lx.next()
lx.ignore()
return lx.pop()
@ -737,6 +753,7 @@ func lexMultilineString(lx *lexer) stateFn {
lx.backup() /// backup: don't include the """ in the item.
lx.backup()
lx.backup()
lx.esc = false
lx.emit(itemMultilineString)
lx.next() /// Read over ''' again and discard it.
lx.next()
@ -770,8 +787,8 @@ func lexRawString(lx *lexer) stateFn {
}
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning ''' has already been consumed and
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such a
// string. It assumes that the beginning triple-' has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next()
@ -826,8 +843,14 @@ func lexMultilineStringEscape(lx *lexer) stateFn {
}
func lexStringEscape(lx *lexer) stateFn {
lx.esc = true
r := lx.next()
switch r {
case 'e':
if !lx.tomlNext {
return lx.error(errLexEscape{r})
}
fallthrough
case 'b':
fallthrough
case 't':
@ -846,6 +869,11 @@ func lexStringEscape(lx *lexer) stateFn {
fallthrough
case '\\':
return lx.pop()
case 'x':
if !lx.tomlNext {
return lx.error(errLexEscape{r})
}
return lexHexEscape
case 'u':
return lexShortUnicodeEscape
case 'U':
@ -854,14 +882,23 @@ func lexStringEscape(lx *lexer) stateFn {
return lx.error(errLexEscape{r})
}
func lexHexEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 2; i++ {
r = lx.next()
if !isHex(r) {
return lx.errorf(`expected two hexadecimal digits after '\x', but got %q instead`, lx.current())
}
}
return lx.pop()
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(
`expected four hexadecimal digits after '\u', but got %q instead`,
lx.current())
if !isHex(r) {
return lx.errorf(`expected four hexadecimal digits after '\u', but got %q instead`, lx.current())
}
}
return lx.pop()
@ -871,10 +908,8 @@ func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(
`expected eight hexadecimal digits after '\U', but got %q instead`,
lx.current())
if !isHex(r) {
return lx.errorf(`expected eight hexadecimal digits after '\U', but got %q instead`, lx.current())
}
}
return lx.pop()
@ -941,7 +976,7 @@ func lexDatetime(lx *lexer) stateFn {
// lexHexInteger consumes a hexadecimal integer after seeing the '0x' prefix.
func lexHexInteger(lx *lexer) stateFn {
r := lx.next()
if isHexadecimal(r) {
if isHex(r) {
return lexHexInteger
}
switch r {
@ -1075,7 +1110,7 @@ func lexBaseNumberOrDate(lx *lexer) stateFn {
return lexOctalInteger
case 'x':
r = lx.peek()
if !isHexadecimal(r) {
if !isHex(r) {
lx.errorf("not a hexidecimal number: '%s%c'", lx.current(), r)
}
return lexHexInteger
@ -1173,7 +1208,7 @@ func (itype itemType) String() string {
return "EOF"
case itemText:
return "Text"
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
case itemString, itemStringEsc, itemRawString, itemMultilineString, itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
@ -1206,7 +1241,7 @@ func (itype itemType) String() string {
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
return fmt.Sprintf("(%s, %s)", item.typ, item.val)
}
func isWhitespace(r rune) bool { return r == '\t' || r == ' ' }
@ -1222,10 +1257,23 @@ func isControl(r rune) bool { // Control characters except \t, \r, \n
func isDigit(r rune) bool { return r >= '0' && r <= '9' }
func isBinary(r rune) bool { return r == '0' || r == '1' }
func isOctal(r rune) bool { return r >= '0' && r <= '7' }
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')
func isHex(r rune) bool { return (r >= '0' && r <= '9') || (r|0x20 >= 'a' && r|0x20 <= 'f') }
func isBareKeyChar(r rune, tomlNext bool) bool {
if tomlNext {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' || r == '-' ||
r == 0xb2 || r == 0xb3 || r == 0xb9 || (r >= 0xbc && r <= 0xbe) ||
(r >= 0xc0 && r <= 0xd6) || (r >= 0xd8 && r <= 0xf6) || (r >= 0xf8 && r <= 0x037d) ||
(r >= 0x037f && r <= 0x1fff) ||
(r >= 0x200c && r <= 0x200d) || (r >= 0x203f && r <= 0x2040) ||
(r >= 0x2070 && r <= 0x218f) || (r >= 0x2460 && r <= 0x24ff) ||
(r >= 0x2c00 && r <= 0x2fef) || (r >= 0x3001 && r <= 0xd7ff) ||
(r >= 0xf900 && r <= 0xfdcf) || (r >= 0xfdf0 && r <= 0xfffd) ||
(r >= 0x10000 && r <= 0xeffff)
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||

View File

@ -13,7 +13,7 @@ type MetaData struct {
context Key // Used only during decoding.
keyInfo map[string]keyInfo
mapping map[string]interface{}
mapping map[string]any
keys []Key
decoded map[string]struct{}
data []byte // Input file; for errors.
@ -31,12 +31,12 @@ func (md *MetaData) IsDefined(key ...string) bool {
}
var (
hash map[string]interface{}
hash map[string]any
ok bool
hashOrVal interface{} = md.mapping
hashOrVal any = md.mapping
)
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
if hash, ok = hashOrVal.(map[string]any); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
@ -94,28 +94,55 @@ func (md *MetaData) Undecoded() []Key {
type Key []string
func (k Key) String() string {
ss := make([]string, len(k))
for i := range k {
ss[i] = k.maybeQuoted(i)
// This is called quite often, so it's a bit funky to make it faster.
var b strings.Builder
b.Grow(len(k) * 25)
outer:
for i, kk := range k {
if i > 0 {
b.WriteByte('.')
}
return strings.Join(ss, ".")
if kk == "" {
b.WriteString(`""`)
} else {
for _, r := range kk {
// "Inline" isBareKeyChar
if !((r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-') {
b.WriteByte('"')
b.WriteString(dblQuotedReplacer.Replace(kk))
b.WriteByte('"')
continue outer
}
}
b.WriteString(kk)
}
}
return b.String()
}
func (k Key) maybeQuoted(i int) string {
if k[i] == "" {
return `""`
}
for _, c := range k[i] {
if !isBareKeyChar(c) {
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
for _, r := range k[i] {
if (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' {
continue
}
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
}
return k[i]
}
// Like append(), but only increase the cap by 1.
func (k Key) add(piece string) Key {
if cap(k) > len(k) {
return append(k, piece)
}
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
func (k Key) parent() Key { return k[:len(k)-1] } // all except the last piece.
func (k Key) last() string { return k[len(k)-1] } // last piece of this key.

View File

@ -2,6 +2,8 @@ package toml
import (
"fmt"
"math"
"os"
"strconv"
"strings"
"time"
@ -15,11 +17,12 @@ type parser struct {
context Key // Full key for the current hash in scope.
currentKey string // Base key name for everything except hashes.
pos Position // Current position in the TOML file.
tomlNext bool
ordered []Key // List of keys in the order that they appear in the TOML data.
keyInfo map[string]keyInfo // Map keyname → info about the TOML key.
mapping map[string]interface{} // Map keyname → key value.
mapping map[string]any // Map keyname → key value.
implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names").
}
@ -29,6 +32,8 @@ type keyInfo struct {
}
func parse(data string) (p *parser, err error) {
_, tomlNext := os.LookupEnv("BURNTSUSHI_TOML_110")
defer func() {
if r := recover(); r != nil {
if pErr, ok := r.(ParseError); ok {
@ -41,9 +46,13 @@ func parse(data string) (p *parser, err error) {
}()
// Read over BOM; do this here as the lexer calls utf8.DecodeRuneInString()
// which mangles stuff.
if strings.HasPrefix(data, "\xff\xfe") || strings.HasPrefix(data, "\xfe\xff") {
// which mangles stuff. UTF-16 BOM isn't strictly valid, but some tools add
// it anyway.
if strings.HasPrefix(data, "\xff\xfe") || strings.HasPrefix(data, "\xfe\xff") { // UTF-16
data = data[2:]
//lint:ignore S1017 https://github.com/dominikh/go-tools/issues/1447
} else if strings.HasPrefix(data, "\xef\xbb\xbf") { // UTF-8
data = data[3:]
}
// Examine first few bytes for NULL bytes; this probably means it's a UTF-16
@ -64,10 +73,11 @@ func parse(data string) (p *parser, err error) {
p = &parser{
keyInfo: make(map[string]keyInfo),
mapping: make(map[string]interface{}),
lx: lex(data),
mapping: make(map[string]any),
lx: lex(data, tomlNext),
ordered: make([]Key, 0),
implicits: make(map[string]struct{}),
tomlNext: tomlNext,
}
for {
item := p.next()
@ -89,7 +99,7 @@ func (p *parser) panicErr(it item, err error) {
})
}
func (p *parser) panicItemf(it item, format string, v ...interface{}) {
func (p *parser) panicItemf(it item, format string, v ...any) {
panic(ParseError{
Message: fmt.Sprintf(format, v...),
Position: it.pos,
@ -98,7 +108,7 @@ func (p *parser) panicItemf(it item, format string, v ...interface{}) {
})
}
func (p *parser) panicf(format string, v ...interface{}) {
func (p *parser) panicf(format string, v ...any) {
panic(ParseError{
Message: fmt.Sprintf(format, v...),
Position: p.pos,
@ -131,7 +141,7 @@ func (p *parser) nextPos() item {
return it
}
func (p *parser) bug(format string, v ...interface{}) {
func (p *parser) bug(format string, v ...any) {
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
}
@ -186,20 +196,21 @@ func (p *parser) topLevel(item item) {
p.assertEqual(itemKeyEnd, k.typ)
/// The current key is the last part.
p.currentKey = key[len(key)-1]
p.currentKey = key.last()
/// All the other parts (if any) are the context; need to set each part
/// as implicit.
context := key[:len(key)-1]
context := key.parent()
for i := range context {
p.addImplicitContext(append(p.context, context[i:i+1]...))
}
p.ordered = append(p.ordered, p.context.add(p.currentKey))
/// Set value.
vItem := p.next()
val, typ := p.value(vItem, false)
p.set(p.currentKey, val, typ, vItem.pos)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ, vItem.pos)
/// Remove the context we added (preserving any context from [tbl] lines).
p.context = outerContext
@ -214,7 +225,7 @@ func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
case itemString, itemStringEsc, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it, false)
return s.(string)
@ -231,12 +242,14 @@ var datetimeRepl = strings.NewReplacer(
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
func (p *parser) value(it item, parentIsArray bool) (any, tomlType) {
switch it.typ {
case itemString:
return it.val, p.typeOfPrimitive(it)
case itemStringEsc:
return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it)
case itemMultilineString:
return p.replaceEscapes(it, stripFirstNewline(p.stripEscapedNewlines(it.val))), p.typeOfPrimitive(it)
return p.replaceEscapes(it, p.stripEscapedNewlines(stripFirstNewline(it.val))), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
@ -266,7 +279,7 @@ func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
panic("unreachable")
}
func (p *parser) valueInteger(it item) (interface{}, tomlType) {
func (p *parser) valueInteger(it item) (any, tomlType) {
if !numUnderscoresOK(it.val) {
p.panicItemf(it, "Invalid integer %q: underscores must be surrounded by digits", it.val)
}
@ -290,7 +303,7 @@ func (p *parser) valueInteger(it item) (interface{}, tomlType) {
return num, p.typeOfPrimitive(it)
}
func (p *parser) valueFloat(it item) (interface{}, tomlType) {
func (p *parser) valueFloat(it item) (any, tomlType) {
parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
@ -314,7 +327,9 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) {
p.panicItemf(it, "Invalid float %q: '.' must be followed by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
if val == "+nan" || val == "-nan" { // Go doesn't support this, but TOML spec does.
signbit := false
if val == "+nan" || val == "-nan" {
signbit = val == "-nan"
val = "nan"
}
num, err := strconv.ParseFloat(val, 64)
@ -325,20 +340,29 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) {
p.panicItemf(it, "Invalid float value: %q", it.val)
}
}
if signbit {
num = math.Copysign(num, -1)
}
return num, p.typeOfPrimitive(it)
}
var dtTypes = []struct {
fmt string
zone *time.Location
next bool
}{
{time.RFC3339Nano, time.Local},
{"2006-01-02T15:04:05.999999999", internal.LocalDatetime},
{"2006-01-02", internal.LocalDate},
{"15:04:05.999999999", internal.LocalTime},
{time.RFC3339Nano, time.Local, false},
{"2006-01-02T15:04:05.999999999", internal.LocalDatetime, false},
{"2006-01-02", internal.LocalDate, false},
{"15:04:05.999999999", internal.LocalTime, false},
// tomlNext
{"2006-01-02T15:04Z07:00", time.Local, true},
{"2006-01-02T15:04", internal.LocalDatetime, true},
{"15:04", internal.LocalTime, true},
}
func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
func (p *parser) valueDatetime(it item) (any, tomlType) {
it.val = datetimeRepl.Replace(it.val)
var (
t time.Time
@ -346,28 +370,49 @@ func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
err error
)
for _, dt := range dtTypes {
if dt.next && !p.tomlNext {
continue
}
t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone)
if err == nil {
if missingLeadingZero(it.val, dt.fmt) {
p.panicErr(it, errParseDate{it.val})
}
ok = true
break
}
}
if !ok {
p.panicItemf(it, "Invalid TOML Datetime: %q.", it.val)
p.panicErr(it, errParseDate{it.val})
}
return t, p.typeOfPrimitive(it)
}
func (p *parser) valueArray(it item) (interface{}, tomlType) {
// Go's time.Parse() will accept numbers without a leading zero; there isn't any
// way to require it. https://github.com/golang/go/issues/29911
//
// Depend on the fact that the separators (- and :) should always be at the same
// location.
func missingLeadingZero(d, l string) bool {
for i, c := range []byte(l) {
if c == '.' || c == 'Z' {
return false
}
if (c < '0' || c > '9') && d[i] != c {
return true
}
}
return false
}
func (p *parser) valueArray(it item) (any, tomlType) {
p.setType(p.currentKey, tomlArray, it.pos)
var (
types []tomlType
// Initialize to a non-nil empty slice. This makes it consistent with
// how S = [] decodes into a non-nil slice inside something like struct
// { S []string }. See #338
array = []interface{}{}
// Initialize to a non-nil slice to make it consistent with how S = []
// decodes into a non-nil slice inside something like struct { S
// []string }. See #338
array = make([]any, 0, 2)
)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
@ -377,20 +422,20 @@ func (p *parser) valueArray(it item) (interface{}, tomlType) {
val, typ := p.value(it, true)
array = append(array, val)
types = append(types, typ)
// XXX: types isn't used here, we need it to record the accurate type
// XXX: type isn't used here, we need it to record the accurate type
// information.
//
// Not entirely sure how to best store this; could use "key[0]",
// "key[1]" notation, or maybe store it on the Array type?
_ = typ
}
return array, tomlArray
}
func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tomlType) {
func (p *parser) valueInlineTable(it item, parentIsArray bool) (any, tomlType) {
var (
hash = make(map[string]interface{})
topHash = make(map[string]any)
outerContext = p.context
outerKey = p.currentKey
)
@ -418,19 +463,33 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom
p.assertEqual(itemKeyEnd, k.typ)
/// The current key is the last part.
p.currentKey = key[len(key)-1]
p.currentKey = key.last()
/// All the other parts (if any) are the context; need to set each part
/// as implicit.
context := key[:len(key)-1]
context := key.parent()
for i := range context {
p.addImplicitContext(append(p.context, context[i:i+1]...))
}
p.ordered = append(p.ordered, p.context.add(p.currentKey))
/// Set the value.
val, typ := p.value(p.next(), false)
p.set(p.currentKey, val, typ, it.pos)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ, it.pos)
hash := topHash
for _, c := range context {
h, ok := hash[c]
if !ok {
h = make(map[string]any)
hash[c] = h
}
hash, ok = h.(map[string]any)
if !ok {
p.panicf("%q is not a table", p.context)
}
}
hash[p.currentKey] = val
/// Restore context.
@ -438,7 +497,7 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
return topHash, tomlHash
}
// numHasLeadingZero checks if this number has leading zeroes, allowing for '0',
@ -468,9 +527,9 @@ func numUnderscoresOK(s string) bool {
}
}
// isHexadecimal is a superset of all the permissable characters
// surrounding an underscore.
accept = isHexadecimal(r)
// isHex is a superset of all the permissible characters surrounding an
// underscore.
accept = isHex(r)
}
return accept
}
@ -493,21 +552,19 @@ func numPeriodsOK(s string) bool {
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) addContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
/// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
keyContext := make(Key, 0, len(key)-1)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
/// We only need implicit hashes for the parents.
for _, k := range key.parent() {
_, ok := hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
hashContext[k] = make(map[string]any)
}
// If the hash context is actually an array of tables, then set
@ -516,9 +573,9 @@ func (p *parser) addContext(key Key, array bool) {
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
case []map[string]any:
hashContext = t[len(t)-1]
case map[string]interface{}:
case map[string]any:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
@ -529,40 +586,33 @@ func (p *parser) addContext(key Key, array bool) {
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
k := key.last()
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 4)
hashContext[k] = make([]map[string]any, 0, 4)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
if hash, ok := hashContext[k].([]map[string]any); ok {
hashContext[k] = append(hash, make(map[string]any))
} else {
p.panicf("Key '%s' was already created and cannot be used as an array.", key)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
p.setValue(key.last(), make(map[string]any))
}
p.context = append(p.context, key[len(key)-1])
}
// set calls setValue and setType.
func (p *parser) set(key string, val interface{}, typ tomlType, pos Position) {
p.setValue(key, val)
p.setType(key, typ, pos)
p.context = append(p.context, key.last())
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
func (p *parser) setValue(key string, value any) {
var (
tmpHash interface{}
tmpHash any
ok bool
hash = p.mapping
keyContext Key
keyContext = make(Key, 0, len(p.context)+1)
)
for _, k := range p.context {
keyContext = append(keyContext, k)
@ -570,11 +620,11 @@ func (p *parser) setValue(key string, value interface{}) {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
case []map[string]any:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
case map[string]any:
hash = t
default:
p.panicf("Key '%s' has already been defined.", keyContext)
@ -601,9 +651,8 @@ func (p *parser) setValue(key string, value interface{}) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
// Otherwise, we have a concrete key trying to override a previous key,
// which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
@ -636,10 +685,7 @@ func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = struct{}
func (p *parser) removeImplicit(key Key) { delete(p.implicits, key.String()) }
func (p *parser) isImplicit(key Key) bool { _, ok := p.implicits[key.String()]; return ok }
func (p *parser) isArray(key Key) bool { return p.keyInfo[key.String()].tomlType == tomlArray }
func (p *parser) addImplicitContext(key Key) {
p.addImplicit(key)
p.addContext(key, false)
}
func (p *parser) addImplicitContext(key Key) { p.addImplicit(key); p.addContext(key, false) }
// current returns the full key name of the current context.
func (p *parser) current() string {
@ -662,114 +708,131 @@ func stripFirstNewline(s string) string {
return s
}
// Remove newlines inside triple-quoted strings if a line ends with "\".
// stripEscapedNewlines removes whitespace after line-ending backslashes in
// multiline strings.
//
// A line-ending backslash is an unescaped \ followed only by whitespace until
// the next newline. After a line-ending backslash, all whitespace is removed
// until the next non-whitespace character.
func (p *parser) stripEscapedNewlines(s string) string {
split := strings.Split(s, "\n")
if len(split) < 1 {
return s
var (
b strings.Builder
i int
)
b.Grow(len(s))
for {
ix := strings.Index(s[i:], `\`)
if ix < 0 {
b.WriteString(s)
return b.String()
}
i += ix
escNL := false // Keep track of the last non-blank line was escaped.
for i, line := range split {
line = strings.TrimRight(line, " \t\r")
if len(line) == 0 || line[len(line)-1] != '\\' {
split[i] = strings.TrimRight(split[i], "\r")
if !escNL && i != len(split)-1 {
split[i] += "\n"
}
if len(s) > i+1 && s[i+1] == '\\' {
// Escaped backslash.
i += 2
continue
}
escBS := true
for j := len(line) - 1; j >= 0 && line[j] == '\\'; j-- {
escBS = !escBS
// Scan until the next non-whitespace.
j := i + 1
whitespaceLoop:
for ; j < len(s); j++ {
switch s[j] {
case ' ', '\t', '\r', '\n':
default:
break whitespaceLoop
}
if escNL {
line = strings.TrimLeft(line, " \t\r")
}
escNL = !escBS
if escBS {
split[i] += "\n"
if j == i+1 {
// Not a whitespace escape.
i++
continue
}
if i == len(split)-1 {
p.panicf("invalid escape: '\\ '")
if !strings.Contains(s[i:j], "\n") {
// This is not a line-ending backslash. (It's a bad escape sequence,
// but we can let replaceEscapes catch it.)
i++
continue
}
split[i] = line[:len(line)-1] // Remove \
if len(split)-1 > i {
split[i+1] = strings.TrimLeft(split[i+1], " \t\r")
b.WriteString(s[:i])
s = s[j:]
i = 0
}
}
return strings.Join(split, "")
}
func (p *parser) replaceEscapes(it item, str string) string {
replaced := make([]rune, 0, len(str))
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
var (
b strings.Builder
skip = 0
)
b.Grow(len(str))
for i, c := range str {
if skip > 0 {
skip--
continue
}
r += 1
if r >= len(s) {
if c != '\\' {
b.WriteRune(c)
continue
}
if i >= len(str) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
switch str[i+1] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
p.bug("Expected valid escape code after \\, but got %q.", str[i+1])
case ' ', '\t':
p.panicItemf(it, "invalid escape: '\\%c'", s[r])
p.panicItemf(it, "invalid escape: '\\%c'", str[i+1])
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
b.WriteByte(0x08)
skip = 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
b.WriteByte(0x09)
skip = 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
b.WriteByte(0x0a)
skip = 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
b.WriteByte(0x0c)
skip = 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
b.WriteByte(0x0d)
skip = 1
case 'e':
if p.tomlNext {
b.WriteByte(0x1b)
skip = 1
}
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
b.WriteByte(0x22)
skip = 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
b.WriteByte(0x5c)
skip = 1
// The lexer guarantees the correct number of characters are present;
// don't need to check here.
case 'x':
if p.tomlNext {
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+4])
b.WriteRune(escaped)
skip = 3
}
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+5])
replaced = append(replaced, escaped)
r += 5
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+6])
b.WriteRune(escaped)
skip = 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+9])
replaced = append(replaced, escaped)
r += 9
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+10])
b.WriteRune(escaped)
skip = 9
}
}
return string(replaced)
return b.String()
}
func (p *parser) asciiEscapeToUnicode(it item, bs []byte) rune {
s := string(bs)
func (p *parser) asciiEscapeToUnicode(it item, s string) rune {
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the lexer claims it's OK: %s", s, err)

View File

@ -26,9 +26,7 @@ type field struct {
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
@ -46,9 +44,7 @@ func (x byName) Less(i, j int) bool {
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {

View File

@ -22,13 +22,8 @@ func typeIsTable(t tomlType) bool {
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
func (btype tomlBaseType) typeString() string { return string(btype) }
func (btype tomlBaseType) String() string { return btype.typeString() }
var (
tomlInteger tomlBaseType = "Integer"
@ -54,7 +49,7 @@ func (p *parser) typeOfPrimitive(lexItem item) tomlType {
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
case itemString, itemStringEsc:
return tomlString
case itemMultilineString:
return tomlString

9
vendor/github.com/jackc/pgpassfile/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,9 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

22
vendor/github.com/jackc/pgpassfile/LICENSE generated vendored Normal file
View File

@ -0,0 +1,22 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

8
vendor/github.com/jackc/pgpassfile/README.md generated vendored Normal file
View File

@ -0,0 +1,8 @@
[![](https://godoc.org/github.com/jackc/pgpassfile?status.svg)](https://godoc.org/github.com/jackc/pgpassfile)
[![Build Status](https://travis-ci.org/jackc/pgpassfile.svg)](https://travis-ci.org/jackc/pgpassfile)
# pgpassfile
Package pgpassfile is a parser for PostgreSQL .pgpass files.
Extracted and rewritten from the original implementation in https://github.com/jackc/pgx.
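A minimal usage sketch (the file path and lookup parameters below are assumptions for illustration):

```go
package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgpassfile"
)

func main() {
	// Parse a .pgpass file (path is an assumption for this example).
	passfile, err := pgpassfile.ReadPassfile("/home/alice/.pgpass")
	if err != nil {
		log.Fatal(err)
	}

	// Look up the password for a host/port/database/user combination.
	// An empty string is returned when no entry matches.
	password := passfile.FindPassword("db.example.com", "5432", "appdb", "alice")
	fmt.Println(password)
}
```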

110
vendor/github.com/jackc/pgpassfile/pgpass.go generated vendored Normal file
View File

@ -0,0 +1,110 @@
// Package pgpassfile is a parser for PostgreSQL .pgpass files.
package pgpassfile
import (
"bufio"
"io"
"os"
"regexp"
"strings"
)
// Entry represents a line in a PG passfile.
type Entry struct {
Hostname string
Port string
Database string
Username string
Password string
}
// Passfile is the in-memory data structure representing a PG passfile.
type Passfile struct {
Entries []*Entry
}
// ReadPassfile reads the file at path and parses it into a Passfile.
func ReadPassfile(path string) (*Passfile, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return ParsePassfile(f)
}
// ParsePassfile reads r and parses it into a Passfile.
func ParsePassfile(r io.Reader) (*Passfile, error) {
passfile := &Passfile{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
entry := parseLine(scanner.Text())
if entry != nil {
passfile.Entries = append(passfile.Entries, entry)
}
}
return passfile, scanner.Err()
}
// Match (not colon, escaped colon, or escaped backslash)+. Essentially gives a split on unescaped
// colons.
var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+")
// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)")
// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable
// line.
func parseLine(line string) *Entry {
const (
tmpBackslash = "\r"
tmpColon = "\n"
)
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "#") {
return nil
}
line = strings.Replace(line, `\\`, tmpBackslash, -1)
line = strings.Replace(line, `\:`, tmpColon, -1)
parts := strings.Split(line, ":")
if len(parts) != 5 {
return nil
}
// Unescape escaped colons and backslashes
for i := range parts {
parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1)
parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1)
}
return &Entry{
Hostname: parts[0],
Port: parts[1],
Database: parts[2],
Username: parts[3],
Password: parts[4],
}
}
// FindPassword finds the password for the provided hostname, port, database, and username. For a
// Unix domain socket, hostname must be set to "localhost". An empty string will be returned if no
// match is found.
//
// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information.
func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) {
for _, e := range pf.Entries {
if (e.Hostname == "*" || e.Hostname == hostname) &&
(e.Port == "*" || e.Port == port) &&
(e.Database == "*" || e.Database == database) &&
(e.Username == "*" || e.Username == username) {
return e.Password
}
}
return ""
}

22
vendor/github.com/jackc/pgservicefile/LICENSE generated vendored Normal file
View File

@ -0,0 +1,22 @@
Copyright (c) 2020 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

7
vendor/github.com/jackc/pgservicefile/README.md generated vendored Normal file
View File

@ -0,0 +1,7 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgservicefile.svg)](https://pkg.go.dev/github.com/jackc/pgservicefile)
[![Build Status](https://github.com/jackc/pgservicefile/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgservicefile/actions/workflows/ci.yml)
# pgservicefile
Package pgservicefile is a parser for PostgreSQL service files (e.g. `.pg_service.conf`).
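A minimal usage sketch (the path and service name are assumptions for illustration):

```go
package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgservicefile"
)

func main() {
	// Parse a service file (path is an assumption for this example).
	sf, err := pgservicefile.ReadServicefile("/home/alice/.pg_service.conf")
	if err != nil {
		log.Fatal(err)
	}

	// Look up a named service; Settings holds its key=value pairs.
	svc, err := sf.GetService("mydb")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(svc.Settings["host"], svc.Settings["dbname"])
}
```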

81
vendor/github.com/jackc/pgservicefile/pgservicefile.go generated vendored Normal file
View File

@ -0,0 +1,81 @@
// Package pgservicefile is a parser for PostgreSQL service files (e.g. .pg_service.conf).
package pgservicefile
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"strings"
)
type Service struct {
Name string
Settings map[string]string
}
type Servicefile struct {
Services []*Service
servicesByName map[string]*Service
}
// GetService returns the named service.
func (sf *Servicefile) GetService(name string) (*Service, error) {
service, present := sf.servicesByName[name]
if !present {
return nil, errors.New("not found")
}
return service, nil
}
// ReadServicefile reads the file at path and parses it into a Servicefile.
func ReadServicefile(path string) (*Servicefile, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return ParseServicefile(f)
}
// ParseServicefile reads r and parses it into a Servicefile.
func ParseServicefile(r io.Reader) (*Servicefile, error) {
servicefile := &Servicefile{}
var service *Service
scanner := bufio.NewScanner(r)
lineNum := 0
for scanner.Scan() {
lineNum += 1
line := scanner.Text()
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
// ignore comments and empty lines
} else if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
service = &Service{Name: line[1 : len(line)-1], Settings: make(map[string]string)}
servicefile.Services = append(servicefile.Services, service)
} else if service != nil {
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("unable to parse line %d", lineNum)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
service.Settings[key] = value
} else {
return nil, fmt.Errorf("line %d is not in a section", lineNum)
}
}
servicefile.servicesByName = make(map[string]*Service, len(servicefile.Services))
for _, service := range servicefile.Services {
servicefile.servicesByName[service.Name] = service
}
return servicefile, scanner.Err()
}

27
vendor/github.com/jackc/pgx/v5/.gitignore generated vendored Normal file
View File

@ -0,0 +1,27 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
.envrc
/.testdb
.DS_Store

400
vendor/github.com/jackc/pgx/v5/CHANGELOG.md generated vendored Normal file
View File

@ -0,0 +1,400 @@
# 5.6.0 (May 25, 2024)
* Add StrictNamedArgs (Tomas Zahradnicek)
* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
* Add SeverityUnlocalized field to PgError / Notice
* Performance optimization of RowToStructByPos/Name (Zach Olstein)
* Allow customizing context canceled behavior for pgconn
* Add ScanLocation to pgtype.Timestamp[tz]Codec
* Add custom data to pgconn.PgConn
* Fix ResultReader.Read() to handle nil values
* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
* pgconn.SafeToRetry checks for wrapped errors (tjasko)
* Failed connection attempts include all errors
* Optimize LargeObject.Read (Mitar)
* Add tracing for connection acquire and release from pool (ngavinsir)
* Fix encode driver.Valuer not called when nil
* Add support for custom JSON marshal and unmarshal (Mitar)
* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
# 5.5.5 (March 9, 2024)
Use spaces instead of parentheses for SQL sanitization.
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
# 5.5.4 (March 4, 2024)
Fix CVE-2024-27304
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
attacker's control.
Thanks to Paul Gerste for reporting this issue.
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
* Fix simple protocol encoding of json.RawMessage
* Fix *Pipeline.getResults should close pipeline on error
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
* Fix deallocation of invalidated cached statements in a transaction
* Handle invalid sslkey file
* Fix scan float4 into sql.Scanner
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
# 5.5.3 (February 3, 2024)
* Fix: prepared statement already exists
* Improve CopyFrom auto-conversion of text-ish values
* Add ltree type support (Florent Viel)
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
* Add AppendRows function (Edoardo Spadolini)
* Optimize convert UUID [16]byte to string (Kirill Malikov)
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
# 5.5.2 (January 13, 2024)
* Allow NamedArgs to start with underscore
* pgproto3: Maximum message body length support (jeremy.spriet)
* Upgrade golang.org/x/crypto to v0.17.0
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
* Fix: update description cache after exec prepare (James Hartig)
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
* Add OnPgError for easier centralized error handling (James Hartig)
# 5.5.1 (December 9, 2023)
* Add CopyFromFunc helper function. (robford)
* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message.
* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid.
* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo)
* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev)
* Add pgtype.Numeric.ScanScientific (Eshton Robateau)
# 5.5.0 (November 4, 2023)
* Add CollectExactlyOneRow. (Julien GOTTELAND)
* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov)
* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements.
* Statement cache now uses deterministic, stable statement names.
* database/sql prepared statement names are deterministically generated.
* Fix: SendBatch wasn't respecting context cancellation.
* Fix: Timeout error from pipeline is now normalized.
* Fix: database/sql encoding json.RawMessage to []byte.
* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin)
* stdlib: Use Ping instead of CheckConn in ResetSession
* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov)
# 5.4.3 (August 5, 2023)
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)
* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb)
* Fix: pgxpool: background health check cannot overflow pool
* Fix: Check for nil in defer when sending batch (recover properly from panic)
* Fix: json scan of non-string pointer to pointer
* Fix: zeronull.Timestamptz should use pgtype.Timestamptz
* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig)
* RowTo(AddrOf)StructByPos ignores fields with "-" db tag
* Optimization: improve text format numeric parsing (horpto)
# 5.4.2 (July 11, 2023)
* Fix: RowScanner errors are fatal to Rows
* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman)
* Hstore text codec internal improvements (Evan Jones)
* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares)
* Fix: Stop background reader as soon as possible.
* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn.
# 5.4.1 (June 18, 2023)
* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov)
* Add TxOptions.BeginQuery to allow overriding the default BEGIN query
# 5.4.0 (June 14, 2023)
* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues.
* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov)
* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
* CancelRequest: don't try to read the reply (Nicola Murino)
* Fix: correctly handle bool type aliases (Wichert Akkerman)
* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones)
* Add BeforeClose to pgxpool.Pool (Evan Cordell)
* Fix: various hstore fixes and optimizations (Evan Jones)
* Fix: RowToStructByPos with embedded unexported struct
* Support different bool string representations (Lev Zakharov)
* Fix: error when using BatchResults.Exec on a select that returns an error after some rows.
* Fix: pipelineBatchResults.Exec() not returning error from ResultReader
* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using
a callback.
* Fix: scanning a table type into a struct
* Fix: scan array of record to pointer to slice of struct
* Fix: handle null for json (Cemre Mengu)
* Batch Query callback is called even when there is an error
* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P)
# 5.3.1 (February 27, 2023)
* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka)
* Fix: sql.Scanner not being used in certain cases
* Add text format jsonpath support
* Fix: fake non-blocking read adaptive wait time
# 5.3.0 (February 11, 2023)
* Fix: json values work with sql.Scanner
* Fixed / improved error messages (Mark Chambers and Yevgeny Pats)
* Fix: support scan into single dimensional arrays
* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub)
* Fix: driver.Value representation of bytea should be []byte not string
* Fix: better handling of unregistered OIDs
* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora)
* Fix: encode to json ignoring driver.Valuer
* Support sql.Scanner on renamed base type
* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers)
* Fix: connect with multiple hostnames when one can't be resolved
* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform
* Fix: scanning json column into **string
* Multiple reductions in memory allocations
* Fake non-blocking read adapts its max wait time
* Improve CopyFrom performance and reduce memory usage
* Fix: encode []any to array
* Fix: LoadType for composite with dropped attributes (Felix Röhrich)
* Support v4 and v5 stdlib in same program
* Fix: text format array decoding with string of "NULL"
* Prefer binary format for arrays
# 5.2.0 (December 5, 2022)
* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov)
* Optimize creating begin transaction SQL string (Petr Evdokimov and ksco)
* `Conn.LoadType` supports range and multirange types (Vitalii Solodilov)
* Fix scan `uint` and `uint64` `ScanNumeric`. This resolves a PostgreSQL `numeric` being incorrectly scanned into `uint` and `uint64`.
# 5.1.1 (November 17, 2022)
* Fix simple query sanitizer where query text contains a Unicode replacement character.
* Remove erroneous `name` argument from `DeallocateAll()`. Technically, this is a breaking change, but given that method was only added 5 days ago this change was accepted. (Bodo Kaiser)
# 5.1.0 (November 12, 2022)
* Update puddle to v2.1.2. This resolves a race condition and a deadlock in pgxpool.
* `QueryRewriter.RewriteQuery` now returns an error. Technically, this is a breaking change for any external implementers, but given the minimal likelihood that there are actually any external implementers this change was accepted.
* Expose `GetSSLPassword` support to pgx.
* Fix encode `ErrorResponse` unknown field handling. This would only affect pgproto3 being used directly as a proxy with a non-PostgreSQL server that included additional error fields.
* Fix date text format encoding with 5 digit years.
* Fix date values passed to a `sql.Scanner` as `string` instead of `time.Time`.
* DateCodec.DecodeValue can return `pgtype.InfinityModifier` instead of `string` for infinite values. This now matches the behavior of the timestamp types.
* Add domain type support to `Conn.LoadType()`.
* Add `RowToStructByName` and `RowToAddrOfStructByName`. (Pavlo Golub)
* Add `Conn.DeallocateAll()` to clear all prepared statements including the statement cache. (Bodo Kaiser)
# 5.0.4 (October 24, 2022)
* Fix: CollectOneRow prefers PostgreSQL error over pgx.ErrNoRows
* Fix: some reflect Kind checks to first check for nil
* Bump golang.org/x/text dependency to placate snyk
* Fix: RowToStructByPos on structs with multiple anonymous sub-structs (Baptiste Fontaine)
* Fix: Exec checks if tx is closed
# 5.0.3 (October 14, 2022)
* Fix `driver.Valuer` handling edge cases that could cause infinite loop or crash
# v5.0.2 (October 8, 2022)
* Fix date encoding in text format to always use 2 digits for month and day
* Prefer driver.Valuer over wrap plans when encoding
* Fix scan to pointer to pointer to renamed type
* Allow scanning NULL even if PG and Go types are incompatible
# v5.0.1 (September 24, 2022)
* Fix 32-bit atomic usage
* Add MarshalJSON for Float8 (yogipristiawan)
* Add `[` and `]` to text encoding of `Lseg`
* Fix sqlScannerWrapper NULL handling
# v5.0.0 (September 17, 2022)
## Merged Packages
`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main
`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional
release work due to releasing multiple packages, and less clear changelogs.
## pgconn
`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`.
The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`.
`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`.
pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features.
`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping.
pgconn now supports pipeline mode.
`*PgConn.ReceiveResults` removed. Use pipeline mode instead.
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
## pgxpool
`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
## pgtype
The `pgtype` package has been significantly changed.
### NULL Representation
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
Previously, a type that implemented `driver.Valuer` would have the `Value` method called even on a nil pointer. All nils
whether typed or untyped now represent `NULL`.
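For illustration, a minimal sketch of the new `Valid` representation (the connection string is an assumption):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// The zero value has Valid: false, which represents NULL.
	var price pgtype.Int8
	if err := conn.QueryRow(ctx, "select null::int8").Scan(&price); err != nil {
		log.Fatal(err)
	}
	if price.Valid {
		fmt.Println("price:", price.Int64)
	} else {
		fmt.Println("price is NULL")
	}
}
```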
### Codec and Value Split
Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled
encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when
there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types. For example, scanning a
PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). These
concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are
generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and
`PointValuer` for the PostgreSQL `point` type).
### Array Types
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also
means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional
arrays.
### Composite Types
Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite
values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite.
### Range Types
Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to
easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`.
### pgxtype
`LoadDataType` moved to `*Conn` as `LoadType`.
### Bytea
The `Bytea` and `GenericBinary` types have been replaced. Use the following instead:
* `[]byte` - For normal usage directly use `[]byte`.
* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation.
* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation.
* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes.
### Dropped lib/pq Support
`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work
in most cases this is no longer supported.
### database/sql Scan
Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now
only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by
considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with
`pgx`. The previous behavior was only necessary for `lib/pq` compatibility.
Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement
`sql.Scanner` directly.
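A short sketch of `SQLScanner` in `database/sql` mode (the connection string is an assumption):

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgtype"
	_ "github.com/jackc/pgx/v5/stdlib"
)

func main() {
	db, err := sql.Open("pgx", "postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	m := pgtype.NewMap()
	var ids []int32
	// []int32 does not implement sql.Scanner itself; wrap it with SQLScanner.
	if err := db.QueryRow("select '{1,2,3}'::int4[]").Scan(m.SQLScanner(&ids)); err != nil {
		log.Fatal(err)
	}
	fmt.Println(ids)
}
```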
### Number Type Fields Include Bit size
`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`.
This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and
`sql.NullInt64` the structures are identical. This means they can be directly converted one to another.
### 3rd Party Type Integrations
* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
the pgx dependency tree.
### Other Changes
* `Bit` and `Varbit` are both replaced by the `Bits` type.
* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type.
* `Hstore` is now defined as `map[string]*string`.
* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly.
* `QChar` type removed. Use `rune` or `byte` directly.
* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`.
* `Macaddr` type removed. Use `net.HardwareAddr` directly.
* Renamed `pgtype.ConnInfo` to `pgtype.Map`.
* Renamed `pgtype.DataType` to `pgtype.Type`.
* Renamed `pgtype.None` to `pgtype.Finite`.
* `RegisterType` now accepts a `*Type` instead of `Type`.
* Assorted array helper methods and types made private.
## stdlib
* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13.
## Reduced Memory Usage by Reusing Read Buffers
Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed
transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy.
However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large
chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now
ownership remains with the read buffer and anything needing to retain a value must make a copy.
## Query Execution Modes
Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode.
See documentation for `QueryExecMode`.
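For example, a connection could be configured to always use the simple protocol (a sketch; the connection string is an assumption):

```go
package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	cfg, err := pgx.ParseConfig("postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	// Skip automatic statement preparation and use the simple protocol.
	cfg.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol

	conn, err := pgx.ConnectConfig(context.Background(), cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())
}
```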
## QueryRewriter Interface and NamedArgs
pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which
allows arbitrary rewriting of query SQL and arguments.
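A brief sketch of `NamedArgs` (the connection string and the `widgets` table are assumptions):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// @-prefixed named placeholders are rewritten to positional arguments.
	var n int
	err = conn.QueryRow(ctx,
		"select count(*) from widgets where color = @color", // hypothetical table
		pgx.NamedArgs{"color": "blue"},
	).Scan(&n)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(n)
}
```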
## RowScanner Interface
The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row.
## Rows Result Helpers
* `CollectRows` and `RowTo*` functions simplify collecting results into a slice (see the sketch after this list).
* `CollectOneRow` collects one row using `RowTo*` functions.
* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`.
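A short sketch of `CollectRows` with `RowTo` (the connection string is an assumption):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	rows, _ := conn.Query(ctx, "select n from generate_series(1, 5) n")
	// CollectRows drains rows (surfacing any query error) and maps each row.
	nums, err := pgx.CollectRows(rows, pgx.RowTo[int32])
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(nums)
}
```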
## Tx Helpers
Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and
`BeginTxFunc`, these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
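A minimal sketch of the function form (the connection string and the `accounts` table are assumptions):

```go
package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// BeginFunc begins a transaction, runs the closure, and commits,
	// rolling back automatically if the closure returns an error.
	err = pgx.BeginFunc(ctx, conn, func(tx pgx.Tx) error {
		_, err := tx.Exec(ctx, "update accounts set balance = balance - 10 where id = $1", 1) // hypothetical table
		return err
	})
	if err != nil {
		log.Fatal(err)
	}
}
```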
## Improved Batch Query Ergonomics
Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the
results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code
to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can
be used to register a callback function that will handle the result. Callback functions are called automatically when
`BatchResults.Close` is called.
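A short sketch of the callback style (the connection string is an assumption):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	b := &pgx.Batch{}
	// The result callback is registered right next to the queued query.
	b.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
		var n int
		if err := row.Scan(&n); err != nil {
			return err
		}
		fmt.Println("sum:", n)
		return nil
	})
	b.Queue("select generate_series(1, 3)").Query(func(rows pgx.Rows) error {
		for rows.Next() {
			var n int
			if err := rows.Scan(&n); err != nil {
				return err
			}
			fmt.Println("series:", n)
		}
		return rows.Err()
	})

	// Callbacks run as the batch results are read and closed.
	if err := conn.SendBatch(ctx, b).Close(); err != nil {
		log.Fatal(err)
	}
}
```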
## SendBatch Uses Pipeline Mode When Appropriate
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
in a single network round trip. So it would only take 2 round trips.
## Tracing and Logging
Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer.
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
tree.

121
vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md generated vendored Normal file
View File

@ -0,0 +1,121 @@
# Contributing
## Discuss Significant Changes
Before you invest a significant amount of time on a change, please create a discussion or issue describing your
proposal. This will help to ensure your proposed change has a reasonable chance of being merged.
## Avoid Dependencies
Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change
that adds a dependency is no.
## Development Environment Setup
pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE`
environment variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or key-value pairs. In addition,
the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to
simplify environment variable handling.
### Using an Existing PostgreSQL Cluster
If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx
test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different
authentication methods).
Create and setup a test database:
```
export PGDATABASE=pgx_test
createdb
psql -c 'create extension hstore;'
psql -c 'create extension ltree;'
psql -c 'create domain uint64 as numeric(20,0);'
```
Ensure a `postgres` user exists. This happens by default in normal PostgreSQL installs, but some installation methods
such as Homebrew do not.
```
createuser -s postgres
```
Ensure your `PGX_TEST_DATABASE` environment variable points to the database you just created and run the tests.
```
export PGX_TEST_DATABASE="host=/private/tmp database=pgx_test"
go test ./...
```
This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods).
### Creating a New PostgreSQL Cluster Exclusively for Testing
The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is
highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`.
```
export PGPORT=5015
export PGUSER=postgres
export PGDATABASE=pgx_test
export POSTGRESQL_DATA_DIR=postgresql
export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test"
export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test"
export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret"
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem"
export PGX_SSL_PASSWORD=certpw
export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key"
```
Create a new database cluster.
```
initdb --locale=en_US -E UTF-8 --username=postgres .testdb/$POSTGRESQL_DATA_DIR
echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
cd .testdb
# Generate CA, server, and encrypted client certificates.
go run ../testsetup/generate_certs.go
# Copy certificates to server directory and set permissions.
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
cp localhost.key $POSTGRESQL_DATA_DIR/server.key
chmod 600 $POSTGRESQL_DATA_DIR/server.key
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
cd ..
```
Start the new cluster. This will be necessary whenever you are running pgx tests.
```
postgres -D .testdb/$POSTGRESQL_DATA_DIR
```
Setup the test database in the new cluster.
```
createdb
psql --no-psqlrc -f testsetup/postgresql_setup.sql
```
### PgBouncer
There are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
### Optional Tests
pgx supports multiple connection types and means of authentication. These tests are optional. They will only run if the
appropriate environment variables are set. In addition, there may be tests specific to particular PostgreSQL versions,
non-PostgreSQL servers (e.g. CockroachDB), or connection poolers (e.g. PgBouncer). `go test ./... -v | grep SKIP` to see
if any tests are being skipped.

22
vendor/github.com/jackc/pgx/v5/LICENSE generated vendored Normal file
View File

@ -0,0 +1,22 @@
Copyright (c) 2013-2021 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

174
vendor/github.com/jackc/pgx/v5/README.md generated vendored Normal file
View File

@ -0,0 +1,174 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5)
[![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml)
# pgx - PostgreSQL Driver and Toolkit
pgx is a pure Go driver and toolkit for PostgreSQL.
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` /
`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
proxies, load balancers, logical replication clients, etc.
## Example Usage
```go
package main
import (
"context"
"fmt"
"os"
"github.com/jackc/pgx/v5"
)
func main() {
// urlExample := "postgres://username:password@localhost:5432/database_name"
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)
os.Exit(1)
}
defer conn.Close(context.Background())
var name string
var weight int64
err = conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
if err != nil {
fmt.Fprintf(os.Stderr, "QueryRow failed: %v\n", err)
os.Exit(1)
}
fmt.Println(name, weight)
}
```
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
## Features
* Support for approximately 70 different PostgreSQL types
* Automatic statement preparation and caching
* Batch queries
* Single-round trip query mode
* Full TLS connection control
* Binary format support for custom types (allows for much quicker encoding/decoding)
* `COPY` protocol support for faster bulk data loads
* Tracing and logging support
* Connection pool with after-connect hook for arbitrary connection setup
* `LISTEN` / `NOTIFY`
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
* `hstore` support
* `json` and `jsonb` support
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
* Large object support
* NULL mapping to pointer to pointer
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
* Notice response handling
* Simulated nested transactions with savepoints
## Choosing Between the pgx and database/sql Interfaces
The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
through the `database/sql` interface.
The pgx interface is recommended when:
1. The application only targets PostgreSQL.
2. No other libraries that require `database/sql` are in use.
It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
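One way to drop down from `database/sql` to pgx via the stdlib adapter, as a sketch (the connection string is an assumption):

```go
package main

import (
	"context"
	"database/sql"
	"log"

	"github.com/jackc/pgx/v5/stdlib"
)

func main() {
	db, err := sql.Open("pgx", "postgres://localhost/mydb") // assumed connection string
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	conn, err := db.Conn(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// Drop down to the pgx interface for a single operation.
	err = conn.Raw(func(driverConn any) error {
		pgxConn := driverConn.(*stdlib.Conn).Conn() // *pgx.Conn
		_, err := pgxConn.Exec(context.Background(), "set application_name = 'example'")
		return err
	})
	if err != nil {
		log.Fatal(err)
	}
}
```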
## Testing
See CONTRIBUTING.md for setup instructions.
## Architecture
See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture.
## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
## Version Policy
pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.
## PGX Family Libraries
### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl)
pglogrepl provides functionality to act as a client for PostgreSQL logical replication.
### [github.com/jackc/pgmock](https://github.com/jackc/pgmock)
pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler).
### [github.com/jackc/tern](https://github.com/jackc/tern)
tern is a stand-alone SQL migration system.
### [github.com/jackc/pgerrcode](https://github.com/jackc/pgerrcode)
pgerrcode contains constants for the PostgreSQL error codes.
## Adapters for 3rd Party Types
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
## Adapters for 3rd Party Tracers
* [https://github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
## Adapters for 3rd Party Loggers
These adapters can be used with the tracelog package.
* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log)
* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15)
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)
* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog)
## 3rd Party Libraries with PGX Support
### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock)
pgxmock is a mock library implementing pgx interfaces.
pgxmock has one and only one purpose: to simulate pgx behavior in tests, without needing a real database connection.
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
Library for scanning data from a database into Go structs and more.
### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql)
A carefully designed SQL client that makes using SQL easier,
more productive, and less error-prone in Go.
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
Adds GSSAPI / Kerberos authentication support.
### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx)
Explicit data mapping and scanning library for Go structs and slices.
### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan)
Type safe and flexible package for scanning database data into Go types.
Supports, structs, maps, slices and custom mapping functions.
### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
Code first migration library for native pgx (no database/sql abstraction).

18
vendor/github.com/jackc/pgx/v5/Rakefile generated vendored Normal file
View File

@ -0,0 +1,18 @@
require "erb"
rule '.go' => '.go.erb' do |task|
erb = ERB.new(File.read(task.source))
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
sh "goimports", "-w", task.name
end
generated_code_files = [
"pgtype/int.go",
"pgtype/int_test.go",
"pgtype/integration_benchmark_test.go",
"pgtype/zeronull/int.go",
"pgtype/zeronull/int_test.go"
]
desc "Generate code"
task generate: generated_code_files

433
vendor/github.com/jackc/pgx/v5/batch.go generated vendored Normal file
View File

@ -0,0 +1,433 @@
package pgx
import (
"context"
"errors"
"fmt"
"github.com/jackc/pgx/v5/pgconn"
)
// QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct {
SQL string
Arguments []any
Fn batchItemFunc
sd *pgconn.StatementDescription
}
type batchItemFunc func(br BatchResults) error
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
qq.Fn = func(br BatchResults) error {
rows, _ := br.Query()
defer rows.Close()
err := fn(rows)
if err != nil {
return err
}
rows.Close()
return rows.Err()
}
}
// QueryRow sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
qq.Fn = func(br BatchResults) error {
row := br.QueryRow()
return fn(row)
}
}
// Exec sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
qq.Fn = func(br BatchResults) error {
ct, err := br.Exec()
if err != nil {
return err
}
return fn(ct)
}
}
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
QueuedQueries []*QueuedQuery
}
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
// The only pgx option argument that is supported is QueryRewriter. Queries are executed using the
// connection's DefaultQueryExecMode.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{
SQL: query,
Arguments: arguments,
}
b.QueuedQueries = append(b.QueuedQueries, qq)
return qq
}
// Len returns number of queries that have been queued so far.
func (b *Batch) Len() int {
return len(b.QueuedQueries)
}
type BatchResults interface {
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
// calling Exec on the QueuedQuery.
Exec() (pgconn.CommandTag, error)
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
// calling Query on the QueuedQuery.
Query() (Rows, error)
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
// Prefer calling QueryRow on the QueuedQuery.
QueryRow() Row
// Close closes the batch operation. All unread results are read and any callback functions registered with
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
// error or the batch encounters an error subsequent callback functions will not be called.
//
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
// operation may have made it impossible to resynchronize the connection with the server. In this case the underlying
// connection will have been closed.
//
// Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback
// functions will not be rerun.
Close() error
}
type batchResults struct {
ctx context.Context
conn *Conn
mrr *pgconn.MultiResultReader
err error
b *Batch
qqIdx int
closed bool
endTraced bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil {
return pgconn.CommandTag{}, br.err
}
if br.closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
query, arguments, _ := br.nextQueryAndArgs()
if !br.mrr.NextResult() {
err := br.mrr.Close()
if err == nil {
err = errors.New("no result")
}
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: err,
})
}
return pgconn.CommandTag{}, err
}
commandTag, err := br.mrr.ResultReader().Close()
if err != nil {
br.err = err
br.mrr.Close()
}
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
CommandTag: commandTag,
Err: br.err,
})
}
return commandTag, br.err
}
// Query reads the results from the next query in the batch as if the query has been sent with Query.
func (br *batchResults) Query() (Rows, error) {
query, arguments, ok := br.nextQueryAndArgs()
if !ok {
query = "batch query"
}
if br.err != nil {
return &baseRows{err: br.err, closed: true}, br.err
}
if br.closed {
alreadyClosedErr := fmt.Errorf("batch already closed")
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
}
rows := br.conn.getRows(br.ctx, query, arguments)
rows.batchTracer = br.conn.batchTracer
if !br.mrr.NextResult() {
rows.err = br.mrr.Close()
if rows.err == nil {
rows.err = errors.New("no result")
}
rows.closed = true
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: rows.err,
})
}
return rows, rows.err
}
rows.resultReader = br.mrr.ResultReader()
return rows, nil
}
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
func (br *batchResults) QueryRow() Row {
rows, _ := br.Query()
return (*connRow)(rows.(*baseRows))
}
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resynchronize the connection with the server. In this case the underlying connection will have been closed.
func (br *batchResults) Close() error {
defer func() {
if !br.endTraced {
if br.conn != nil && br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
}
br.endTraced = true
}
}()
if br.err != nil {
return br.err
}
if br.closed {
return nil
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
br.err = err
}
} else {
br.Exec()
}
}
br.closed = true
err := br.mrr.Close()
if br.err == nil {
br.err = err
}
return br.err
}
func (br *batchResults) earlyError() error {
return br.err
}
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
ok = true
br.qqIdx++
}
return
}
type pipelineBatchResults struct {
ctx context.Context
conn *Conn
pipeline *pgconn.Pipeline
lastRows *baseRows
err error
b *Batch
qqIdx int
closed bool
endTraced bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil {
return pgconn.CommandTag{}, br.err
}
if br.closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
if br.lastRows != nil && br.lastRows.err != nil {
return pgconn.CommandTag{}, br.lastRows.err
}
query, arguments, _ := br.nextQueryAndArgs()
results, err := br.pipeline.GetResults()
if err != nil {
br.err = err
return pgconn.CommandTag{}, br.err
}
var commandTag pgconn.CommandTag
switch results := results.(type) {
case *pgconn.ResultReader:
commandTag, br.err = results.Close()
default:
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
}
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
CommandTag: commandTag,
Err: br.err,
})
}
return commandTag, br.err
}
// Query reads the results from the next query in the batch as if the query has been sent with Query.
func (br *pipelineBatchResults) Query() (Rows, error) {
if br.err != nil {
return &baseRows{err: br.err, closed: true}, br.err
}
if br.closed {
alreadyClosedErr := fmt.Errorf("batch already closed")
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
}
if br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return &baseRows{err: br.err, closed: true}, br.err
}
query, arguments, ok := br.nextQueryAndArgs()
if !ok {
query = "batch query"
}
rows := br.conn.getRows(br.ctx, query, arguments)
rows.batchTracer = br.conn.batchTracer
br.lastRows = rows
results, err := br.pipeline.GetResults()
if err != nil {
br.err = err
rows.err = err
rows.closed = true
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: err,
})
}
} else {
switch results := results.(type) {
case *pgconn.ResultReader:
rows.resultReader = results
default:
err = fmt.Errorf("unexpected pipeline result: %T", results)
br.err = err
rows.err = err
rows.closed = true
}
}
return rows, rows.err
}
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
func (br *pipelineBatchResults) QueryRow() Row {
rows, _ := br.Query()
return (*connRow)(rows.(*baseRows))
}
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resynchronize the connection with the server. In this case the underlying connection will have been closed.
func (br *pipelineBatchResults) Close() error {
defer func() {
if !br.endTraced {
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
}
br.endTraced = true
}
}()
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return br.err
}
if br.closed {
return br.err
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
br.err = err
}
} else {
br.Exec()
}
}
br.closed = true
err := br.pipeline.Close()
if br.err == nil {
br.err = err
}
return br.err
}
func (br *pipelineBatchResults) earlyError() error {
return br.err
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
ok = true
br.qqIdx++
}
return
}
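
For orientation, a minimal sketch (assuming a reachable DATABASE_URL) of how application code drives these result readers; Close drains every result and runs the queued callbacks exactly as the interface comment above describes:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer conn.Close(ctx)

	var widgetCount int64

	b := &pgx.Batch{}
	// Plain queued statement; its result is drained by Close.
	b.Queue("update widgets set touched = now() where id = $1", 1)
	// Queued statement with a QueryRow callback; Close invokes it when
	// this query's result is read.
	b.Queue("select count(*) from widgets").QueryRow(func(row pgx.Row) error {
		return row.Scan(&widgetCount)
	})

	// SendBatch returns a BatchResults backed by one of the readers
	// above; Close reads all results and runs the callbacks in order.
	if err := conn.SendBatch(ctx, b).Close(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("widgets:", widgetCount)
}
```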

1395
vendor/github.com/jackc/pgx/v5/conn.go generated vendored Normal file

File diff suppressed because it is too large

276
vendor/github.com/jackc/pgx/v5/copy_from.go generated vendored Normal file
View File

@ -0,0 +1,276 @@
package pgx
import (
"bytes"
"context"
"fmt"
"io"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn"
)
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
// making it usable by *Conn.CopyFrom.
func CopyFromRows(rows [][]any) CopyFromSource {
return &copyFromRows{rows: rows, idx: -1}
}
type copyFromRows struct {
rows [][]any
idx int
}
func (ctr *copyFromRows) Next() bool {
ctr.idx++
return ctr.idx < len(ctr.rows)
}
func (ctr *copyFromRows) Values() ([]any, error) {
return ctr.rows[ctr.idx], nil
}
func (ctr *copyFromRows) Err() error {
return nil
}
// CopyFromSlice returns a CopyFromSource interface over a dynamic func
// making it usable by *Conn.CopyFrom.
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
return &copyFromSlice{next: next, idx: -1, len: length}
}
type copyFromSlice struct {
next func(int) ([]any, error)
idx int
len int
err error
}
func (cts *copyFromSlice) Next() bool {
cts.idx++
return cts.idx < cts.len
}
func (cts *copyFromSlice) Values() ([]any, error) {
values, err := cts.next(cts.idx)
if err != nil {
cts.err = err
}
return values, err
}
func (cts *copyFromSlice) Err() error {
return cts.err
}
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
return &copyFromFunc{next: nxtf}
}
type copyFromFunc struct {
next func() ([]any, error)
valueRow []any
err error
}
func (g *copyFromFunc) Next() bool {
g.valueRow, g.err = g.next()
// only return true if valueRow exists and no error
return g.valueRow != nil && g.err == nil
}
func (g *copyFromFunc) Values() ([]any, error) {
return g.valueRow, g.err
}
func (g *copyFromFunc) Err() error {
return g.err
}
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
type CopyFromSource interface {
// Next returns true if there is another row and makes the next row data
// available to Values(). When there are no more rows available or an error
// has occurred it returns false.
Next() bool
// Values returns the values for the current row.
Values() ([]any, error)
// Err returns any error that has been encountered by the CopyFromSource. If
// this is not nil *Conn.CopyFrom will abort the copy.
Err() error
}
type copyFrom struct {
conn *Conn
tableName Identifier
columnNames []string
rowSrc CopyFromSource
readerErrChan chan error
mode QueryExecMode
}
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
if ct.conn.copyFromTracer != nil {
ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
TableName: ct.tableName,
ColumnNames: ct.columnNames,
})
}
quotedTableName := ct.tableName.Sanitize()
cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
if i != 0 {
cbuf.WriteString(", ")
}
cbuf.WriteString(quoteIdentifier(cn))
}
quotedColumnNames := cbuf.String()
var sd *pgconn.StatementDescription
switch ct.mode {
case QueryExecModeExec, QueryExecModeSimpleProtocol:
// These modes don't support the binary format. Before the inclusion of the
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
// the OIDs. These prepared statements were not cached.
//
// Since that's the same behavior provided by QueryExecModeDescribeExec,
// we'll default to that mode.
ct.mode = QueryExecModeDescribeExec
fallthrough
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
var err error
sd, err = ct.conn.getStatementDescription(
ctx,
ct.mode,
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
)
if err != nil {
return 0, fmt.Errorf("statement description failed: %w", err)
}
default:
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
}
r, w := io.Pipe()
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
buf := ct.conn.wbuf
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
moreRows := true
for moreRows {
var err error
moreRows, buf, err = ct.buildCopyBuf(buf, sd)
if err != nil {
w.CloseWithError(err)
return
}
if ct.rowSrc.Err() != nil {
w.CloseWithError(ct.rowSrc.Err())
return
}
if len(buf) > 0 {
_, err = w.Write(buf)
if err != nil {
w.Close()
return
}
}
buf = buf[:0]
}
w.Close()
}()
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
r.Close()
<-doneChan
if ct.conn.copyFromTracer != nil {
ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
CommandTag: commandTag,
Err: err,
})
}
return commandTag.RowsAffected(), err
}
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
const sendBufSize = 65536 - 5 // The packet has a 5-byte header
lastBufLen := 0
largestRowLen := 0
for ct.rowSrc.Next() {
lastBufLen = len(buf)
values, err := ct.rowSrc.Values()
if err != nil {
return false, nil, err
}
if len(values) != len(ct.columnNames) {
return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
if err != nil {
return false, nil, err
}
}
rowLen := len(buf) - lastBufLen
if rowLen > largestRowLen {
largestRowLen = rowLen
}
// Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of
// io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531
// 13, 65531, 13, 65531, 13.
if len(buf) > sendBufSize-largestRowLen {
return true, buf, nil
}
}
return false, buf, nil
}
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
// an error.
//
// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
// for the type of each column. Almost all types implemented by pgx support the binary format.
//
// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with
// Conn.LoadType and pgtype.Map.RegisterType.
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
ct := &copyFrom{
conn: c,
tableName: tableName,
columnNames: columnNames,
rowSrc: rowSrc,
readerErrChan: make(chan error),
mode: c.config.DefaultQueryExecMode,
}
return ct.run(ctx)
}
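
The doc.go below shows CopyFromRows and CopyFromSlice; here is a minimal sketch of the third source, CopyFromFunc, which streams rows without buffering the whole data set (the people table and its columns are assumptions):

```go
package pgxexample

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// copyPeople streams rows into people(first_name, last_name, age)
// through CopyFromFunc; returning row == nil, err == nil signals
// end of data.
func copyPeople(ctx context.Context, conn *pgx.Conn, rows [][]any) (int64, error) {
	i := 0
	return conn.CopyFrom(ctx,
		pgx.Identifier{"people"},
		[]string{"first_name", "last_name", "age"},
		pgx.CopyFromFunc(func() ([]any, error) {
			if i >= len(rows) {
				return nil, nil // end of data
			}
			row := rows[i]
			i++
			return row, nil
		}),
	)
}
```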

194
vendor/github.com/jackc/pgx/v5/doc.go generated vendored Normal file
View File

@ -0,0 +1,194 @@
// Package pgx is a PostgreSQL database driver.
/*
pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar
to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use
github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for
details.
Establishing a Connection
The primary way of establishing a connection is with [pgx.Connect]:
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
string.
Connection Pool
[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
Query Interface
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
rows.Scan, and rows.Err().
CollectRows can be used to collect all returned rows into a slice.
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5)
numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
if err != nil {
return err
}
// numbers => [1 2 3 4 5]
ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows
directly.
var sum, n int32
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
_, err := pgx.ForEachRow(rows, []any{&n}, func() error {
sum += n
return nil
})
if err != nil {
return err
}
pgx also implements QueryRow in the same style as database/sql.
var name string
var weight int64
err := conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
if err != nil {
return err
}
Use Exec to execute a query that does not return a result set.
commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42)
if err != nil {
return err
}
if commandTag.RowsAffected() != 1 {
return errors.New("No row found to delete")
}
PostgreSQL Data Types
pgx uses the pgtype package for converting Go values to and from PostgreSQL values. It supports many PostgreSQL types
directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
require type registration. See that package's documentation for details.
Transactions
Transactions are started by calling Begin.
tx, err := conn.Begin(context.Background())
if err != nil {
return err
}
// Rollback is safe to call even if the tx is already closed, so if
// the tx commits successfully, this is a no-op
defer tx.Rollback(context.Background())
_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
if err != nil {
return err
}
err = tx.Commit(context.Background())
if err != nil {
return err
}
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
These are internally implemented with savepoints.
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
a pseudo nested transaction.
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use.
err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
return err
})
if err != nil {
return err
}
Prepared Statements
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx
includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are
automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig
for information on how to customize or disable the statement cache.
Copy Protocol
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a
CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface.
Or implement CopyFromSource to avoid buffering the entire data set in memory.
rows := [][]any{
{"John", "Smith", int32(36)},
{"Jane", "Doe", int32(29)},
}
copyCount, err := conn.CopyFrom(
context.Background(),
pgx.Identifier{"people"},
[]string{"first_name", "last_name", "age"},
pgx.CopyFromRows(rows),
)
When you already have a typed array, using CopyFromSlice can be more convenient.
rows := []User{
{"John", "Smith", 36},
{"Jane", "Doe", 29},
}
copyCount, err := conn.CopyFrom(
context.Background(),
pgx.Identifier{"people"},
[]string{"first_name", "last_name", "age"},
pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) {
return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil
}),
)
CopyFrom can be faster than an insert with as few as 5 rows.
Listen and Notify
pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a
notification is received or the context is canceled.
_, err := conn.Exec(context.Background(), "listen channelname")
if err != nil {
return err
}
notification, err := conn.WaitForNotification(context.Background())
if err != nil {
return err
}
// do something with notification
Tracing and Logging
pgx supports tracing by setting ConnConfig.Tracer.
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3.
Lower Level PostgreSQL Functionality
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
PgBouncer
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
*/
package pgx
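
The PgBouncer note above is the configuration point most pooled deployments need; a minimal sketch of switching the exec mode (all names are from the pgx public API):

```go
package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	cfg, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	// Avoid server-side prepared statements, e.g. behind PgBouncer in
	// transaction pooling mode.
	cfg.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol

	conn, err := pgx.ConnectConfig(context.Background(), cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())
}
```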

View File

@ -0,0 +1,146 @@
package pgx
import (
"fmt"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
// formats for an extended query.
type ExtendedQueryBuilder struct {
ParamValues [][]byte
paramValueBytes []byte
ParamFormats []int16
ResultFormats []int16
}
// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
// sd is nil then QueryExecModeExec behavior will be used.
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
eqb.reset()
if sd == nil {
for i := range args {
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
return nil
}
if len(sd.ParamOIDs) != len(args) {
return fmt.Errorf("mismatched param and argument count")
}
for i := range args {
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
for i := range sd.Fields {
eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
return nil
}
// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
// must be an untyped nil.
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
if format == -1 {
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
if preferredErr == nil {
return nil
}
var otherFormat int16
if preferredFormat == TextFormatCode {
otherFormat = BinaryFormatCode
} else {
otherFormat = TextFormatCode
}
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
if otherErr == nil {
return nil
}
return preferredErr // return the error from the preferred format
}
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
if err != nil {
return err
}
eqb.ParamFormats = append(eqb.ParamFormats, format)
eqb.ParamValues = append(eqb.ParamValues, v)
return nil
}
// appendResultFormat appends a result format to the query.
func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
eqb.ResultFormats = append(eqb.ResultFormats, format)
}
// reset readies eqb to build another query.
func (eqb *ExtendedQueryBuilder) reset() {
eqb.ParamValues = eqb.ParamValues[0:0]
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
eqb.ParamFormats = eqb.ParamFormats[0:0]
eqb.ResultFormats = eqb.ResultFormats[0:0]
if cap(eqb.ParamValues) > 64 {
eqb.ParamValues = make([][]byte, 0, 64)
}
if cap(eqb.paramValueBytes) > 256 {
eqb.paramValueBytes = make([]byte, 0, 256)
}
if cap(eqb.ParamFormats) > 64 {
eqb.ParamFormats = make([]int16, 0, 64)
}
if cap(eqb.ResultFormats) > 64 {
eqb.ResultFormats = make([]int16, 0, 64)
}
}
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
if eqb.paramValueBytes == nil {
eqb.paramValueBytes = make([]byte, 0, 128)
}
pos := len(eqb.paramValueBytes)
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
eqb.paramValueBytes = buf
return eqb.paramValueBytes[pos:], nil
}
// chooseParameterFormatCode determines the correct format code for an
// argument to a prepared statement. It defaults to TextFormatCode if no
// determination can be made.
func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
switch arg.(type) {
case string, *string:
return TextFormatCode
}
return m.FormatCodeForOID(oid)
}
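
A minimal sketch of ExtendedQueryBuilder in isolation; with a nil statement description, Build encodes every argument in the text format, mirroring what QueryExecModeExec does internally:

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

func main() {
	m := pgtype.NewMap()
	var eqb pgx.ExtendedQueryBuilder
	// sd == nil: every argument is encoded with the text format code.
	if err := eqb.Build(m, nil, []any{int64(42), "hello"}); err != nil {
		panic(err)
	}
	for i, v := range eqb.ParamValues {
		fmt.Printf("param %d: format=%d value=%q\n", i, eqb.ParamFormats[i], v)
	}
}
```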

View File

@ -0,0 +1,70 @@
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
//
// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
package iobufpool
import "sync"
const minPoolExpOf2 = 8
var pools [18]*sync.Pool
func init() {
for i := range pools {
bufLen := 1 << (minPoolExpOf2 + i)
pools[i] = &sync.Pool{
New: func() any {
buf := make([]byte, bufLen)
return &buf
},
}
}
}
// Get gets a []byte of len size with cap <= size*2.
func Get(size int) *[]byte {
i := getPoolIdx(size)
if i >= len(pools) {
buf := make([]byte, size)
return &buf
}
ptrBuf := (pools[i].Get().(*[]byte))
*ptrBuf = (*ptrBuf)[:size]
return ptrBuf
}
func getPoolIdx(size int) int {
size--
size >>= minPoolExpOf2
i := 0
for size > 0 {
size >>= 1
i++
}
return i
}
// Put returns buf to the pool.
func Put(buf *[]byte) {
i := putPoolIdx(cap(*buf))
if i < 0 {
return
}
pools[i].Put(buf)
}
func putPoolIdx(size int) int {
minPoolSize := 1 << minPoolExpOf2
for i := range pools {
if size == minPoolSize<<i {
return i
}
}
return -1
}
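
iobufpool is an internal package, so the following sketch only compiles inside the pgx module itself; it shows the Get/Put contract:

```go
package iobufpooldemo

import (
	"io"

	"github.com/jackc/pgx/v5/internal/iobufpool" // internal: illustrative only
)

// readFrame shows the contract: Get returns a *[]byte with len == size,
// and Put recycles buffers whose capacity exactly matches a pool size
// class while silently dropping the rest.
func readFrame(r io.Reader) error {
	buf := iobufpool.Get(4096)
	defer iobufpool.Put(buf)
	_, err := io.ReadFull(r, *buf)
	return err
}
```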

View File

@ -0,0 +1,6 @@
# pgio
Package pgio is a low-level toolkit for building messages in the PostgreSQL wire protocol.
pgio provides functions for appending integers to a []byte while doing byte
order conversion.

6
vendor/github.com/jackc/pgx/v5/internal/pgio/doc.go generated vendored Normal file
View File

@ -0,0 +1,6 @@
// Package pgio is a low-level toolkit for building messages in the PostgreSQL wire protocol.
/*
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
*/
package pgio

40
vendor/github.com/jackc/pgx/v5/internal/pgio/write.go generated vendored Normal file
View File

@ -0,0 +1,40 @@
package pgio
import "encoding/binary"
func AppendUint16(buf []byte, n uint16) []byte {
wp := len(buf)
buf = append(buf, 0, 0)
binary.BigEndian.PutUint16(buf[wp:], n)
return buf
}
func AppendUint32(buf []byte, n uint32) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0)
binary.BigEndian.PutUint32(buf[wp:], n)
return buf
}
func AppendUint64(buf []byte, n uint64) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
binary.BigEndian.PutUint64(buf[wp:], n)
return buf
}
func AppendInt16(buf []byte, n int16) []byte {
return AppendUint16(buf, uint16(n))
}
func AppendInt32(buf []byte, n int32) []byte {
return AppendUint32(buf, uint32(n))
}
func AppendInt64(buf []byte, n int64) []byte {
return AppendUint64(buf, uint64(n))
}
func SetInt32(buf []byte, n int32) {
binary.BigEndian.PutUint32(buf, uint32(n))
}
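
pgio is likewise internal to pgx; a minimal sketch of the usual pattern of appending a length placeholder and backfilling it with SetInt32 (the 'Q' simple-query framing is only for illustration):

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/internal/pgio" // internal: illustrative only
)

func main() {
	// Append a 4-byte length placeholder, write the payload, then
	// backfill the length. In the PostgreSQL protocol the length counts
	// itself but not the leading message-type byte.
	buf := []byte{'Q'}
	lenPos := len(buf)
	buf = pgio.AppendInt32(buf, -1) // placeholder, fixed up below
	buf = append(buf, "select 1\x00"...)
	pgio.SetInt32(buf[lenPos:], int32(len(buf)-lenPos))
	fmt.Printf("% x\n", buf)
}
```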

View File

@ -0,0 +1,331 @@
package sanitize
import (
"bytes"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
"unicode/utf8"
)
// Part is either a string or an int. A string is raw SQL. An int is an
// argument placeholder.
type Part any
type Query struct {
Parts []Part
}
// utf8.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
// character. utf8.RuneError is not an error if it is also width 3.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
case int:
argIdx := part - 1
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
case int64:
str = strconv.FormatInt(arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
case []byte:
str = QuoteBytes(arg)
case string:
str = QuoteString(arg)
case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
if !used {
return "", fmt.Errorf("unused argument: %d", i)
}
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
}
func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
l.pos += width
return oneLineCommentState
}
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
return multilineCommentState
}
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
// placeholderState consumes a placeholder value. The $ must have
// already been consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else {
l.parts = append(l.parts, num)
l.pos -= width
l.start = l.pos
return rawState
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func oneLineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n', '\r':
return rawState
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func multilineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
l.nested++
}
case '*':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '/' {
continue
}
l.pos += width
if l.nested == 0 {
return rawState
}
l.nested--
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
return query.Sanitize(args...)
}
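
sanitize is internal to pgx as well; a minimal sketch of SanitizeSQL. Note that substituted values come back space-padded and quoted, and only nil, int64, float64, bool, []byte, string, and time.Time arguments are accepted:

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/internal/sanitize" // internal: illustrative only
)

func main() {
	sql, err := sanitize.SanitizeSQL(
		"select * from widgets where id = $1 and name = $2",
		int64(42), "tea'pot",
	)
	if err != nil {
		panic(err)
	}
	// Single quotes are doubled and each value is wrapped in spaces to
	// block line-comment injection.
	fmt.Println(sql)
}
```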

View File

@ -0,0 +1,112 @@
package stmtcache
import (
"container/list"
"github.com/jackc/pgx/v5/pgconn"
)
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
type LRUCache struct {
cap int
m map[string]*list.Element
l *list.List
invalidStmts []*pgconn.StatementDescription
}
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
func NewLRUCache(cap int) *LRUCache {
return &LRUCache{
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
if el, ok := c.m[key]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription)
}
return nil
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
// sd.SQL has been invalidated and RemoveInvalidated has not been called yet.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
}
if _, present := c.m[sd.SQL]; present {
return
}
// The statement may have been invalidated but not yet handled. Do not re-add it to the cache.
for _, invalidSD := range c.invalidStmts {
if invalidSD.SQL == sd.SQL {
return
}
}
if c.l.Len() == c.cap {
c.invalidateOldest()
}
el := c.l.PushFront(sd)
c.m[sd.SQL] = el
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *LRUCache) Invalidate(sql string) {
if el, ok := c.m[sql]; ok {
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
c.l.Remove(el)
}
}
// InvalidateAll invalidates all statement descriptions.
func (c *LRUCache) InvalidateAll() {
el := c.l.Front()
for el != nil {
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
el = el.Next()
}
c.m = make(map[string]*list.Element)
c.l = list.New()
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRUCache) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRUCache) Cap() int {
return c.cap
}
func (c *LRUCache) invalidateOldest() {
oldest := c.l.Back()
sd := oldest.Value.(*pgconn.StatementDescription)
c.invalidStmts = append(c.invalidStmts, sd)
delete(c.m, sd.SQL)
c.l.Remove(oldest)
}

View File

@ -0,0 +1,45 @@
// Package stmtcache is a cache for statement descriptions.
package stmtcache
import (
"crypto/sha256"
"encoding/hex"
"github.com/jackc/pgx/v5/pgconn"
)
// StatementName returns a statement name that will be stable for sql across multiple connections and program
// executions.
func StatementName(sql string) string {
digest := sha256.Sum256([]byte(sql))
return "stmtcache_" + hex.EncodeToString(digest[0:24])
}
// Cache caches statement descriptions.
type Cache interface {
// Get returns the statement description for sql. Returns nil if not found.
Get(sql string) *pgconn.StatementDescription
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
Put(sd *pgconn.StatementDescription)
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
Invalidate(sql string)
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
GetInvalidated() []*pgconn.StatementDescription
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
}
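
A minimal sketch of the eviction and invalidation contract (stmtcache is internal, so this only compiles inside the pgx module):

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/internal/stmtcache" // internal: illustrative only
	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	// cap is 1, so the second Put evicts "select 1" into the
	// invalidated list instead of dropping it; the owning connection
	// must still close that statement server-side before calling
	// RemoveInvalidated.
	c := stmtcache.NewLRUCache(1)
	c.Put(&pgconn.StatementDescription{Name: stmtcache.StatementName("select 1"), SQL: "select 1"})
	c.Put(&pgconn.StatementDescription{Name: stmtcache.StatementName("select 2"), SQL: "select 2"})

	for _, sd := range c.GetInvalidated() {
		fmt.Println("deallocate server-side:", sd.Name)
	}
	c.RemoveInvalidated()
}
```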

View File

@ -0,0 +1,77 @@
package stmtcache
import (
"math"
"github.com/jackc/pgx/v5/pgconn"
)
// UnlimitedCache implements Cache with no capacity limit.
type UnlimitedCache struct {
m map[string]*pgconn.StatementDescription
invalidStmts []*pgconn.StatementDescription
}
// NewUnlimitedCache creates a new UnlimitedCache.
func NewUnlimitedCache() *UnlimitedCache {
return &UnlimitedCache{
m: make(map[string]*pgconn.StatementDescription),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
return c.m[sql]
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
}
if _, present := c.m[sd.SQL]; present {
return
}
c.m[sd.SQL] = sd
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *UnlimitedCache) Invalidate(sql string) {
if sd, ok := c.m[sql]; ok {
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, sd)
}
}
// InvalidateAll invalidates all statement descriptions.
func (c *UnlimitedCache) InvalidateAll() {
for _, sd := range c.m {
c.invalidStmts = append(c.invalidStmts, sd)
}
c.m = make(map[string]*pgconn.StatementDescription)
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *UnlimitedCache) RemoveInvalidated() {
c.invalidStmts = nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *UnlimitedCache) Len() int {
return len(c.m)
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *UnlimitedCache) Cap() int {
return math.MaxInt
}

161
vendor/github.com/jackc/pgx/v5/large_objects.go generated vendored Normal file
View File

@ -0,0 +1,161 @@
package pgx
import (
"context"
"errors"
"io"
"github.com/jackc/pgx/v5/pgtype"
)
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
// in the message, maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
// was created.
//
// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html
type LargeObjects struct {
tx Tx
}
type LargeObjectMode int32
const (
LargeObjectModeWrite LargeObjectMode = 0x20000
LargeObjectModeRead LargeObjectMode = 0x40000
)
// Create creates a new large object. If oid is zero, the server assigns an unused OID.
func (o *LargeObjects) Create(ctx context.Context, oid uint32) (uint32, error) {
err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid)
return oid, err
}
// Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large
// object.
func (o *LargeObjects) Open(ctx context.Context, oid uint32, mode LargeObjectMode) (*LargeObject, error) {
var fd int32
err := o.tx.QueryRow(ctx, "select lo_open($1, $2)", oid, mode).Scan(&fd)
if err != nil {
return nil, err
}
return &LargeObject{fd: fd, tx: o.tx, ctx: ctx}, nil
}
// Unlink removes a large object from the database.
func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error {
var result int32
err := o.tx.QueryRow(ctx, "select lo_unlink($1)", oid).Scan(&result)
if err != nil {
return err
}
if result != 1 {
return errors.New("failed to remove large object")
}
return nil
}
// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized
// in. It uses the context it was initialized with for all operations. It implements these interfaces:
//
// io.Writer
// io.Reader
// io.Seeker
// io.Closer
type LargeObject struct {
ctx context.Context
tx Tx
fd int32
}
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
func (o *LargeObject) Write(p []byte) (int, error) {
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
if err != nil {
return nTotal, err
}
if n < 0 {
return nTotal, errors.New("failed to write to large object")
}
nTotal += n
if n < expected {
return nTotal, errors.New("short write to large object")
} else if n > expected {
return nTotal, errors.New("invalid write to large object")
}
}
return nTotal, nil
}
// Read reads up to len(p) bytes into p returning the number of bytes read.
func (o *LargeObject) Read(p []byte) (int, error) {
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
res := pgtype.PreallocBytes(p[nTotal:])
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
// We compute expected so that it always fits into p, so it should never happen
// that PreallocBytes's ScanBytes had to allocate a new slice.
nTotal += len(res)
if err != nil {
return nTotal, err
}
if len(res) < expected {
return nTotal, io.EOF
} else if len(res) > expected {
return nTotal, errors.New("invalid read of large object")
}
}
return nTotal, nil
}
// Seek moves the current location pointer to the new location specified by offset.
func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) {
err = o.tx.QueryRow(o.ctx, "select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n)
return n, err
}
// Tell returns the current read or write location of the large object descriptor.
func (o *LargeObject) Tell() (n int64, err error) {
err = o.tx.QueryRow(o.ctx, "select lo_tell64($1)", o.fd).Scan(&n)
return n, err
}
// Truncate the large object to size.
func (o *LargeObject) Truncate(size int64) (err error) {
_, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size)
return err
}
// Close the large object descriptor.
func (o *LargeObject) Close() error {
_, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd)
return err
}
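
A minimal sketch of the large objects API end to end (assuming a reachable DATABASE_URL); note that it must run inside a transaction:

```go
package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// LargeObjects is only valid within the transaction it came from.
	tx, err := conn.Begin(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer tx.Rollback(ctx)

	los := tx.LargeObjects()
	oid, err := los.Create(ctx, 0) // 0 lets the server assign the OID
	if err != nil {
		log.Fatal(err)
	}
	obj, err := los.Open(ctx, oid, pgx.LargeObjectModeWrite)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := obj.Write([]byte("hello large object")); err != nil {
		log.Fatal(err)
	}
	if err := obj.Close(); err != nil {
		log.Fatal(err)
	}
	if err := tx.Commit(ctx); err != nil {
		log.Fatal(err)
	}
}
```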

295
vendor/github.com/jackc/pgx/v5/named_args.go generated vendored Normal file
View File

@ -0,0 +1,295 @@
package pgx
import (
"context"
"fmt"
"strconv"
"strings"
"unicode/utf8"
)
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
// ordinal placeholder and construct the appropriate arguments.
//
// For example, the following two queries are equivalent:
//
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
//
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
// letters, numbers, or underscores.
type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(na, sql, false)
}
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
// named arguments that the sql query uses, and no extra arguments.
type StrictNamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(sna, sql, true)
}
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
nameToOrdinal: make(map[namedArg]int, len(na)),
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
sb := strings.Builder{}
for _, p := range l.parts {
switch p := p.(type) {
case string:
sb.WriteString(p)
case namedArg:
sb.WriteRune('$')
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
}
}
newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
var found bool
newArgs[ordinal-1], found = na[string(name)]
if isStrict && !found {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
}
}
if isStrict {
for name := range na {
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
}
}
}
return sb.String(), newArgs, nil
}
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) || nextRune == '_' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return namedArgState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
l.pos += width
return oneLineCommentState
}
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
return multilineCommentState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func isLetter(r rune) bool {
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
}
func namedArgState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if r == utf8.RuneError {
if l.pos-l.start > 0 {
na := namedArg(l.src[l.start:l.pos])
if _, found := l.nameToOrdinal[na]; !found {
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
}
l.parts = append(l.parts, na)
l.start = l.pos
}
return nil
} else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
l.pos -= width
na := namedArg(l.src[l.start:l.pos])
if _, found := l.nameToOrdinal[na]; !found {
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
}
l.parts = append(l.parts, namedArg(na))
l.start = l.pos
return rawState
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func oneLineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n', '\r':
return rawState
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func multilineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
l.nested++
}
case '*':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '/' {
continue
}
l.pos += width
if l.nested == 0 {
return rawState
}
l.nested--
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
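
A minimal sketch of StrictNamedArgs (the widgets table is illustrative); unlike NamedArgs, a key missing from the map or absent from the query is reported as an error rather than becoming a silent nil parameter:

```go
package pgxexample

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// findWidgets rewrites the @-placeholders to $1, $2 and checks that the
// provided keys and the query's named arguments match exactly.
func findWidgets(ctx context.Context, conn *pgx.Conn) ([]int32, error) {
	rows, err := conn.Query(ctx,
		"select id from widgets where foo = @foo and bar = @bar",
		pgx.StrictNamedArgs{"foo": int32(1), "bar": int32(2)},
	)
	if err != nil {
		return nil, err
	}
	return pgx.CollectRows(rows, pgx.RowTo[int32])
}
```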

29
vendor/github.com/jackc/pgx/v5/pgconn/README.md generated vendored Normal file
View File

@ -0,0 +1,29 @@
# pgconn
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
low-level access to PostgreSQL functionality.
## Example Usage
```go
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatalln("pgconn failed to connect:", err)
}
defer pgConn.Close(context.Background())
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
for result.NextRow() {
fmt.Println("User 123 has email:", string(result.Values()[0]))
}
_, err = result.Close()
if err != nil {
log.Fatalln("failed reading result:", err)
}
```
## Testing
See CONTRIBUTING.md for setup instructions.

272
vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go generated vendored Normal file
View File

@ -0,0 +1,272 @@
// SCRAM-SHA-256 authentication
//
// Resources:
// https://tools.ietf.org/html/rfc5802
// https://tools.ietf.org/html/rfc8265
// https://www.postgresql.org/docs/current/sasl-authentication.html
//
// Inspiration drawn from other implementations:
// https://github.com/lib/pq/pull/608
// https://github.com/lib/pq/pull/788
// https://github.com/lib/pq/pull/833
package pgconn
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strconv"
"github.com/jackc/pgx/v5/pgproto3"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/text/secure/precis"
)
const clientNonceLen = 18
// Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil {
return err
}
// Send client-first-message in a SASLInitialResponse
saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
c.frontend.Send(saslInitialResponse)
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
// Receive server-first-message payload in an AuthenticationSASLContinue.
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
}
err = sc.recvServerFirstMessage(saslContinue.Data)
if err != nil {
return err
}
// Send client-final-message in a SASLResponse
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
c.frontend.Send(saslResponse)
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
// Receive server-final-message payload in an AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err
}
return sc.recvServerFinalMessage(saslFinal.Data)
}
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLFinal:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
}
type scramClient struct {
serverAuthMechanisms []string
password []byte
clientNonce []byte
clientFirstMessageBare []byte
serverFirstMessage []byte
clientAndServerNonce []byte
salt []byte
iterations int
saltedPassword []byte
authMessage []byte
}
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
sc := &scramClient{
serverAuthMechanisms: serverAuthMechanisms,
}
// Ensure server supports SCRAM-SHA-256
hasScramSHA256 := false
for _, mech := range sc.serverAuthMechanisms {
if mech == "SCRAM-SHA-256" {
hasScramSHA256 = true
break
}
}
if !hasScramSHA256 {
return nil, errors.New("server does not support SCRAM-SHA-256")
}
// precis.OpaqueString is equivalent to SASLprep for password.
var err error
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
if err != nil {
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
sc.password = []byte(password)
}
buf := make([]byte, clientNonceLen)
_, err = rand.Read(buf)
if err != nil {
return nil, err
}
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
return sc, nil
}
func (sc *scramClient) clientFirstMessage() []byte {
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
}
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
sc.serverFirstMessage = serverFirstMessage
buf := serverFirstMessage
if !bytes.HasPrefix(buf, []byte("r=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
}
buf = buf[2:]
idx := bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
sc.clientAndServerNonce = buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("s=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
buf = buf[2:]
idx = bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
saltStr := buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("i=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
buf = buf[2:]
iterationsStr := buf
var err error
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
if err != nil {
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
}
sc.iterations, err = strconv.Atoi(string(iterationsStr))
if err != nil || sc.iterations <= 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
}
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not start with client nonce")
}
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not include server nonce")
}
return nil
}
func (sc *scramClient) clientFinalMessage() string {
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
}
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
return errors.New("invalid SCRAM server-final-message received from server")
}
serverSignature := serverFinalMessage[2:]
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
return errors.New("invalid SCRAM ServerSignature received from server")
}
return nil
}
func computeHMAC(key, msg []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(msg)
return mac.Sum(nil)
}
func computeClientProof(saltedPassword, authMessage []byte) []byte {
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
storedKey := sha256.Sum256(clientKey)
clientSignature := computeHMAC(storedKey[:], authMessage)
clientProof := make([]byte, len(clientSignature))
for i := 0; i < len(clientSignature); i++ {
clientProof[i] = clientKey[i] ^ clientSignature[i]
}
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
base64.StdEncoding.Encode(buf, clientProof)
return buf
}
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
serverSignature := computeHMAC(serverKey, authMessage)
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
base64.StdEncoding.Encode(buf, serverSignature)
return buf
}
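Taken together, these helpers implement the client side of the SCRAM-SHA-256 exchange (RFC 7677). A rough package-internal sketch of the four-message flow, assuming the constructor whose tail appears at the top of this file is newScramClient(authMechanisms, password) as in upstream pgx; send and recv are hypothetical stand-ins for the real pgproto3 SASL message plumbing:

sc, err := newScramClient([]string{"SCRAM-SHA-256"}, password) // assumed constructor name
if err != nil {
    return err
}
send(sc.clientFirstMessage()) // "n,,n=,r=<client-nonce>"
if err := sc.recvServerFirstMessage(recv()); err != nil { // parses r=, s= and i=
    return err
}
send([]byte(sc.clientFinalMessage())) // "c=biws,r=<nonces>,p=<proof>"
if err := sc.recvServerFinalMessage(recv()); err != nil { // verifies v=<ServerSignature>
    return err
}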

918
vendor/github.com/jackc/pgx/v5/pgconn/config.go generated vendored Normal file

@ -0,0 +1,918 @@
package pgconn
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"math"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/jackc/pgpassfile"
"github.com/jackc/pgservicefile"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
type GetSSLPasswordFunc func(ctx context.Context) string
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string
KerberosSpn string
Fallbacks []*FallbackConfig
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
ValidateConnect ValidateConnectFunc
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
// or prepare statements). If this returns an error the connection attempt fails.
AfterConnect AfterConnectFunc
// OnNotice is a callback function called when a notice response is received.
OnNotice NoticeHandler
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received from the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
// PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the TLSConfig field:
// according to the tls.Config docs it must not be modified after creation.
func (c *Config) Copy() *Config {
newConf := new(Config)
*newConf = *c
if newConf.TLSConfig != nil {
newConf.TLSConfig = c.TLSConfig.Clone()
}
if newConf.RuntimeParams != nil {
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
for k, v := range c.RuntimeParams {
newConf.RuntimeParams[k] = v
}
}
if newConf.Fallbacks != nil {
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
for i, fallback := range c.Fallbacks {
newFallback := new(FallbackConfig)
*newFallback = *fallback
if newFallback.TLSConfig != nil {
newFallback.TLSConfig = fallback.TLSConfig.Clone()
}
newConf.Fallbacks[i] = newFallback
}
}
return newConf
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
type FallbackConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
TLSConfig *tls.Config // nil disables TLS
}
// connectOneConfig is the configuration for a single attempt to connect to a single host.
type connectOneConfig struct {
network string
address string
originalHostname string // original hostname before resolving
tlsConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
func isAbsolutePath(path string) bool {
isWindowsPath := func(p string) bool {
if len(p) < 3 {
return false
}
drive := p[0]
colon := p[1]
backslash := p[2]
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
return true
}
return false
}
return strings.HasPrefix(path, "/") || isWindowsPath(path)
}
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
func NetworkAddress(host string, port uint16) (network, address string) {
if isAbsolutePath(host) {
network = "unix"
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
network = "tcp"
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
return network, address
}
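Concretely, the two branches of NetworkAddress produce values like these (a small sketch with illustrative inputs):

network, address := pgconn.NetworkAddress("/var/run/postgresql", 5432)
// network == "unix", address == "/var/run/postgresql/.s.PGSQL.5432"

network, address = pgconn.NetworkAddress("db.example.com", 5432)
// network == "tcp", address == "db.example.com:5432"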
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example Keyword/Value
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
//
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
// not be modified individually. They should all be modified or all left unchanged.
//
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
// values that will be tried in order. This can be used as part of a high availability system. See
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
//
// # Example URL
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variables and their parameter key word equivalents passed
// via database URL or keyword/value:
//
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGPASSFILE
// PGSERVICE
// PGSERVICEFILE
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGSSLPASSWORD
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
//
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
//
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
// usually but not always the environment variable name downcased and without the "PG" prefix.
//
// Important Security Notes:
//
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
// not set.
//
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
// security each sslmode provides.
//
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
// TLSConfig.
//
// Other known differences with libpq:
//
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not.
//
// In addition, ParseConfig accepts the following options:
//
// - servicefile.
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
var parseConfigOptions ParseConfigOptions
return ParseConfigWithOptions(connString, parseConfigOptions)
}
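A minimal usage sketch, reusing the URL and keyword/value examples from the comment above (host and credentials are illustrative):

cfg, err := pgconn.ParseConfig("postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca")
if err != nil {
    log.Fatal(err)
}
fmt.Println(cfg.Host, cfg.Port, cfg.Database) // pg.example.com 5432 mydb

// The keyword/value form parses to the equivalent *Config.
cfg, err = pgconn.ParseConfig("user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca")
if err != nil {
    log.Fatal(err)
}
fmt.Println(cfg.User) // jack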
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
// get the SSL password.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
defaultSettings := defaultSettings()
envSettings := parseEnvSettings()
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or in PostgreSQL keyword/value format
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseKeywordValueSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
}
}
}
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
}
config := &Config{
createdByParseConfig: true,
Database: settings["database"],
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else {
defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext
}
config.LookupFunc = makeDefaultResolver().LookupHost
notRuntimeParams := map[string]struct{}{
"host": {},
"port": {},
"database": {},
"user": {},
"password": {},
"passfile": {},
"connect_timeout": {},
"sslmode": {},
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
"service": {},
"servicefile": {},
}
// Add Kerberos configuration.
if _, present := settings["krbsrvname"]; present {
config.KerberosSrvName = settings["krbsrvname"]
}
if _, present := settings["krbspn"]; present {
config.KerberosSpn = settings["krbspn"]
}
for k, v := range settings {
if _, present := notRuntimeParams[k]; present {
continue
}
config.RuntimeParams[k] = v
}
fallbacks := []*FallbackConfig{}
hosts := strings.Split(settings["host"], ",")
ports := strings.Split(settings["port"], ",")
for i, host := range hosts {
var portStr string
if i < len(ports) {
portStr = ports[i]
} else {
portStr = ports[0]
}
port, err := parsePort(portStr)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
// Ignore TLS settings if Unix domain socket like libpq
if network, _ := NetworkAddress(host, port); network == "unix" {
tlsConfigs = append(tlsConfigs, nil)
} else {
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
}
}
for _, tlsConfig := range tlsConfigs {
fallbacks = append(fallbacks, &FallbackConfig{
Host: host,
Port: port,
TLSConfig: tlsConfig,
})
}
}
config.Host = fallbacks[0].Host
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {
if config.Password == "" {
host := config.Host
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
host = "localhost"
}
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
}
}
switch tsa := settings["target_session_attrs"]; tsa {
case "read-write":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
case "read-only":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
case "primary":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
case "standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
case "prefer-standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
case "any":
// do nothing
default:
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
}
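A sketch of the one thing ParseConfigWithOptions adds over ParseConfig: supplying a GetSSLPassword callback to decrypt an encrypted client key. The PGSSLKEYPASS environment variable and the certificate paths here are hypothetical:

opts := pgconn.ParseConfigOptions{
    GetSSLPassword: func(ctx context.Context) string {
        return os.Getenv("PGSSLKEYPASS") // hypothetical passphrase source
    },
}
cfg, err := pgconn.ParseConfigWithOptions(
    "host=db user=app sslmode=verify-full sslrootcert=root.crt sslcert=client.crt sslkey=client.key",
    opts,
)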
func mergeSettings(settingSets ...map[string]string) map[string]string {
settings := make(map[string]string)
for _, s2 := range settingSets {
for k, v := range s2 {
settings[k] = v
}
}
return settings
}
func parseEnvSettings() map[string]string {
settings := make(map[string]string)
nameMap := map[string]string{
"PGHOST": "host",
"PGPORT": "port",
"PGDATABASE": "database",
"PGUSER": "user",
"PGPASSWORD": "password",
"PGPASSFILE": "passfile",
"PGAPPNAME": "application_name",
"PGCONNECT_TIMEOUT": "connect_timeout",
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLSNI": "sslsni",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
}
for envname, realname := range nameMap {
value := os.Getenv(envname)
if value != "" {
settings[realname] = value
}
}
return settings
}
func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string)
url, err := url.Parse(connString)
if err != nil {
return nil, err
}
if url.User != nil {
settings["user"] = url.User.Username()
if password, present := url.User.Password(); present {
settings["password"] = password
}
}
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string
var ports []string
for _, host := range strings.Split(url.Host, ",") {
if host == "" {
continue
}
if isIPOnly(host) {
hosts = append(hosts, strings.Trim(host, "[]"))
continue
}
h, p, err := net.SplitHostPort(host)
if err != nil {
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
}
if h != "" {
hosts = append(hosts, h)
}
if p != "" {
ports = append(ports, p)
}
}
if len(hosts) > 0 {
settings["host"] = strings.Join(hosts, ",")
}
if len(ports) > 0 {
settings["port"] = strings.Join(ports, ",")
}
database := strings.TrimLeft(url.Path, "/")
if database != "" {
settings["database"] = database
}
nameMap := map[string]string{
"dbname": "database",
}
for k, v := range url.Query() {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v[0]
}
return settings, nil
}
func isIPOnly(host string) bool {
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseKeywordValueSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
"dbname": "database",
}
for len(s) > 0 {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid keyword/value")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
if len(s) == 0 {
} else if s[0] != '\'' {
end := 0
for ; end < len(s); end++ {
if asciiSpace[s[end]] == 1 {
break
}
if s[end] == '\\' {
end++
if end == len(s) {
return nil, errors.New("invalid backslash")
}
}
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
} else { // quoted string
s = s[1:]
end := 0
for ; end < len(s); end++ {
if s[end] == '\'' {
break
}
if s[end] == '\\' {
end++
}
}
if end == len(s) {
return nil, errors.New("unterminated quoted string in connection info string")
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
}
if k, ok := nameMap[key]; ok {
key = k
}
if key == "" {
return nil, errors.New("invalid keyword/value")
}
settings[key] = val
}
return settings, nil
}
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
if err != nil {
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
}
service, err := servicefile.GetService(serviceName)
if err != nil {
return nil, fmt.Errorf("unable to find service: %v", serviceName)
}
nameMap := map[string]string{
"dbname": "database",
}
settings := make(map[string]string, len(service.Settings))
for k, v := range service.Settings {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v
}
return settings, nil
}
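For reference, the service file read here uses the libpq INI-style format; a service defined like this (illustrative values) is then selected with service=mydb in the connection string:

# ~/.pg_service.conf
[mydb]
host=pg.example.com
port=5432
dbname=mydb
user=jack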
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
// necessary to allow returning multiple TLS configs as sslmode "allow" and
// "prefer" allow fallback.
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
host := thisHost
sslmode := settings["sslmode"]
sslrootcert := settings["sslrootcert"]
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
sslsni := settings["sslsni"]
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
}
if sslsni == "" {
sslsni = "1"
}
tlsConfig := &tls.Config{}
switch sslmode {
case "disable":
return []*tls.Config{nil}, nil
case "allow", "prefer":
tlsConfig.InsecureSkipVerify = true
case "require":
// According to PostgreSQL documentation, if a root CA file exists,
// the behavior of sslmode=require should be the same as that of verify-ca
//
// See https://www.postgresql.org/docs/12/libpq-ssl.html
if sslrootcert != "" {
goto nextCase
}
tlsConfig.InsecureSkipVerify = true
break
nextCase:
fallthrough
case "verify-ca":
// Don't perform the default certificate verification because it
// will verify the hostname. Instead, verify the server's
// certificate chain ourselves in VerifyPeerCertificate and
// ignore the server name. This emulates libpq's verify-ca
// behavior.
//
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
// for more info.
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(certificates))
for i, asn1Data := range certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return errors.New("failed to parse certificate from server: " + err.Error())
}
certs[i] = cert
}
// Leave DNSName empty to skip hostname verification.
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
Intermediates: x509.NewCertPool(),
}
// Skip the first cert because it's the leaf. All others
// are intermediates.
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}
case "verify-full":
tlsConfig.ServerName = host
default:
return nil, errors.New("sslmode is invalid")
}
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
caPath := sslrootcert
caCert, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
}
if sslcert != "" && sslkey != "" {
buf, err := os.ReadFile(sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte
var decryptedKey []byte
var decryptedError error
// If PEM is encrypted, attempt to decrypt using pass phrase
if x509.IsEncryptedPEMBlock(block) {
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
if sslpassword != "" {
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
}
// If sslpassword was not provided, or decrypting with it failed,
// try to obtain the password via the GetSSLPassword callback.
if sslpassword == "" || decryptedError != nil {
if parseConfigOptions.GetSSLPassword != nil {
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
}
if sslpassword == "" {
return nil, fmt.Errorf("unable to find sslpassword")
}
}
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
// Should we also warn that only PKCS#1 keys are supported?
if decryptedError != nil {
return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError)
}
}
pemBytes := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: decryptedKey,
}
pemKey = pem.EncodeToMemory(&pemBytes)
} else {
pemKey = pem.EncodeToMemory(block)
}
certfile, err := os.ReadFile(sslcert)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
cert, err := tls.X509KeyPair(certfile, pemKey)
if err != nil {
return nil, fmt.Errorf("unable to load cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
// Set Server Name Indication (SNI), if enabled by connection parameters.
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
// or IPv6).
if sslsni == "1" && net.ParseIP(host) == nil {
tlsConfig.ServerName = host
}
switch sslmode {
case "allow":
return []*tls.Config{nil, tlsConfig}, nil
case "prefer":
return []*tls.Config{tlsConfig, nil}, nil
case "require", "verify-ca", "verify-full":
return []*tls.Config{tlsConfig}, nil
default:
panic("BUG: bad sslmode should already have been caught")
}
}
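One branch above is easy to miss: sslmode=require silently upgrades to verify-ca semantics as soon as a root CA file is configured. A sketch (paths are illustrative, and the CA file must exist or ParseConfig fails while building the TLS config):

// TLS without any certificate verification.
cfgA, err := pgconn.ParseConfig("host=db.example.com user=app sslmode=require")

// Chain verified against pg-root.crt, hostname still unverified (verify-ca behavior).
cfgB, err := pgconn.ParseConfig("host=db.example.com user=app sslmode=require sslrootcert=/etc/ssl/pg-root.crt")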
func parsePort(s string) (uint16, error) {
port, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return 0, err
}
if port < 1 || port > math.MaxUint16 {
return 0, errors.New("outside range")
}
return uint16(port), nil
}
func makeDefaultDialer() *net.Dialer {
// rely on GOLANG KeepAlive settings
return &net.Dialer{}
}
func makeDefaultResolver() *net.Resolver {
return net.DefaultResolver
}
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, err
}
if timeout < 0 {
return 0, errors.New("negative timeout")
}
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer()
d.Timeout = timeout
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) == "on" {
return errors.New("read only connection")
}
return nil
}
// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "on" {
return errors.New("connection is not read only")
}
return nil
}
// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "t" {
return errors.New("server is not in hot standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) == "t" {
return errors.New("server is in standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "t" {
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
}
return nil
}


@ -0,0 +1,80 @@
package ctxwatch
import (
"context"
"sync"
)
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
handler Handler
unwatchChan chan struct{}
lock sync.Mutex
watchInProgress bool
onCancelWasCalled bool
}
// NewContextWatcher returns a ContextWatcher. The handler's HandleCancel will be called when a watched context is
// canceled. HandleUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been
// canceled and HandleCancel called.
func NewContextWatcher(handler Handler) *ContextWatcher {
cw := &ContextWatcher{
handler: handler,
unwatchChan: make(chan struct{}),
}
return cw
}
// Watch starts watching ctx. If ctx is canceled then the handler's HandleCancel will be called.
func (cw *ContextWatcher) Watch(ctx context.Context) {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
panic("Watch already in progress")
}
cw.onCancelWasCalled = false
if ctx.Done() != nil {
cw.watchInProgress = true
go func() {
select {
case <-ctx.Done():
cw.handler.HandleCancel(ctx)
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
}
}()
} else {
cw.watchInProgress = false
}
}
// Unwatch stops watching the previously watched context. If the handler's HandleCancel was called then
// HandleUnwatchAfterCancel will also be called.
func (cw *ContextWatcher) Unwatch() {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.handler.HandleUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}
type Handler interface {
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
// context that was canceled.
HandleCancel(canceledCtx context.Context)
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
HandleUnwatchAfterCancel()
}
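A runnable sketch of the watcher lifecycle, with a hypothetical Handler that just logs both callbacks:

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/jackc/pgx/v5/pgconn/ctxwatch"
)

type logHandler struct{}

func (logHandler) HandleCancel(ctx context.Context) { fmt.Println("canceled:", ctx.Err()) }
func (logHandler) HandleUnwatchAfterCancel()        { fmt.Println("unwatched after cancel") }

func main() {
    cw := ctxwatch.NewContextWatcher(logHandler{})
    ctx, cancel := context.WithCancel(context.Background())
    cw.Watch(ctx)
    cancel()
    time.Sleep(10 * time.Millisecond) // let the watch goroutine run HandleCancel
    cw.Unwatch()                      // HandleUnwatchAfterCancel fires since cancel came first
}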

63
vendor/github.com/jackc/pgx/v5/pgconn/defaults.go generated vendored Normal file

@ -0,0 +1,63 @@
//go:build !windows
// +build !windows
package pgconn
import (
"os"
"os/user"
"path/filepath"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
if err == nil {
settings["user"] = user.Username
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
candidatePaths := []string{
"/var/run/postgresql", // Debian
"/private/tmp", // OSX - homebrew
"/tmp", // standard PostgreSQL
}
for _, path := range candidatePaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return "localhost"
}


@ -0,0 +1,57 @@
package pgconn
import (
"os"
"os/user"
"path/filepath"
"strings"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
appData := os.Getenv("APPDATA")
if err == nil {
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
// but the libpq default is just the `user` portion, so we strip off the first part.
username := user.Username
if strings.Contains(username, "\\") {
username = username[strings.LastIndex(username, "\\")+1:]
}
settings["user"] = username
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
return "localhost"
}

38
vendor/github.com/jackc/pgx/v5/pgconn/doc.go generated vendored Normal file

@ -0,0 +1,38 @@
// Package pgconn is a low-level PostgreSQL database driver.
/*
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
nearly the same level as the C library libpq.
Establishing a Connection
Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
environment for libpq style environment variables.
Executing a Query
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
reads all rows into memory.
Executing Multiple Queries in a Single Round Trip
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
result. The ReadAll method reads all query results into memory.
Pipeline Mode
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of
exactly how many and when network round trips occur.
Context Support
All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
This can be especially useful when queries are frequently canceled and the overhead of creating new connections is
a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
interrupting the query in such a way as to close the connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/
package pgconn
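A minimal end-to-end sketch of the API the package comment describes, with an illustrative connection string:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/jackc/pgx/v5/pgconn"
)

func main() {
    ctx := context.Background()
    conn, err := pgconn.Connect(ctx, "postgres://jack:secret@localhost:5432/mydb")
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close(ctx)

    // ExecParams returns a result reader; Read collects all rows into memory.
    result := conn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("hello")}, nil, nil, nil).Read()
    if result.Err != nil {
        log.Fatal(result.Err)
    }
    fmt.Println(string(result.Rows[0][0])) // hello
}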

248
vendor/github.com/jackc/pgx/v5/pgconn/errors.go generated vendored Normal file

@ -0,0 +1,248 @@
package pgconn
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"regexp"
"strings"
)
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
var retryableErr interface{ SafeToRetry() bool }
if errors.As(err, &retryableErr) {
return retryableErr.SafeToRetry()
}
return false
}
// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
var timeoutErr *errTimeout
return errors.As(err, &timeoutErr)
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
SeverityUnlocalized string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// SQLState returns the SQLState of the error.
func (pe *PgError) SQLState() string {
return pe.Code
}
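A typical caller-side pattern built on *PgError and SQLState; a small sketch (SQLSTATE 23505 is PostgreSQL's unique_violation):

// isUniqueViolation reports whether err wraps a *pgconn.PgError carrying
// SQLSTATE 23505 (unique_violation).
func isUniqueViolation(err error) bool {
    var pgErr *pgconn.PgError
    return errors.As(err, &pgErr) && pgErr.SQLState() == "23505"
}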
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
err error
}
func (e *ConnectError) Error() string {
prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
details := e.err.Error()
if strings.Contains(details, "\n") {
return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
} else {
return prefix + " " + details
}
}
func (e *ConnectError) Unwrap() error {
return e.err
}
type perDialConnectError struct {
address string
originalHostname string
err error
}
func (e *perDialConnectError) Error() string {
return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
}
func (e *perDialConnectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
// ParseConfigError is the error returned when a connection string cannot be parsed.
type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
msg string
err error
}
func (e *ParseConfigError) Error() string {
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better to only
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
connString := redactPW(e.ConnString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *ParseConfigError) Unwrap() error {
return e.err
}
func normalizeTimeoutError(ctx context.Context, err error) error {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if ctx.Err() == context.Canceled {
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
return context.Canceled
} else if ctx.Err() == context.DeadlineExceeded {
return &errTimeout{err: ctx.Err()}
} else {
return &errTimeout{err: netErr}
}
}
return err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
type errTimeout struct {
err error
}
func (e *errTimeout) Error() string {
return fmt.Sprintf("timeout: %s", e.err.Error())
}
func (e *errTimeout) SafeToRetry() bool {
return SafeToRetry(e.err)
}
func (e *errTimeout) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
func newContextAlreadyDoneError(ctx context.Context) (err error) {
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
}
func redactPW(connString string) string {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
if u, err := url.Parse(connString); err == nil {
return redactURL(u)
}
}
quotedKV := regexp.MustCompile(`password='[^']*'`)
connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
plainKV := regexp.MustCompile(`password=[^ ]*`)
connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
return connString
}
func redactURL(u *url.URL) string {
if u == nil {
return ""
}
if _, pwSet := u.User.Password(); pwSet {
u.User = url.UserPassword(u.User.Username(), "xxxxx")
}
return u.String()
}
type NotPreferredError struct {
err error
safeToRetry bool
}
func (e *NotPreferredError) Error() string {
return fmt.Sprintf("standby server not found: %s", e.err.Error())
}
func (e *NotPreferredError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *NotPreferredError) Unwrap() error {
return e.err
}


@ -0,0 +1,139 @@
// Package bgreader provides an io.Reader that can optionally buffer reads in the background.
package bgreader
import (
"io"
"sync"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
const (
StatusStopped = iota
StatusRunning
StatusStopping
)
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
type BGReader struct {
r io.Reader
cond *sync.Cond
status int32
readResults []readResult
}
type readResult struct {
buf *[]byte
err error
}
// Start starts the background reader. If the background reader is already running this is a no-op. The background
// reader will stop automatically when the underlying reader returns an error.
func (r *BGReader) Start() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
r.status = StatusRunning
go r.bgRead()
case StatusRunning:
// no-op
case StatusStopping:
r.status = StatusRunning
}
}
// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the
// background reader is not running.
func (r *BGReader) Stop() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
// no-op
case StatusRunning:
r.status = StatusStopping
case StatusStopping:
// no-op
}
}
// Status returns the current status of the background reader.
func (r *BGReader) Status() int32 {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.status
}
func (r *BGReader) bgRead() {
keepReading := true
for keepReading {
buf := iobufpool.Get(8192)
n, err := r.r.Read(*buf)
*buf = (*buf)[:n]
r.cond.L.Lock()
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
if r.status == StatusStopping || err != nil {
r.status = StatusStopped
keepReading = false
}
r.cond.L.Unlock()
r.cond.Broadcast()
}
}
// Read implements the io.Reader interface.
func (r *BGReader) Read(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
if len(r.readResults) > 0 {
return r.readFromReadResults(p)
}
// There are no unread background read results and the background reader is stopped.
if r.status == StatusStopped {
return r.r.Read(p)
}
// Wait for results from the background reader
for len(r.readResults) == 0 {
r.cond.Wait()
}
return r.readFromReadResults(p)
}
// readFromReadResults reads a result previously read by the background reader. r.cond.L must be held.
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
buf := r.readResults[0].buf
var err error
n := copy(p, *buf)
if n == len(*buf) {
err = r.readResults[0].err
iobufpool.Put(buf)
if len(r.readResults) == 1 {
r.readResults = nil
} else {
r.readResults = r.readResults[1:]
}
} else {
*buf = (*buf)[n:]
r.readResults[0].buf = buf
}
return n, err
}
func New(r io.Reader) *BGReader {
return &BGReader{
r: r,
cond: &sync.Cond{
L: &sync.Mutex{},
},
}
}
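A sketch of the intended use: buffer in the background across a window where the caller cannot block on the source, then keep consuming through the same Read interface:

func drain(src io.Reader) (int, error) {
    br := bgreader.New(src)
    br.Start() // background goroutine now pulls from src into pooled buffers
    // ... do work that must not block on src ...
    br.Stop() // stop once the in-progress background Read returns
    buf := make([]byte, 8192)
    return br.Read(buf) // serves buffered results first, then reads src directly
}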

100
vendor/github.com/jackc/pgx/v5/pgconn/krb5.go generated vendored Normal file

@ -0,0 +1,100 @@
package pgconn
import (
"errors"
"fmt"
"github.com/jackc/pgx/v5/pgproto3"
)
// NewGSSFunc creates a GSS authentication provider, for use with
// RegisterGSSProvider.
type NewGSSFunc func() (GSS, error)
var newGSS NewGSSFunc
// RegisterGSSProvider registers a GSS authentication provider. For example, if
// you need to use Kerberos to authenticate with your server, add this to your
// main package:
//
// import "github.com/otan/gopgkrb5"
//
// func init() {
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
// }
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
newGSS = newGSSArg
}
// GSS provides GSSAPI authentication (e.g., Kerberos).
type GSS interface {
GetInitToken(host string, service string) ([]byte, error)
GetInitTokenFromSPN(spn string) ([]byte, error)
Continue(inToken []byte) (done bool, outToken []byte, err error)
}
func (c *PgConn) gssAuth() error {
if newGSS == nil {
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
}
cli, err := newGSS()
if err != nil {
return err
}
var nextData []byte
if c.config.KerberosSpn != "" {
// Use the supplied SPN if provided.
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
} else {
// Allow the kerberos service name to be overridden
service := "postgres"
if c.config.KerberosSrvName != "" {
service = c.config.KerberosSrvName
}
nextData, err = cli.GetInitToken(c.config.Host, service)
}
if err != nil {
return err
}
for {
gssResponse := &pgproto3.GSSResponse{
Data: nextData,
}
c.frontend.Send(gssResponse)
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
resp, err := c.rxGSSContinue()
if err != nil {
return err
}
var done bool
done, nextData, err = cli.Continue(resp.Data)
if err != nil {
return err
}
if done {
break
}
}
return nil
}
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationGSSContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
}

2346
vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go generated vendored Normal file

File diff suppressed because it is too large

7
vendor/github.com/jackc/pgx/v5/pgproto3/README.md generated vendored Normal file

@ -0,0 +1,7 @@
# pgproto3
Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
See example/pgfortune for a playful example of a fake PostgreSQL server.
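A minimal sketch of driving the protocol by hand with a Frontend, as a driver or proxy would (address and credentials are illustrative):

package main

import (
    "log"
    "net"

    "github.com/jackc/pgx/v5/pgproto3"
)

func main() {
    conn, err := net.Dial("tcp", "127.0.0.1:5432")
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    frontend := pgproto3.NewFrontend(conn, conn)
    frontend.Send(&pgproto3.StartupMessage{
        ProtocolVersion: pgproto3.ProtocolVersionNumber,
        Parameters:      map[string]string{"user": "jack", "database": "mydb"},
    })
    if err := frontend.Flush(); err != nil {
        log.Fatal(err)
    }
    msg, err := frontend.Receive() // first reply is an Authentication* message
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("server sent %T", msg)
}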


@ -0,0 +1,51 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
type AuthenticationCleartextPassword struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {}
// AuthenticationResponse identifies this message as an authentication response.
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeCleartextPassword {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationCleartextPassword",
})
}


@ -0,0 +1,58 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type AuthenticationGSS struct{}
func (a *AuthenticationGSS) Backend() {}
func (a *AuthenticationGSS) AuthenticationResponse() {}
func (a *AuthenticationGSS) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSS {
return errors.New("bad auth type")
}
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return finishMessage(dst, sp)
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSS",
})
}
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
return nil
}


@ -0,0 +1,67 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type AuthenticationGSSContinue struct {
Data []byte
}
func (a *AuthenticationGSSContinue) Backend() {}
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSSCont {
return errors.New("bad auth type")
}
a.Data = src[4:]
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return finishMessage(dst, sp)
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSSContinue",
Data: a.Data,
})
}
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
a.Data = msg.Data
return nil
}


@ -0,0 +1,76 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
type AuthenticationMD5Password struct {
Salt [4]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {}
// AuthenticationResponse identifies this message as an authentication response.
func (*AuthenticationMD5Password) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
if len(src) != 8 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeMD5Password {
return errors.New("bad auth type")
}
copy(dst.Salt[:], src[4:8])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Salt [4]byte
}{
Type: "AuthenticationMD5Password",
Salt: src.Salt,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Salt [4]byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Salt = msg.Salt
return nil
}
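The Encode/Decode contracts above compose into a simple round trip: Encode prepends the 1-byte type and 4-byte length, and Decode expects the payload without them. A sketch:

src := &pgproto3.AuthenticationMD5Password{Salt: [4]byte{1, 2, 3, 4}}
buf, err := src.Encode(nil) // 'R' + length + auth type + salt
if err != nil {
    log.Fatal(err)
}
var dst pgproto3.AuthenticationMD5Password
if err := dst.Decode(buf[5:]); err != nil { // strip the type byte and length
    log.Fatal(err)
}
// dst.Salt == src.Salt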


@ -0,0 +1,51 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
type AuthenticationOk struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {}
// AuthenticationResponse identifies this message as an authentication response.
func (*AuthenticationOk) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeOk {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeOk)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationOK",
})
}


@ -0,0 +1,72 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
type AuthenticationSASL struct {
AuthMechanisms []string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {}
// AuthenticationResponse identifies this message as an authentication response.
func (*AuthenticationSASL) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASL {
return errors.New("bad auth type")
}
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx == -1 {
return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"}
}
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
AuthMechanisms []string
}{
Type: "AuthenticationSASL",
AuthMechanisms: src.AuthMechanisms,
})
}


@ -0,0 +1,75 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
type AuthenticationSASLContinue struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {}
// AuthenticationResponse identifies this message as an authentication response.
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLContinue {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = append(dst, src.Data...)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLContinue",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}


@ -0,0 +1,75 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
type AuthenticationSASLFinal struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {}
// AuthenticationResponse identifies this message as an authentication response.
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLFinal {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = append(dst, src.Data...)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLFinal",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

vendor/github.com/jackc/pgx/v5/pgproto3/backend.go

@ -0,0 +1,292 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"fmt"
"io"
)
// Backend acts as a server for the PostgreSQL wire protocol version 3.
type Backend struct {
cr *chunkReader
w io.Writer
// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
// before it is actually transmitted (i.e. before Flush).
tracer *tracer
wbuf []byte
encodeError error
// Frontend message flyweights
bind Bind
cancelRequest CancelRequest
_close Close
copyFail CopyFail
copyData CopyData
copyDone CopyDone
describe Describe
execute Execute
flush Flush
functionCall FunctionCall
gssEncRequest GSSEncRequest
parse Parse
query Query
sslRequest SSLRequest
startupMessage StartupMessage
sync Sync
terminate Terminate
bodyLen int
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
msgType byte
partialMsg bool
authType uint32
}
const (
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)
// NewBackend creates a new Backend.
func NewBackend(r io.Reader, w io.Writer) *Backend {
cr := newChunkReader(r, 0)
return &Backend{cr: cr, w: w}
}
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
func (b *Backend) Send(msg BackendMessage) {
if b.encodeError != nil {
return
}
prevLen := len(b.wbuf)
newBuf, err := msg.Encode(b.wbuf)
if err != nil {
b.encodeError = err
return
}
b.wbuf = newBuf
if b.tracer != nil {
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
}
}
// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
if err := b.encodeError; err != nil {
b.encodeError = nil
b.wbuf = b.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
n, err := b.w.Write(b.wbuf)
const maxLen = 1024
if len(b.wbuf) > maxLen {
b.wbuf = make([]byte, 0, maxLen)
} else {
b.wbuf = b.wbuf[:0]
}
if err != nil {
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function
// PQtrace.
func (b *Backend) Trace(w io.Writer, options TracerOptions) {
b.tracer = &tracer{
w: w,
buf: &bytes.Buffer{},
TracerOptions: options,
}
}
// Untrace stops tracing.
func (b *Backend) Untrace() {
b.tracer = nil
}
// ReceiveStartupMessage receives the initial connection message. This method is used instead of the normal Receive method
// because the initial connection message is "special" and does not include the message type as the first byte. This
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
buf, err := b.cr.Next(4)
if err != nil {
return nil, err
}
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
}
buf, err = b.cr.Next(msgSize)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
code := binary.BigEndian.Uint32(buf)
switch code {
case ProtocolVersionNumber:
err = b.startupMessage.Decode(buf)
if err != nil {
return nil, err
}
return &b.startupMessage, nil
case sslRequestNumber:
err = b.sslRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.sslRequest, nil
case cancelRequestCode:
err = b.cancelRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.cancelRequest, nil
case gssEncReqNumber:
err = b.gssEncRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.gssEncRequest, nil
default:
return nil, fmt.Errorf("unknown startup message code: %d", code)
}
}
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
func (b *Backend) Receive() (FrontendMessage, error) {
if !b.partialMsg {
header, err := b.cr.Next(5)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
}
b.partialMsg = true
}
var msg FrontendMessage
switch b.msgType {
case 'B':
msg = &b.bind
case 'C':
msg = &b._close
case 'D':
msg = &b.describe
case 'E':
msg = &b.execute
case 'F':
msg = &b.functionCall
case 'f':
msg = &b.copyFail
case 'd':
msg = &b.copyData
case 'c':
msg = &b.copyDone
case 'H':
msg = &b.flush
case 'P':
msg = &b.parse
case 'p':
switch b.authType {
case AuthTypeSASL:
msg = &SASLInitialResponse{}
case AuthTypeSASLContinue:
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeGSS, AuthTypeGSSCont:
msg = &GSSResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
// to maintain backwards compatibility
msg = &PasswordMessage{}
}
case 'Q':
msg = &b.query
case 'S':
msg = &b.sync
case 'X':
msg = &b.terminate
default:
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.partialMsg = false
err = msg.Decode(msgBody)
if err != nil {
return nil, err
}
if b.tracer != nil {
b.tracer.traceMessage('F', int32(5+len(msgBody)), msg)
}
return msg, nil
}
// SetAuthType sets the authentication type in the backend.
// Since multiple message types can start with 'p', SetAuthType allows
// contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage:
//
// Byte1('p')
//
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
//
// Since the Backend does not track the authentication state on its own, it is
// important to call SetAuthType() after the backend has sent an authentication
// request to the frontend.
func (b *Backend) SetAuthType(authType uint32) error {
switch authType {
case AuthTypeOk,
AuthTypeCleartextPassword,
AuthTypeMD5Password,
AuthTypeSCMCreds,
AuthTypeGSS,
AuthTypeGSSCont,
AuthTypeSSPI,
AuthTypeSASL,
AuthTypeSASLContinue,
AuthTypeSASLFinal:
b.authType = authType
default:
return fmt.Errorf("authType not recognized: %d", authType)
}
return nil
}
// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return
// an error. This is useful for protecting against malicious clients that send large messages with the intent of
// causing memory exhaustion.
// The default value is 0, in which case no maximum is enforced.
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
b.maxBodyLen = maxBodyLen
}

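A runnable sketch of the server-side flow described above, wired to a Frontend over net.Pipe: it shows ReceiveStartupMessage handling the headerless startup packet and SetMaxBodyLen guarding later Receive calls. The StartupMessage field names are taken from pgx v5's pgproto3; this is illustrative only, with no real authentication:

package main

import (
	"fmt"
	"log"
	"net"

	"github.com/jackc/pgx/v5/pgproto3"
)

func main() {
	client, server := net.Pipe()

	// Stand-in client: send only the startup packet.
	go func() {
		fe := pgproto3.NewFrontend(client, client)
		fe.Send(&pgproto3.StartupMessage{
			ProtocolVersion: pgproto3.ProtocolVersionNumber,
			Parameters:      map[string]string{"user": "demo"},
		})
		if err := fe.Flush(); err != nil {
			log.Print(err)
		}
	}()

	be := pgproto3.NewBackend(server, server)
	be.SetMaxBodyLen(1 << 20) // cap later Receive bodies at 1 MiB

	startup, err := be.ReceiveStartupMessage()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%T %v\n", startup, startup.(*pgproto3.StartupMessage).Parameters)
}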

@ -0,0 +1,50 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type BackendKeyData struct {
ProcessID uint32
SecretKey uint32
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*BackendKeyData) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *BackendKeyData) Decode(src []byte) error {
if len(src) != 8 {
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
}
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'K')
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "BackendKeyData",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

vendor/github.com/jackc/pgx/v5/pgproto3/big_endian.go

@ -0,0 +1,37 @@
package pgproto3
import (
"encoding/binary"
)
type BigEndianBuf [8]byte
func (b BigEndianBuf) Int16(n int16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, uint16(n))
return buf
}
func (b BigEndianBuf) Uint16(n uint16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, n)
return buf
}
func (b BigEndianBuf) Int32(n int32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, uint32(n))
return buf
}
func (b BigEndianBuf) Uint32(n uint32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, n)
return buf
}
func (b BigEndianBuf) Int64(n int64) []byte {
buf := b[0:8]
binary.BigEndian.PutUint64(buf, uint64(n))
return buf
}

vendor/github.com/jackc/pgx/v5/pgproto3/bind.go

@ -0,0 +1,223 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Bind struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters [][]byte
ResultFormatCodes []int16
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Bind) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Bind) Decode(src []byte) error {
*dst = Bind{}
idx := bytes.IndexByte(src, 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.DestinationPortal = string(src[:idx])
rp := idx + 1
idx = bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.PreparedStatement = string(src[rp : rp+idx])
rp += idx + 1
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if parameterFormatCodeCount > 0 {
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
for i := 0; i < parameterFormatCodeCount; i++ {
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
}
}
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if parameterCount > 0 {
dst.Parameters = make([][]byte, parameterCount)
for i := 0; i < parameterCount; i++ {
if len(src[rp:]) < 4 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
// null
if msgSize == -1 {
continue
}
if len(src[rp:]) < msgSize {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.Parameters[i] = src[rp : rp+msgSize]
rp += msgSize
}
}
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
for i := 0; i < resultFormatCodeCount; i++ {
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'B')
dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
dst = pgio.AppendInt32(dst, -1)
continue
}
dst = pgio.AppendInt32(dst, int32(len(p)))
dst = append(dst, p...)
}
if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Bind) MarshalJSON() ([]byte, error) {
formattedParameters := make([]map[string]string, len(src.Parameters))
for i, p := range src.Parameters {
if p == nil {
continue
}
textFormat := true
if len(src.ParameterFormatCodes) == 1 {
textFormat = src.ParameterFormatCodes[0] == 0
} else if len(src.ParameterFormatCodes) > 1 {
textFormat = src.ParameterFormatCodes[i] == 0
}
if textFormat {
formattedParameters[i] = map[string]string{"text": string(p)}
} else {
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
}
}
return json.Marshal(struct {
Type string
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}{
Type: "Bind",
DestinationPortal: src.DestinationPortal,
PreparedStatement: src.PreparedStatement,
ParameterFormatCodes: src.ParameterFormatCodes,
Parameters: formattedParameters,
ResultFormatCodes: src.ResultFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Bind) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.DestinationPortal = msg.DestinationPortal
dst.PreparedStatement = msg.PreparedStatement
dst.ParameterFormatCodes = msg.ParameterFormatCodes
dst.Parameters = make([][]byte, len(msg.Parameters))
dst.ResultFormatCodes = msg.ResultFormatCodes
for n, parameter := range msg.Parameters {
dst.Parameters[n], err = getValueFromJSON(parameter)
if err != nil {
return fmt.Errorf("cannot get param %d: %w", n, err)
}
}
return nil
}

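A short round-trip sketch of the parameter framing above: a nil parameter is written with length -1 and comes back as nil, i.e. SQL NULL (again assuming the error-returning Encode signature shown in this vendored copy):

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgproto3"
)

func main() {
	bind := &pgproto3.Bind{
		PreparedStatement: "stmt_1",
		Parameters:        [][]byte{[]byte("42"), nil}, // nil -> length -1 -> NULL
	}
	wire, err := bind.Encode(nil)
	if err != nil {
		log.Fatal(err)
	}

	var decoded pgproto3.Bind
	if err := decoded.Decode(wire[5:]); err != nil { // skip type byte + length
		log.Fatal(err)
	}
	fmt.Println(string(decoded.Parameters[0]), decoded.Parameters[1] == nil) // 42 true
}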

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type BindComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*BindComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *BindComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '2', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src BindComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "BindComplete",
})
}


@ -0,0 +1,58 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
const cancelRequestCode = 80877102
type CancelRequest struct {
ProcessID uint32
SecretKey uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CancelRequest) Frontend() {}
func (dst *CancelRequest) Decode(src []byte) error {
if len(src) != 12 {
return errors.New("bad cancel request size")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != cancelRequestCode {
return errors.New("bad cancel request code")
}
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CancelRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "CancelRequest",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

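Unlike most messages, CancelRequest has no 1-byte type identifier; a minimal sketch of the fixed 16-byte frame (length, request code, process ID, secret key):

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgproto3"
)

func main() {
	req := &pgproto3.CancelRequest{ProcessID: 1234, SecretKey: 5678}
	wire, err := req.Encode(nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(wire)) // 16: no type byte, just length + code + IDs
}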
vendor/github.com/jackc/pgx/v5/pgproto3/chunkreader.go

@ -0,0 +1,90 @@
package pgproto3
import (
"io"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
// chunkReader is an io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
// requested. The memory returned via Next is only valid until the next call to Next.
//
// This is roughly equivalent to a bufio.Reader that only uses Peek and Discard to never copy bytes.
type chunkReader struct {
r io.Reader
buf *[]byte
rp, wp int // buf read position and write position
minBufSize int
}
// newChunkReader creates and returns a new chunkReader for r with default configuration. If minBufSize is <= 0 it uses
// a default value.
func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
if minBufSize <= 0 {
// For historical reasons Postgres currently has an 8KB send buffer inside,
// so here we want to have at least the same size buffer.
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
//
// In addition, testing has found no benefit of any larger buffer.
minBufSize = 8192
}
return &chunkReader{
r: r,
minBufSize: minBufSize,
buf: iobufpool.Get(minBufSize),
}
}
// Next returns buf filled with the next n bytes. buf is only valid until the next call to Next. If an error occurs, buf
// will be nil.
func (r *chunkReader) Next(n int) (buf []byte, err error) {
// Reset the buffer if it is empty
if r.rp == r.wp {
if len(*r.buf) != r.minBufSize {
iobufpool.Put(r.buf)
r.buf = iobufpool.Get(r.minBufSize)
}
r.rp = 0
r.wp = 0
}
// n bytes already in buf
if (r.wp - r.rp) >= n {
buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
r.rp += n
return buf, err
}
// buf is smaller than the requested number of bytes
if len(*r.buf) < n {
bigBuf := iobufpool.Get(n)
r.wp = copy((*bigBuf), (*r.buf)[r.rp:r.wp])
r.rp = 0
iobufpool.Put(r.buf)
r.buf = bigBuf
}
// buf is large enough, but the filled area must be shifted to the start to make enough contiguous space
minReadCount := n - (r.wp - r.rp)
if (len(*r.buf) - r.wp) < minReadCount {
r.wp = copy((*r.buf), (*r.buf)[r.rp:r.wp])
r.rp = 0
}
// Read at least the required number of bytes from the underlying io.Reader
readBytesCount, err := io.ReadAtLeast(r.r, (*r.buf)[r.wp:], minReadCount)
r.wp += readBytesCount
// fmt.Println("read", n)
if err != nil {
return nil, err
}
buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
r.rp += n
return buf, nil
}

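chunkReader is unexported, so the following sketch only compiles inside package pgproto3 (in a test file there, say); it illustrates the Next contract described above: each returned slice is valid only until the following call.

package pgproto3

import (
	"bytes"
	"fmt"
)

func chunkReaderSketch() {
	r := newChunkReader(bytes.NewReader([]byte("hello world")), 0)

	first, _ := r.Next(5)
	fmt.Printf("%s\n", first) // "hello"

	// This call may reuse the buffer backing `first`.
	rest, _ := r.Next(6)
	fmt.Printf("%s\n", rest) // " world"
}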
vendor/github.com/jackc/pgx/v5/pgproto3/close.go

@ -0,0 +1,81 @@
package pgproto3
import (
"bytes"
"encoding/json"
"errors"
)
type Close struct {
ObjectType byte // 'S' = prepared statement, 'P' = portal
Name string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Close) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Close) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.ObjectType = src[0]
rp := 1
idx := bytes.IndexByte(src[rp:], 0)
if idx != len(src[rp:])-1 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.Name = string(src[rp : len(src)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Close) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ObjectType string
Name string
}{
Type: "Close",
ObjectType: string(src.ObjectType),
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Close) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Close.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}


@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type CloseComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CloseComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CloseComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '3', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CloseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CloseComplete",
})
}


@ -0,0 +1,66 @@
package pgproto3
import (
"bytes"
"encoding/json"
)
type CommandComplete struct {
CommandTag []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CommandComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CommandComplete) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0)
if idx == -1 {
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"}
}
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"}
}
dst.CommandTag = src[:idx]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CommandComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
CommandTag string
}{
Type: "CommandComplete",
CommandTag: string(src.CommandTag),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
CommandTag string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.CommandTag = []byte(msg.CommandTag)
return nil
}

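The command tag is simply a NUL-terminated string; a tiny decode sketch of the body as it appears on the wire:

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgproto3"
)

func main() {
	var cc pgproto3.CommandComplete
	// CommandComplete body: the tag followed by its terminating NUL.
	if err := cc.Decode([]byte("INSERT 0 1\x00")); err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(cc.CommandTag)) // INSERT 0 1
}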

@ -0,0 +1,95 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyBothResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyBothResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyBothResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'W')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyBothResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyBothResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

vendor/github.com/jackc/pgx/v5/pgproto3/copy_data.go

@ -0,0 +1,59 @@
package pgproto3
import (
"encoding/hex"
"encoding/json"
)
type CopyData struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyData) Backend() {}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyData) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyData) Decode(src []byte) error {
dst.Data = src
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'd')
dst = append(dst, src.Data...)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "CopyData",
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyData) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go

@ -0,0 +1,38 @@
package pgproto3
import (
"encoding/json"
)
type CopyDone struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyDone) Backend() {}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyDone) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyDone) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
return append(dst, 'c', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyDone) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CopyDone",
})
}

vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go

@ -0,0 +1,45 @@
package pgproto3
import (
"bytes"
"encoding/json"
)
type CopyFail struct {
Message string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyFail) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyFail) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0)
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CopyFail"}
}
dst.Message = string(src[:idx])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'f')
dst = append(dst, src.Message...)
dst = append(dst, 0)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyFail) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Message string
}{
Type: "CopyFail",
Message: src.Message,
})
}


@ -0,0 +1,96 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyInResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyInResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyInResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'G')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyInResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyInResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyInResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

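A decode sketch for the format-code layout above: one overall-format byte, a uint16 column count, then one uint16 per column (0 = text, 1 = binary):

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgproto3"
)

func main() {
	// Overall format 0 (text), 2 columns, both with format code 0.
	body := []byte{0, 0, 2, 0, 0, 0, 0}

	var msg pgproto3.CopyInResponse
	if err := msg.Decode(body); err != nil {
		log.Fatal(err)
	}
	fmt.Println(msg.OverallFormat, msg.ColumnFormatCodes) // 0 [0 0]
}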

@ -0,0 +1,96 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyOutResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyOutResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyOutResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'H')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyOutResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyOutResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go

@ -0,0 +1,143 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type DataRow struct {
Values [][]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*DataRow) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *DataRow) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
rp := 0
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
// If the capacity of the values slice is too small OR substantially too
// large, reallocate. This is to avoid one row with many columns
// permanently allocating memory.
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
newCap := 32
if newCap < fieldCount {
newCap = fieldCount
}
dst.Values = make([][]byte, fieldCount, newCap)
} else {
dst.Values = dst.Values[:fieldCount]
}
for i := 0; i < fieldCount; i++ {
if len(src[rp:]) < 4 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
// null
if valueLen == -1 {
dst.Values[i] = nil
} else {
if len(src[rp:]) < valueLen || valueLen < 0 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
dst.Values[i] = src[rp : rp+valueLen : rp+valueLen]
rp += valueLen
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
if len(src.Values) > math.MaxUint16 {
return nil, errors.New("too many values")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
dst = pgio.AppendInt32(dst, -1)
continue
}
dst = pgio.AppendInt32(dst, int32(len(v)))
dst = append(dst, v...)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src DataRow) MarshalJSON() ([]byte, error) {
formattedValues := make([]map[string]string, len(src.Values))
for i, v := range src.Values {
if v == nil {
continue
}
var hasNonPrintable bool
for _, b := range v {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
} else {
formattedValues[i] = map[string]string{"text": string(v)}
}
}
return json.Marshal(struct {
Type string
Values []map[string]string
}{
Type: "DataRow",
Values: formattedValues,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *DataRow) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Values []map[string]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Values = make([][]byte, len(msg.Values))
for n, parameter := range msg.Values {
var err error
dst.Values[n], err = getValueFromJSON(parameter)
if err != nil {
return err
}
}
return nil
}

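A round-trip sketch of the value framing above; note that Decode aliases src and reuses the Values slice across calls, so values must be copied if they need to outlive the next message:

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgproto3"
)

func main() {
	row := &pgproto3.DataRow{Values: [][]byte{[]byte("a"), nil}}
	wire, err := row.Encode(nil)
	if err != nil {
		log.Fatal(err)
	}

	var decoded pgproto3.DataRow
	if err := decoded.Decode(wire[5:]); err != nil { // body only
		log.Fatal(err)
	}
	fmt.Println(string(decoded.Values[0]), decoded.Values[1] == nil) // a true
}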
vendor/github.com/jackc/pgx/v5/pgproto3/describe.go

@ -0,0 +1,80 @@
package pgproto3
import (
"bytes"
"encoding/json"
"errors"
)
type Describe struct {
ObjectType byte // 'S' = prepared statement, 'P' = portal
Name string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Describe) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Describe) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "Describe"}
}
dst.ObjectType = src[0]
rp := 1
idx := bytes.IndexByte(src[rp:], 0)
if idx != len(src[rp:])-1 {
return &invalidMessageFormatErr{messageType: "Describe"}
}
dst.Name = string(src[rp : len(src)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Describe) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Describe) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ObjectType string
Name string
}{
Type: "Describe",
ObjectType: string(src.ObjectType),
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Describe) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Describe.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}

vendor/github.com/jackc/pgx/v5/pgproto3/doc.go

@ -0,0 +1,11 @@
// Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
//
// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are
// sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call
// Flush to ensure a message has actually been sent.
//
// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a
// similar format to the PQtrace function in libpq.
//
// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages.
package pgproto3

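A sketch of the Send/Flush/Trace flow the package comment describes. The helper name is hypothetical, the Query field name follows pgx v5's pgproto3, and the startup/authentication handshake is deliberately skipped, so this is illustrative rather than a working login:

package main

import (
	"log"
	"net"
	"os"

	"github.com/jackc/pgx/v5/pgproto3"
)

// sendSimpleQuery is a hypothetical helper; it assumes conn has already
// completed startup and authentication, which this sketch does not perform.
func sendSimpleQuery(conn net.Conn, sql string) error {
	fe := pgproto3.NewFrontend(conn, conn)
	fe.Trace(os.Stdout, pgproto3.TracerOptions{}) // PQtrace-style logging
	fe.Send(&pgproto3.Query{String: sql})         // buffered only
	return fe.Flush()                             // nothing hits the wire until here
}

func main() {
	conn, err := net.Dial("tcp", "127.0.0.1:5432")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	if err := sendSimpleQuery(conn, "select 1"); err != nil {
		log.Fatal(err)
	}
}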

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type EmptyQueryResponse struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*EmptyQueryResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *EmptyQueryResponse) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
return append(dst, 'I', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "EmptyQueryResponse",
})
}


@ -0,0 +1,326 @@
package pgproto3
import (
"bytes"
"encoding/json"
"strconv"
)
type ErrorResponse struct {
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ErrorResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ErrorResponse) Decode(src []byte) error {
*dst = ErrorResponse{}
buf := bytes.NewBuffer(src)
for {
k, err := buf.ReadByte()
if err != nil {
return err
}
if k == 0 {
break
}
vb, err := buf.ReadBytes(0)
if err != nil {
return err
}
v := string(vb[:len(vb)-1])
switch k {
case 'S':
dst.Severity = v
case 'V':
dst.SeverityUnlocalized = v
case 'C':
dst.Code = v
case 'M':
dst.Message = v
case 'D':
dst.Detail = v
case 'H':
dst.Hint = v
case 'P':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Position = int32(n)
case 'p':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.InternalPosition = int32(n)
case 'q':
dst.InternalQuery = v
case 'W':
dst.Where = v
case 's':
dst.SchemaName = v
case 't':
dst.TableName = v
case 'c':
dst.ColumnName = v
case 'd':
dst.DataTypeName = v
case 'n':
dst.ConstraintName = v
case 'F':
dst.File = v
case 'L':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Line = int32(n)
case 'R':
dst.Routine = v
default:
if dst.UnknownFields == nil {
dst.UnknownFields = make(map[byte]string)
}
dst.UnknownFields[k] = v
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
dst = src.appendFields(dst)
return finishMessage(dst, sp)
}
func (src *ErrorResponse) appendFields(dst []byte) []byte {
if src.Severity != "" {
dst = append(dst, 'S')
dst = append(dst, src.Severity...)
dst = append(dst, 0)
}
if src.SeverityUnlocalized != "" {
dst = append(dst, 'V')
dst = append(dst, src.SeverityUnlocalized...)
dst = append(dst, 0)
}
if src.Code != "" {
dst = append(dst, 'C')
dst = append(dst, src.Code...)
dst = append(dst, 0)
}
if src.Message != "" {
dst = append(dst, 'M')
dst = append(dst, src.Message...)
dst = append(dst, 0)
}
if src.Detail != "" {
dst = append(dst, 'D')
dst = append(dst, src.Detail...)
dst = append(dst, 0)
}
if src.Hint != "" {
dst = append(dst, 'H')
dst = append(dst, src.Hint...)
dst = append(dst, 0)
}
if src.Position != 0 {
dst = append(dst, 'P')
dst = append(dst, strconv.Itoa(int(src.Position))...)
dst = append(dst, 0)
}
if src.InternalPosition != 0 {
dst = append(dst, 'p')
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
dst = append(dst, 0)
}
if src.InternalQuery != "" {
dst = append(dst, 'q')
dst = append(dst, src.InternalQuery...)
dst = append(dst, 0)
}
if src.Where != "" {
dst = append(dst, 'W')
dst = append(dst, src.Where...)
dst = append(dst, 0)
}
if src.SchemaName != "" {
dst = append(dst, 's')
dst = append(dst, src.SchemaName...)
dst = append(dst, 0)
}
if src.TableName != "" {
dst = append(dst, 't')
dst = append(dst, src.TableName...)
dst = append(dst, 0)
}
if src.ColumnName != "" {
dst = append(dst, 'c')
dst = append(dst, src.ColumnName...)
dst = append(dst, 0)
}
if src.DataTypeName != "" {
dst = append(dst, 'd')
dst = append(dst, src.DataTypeName...)
dst = append(dst, 0)
}
if src.ConstraintName != "" {
dst = append(dst, 'n')
dst = append(dst, src.ConstraintName...)
dst = append(dst, 0)
}
if src.File != "" {
dst = append(dst, 'F')
dst = append(dst, src.File...)
dst = append(dst, 0)
}
if src.Line != 0 {
dst = append(dst, 'L')
dst = append(dst, strconv.Itoa(int(src.Line))...)
dst = append(dst, 0)
}
if src.Routine != "" {
dst = append(dst, 'R')
dst = append(dst, src.Routine...)
dst = append(dst, 0)
}
for k, v := range src.UnknownFields {
dst = append(dst, k)
dst = append(dst, v...)
dst = append(dst, 0)
}
dst = append(dst, 0)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ErrorResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}{
Type: "ErrorResponse",
Severity: src.Severity,
SeverityUnlocalized: src.SeverityUnlocalized,
Code: src.Code,
Message: src.Message,
Detail: src.Detail,
Hint: src.Hint,
Position: src.Position,
InternalPosition: src.InternalPosition,
InternalQuery: src.InternalQuery,
Where: src.Where,
SchemaName: src.SchemaName,
TableName: src.TableName,
ColumnName: src.ColumnName,
DataTypeName: src.DataTypeName,
ConstraintName: src.ConstraintName,
File: src.File,
Line: src.Line,
Routine: src.Routine,
UnknownFields: src.UnknownFields,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Severity = msg.Severity
dst.SeverityUnlocalized = msg.SeverityUnlocalized
dst.Code = msg.Code
dst.Message = msg.Message
dst.Detail = msg.Detail
dst.Hint = msg.Hint
dst.Position = msg.Position
dst.InternalPosition = msg.InternalPosition
dst.InternalQuery = msg.InternalQuery
dst.Where = msg.Where
dst.SchemaName = msg.SchemaName
dst.TableName = msg.TableName
dst.ColumnName = msg.ColumnName
dst.DataTypeName = msg.DataTypeName
dst.ConstraintName = msg.ConstraintName
dst.File = msg.File
dst.Line = msg.Line
dst.Routine = msg.Routine
dst.UnknownFields = msg.UnknownFields
return nil
}

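A round-trip sketch of the field encoding above (single-letter field keys, each value NUL-terminated, the list closed by an extra NUL):

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgproto3"
)

func main() {
	er := &pgproto3.ErrorResponse{
		Severity: "ERROR",
		Code:     "42P01",
		Message:  `relation "foo" does not exist`,
	}
	wire, err := er.Encode(nil)
	if err != nil {
		log.Fatal(err)
	}

	var decoded pgproto3.ErrorResponse
	if err := decoded.Decode(wire[5:]); err != nil {
		log.Fatal(err)
	}
	fmt.Println(decoded.Code, decoded.Message) // 42P01 relation "foo" does not exist
}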
vendor/github.com/jackc/pgx/v5/pgproto3/execute.go

@ -0,0 +1,58 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Execute struct {
Portal string
MaxRows uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Execute) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Execute) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Portal = string(b[:len(b)-1])
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "Execute"}
}
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Execute) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
dst = append(dst, src.Portal...)
dst = append(dst, 0)
dst = pgio.AppendUint32(dst, src.MaxRows)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Execute) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Portal string
MaxRows uint32
}{
Type: "Execute",
Portal: src.Portal,
MaxRows: src.MaxRows,
})
}

vendor/github.com/jackc/pgx/v5/pgproto3/flush.go

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type Flush struct{}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Flush) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Flush) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Flush) Encode(dst []byte) ([]byte, error) {
return append(dst, 'H', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Flush) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "Flush",
})
}

vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go

@ -0,0 +1,454 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
)
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
type Frontend struct {
cr *chunkReader
w io.Writer
// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
// before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is
// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
tracer *tracer
wbuf []byte
encodeError error
// Backend message flyweights
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationGSS AuthenticationGSS
authenticationGSSContinue AuthenticationGSSContinue
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
backendKeyData BackendKeyData
bindComplete BindComplete
closeComplete CloseComplete
commandComplete CommandComplete
copyBothResponse CopyBothResponse
copyData CopyData
copyInResponse CopyInResponse
copyOutResponse CopyOutResponse
copyDone CopyDone
dataRow DataRow
emptyQueryResponse EmptyQueryResponse
errorResponse ErrorResponse
functionCallResponse FunctionCallResponse
noData NoData
noticeResponse NoticeResponse
notificationResponse NotificationResponse
parameterDescription ParameterDescription
parameterStatus ParameterStatus
parseComplete ParseComplete
readyForQuery ReadyForQuery
rowDescription RowDescription
portalSuspended PortalSuspended
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
// NewFrontend creates a new Frontend.
func NewFrontend(r io.Reader, w io.Writer) *Frontend {
cr := newChunkReader(r, 0)
return &Frontend{cr: cr, w: w}
}
// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
//
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
// behind an interface.
func (f *Frontend) Send(msg FrontendMessage) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// Flush writes any pending messages to the backend (i.e. the server).
func (f *Frontend) Flush() error {
if err := f.encodeError; err != nil {
f.encodeError = nil
f.wbuf = f.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
if len(f.wbuf) == 0 {
return nil
}
n, err := f.w.Write(f.wbuf)
const maxLen = 1024
if len(f.wbuf) > maxLen {
f.wbuf = make([]byte, 0, maxLen)
} else {
f.wbuf = f.wbuf[:0]
}
if err != nil {
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
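As a hedged sketch of the buffer-then-flush cycle above (the helper name and the assumption that conn already carries an authenticated session are illustrative, not part of this file):
import (
	"net"
	"github.com/jackc/pgx/v5/pgproto3"
)
// sendSimpleQuery queues a Query message and pushes it to the server in a
// single write. Startup and authentication are assumed to have happened
// elsewhere on conn.
func sendSimpleQuery(conn net.Conn) error {
	frontend := pgproto3.NewFrontend(conn, conn)
	frontend.Send(&pgproto3.Query{String: "select 1"}) // buffered, not yet written
	return frontend.Flush()                            // any Send encode error also surfaces here
}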
// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function
// PQtrace.
func (f *Frontend) Trace(w io.Writer, options TracerOptions) {
f.tracer = &tracer{
w: w,
buf: &bytes.Buffer{},
TracerOptions: options,
}
}
// Untrace stops tracing.
func (f *Frontend) Untrace() {
f.tracer = nil
}
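A minimal tracing sketch (writing to os.Stderr is an arbitrary choice for the example):
import (
	"os"
	"github.com/jackc/pgx/v5/pgproto3"
)
// traceSession mirrors all wire traffic on frontend to stderr in a
// PQtrace-like format until Untrace is called.
func traceSession(frontend *pgproto3.Frontend) {
	frontend.Trace(os.Stderr, pgproto3.TracerOptions{
		SuppressTimestamps: true, // keep output stable across runs
		RegressMode:        true, // redact values that vary per session, e.g. backend PIDs
	})
}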
// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendBind(msg *Bind) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendParse(msg *Parse) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendClose(msg *Close) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
// called. Any error encountered will be returned from Flush.
func (f *Frontend) SendDescribe(msg *Describe) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
// Any error encountered will be returned from Flush.
func (f *Frontend) SendExecute(msg *Execute) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendSync(msg *Sync) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendQuery(msg *Query) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method
// is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer
// before being written out. The internal buffer is flushed before the message is sent.
func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error {
err := f.Flush()
if err != nil {
return err
}
n, err := f.w.Write(msg)
if err != nil {
return &writeError{err: err, safeToRetry: n == 0}
}
if f.tracer != nil {
f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{})
}
return nil
}
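A sketch of the intended fast path (the row encoding is illustrative; real COPY data must match the format negotiated by the CopyInResponse):
import "github.com/jackc/pgx/v5/pgproto3"
// streamCopyRow encodes one CopyData frame and writes it straight to the
// socket, bypassing the frontend's internal buffer.
func streamCopyRow(frontend *pgproto3.Frontend, row []byte) error {
	frame, err := (&pgproto3.CopyData{Data: row}).Encode(nil) // 'd' + length + payload
	if err != nil {
		return err
	}
	return frontend.SendUnbufferedEncodedCopyData(frame)
}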
func translateEOFtoErrUnexpectedEOF(err error) error {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
return err
}
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
func (f *Frontend) Receive() (BackendMessage, error) {
if !f.partialMsg {
header, err := f.cr.Next(5)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
f.msgType = header[0]
msgLength := int(binary.BigEndian.Uint32(header[1:]))
if msgLength < 4 {
return nil, fmt.Errorf("invalid message length: %d", msgLength)
}
f.bodyLen = msgLength - 4
f.partialMsg = true
}
msgBody, err := f.cr.Next(f.bodyLen)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
f.partialMsg = false
var msg BackendMessage
switch f.msgType {
case '1':
msg = &f.parseComplete
case '2':
msg = &f.bindComplete
case '3':
msg = &f.closeComplete
case 'A':
msg = &f.notificationResponse
case 'c':
msg = &f.copyDone
case 'C':
msg = &f.commandComplete
case 'd':
msg = &f.copyData
case 'D':
msg = &f.dataRow
case 'E':
msg = &f.errorResponse
case 'G':
msg = &f.copyInResponse
case 'H':
msg = &f.copyOutResponse
case 'I':
msg = &f.emptyQueryResponse
case 'K':
msg = &f.backendKeyData
case 'n':
msg = &f.noData
case 'N':
msg = &f.noticeResponse
case 'R':
var err error
msg, err = f.findAuthenticationMessageType(msgBody)
if err != nil {
return nil, err
}
case 's':
msg = &f.portalSuspended
case 'S':
msg = &f.parameterStatus
case 't':
msg = &f.parameterDescription
case 'T':
msg = &f.rowDescription
case 'V':
msg = &f.functionCallResponse
case 'W':
msg = &f.copyBothResponse
case 'Z':
msg = &f.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
}
err = msg.Decode(msgBody)
if err != nil {
return nil, err
}
if f.tracer != nil {
f.tracer.traceMessage('B', int32(5+len(msgBody)), msg)
}
return msg, nil
}
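Since each returned message is only valid until the next Receive, callers must copy anything they keep. A hedged sketch of a result-draining loop (the error formatting is illustrative):
import (
	"fmt"
	"github.com/jackc/pgx/v5/pgproto3"
)
// collectRows reads backend messages after a query was flushed and returns
// copies of all data rows, stopping at ReadyForQuery.
func collectRows(frontend *pgproto3.Frontend) ([][][]byte, error) {
	var rows [][][]byte
	for {
		msg, err := frontend.Receive()
		if err != nil {
			return nil, err
		}
		switch m := msg.(type) {
		case *pgproto3.DataRow:
			// m.Values aliases the read buffer and is overwritten by the
			// next Receive, so deep-copy the row before keeping it.
			row := make([][]byte, len(m.Values))
			for i, v := range m.Values {
				row[i] = append([]byte(nil), v...)
			}
			rows = append(rows, row)
		case *pgproto3.ErrorResponse:
			return nil, fmt.Errorf("server error %s: %s", m.Code, m.Message)
		case *pgproto3.ReadyForQuery:
			return rows, nil
		}
	}
}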
// Authentication message type constants.
// See src/include/libpq/pqcomm.h for all
// constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSCMCreds = 6
AuthTypeGSS = 7
AuthTypeGSSCont = 8
AuthTypeSSPI = 9
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
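As a sketch of how these constants line up with the concrete message types returned by Receive (only the cleartext branch is shown; SASL needs a multi-step exchange, and a real client would loop until AuthenticationOk):
import (
	"fmt"
	"github.com/jackc/pgx/v5/pgproto3"
)
// awaitAuth handles the first authentication request after startup. Anything
// beyond AuthenticationOk and cleartext passwords is rejected here.
func awaitAuth(frontend *pgproto3.Frontend, password string) error {
	msg, err := frontend.Receive()
	if err != nil {
		return err
	}
	switch msg.(type) {
	case *pgproto3.AuthenticationOk:
		return nil
	case *pgproto3.AuthenticationCleartextPassword:
		frontend.Send(&pgproto3.PasswordMessage{Password: password})
		return frontend.Flush()
	default:
		return fmt.Errorf("unsupported authentication request: auth type %d", frontend.GetAuthType())
	}
}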
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
if len(src) < 4 {
return nil, errors.New("authentication message too short")
}
f.authType = binary.BigEndian.Uint32(src[:4])
switch f.authType {
case AuthTypeOk:
return &f.authenticationOk, nil
case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return &f.authenticationGSS, nil
case AuthTypeGSSCont:
return &f.authenticationGSSContinue, nil
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:
return &f.authenticationSASL, nil
case AuthTypeSASLContinue:
return &f.authenticationSASLContinue, nil
case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
}
}
// GetAuthType returns the authType used in the current state of the frontend.
// See SetAuthType for more information.
func (f *Frontend) GetAuthType() uint32 {
return f.authType
}
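// ReadBufferLen returns the number of bytes received from the backend that
// are buffered but not yet consumed by Receive.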
func (f *Frontend) ReadBufferLen() int {
return f.cr.wp - f.cr.rp
}

vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go generated vendored Normal file
View File

@ -0,0 +1,102 @@
package pgproto3
import (
"encoding/binary"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type FunctionCall struct {
Function uint32
ArgFormatCodes []uint16
Arguments [][]byte
ResultFormatCode uint16
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*FunctionCall) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *FunctionCall) Decode(src []byte) error {
*dst = FunctionCall{}
rp := 0
// Specifies the object ID of the function to call.
dst.Function = binary.BigEndian.Uint32(src[rp:])
rp += 4
// The number of argument format codes that follow (denoted C below).
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
// or one, in which case the specified format code is applied to all arguments;
// or it can equal the actual number of arguments.
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
argumentCodes := make([]uint16, nArgumentCodes)
for i := 0; i < nArgumentCodes; i++ {
// The argument format codes. Each must presently be zero (text) or one (binary).
ac := binary.BigEndian.Uint16(src[rp:])
if ac != 0 && ac != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
argumentCodes[i] = ac
rp += 2
}
dst.ArgFormatCodes = argumentCodes
// Specifies the number of arguments being supplied to the function.
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
arguments := make([][]byte, nArguments)
for i := 0; i < nArguments; i++ {
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if argumentLength == -1 {
arguments[i] = nil
} else {
// The value of the argument, in the format indicated by the associated format code. n is the above length.
argumentValue := src[rp : rp+argumentLength]
rp += argumentLength
arguments[i] = argumentValue
}
}
dst.Arguments = arguments
// The format code for the function result. Must presently be zero (text) or one (binary).
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
if resultFormatCode != 0 && resultFormatCode != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
dst.ResultFormatCode = resultFormatCode
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'F')
dst = pgio.AppendUint32(dst, src.Function)
if len(src.ArgFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many arg format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
if len(src.Arguments) > math.MaxUint16 {
return nil, errors.New("too many arguments")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(argument)))
dst = append(dst, argument...)
}
}
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
return finishMessage(dst, sp)
}
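A small encode sketch (the function OID and argument bytes are hypothetical; a real caller would resolve the OID from pg_proc and encode arguments per the chosen format codes):
import "github.com/jackc/pgx/v5/pgproto3"
// encodeFunctionCall builds a complete 'F' frame calling a hypothetical
// function with a single binary int4 argument.
func encodeFunctionCall() ([]byte, error) {
	fc := &pgproto3.FunctionCall{
		Function:         1598,                   // hypothetical function OID
		ArgFormatCodes:   []uint16{1},            // one code, applied to every argument (binary)
		Arguments:        [][]byte{{0, 0, 0, 7}}, // int4 value 7, big-endian
		ResultFormatCode: 1,                      // ask for a binary result
	}
	return fc.Encode(nil)
}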

vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go generated vendored Normal file
View File

@ -0,0 +1,97 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type FunctionCallResponse struct {
Result []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*FunctionCallResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *FunctionCallResponse) Decode(src []byte) error {
if len(src) < 4 {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
rp := 0
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if resultSize == -1 {
dst.Result = nil
return nil
}
if len(src[rp:]) != resultSize {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
dst.Result = src[rp:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'V')
if src.Result == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
dst = append(dst, src.Result...)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
var formattedValue map[string]string
var hasNonPrintable bool
for _, b := range src.Result {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
} else {
formattedValue = map[string]string{"text": string(src.Result)}
}
return json.Marshal(struct {
Type string
Result map[string]string
}{
Type: "FunctionCallResponse",
Result: formattedValue,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Result map[string]string
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.Result, err = getValueFromJSON(msg.Result)
return err
}

vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go generated vendored Normal file
View File

@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
const gssEncReqNumber = 80877104
type GSSEncRequest struct {
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*GSSEncRequest) Frontend() {}
func (dst *GSSEncRequest) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("gss encoding request too short")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != gssEncReqNumber {
return errors.New("bad gss encoding request code")
}
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, gssEncReqNumber)
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src GSSEncRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProtocolVersion uint32
Parameters map[string]string
}{
Type: "GSSEncRequest",
})
}

vendor/github.com/jackc/pgx/v5/pgproto3/gss_response.go generated vendored Normal file
View File

@ -0,0 +1,46 @@
package pgproto3
import (
"encoding/json"
)
type GSSResponse struct {
Data []byte
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (g *GSSResponse) Frontend() {}
func (g *GSSResponse) Decode(data []byte) error {
g.Data = data
return nil
}
func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, g.Data...)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "GSSResponse",
Data: g.Data,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
g.Data = msg.Data
return nil
}

vendor/github.com/jackc/pgx/v5/pgproto3/no_data.go generated vendored Normal file
View File

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type NoData struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NoData) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NoData) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoData) Encode(dst []byte) ([]byte, error) {
return append(dst, 'n', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.
func (src NoData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "NoData",
})
}

vendor/github.com/jackc/pgx/v5/pgproto3/notice_response.go generated vendored Normal file
View File

@ -0,0 +1,19 @@
package pgproto3
type NoticeResponse ErrorResponse
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NoticeResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NoticeResponse) Decode(src []byte) error {
return (*ErrorResponse)(dst).Decode(src)
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'N')
dst = (*ErrorResponse)(src).appendFields(dst)
return finishMessage(dst, sp)
}

vendor/github.com/jackc/pgx/v5/pgproto3/notification_response.go generated vendored Normal file
View File

@ -0,0 +1,71 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type NotificationResponse struct {
PID uint32
Channel string
Payload string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NotificationResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NotificationResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"}
}
pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
channel := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
payload := string(b[:len(b)-1])
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'A')
dst = pgio.AppendUint32(dst, src.PID)
dst = append(dst, src.Channel...)
dst = append(dst, 0)
dst = append(dst, src.Payload...)
dst = append(dst, 0)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src NotificationResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
PID uint32
Channel string
Payload string
}{
Type: "NotificationResponse",
PID: src.PID,
Channel: src.Channel,
Payload: src.Payload,
})
}
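A decode sketch for the LISTEN/NOTIFY path, with the wire body hand-rolled to show the layout (PID, then NUL-terminated channel and payload; the channel and payload values are illustrative):
import "github.com/jackc/pgx/v5/pgproto3"
// decodeNotification rebuilds a NotificationResponse from its raw body,
// i.e. everything after the 'A' type byte and the 4-byte length.
func decodeNotification() (*pgproto3.NotificationResponse, error) {
	body := []byte{0x00, 0x00, 0x30, 0x39} // PID 12345
	body = append(body, "jobs"...)
	body = append(body, 0) // channel terminator
	body = append(body, "job 7 done"...)
	body = append(body, 0) // payload terminator
	var n pgproto3.NotificationResponse
	if err := n.Decode(body); err != nil {
		return nil, err
	}
	return &n, nil // PID 12345, Channel "jobs", Payload "job 7 done"
}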

vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go generated vendored Normal file
View File

@ -0,0 +1,67 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type ParameterDescription struct {
ParameterOIDs []uint32
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParameterDescription) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParameterDescription) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
}
// Reported parameter count will be incorrect when number of args is greater than uint16
buf.Next(2)
// Instead infer parameter count by remaining size of message
parameterCount := buf.Len() / 4
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
for i := 0; i < parameterCount; i++ {
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 't')
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ParameterDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ParameterOIDs []uint32
}{
Type: "ParameterDescription",
ParameterOIDs: src.ParameterOIDs,
})
}
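To illustrate the count inference above: Decode trusts the body length, not the declared uint16, so even a wrong count decodes correctly (bytes hand-rolled for the example):
import "github.com/jackc/pgx/v5/pgproto3"
// decodeParamDesc decodes a 't' body whose declared count (1) disagrees with
// the two OIDs actually present; the result still holds both OIDs.
func decodeParamDesc() ([]uint32, error) {
	body := []byte{
		0x00, 0x01, // declared parameter count: 1 (deliberately wrong)
		0x00, 0x00, 0x00, 0x17, // OID 23, int4
		0x00, 0x00, 0x00, 0x19, // OID 25, text
	}
	var pd pgproto3.ParameterDescription
	if err := pd.Decode(body); err != nil {
		return nil, err
	}
	return pd.ParameterOIDs, nil // [23 25], count inferred as (len(body)-2)/4
}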

vendor/github.com/jackc/pgx/v5/pgproto3/parameter_status.go generated vendored Normal file
View File

@ -0,0 +1,58 @@
package pgproto3
import (
"bytes"
"encoding/json"
)
type ParameterStatus struct {
Name string
Value string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParameterStatus) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParameterStatus) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
name := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
value := string(b[:len(b)-1])
*dst = ParameterStatus{Name: name, Value: value}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'S')
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Value...)
dst = append(dst, 0)
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (ps ParameterStatus) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Value string
}{
Type: "ParameterStatus",
Name: ps.Name,
Value: ps.Value,
})
}

vendor/github.com/jackc/pgx/v5/pgproto3/parse.go generated vendored Normal file
View File

@ -0,0 +1,89 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Parse struct {
Name string
Query string
ParameterOIDs []uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Parse) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Parse) Decode(src []byte) error {
*dst = Parse{}
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Name = string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
dst.Query = string(b[:len(b)-1])
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "Parse"}
}
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
for i := 0; i < parameterOIDCount; i++ {
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "Parse"}
}
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Parse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'P')
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Query...)
dst = append(dst, 0)
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Parse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Query string
ParameterOIDs []uint32
}{
Type: "Parse",
Name: src.Name,
Query: src.Query,
ParameterOIDs: src.ParameterOIDs,
})
}
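Combining Parse with the specialized senders from frontend.go, a hedged sketch of one extended-protocol round trip (statement text, names, and the text-format parameter value are illustrative):
import "github.com/jackc/pgx/v5/pgproto3"
// extendedQuery pipelines Parse/Bind/Describe/Execute/Sync and writes them
// with a single Flush; responses would then be read via Receive.
func extendedQuery(frontend *pgproto3.Frontend) error {
	frontend.SendParse(&pgproto3.Parse{
		Name:          "get_user",
		Query:         "select name from users where id = $1",
		ParameterOIDs: []uint32{23}, // int4
	})
	frontend.SendBind(&pgproto3.Bind{
		PreparedStatement: "get_user",
		Parameters:        [][]byte{[]byte("42")}, // text format by default
	})
	frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) // describe the unnamed portal
	frontend.SendExecute(&pgproto3.Execute{})                  // MaxRows 0 means no row limit
	frontend.SendSync(&pgproto3.Sync{})
	return frontend.Flush()
}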

Some files were not shown because too many files have changed in this diff.