Skip to content

Commit

Permalink
feat: support getting the commit timestamp of a transaction or statem…
Browse files Browse the repository at this point in the history
…ent (#52)

* feat: support getting the commit timestamp of a transaction or statement

Cloud Spanner returns the commit timestamp of a read/write transaction as a part of the commit response.
This PR adds support for getting that timestamp in two ways:
1. Through a CommitTimestamp() method on the SpannerConn interface.
2. Through the SHOW VARIABLE COMMIT_TIMESTAMP client side statement.

* docs: add comments
  • Loading branch information
olavloite committed Oct 19, 2021
1 parent 01828bb commit 802e7be
Show file tree
Hide file tree
Showing 8 changed files with 411 additions and 19 deletions.
20 changes: 20 additions & 0 deletions client_side_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ import (
type statementExecutor struct {
}

func (s *statementExecutor) ShowCommitTimestamp(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) {
ts, err := c.CommitTimestamp()
var commitTs *time.Time
if err == nil {
commitTs = &ts
}
it, err := createTimestampIterator("CommitTimestamp", commitTs)
if err != nil {
return nil, err
}
return &rows{it: it}, nil
}

func (s *statementExecutor) ShowRetryAbortsInternally(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) {
it, err := createBooleanIterator("RetryAbortsInternally", c.RetryAbortsInternally())
if err != nil {
Expand Down Expand Up @@ -210,6 +223,13 @@ func createStringIterator(column string, value string) (*clientSideIterator, err
return createSingleValueIterator(column, value, sppb.TypeCode_STRING)
}

// createTimestampIterator creates a row iterator with a single TIMESTAMP column with
// one row. This is used for client side statements that return a result set
// containing a TIMESTAMP value.
func createTimestampIterator(column string, value *time.Time) (*clientSideIterator, error) {
return createSingleValueIterator(column, value, sppb.TypeCode_TIMESTAMP)
}

func createSingleValueIterator(column string, value interface{}, code sppb.TypeCode) (*clientSideIterator, error) {
row, err := spanner.NewRow([]string{column}, []interface{}{value})
if err != nil {
Expand Down
44 changes: 44 additions & 0 deletions client_side_statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,47 @@ func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) {
}
}
}

func TestShowCommitTimestamp(t *testing.T) {
t.Parallel()

c := &conn{retryAborts: true}
s := &statementExecutor{}
ctx := context.Background()

ts := time.Now()
for _, test := range []struct {
wantValue *time.Time
}{
{&ts},
{nil},
} {
c.commitTs = test.wantValue

it, err := s.ShowCommitTimestamp(ctx, c, "", nil)
if err != nil {
t.Fatalf("could not get current commit timestamp from connection: %v", err)
}
cols := it.Columns()
wantCols := []string{"CommitTimestamp"}
if !cmp.Equal(cols, wantCols) {
t.Fatalf("column names mismatch\nGot: %v\nWant: %v", cols, wantCols)
}
values := make([]driver.Value, len(cols))
if err := it.Next(values); err != nil {
t.Fatalf("failed to get first row for commit timestamp: %v", err)
}
var wantValues []driver.Value
if test.wantValue != nil {
wantValues = []driver.Value{*test.wantValue}
} else {
wantValues = []driver.Value{nil}
}
if !cmp.Equal(values, wantValues) {
t.Fatalf("commit timestamp values mismatch\nGot: %v\nWant: %v", values, wantValues)
}
if err := it.Next(values); err != io.EOF {
t.Fatalf("error mismatch\nGot: %v\nWant: %v", err, io.EOF)
}
}
}
11 changes: 10 additions & 1 deletion client_side_statements_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@ package spannerdriver
var jsonFile = `{
"statements":
[
{
{
"name": "SHOW VARIABLE COMMIT_TIMESTAMP",
"executorName": "ClientSideStatementNoParamExecutor",
"resultType": "RESULT_SET",
"regex": "(?is)\\A\\s*show\\s+variable\\s+commit_timestamp\\s*\\z",
"method": "statementShowCommitTimestamp",
"exampleStatements": ["show variable commit_timestamp"],
"examplePrerequisiteStatements": ["update foo set bar=1"]
},
{
"name": "SHOW VARIABLE RETRY_ABORTS_INTERNALLY",
"executorName": "ClientSideStatementNoParamExecutor",
"resultType": "RESULT_SET",
Expand Down
40 changes: 33 additions & 7 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,18 +287,24 @@ type SpannerConn interface {
// connection is in a read/write transaction. Use Apply to write mutations outside a transaction.
// See also spanner.ReadWriteTransaction#BufferWrite
BufferWrite(ms []*spanner.Mutation) error

// CommitTimestamp returns the commit timestamp of the last implicit or explicit read/write transaction that
// was executed on the connection, or an error if the connection has not executed a read/write transaction
// that committed successfully. The timestamp is in the local timezone.
CommitTimestamp() (commitTimestamp time.Time, err error)
}

type conn struct {
closed bool
client *spanner.Client
adminClient *adminapi.DatabaseAdminClient
tx contextTransaction
commitTs *time.Time
database string
retryAborts bool

execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error)
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, time.Time, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error)

// batch is the currently active DDL or DML batch on this connection.
Expand Down Expand Up @@ -345,6 +351,13 @@ const (
PartitionedNonAtomic
)

func (c *conn) CommitTimestamp() (time.Time, error) {
if c.commitTs == nil {
return time.Time{}, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "this connection has not executed a read/write transaction that committed successfully"))
}
return *c.commitTs, nil
}

func (c *conn) RetryAbortsInternally() bool {
return c.retryAborts
}
Expand Down Expand Up @@ -597,6 +610,7 @@ func (c *conn) ResetSession(_ context.Context) error {
return driver.ErrBadConn
}
}
c.commitTs = nil
c.batch = nil
c.retryAborts = true
c.autocommitDMLMode = Transactional
Expand Down Expand Up @@ -698,6 +712,8 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
if clientStmt != nil {
return clientStmt.QueryContext(ctx, args)
}
// Clear the commit timestamp of this connection before we execute the query.
c.commitTs = nil

stmt, err := prepareSpannerStmt(query, args)
if err != nil {
Expand All @@ -721,6 +737,8 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
if stmt != nil {
return stmt.ExecContext(ctx, args)
}
// Clear the commit timestamp of this connection before we execute the statement.
c.commitTs = nil

// Use admin API if DDL statement is provided.
isDDL, err := isDDL(query)
Expand All @@ -743,12 +761,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
}

var rowsAffected int64
var commitTs time.Time
if c.tx == nil {
if c.InDMLBatch() {
c.batch.statements = append(c.batch.statements, ss)
} else {
if c.autocommitDMLMode == Transactional {
rowsAffected, err = c.execSingleDMLTransactional(ctx, c.client, ss)
rowsAffected, commitTs, err = c.execSingleDMLTransactional(ctx, c.client, ss)
if err == nil {
c.commitTs = &commitTs
}
} else if c.autocommitDMLMode == PartitionedNonAtomic {
rowsAffected, err = c.execSingleDMLPartitioned(ctx, c.client, ss)
} else {
Expand Down Expand Up @@ -800,11 +822,15 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
ctx: ctx,
client: c.client,
rwTx: tx,
close: func() {
close: func(commitTs *time.Time, commitErr error) {
c.tx = nil
if commitErr == nil {
c.commitTs = commitTs
}
},
retryAborts: c.retryAborts,
}
c.commitTs = nil
return c.tx, nil
}

Expand Down Expand Up @@ -832,18 +858,18 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
return c.Single().WithTimestampBound(tb).Query(ctx, statement)
}

func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, time.Time, error) {
var rowsAffected int64
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
count, err := tx.Update(ctx, statement)
rowsAffected = count
return err
}
_, err := c.ReadWriteTransaction(ctx, fn)
ts, err := c.ReadWriteTransaction(ctx, fn)
if err != nil {
return 0, err
return 0, time.Time{}, err
}
return rowsAffected, nil
return rowsAffected, ts, nil
}

func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
Expand Down
61 changes: 57 additions & 4 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"cloud.google.com/go/spanner"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/codes"
)

func TestExtractDnsParts(t *testing.T) {
Expand Down Expand Up @@ -135,6 +136,7 @@ func TestConnection_Reset(t *testing.T) {
c := conn{
readOnlyStaleness: spanner.ExactStaleness(time.Second),
batch: &batch{tp: dml},
commitTs: &time.Time{},
tx: &readOnlyTransaction{
close: func() {
txClosed = true
Expand All @@ -151,6 +153,9 @@ func TestConnection_Reset(t *testing.T) {
if c.inBatch() {
t.Error("failed to clear batch")
}
if c.commitTs != nil {
t.Errorf("failed to clear commit timestamp")
}
if !txClosed {
t.Error("failed to close transaction")
}
Expand Down Expand Up @@ -259,8 +264,8 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) {
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
return 0, nil
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, time.Time, error) {
return 0, time.Time{}, nil
},
execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
return 0, nil
Expand Down Expand Up @@ -292,8 +297,8 @@ func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) {
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
return 0, nil
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, time.Time, error) {
return 0, time.Time{}, nil
},
execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
return 0, nil
Expand All @@ -319,3 +324,51 @@ func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) {
t.Fatalf("executing query failed: %v", err)
}
}

func TestConn_GetCommitTimestampAfterAutocommitDml(t *testing.T) {
want := time.Now()
c := &conn{
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, time.Time, error) {
return 0, want, nil
},
execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
return 0, nil
},
}
ctx := context.Background()
if _, err := c.ExecContext(ctx, "UPDATE FOO SET BAR=1 WHERE TRUE", []driver.NamedValue{}); err != nil {
t.Fatalf("failed to execute DML statement: %v", err)
}
got, err := c.CommitTimestamp()
if err != nil {
t.Fatalf("failed to get commit timestamp: %v", err)
}
if !cmp.Equal(want, got) {
t.Fatalf("commit timestamp mismatch\n Got: %v\nWant: %v", got, want)
}
}

func TestConn_GetCommitTimestampAfterAutocommitQuery(t *testing.T) {
c := &conn{
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, time.Time, error) {
return 0, time.Time{}, nil
},
execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement) (int64, error) {
return 0, nil
},
}
ctx := context.Background()
if _, err := c.QueryContext(ctx, "SELECT * FROM Foo", []driver.NamedValue{}); err != nil {
t.Fatalf("failed to execute query: %v", err)
}
_, err := c.CommitTimestamp()
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}
}

0 comments on commit 802e7be

Please sign in to comment.