aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/lib/pq/copy.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/lib/pq/copy.go')
-rw-r--r--vendor/github.com/lib/pq/copy.go303
1 files changed, 303 insertions, 0 deletions
diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go
new file mode 100644
index 0000000..c072bc3
--- /dev/null
+++ b/vendor/github.com/lib/pq/copy.go
@@ -0,0 +1,303 @@
+package pq
+
+import (
+ "database/sql/driver"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "sync"
+)
+
+var (
+ errCopyInClosed = errors.New("pq: copyin statement has already been closed")
+ errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
+ errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
+ errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
+ errCopyInProgress = errors.New("pq: COPY in progress")
+)
+
+// CopyIn creates a COPY FROM statement which can be prepared with
+// Tx.Prepare(). The target table should be visible in search_path.
+func CopyIn(table string, columns ...string) string {
+ stmt := "COPY " + QuoteIdentifier(table) + " ("
+ for i, col := range columns {
+ if i != 0 {
+ stmt += ", "
+ }
+ stmt += QuoteIdentifier(col)
+ }
+ stmt += ") FROM STDIN"
+ return stmt
+}
+
+// CopyInSchema creates a COPY FROM statement which can be prepared with
+// Tx.Prepare().
+func CopyInSchema(schema, table string, columns ...string) string {
+ stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
+ for i, col := range columns {
+ if i != 0 {
+ stmt += ", "
+ }
+ stmt += QuoteIdentifier(col)
+ }
+ stmt += ") FROM STDIN"
+ return stmt
+}
+
+type copyin struct {
+ cn *conn
+ buffer []byte
+ rowData chan []byte
+ done chan bool
+
+ closed bool
+
+ mu struct {
+ sync.Mutex
+ err error
+ driver.Result
+ }
+}
+
+const ciBufferSize = 64 * 1024
+
+// flush buffer before the buffer is filled up and needs reallocation
+const ciBufferFlushSize = 63 * 1024
+
+func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
+ if !cn.isInTransaction() {
+ return nil, errCopyNotSupportedOutsideTxn
+ }
+
+ ci := &copyin{
+ cn: cn,
+ buffer: make([]byte, 0, ciBufferSize),
+ rowData: make(chan []byte),
+ done: make(chan bool, 1),
+ }
+ // add CopyData identifier + 4 bytes for message length
+ ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
+
+ b := cn.writeBuf('Q')
+ b.string(q)
+ cn.send(b)
+
+awaitCopyInResponse:
+ for {
+ t, r := cn.recv1()
+ switch t {
+ case 'G':
+ if r.byte() != 0 {
+ err = errBinaryCopyNotSupported
+ break awaitCopyInResponse
+ }
+ go ci.resploop()
+ return ci, nil
+ case 'H':
+ err = errCopyToNotSupported
+ break awaitCopyInResponse
+ case 'E':
+ err = parseError(r)
+ case 'Z':
+ if err == nil {
+ ci.setBad(driver.ErrBadConn)
+ errorf("unexpected ReadyForQuery in response to COPY")
+ }
+ cn.processReadyForQuery(r)
+ return nil, err
+ default:
+ ci.setBad(driver.ErrBadConn)
+ errorf("unknown response for copy query: %q", t)
+ }
+ }
+
+ // something went wrong, abort COPY before we return
+ b = cn.writeBuf('f')
+ b.string(err.Error())
+ cn.send(b)
+
+ for {
+ t, r := cn.recv1()
+ switch t {
+ case 'c', 'C', 'E':
+ case 'Z':
+ // correctly aborted, we're done
+ cn.processReadyForQuery(r)
+ return nil, err
+ default:
+ ci.setBad(driver.ErrBadConn)
+ errorf("unknown response for CopyFail: %q", t)
+ }
+ }
+}
+
+func (ci *copyin) flush(buf []byte) {
+ // set message length (without message identifier)
+ binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
+
+ _, err := ci.cn.c.Write(buf)
+ if err != nil {
+ panic(err)
+ }
+}
+
+func (ci *copyin) resploop() {
+ for {
+ var r readBuf
+ t, err := ci.cn.recvMessage(&r)
+ if err != nil {
+ ci.setBad(driver.ErrBadConn)
+ ci.setError(err)
+ ci.done <- true
+ return
+ }
+ switch t {
+ case 'C':
+ // complete
+ res, _ := ci.cn.parseComplete(r.string())
+ ci.setResult(res)
+ case 'N':
+ if n := ci.cn.noticeHandler; n != nil {
+ n(parseError(&r))
+ }
+ case 'Z':
+ ci.cn.processReadyForQuery(&r)
+ ci.done <- true
+ return
+ case 'E':
+ err := parseError(&r)
+ ci.setError(err)
+ default:
+ ci.setBad(driver.ErrBadConn)
+ ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
+ ci.done <- true
+ return
+ }
+ }
+}
+
+func (ci *copyin) setBad(err error) {
+ ci.cn.err.set(err)
+}
+
+func (ci *copyin) getBad() error {
+ return ci.cn.err.get()
+}
+
+func (ci *copyin) err() error {
+ ci.mu.Lock()
+ err := ci.mu.err
+ ci.mu.Unlock()
+ return err
+}
+
+// setError() sets ci.err if one has not been set already. Caller must not be
+// holding ci.Mutex.
+func (ci *copyin) setError(err error) {
+ ci.mu.Lock()
+ if ci.mu.err == nil {
+ ci.mu.err = err
+ }
+ ci.mu.Unlock()
+}
+
+func (ci *copyin) setResult(result driver.Result) {
+ ci.mu.Lock()
+ ci.mu.Result = result
+ ci.mu.Unlock()
+}
+
+func (ci *copyin) getResult() driver.Result {
+ ci.mu.Lock()
+ result := ci.mu.Result
+ ci.mu.Unlock()
+ if result == nil {
+ return driver.RowsAffected(0)
+ }
+ return result
+}
+
+func (ci *copyin) NumInput() int {
+ return -1
+}
+
+func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
+ return nil, ErrNotSupported
+}
+
+// Exec inserts values into the COPY stream. The insert is asynchronous
+// and Exec can return errors from previous Exec calls to the same
+// COPY stmt.
+//
+// You need to call Exec(nil) to sync the COPY stream and to get any
+// errors from pending data, since Stmt.Close() doesn't return errors
+// to the user.
+func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
+ if ci.closed {
+ return nil, errCopyInClosed
+ }
+
+ if err := ci.getBad(); err != nil {
+ return nil, err
+ }
+ defer ci.cn.errRecover(&err)
+
+ if err := ci.err(); err != nil {
+ return nil, err
+ }
+
+ if len(v) == 0 {
+ if err := ci.Close(); err != nil {
+ return driver.RowsAffected(0), err
+ }
+
+ return ci.getResult(), nil
+ }
+
+ numValues := len(v)
+ for i, value := range v {
+ ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
+ if i < numValues-1 {
+ ci.buffer = append(ci.buffer, '\t')
+ }
+ }
+
+ ci.buffer = append(ci.buffer, '\n')
+
+ if len(ci.buffer) > ciBufferFlushSize {
+ ci.flush(ci.buffer)
+ // reset buffer, keep bytes for message identifier and length
+ ci.buffer = ci.buffer[:5]
+ }
+
+ return driver.RowsAffected(0), nil
+}
+
+func (ci *copyin) Close() (err error) {
+ if ci.closed { // Don't do anything, we're already closed
+ return nil
+ }
+ ci.closed = true
+
+ if err := ci.getBad(); err != nil {
+ return err
+ }
+ defer ci.cn.errRecover(&err)
+
+ if len(ci.buffer) > 0 {
+ ci.flush(ci.buffer)
+ }
+ // Avoid touching the scratch buffer as resploop could be using it.
+ err = ci.cn.sendSimpleMessage('c')
+ if err != nil {
+ return err
+ }
+
+ <-ci.done
+ ci.cn.inCopy = false
+
+ if err := ci.err(); err != nil {
+ return err
+ }
+ return nil
+}