aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn/krb5.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/krb5.go100
1 files changed, 100 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go
new file mode 100644
index 0000000..3c1af34
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go
@@ -0,0 +1,100 @@
+package pgconn
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/jackc/pgx/v5/pgproto3"
+)
+
+// NewGSSFunc creates a GSS authentication provider, for use with
+// RegisterGSSProvider.
+type NewGSSFunc func() (GSS, error)
+
+var newGSS NewGSSFunc
+
+// RegisterGSSProvider registers a GSS authentication provider. For example, if
+// you need to use Kerberos to authenticate with your server, add this to your
+// main package:
+//
+// import "github.com/otan/gopgkrb5"
+//
+// func init() {
+// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
+// }
+func RegisterGSSProvider(newGSSArg NewGSSFunc) {
+ newGSS = newGSSArg
+}
+
+// GSS provides GSSAPI authentication (e.g., Kerberos).
+type GSS interface {
+ GetInitToken(host string, service string) ([]byte, error)
+ GetInitTokenFromSPN(spn string) ([]byte, error)
+ Continue(inToken []byte) (done bool, outToken []byte, err error)
+}
+
+func (c *PgConn) gssAuth() error {
+ if newGSS == nil {
+ return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
+ }
+ cli, err := newGSS()
+ if err != nil {
+ return err
+ }
+
+ var nextData []byte
+ if c.config.KerberosSpn != "" {
+ // Use the supplied SPN if provided.
+ nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
+ } else {
+ // Allow the kerberos service name to be overridden
+ service := "postgres"
+ if c.config.KerberosSrvName != "" {
+ service = c.config.KerberosSrvName
+ }
+ nextData, err = cli.GetInitToken(c.config.Host, service)
+ }
+ if err != nil {
+ return err
+ }
+
+ for {
+ gssResponse := &pgproto3.GSSResponse{
+ Data: nextData,
+ }
+ c.frontend.Send(gssResponse)
+ err = c.flushWithPotentialWriteReadDeadlock()
+ if err != nil {
+ return err
+ }
+ resp, err := c.rxGSSContinue()
+ if err != nil {
+ return err
+ }
+ var done bool
+ done, nextData, err = cli.Continue(resp.Data)
+ if err != nil {
+ return err
+ }
+ if done {
+ break
+ }
+ }
+ return nil
+}
+
+func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
+ msg, err := c.receiveMessage()
+ if err != nil {
+ return nil, err
+ }
+
+ switch m := msg.(type) {
+ case *pgproto3.AuthenticationGSSContinue:
+ return m, nil
+ case *pgproto3.ErrorResponse:
+ return nil, ErrorResponseToPgError(m)
+ }
+
+ return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
+}