diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index 5bbcf20db2f07..e3e8e0e7da7e2 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -542,3 +542,24 @@ func (noRows) LastInsertId() (int64, error) { func (noRows) RowsAffected() (int64, error) { return 0, errors.New("no RowsAffected available after DDL statement") } + +// CopyResult signals termination of a CopyIn. +// Both Res and Err may be populated in case of a partial write. +type CopyResult struct { + Res Result + Err error +} + +// Copier is an optional interface that may be implemented by a Conn. +type Copier interface { + + // CopyIn sends a batch of data to the server in a single execution. + // The driver may do this asynchronously. + // + // The data channel may be buffered. + // The sender closes the data channel when all data is sent. + // + // Exact one of CopyResult must always be send on the result channel. + // Either after all data is flushed, or after encountering an error. + CopyIn(ctx context.Context, table string, columns ...string) (data chan<- []NamedValue, result <-chan CopyResult) +} diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index e3580698fd6fa..8f6140175ddda 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -3293,3 +3293,96 @@ func withLock(lk sync.Locker, fn func()) { defer lk.Unlock() // in case fn panics fn() } + +// Batch is input to a CopyIn command. +type Batch interface { + Table() string + Columns() []string + // Next returns a row of data in a batch. + // It should return io.EOF when no more data is available. + Next() ([]interface{}, error) +} + +// ErrNotCopier is returned when the driver does not support batch operations. +var ErrNotCopier = errors.New("sql: driver does not support CopyIn") + +// CopyIn sends the provided batch in a single, asynchronous operation. +// +// In case of an error, partial data may have been written to the database. +// If supported, a driver may return both Result and error to indicate the amount of rows written. +func (tx *Tx) CopyIn(ctx context.Context, batch Batch) (Result, error) { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return nil, err + } + return tx.db.copyDC(ctx, dc, release, batch) +} + +func (db *DB) copyDC(ctx context.Context, dc *driverConn, release func(error), batch Batch) (res Result, err error) { + defer func() { + release(err) + }() + + copier, ok := dc.ci.(driver.Copier) + if !ok { + err = ErrNotCopier + return + } + + withLock(dc, func() { + data, result := copier.CopyIn(ctx, batch.Table(), batch.Columns()...) + + for { + var ( + args []interface{} + nvdargs []driver.NamedValue + ) + + args, err = batch.Next() + if err != nil { + close(data) + r := <-result + + if err == io.EOF { + res, err = r.Res, r.Err + } else { + res, err = r.Res, fmt.Errorf("sql Batch: %w", err) + } + + return + } + + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + + select { + + case r := <-result: + res, err = r.Res, r.Err + return + + case <-ctx.Done(): + close(data) + + r := <-result + + // Give priority for the error reported from the driver. + // If its nil, we set the context error. + if err = r.Err; err == nil { + err = ctx.Err() + } + + return + + case data <- nvdargs: + continue + } + + } + + }) + + return +}