From b7c3ca6c83cbdca95d734df8aa07c5ddb8ab3db0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 4 Feb 2021 08:32:03 +0100 Subject: [PATCH] feat(spanner): support CommitStats (#3444) Adds support for `CommitStats` to the Spanner client. Based on https://github.com/googleapis/google-cloud-go/pull/3056. --- spanner/client.go | 2 + spanner/client_test.go | 75 +++++++++ spanner/integration_test.go | 157 ++++++++++++++++++ .../internal/testutil/inmem_spanner_server.go | 8 +- spanner/transaction.go | 35 +++- spanner/transaction_test.go | 56 +++++++ 6 files changed, 327 insertions(+), 6 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index 0e12690f74b..cd02d23fee3 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -459,6 +459,8 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea t.txReadOnly.sh = sh t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo + t.txOpts = options + trace.TracePrintf(ctx, map[string]interface{}{"transactionID": string(sh.getTransactionID())}, "Starting transaction attempt") if err = t.begin(ctx); err != nil { diff --git a/spanner/client_test.go b/spanner/client_test.go index cd29a79367c..8a2504706ac 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -829,6 +829,81 @@ func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) { } } +func TestClient_ReadWriteTransactionWithOptions(t *testing.T) { + _, client, teardown := setupMockedTestServer(t) + defer teardown() + ctx := context.Background() + resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + rowCount := int64(0) + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + rowCount++ + } + if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount { + return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + } + return nil + }, TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}}) + if err != nil { + t.Fatalf("Failed to execute the transaction: %s", err) + } + if got, want := resp.CommitStats.MutationCount, int64(1); got != want { + t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want) + } +} + +func TestClient_ReadWriteStmtBasedTransactionWithOptions(t *testing.T) { + _, client, teardown := setupMockedTestServer(t) + defer teardown() + ctx := context.Background() + tx, err := NewReadWriteStmtBasedTransactionWithOptions(ctx, client, TransactionOptions{CommitOptions{ReturnCommitStats: true}}) + if err != nil { + t.Fatalf("Unexpected error when creating transaction: %v", err) + } + + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + rowCount := int64(0) + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Unexpected error when fetching query results: %v", err) + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + t.Fatalf("Unexpected error when getting query data: %v", err) + } + rowCount++ + } + resp, err := tx.CommitWithReturnResp(ctx) + if err != nil { + t.Fatalf("Unexpected error when committing transaction: %v", err) + } + if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount { + t.Errorf("Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + } + if got, want := resp.CommitStats.MutationCount, int64(1); got != want { + t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want) + } +} + func TestClient_ReadWriteTransaction_DoNotLeakSessionOnPanic(t *testing.T) { // Make sure that there is always only one session in the pool. sc := SessionPoolConfig{ diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 3f87716f47c..73287c4b8d6 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -1021,8 +1021,79 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) { verifyDirectPathRemoteAddress(t) } +// Test ReadWriteTransactionWithOptions. +func TestIntegration_ReadWriteTransactionWithOptions(t *testing.T) { + t.Parallel() + skipEmulatorTest(t) + + // Give a longer deadline because of transaction backoffs. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + defer cleanup() + + // Set up two accounts + accounts := []*Mutation{ + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), + } + if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + readBalance := func(iter *RowIterator) (int64, error) { + defer iter.Stop() + var bal int64 + for { + row, err := iter.Next() + if err == iterator.Done { + return bal, nil + } + if err != nil { + return 0, err + } + if err := row.Column(0, &bal); err != nil { + return 0, err + } + } + } + + txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}} + resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + // Query Foo's balance and Bar's balance. + bf, e := readBalance(tx.Query(ctx, + Statement{"SELECT Balance FROM Accounts WHERE AccountId = @id", map[string]interface{}{"id": int64(1)}})) + if e != nil { + return e + } + bb, e := readBalance(tx.Read(ctx, "Accounts", KeySets(Key{int64(2)}), []string{"Balance"})) + if e != nil { + return e + } + if bf <= 0 { + return nil + } + bf-- + bb++ + return tx.BufferWrite([]*Mutation{ + Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(1), bf}), + Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(2), bb}), + }) + }, txOpts) + if err != nil { + t.Fatalf("Failed to execute transaction: %v", err) + } + if resp.CommitStats == nil { + t.Fatal("Missing commit stats in commit response") + } + if got, want := resp.CommitStats.MutationCount, int64(8); got != want { + t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want) + } +} + func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) { t.Parallel() + skipEmulatorTest(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -1113,6 +1184,92 @@ func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) { } } +func TestIntegration_ReadWriteTransaction_StatementBasedWithOptions(t *testing.T) { + t.Parallel() + skipEmulatorTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + defer cleanup() + + // Set up two accounts + accounts := []*Mutation{ + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), + } + if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + getBalance := func(txn *ReadWriteStmtBasedTransaction, key Key) (int64, error) { + row, err := txn.ReadRow(ctx, "Accounts", key, []string{"Balance"}) + if err != nil { + return 0, err + } + var balance int64 + if err := row.Column(0, &balance); err != nil { + return 0, err + } + return balance, nil + } + + statements := func(txn *ReadWriteStmtBasedTransaction) error { + outBalance, err := getBalance(txn, Key{1}) + if err != nil { + return err + } + const transferAmt = 20 + if outBalance >= transferAmt { + inBalance, err := getBalance(txn, Key{2}) + if err != nil { + return err + } + inBalance += transferAmt + outBalance -= transferAmt + cols := []string{"AccountId", "Balance"} + txn.BufferWrite([]*Mutation{ + Update("Accounts", cols, []interface{}{1, outBalance}), + Update("Accounts", cols, []interface{}{2, inBalance}), + }) + } + return nil + } + + var resp CommitResponse + txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}} + for { + tx, err := NewReadWriteStmtBasedTransactionWithOptions(ctx, client, txOpts) + if err != nil { + t.Fatalf("failed to begin a transaction: %v", err) + } + err = statements(tx) + if err != nil && status.Code(err) != codes.Aborted { + tx.Rollback(ctx) + t.Fatalf("failed to execute statements: %v", err) + } else if err == nil { + resp, err = tx.CommitWithReturnResp(ctx) + if err == nil { + break + } else if status.Code(err) != codes.Aborted { + t.Fatalf("failed to commit a transaction: %v", err) + } + } + // Set a default sleep time if the server delay is absent. + delay := 10 * time.Millisecond + if serverDelay, hasServerDelay := ExtractRetryDelay(err); hasServerDelay { + delay = serverDelay + } + time.Sleep(delay) + } + if resp.CommitStats == nil { + t.Fatal("Missing commit stats in commit response") + } + if got, want := resp.CommitStats.MutationCount, int64(8); got != want { + t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want) + } +} + func TestIntegration_Reads(t *testing.T) { t.Parallel() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index fd39d747868..c2b1b8894e4 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -984,7 +984,13 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") } s.removeTransaction(tx) - return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil + resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()} + if req.ReturnCommitStats { + resp.CommitStats = &spannerpb.CommitResponse_CommitStats{ + MutationCount: int64(1), + } + } + return resp, nil } func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { diff --git a/spanner/transaction.go b/spanner/transaction.go index cd7ea25bcaf..67bf2b7c16c 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -72,10 +72,14 @@ type txReadOnly struct { // qo provides options for executing a sql query. qo QueryOptions + + // txOpts provides options for a transaction. + txOpts TransactionOptions } // TransactionOptions provides options for a transaction. type TransactionOptions struct { + CommitOptions CommitOptions } // errSessionClosed returns error for using a recycled/destroyed session @@ -962,11 +966,18 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { type CommitResponse struct { // CommitTs is the commit time for a transaction. CommitTs time.Time + // CommitStats is the commit statistics for a transaction. + CommitStats *sppb.CommitResponse_CommitStats +} + +// CommitOptions provides options for commiting a transaction in a database. +type CommitOptions struct { + ReturnCommitStats bool } // commit tries to commit a readwrite transaction to Cloud Spanner. It also // returns the commit response for the transactions. -func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, error) { +func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions) (CommitResponse, error) { resp := CommitResponse{} t.mu.Lock() t.state = txClosed // No further operations after commit. @@ -987,7 +998,8 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, erro Transaction: &sppb.CommitRequest_TransactionId{ TransactionId: t.tx, }, - Mutations: mPb, + Mutations: mPb, + ReturnCommitStats: options.ReturnCommitStats, }) if e != nil { return resp, toSpannerErrorWithCommitInfo(e, true) @@ -995,6 +1007,9 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, erro if tstamp := res.GetCommitTimestamp(); tstamp != nil { resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) } + if options.ReturnCommitStats { + resp.CommitStats = res.CommitStats + } if isSessionNotFoundError(err) { t.sh.destroy() } @@ -1032,7 +1047,7 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont ) if err = f(context.WithValue(ctx, transactionInProgressKey{}, 1), t); err == nil { // Try to commit if transaction body returns no error. - resp, err = t.commit(ctx) + resp, err = t.commit(ctx, t.txOpts.CommitOptions) errDuringCommit = err != nil } if err != nil { @@ -1069,6 +1084,8 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont type ReadWriteStmtBasedTransaction struct { // ReadWriteTransaction contains methods for performing transactional reads. ReadWriteTransaction + + options TransactionOptions } // NewReadWriteStmtBasedTransaction starts a read-write transaction. Commit() or @@ -1114,6 +1131,7 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client, t.txReadOnly.sh = sh t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo + t.txOpts = options if err = t.begin(ctx); err != nil { if sh != nil { @@ -1127,7 +1145,14 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client, // Commit tries to commit a readwrite transaction to Cloud Spanner. It also // returns the commit timestamp for the transactions. func (t *ReadWriteStmtBasedTransaction) Commit(ctx context.Context) (time.Time, error) { - resp, err := t.commit(ctx) + resp, err := t.CommitWithReturnResp(ctx) + return resp.CommitTs, err +} + +// CommitWithReturnResp tries to commit a readwrite transaction. It also returns +// the commit timestamp and stats for the transactions. +func (t *ReadWriteStmtBasedTransaction) CommitWithReturnResp(ctx context.Context) (CommitResponse, error) { + resp, err := t.commit(ctx, t.txOpts.CommitOptions) // Rolling back an aborted transaction is not necessary. if err != nil && status.Code(err) != codes.Aborted { t.rollback(ctx) @@ -1135,7 +1160,7 @@ func (t *ReadWriteStmtBasedTransaction) Commit(ctx context.Context) (time.Time, if t.sh != nil { t.sh.recycle() } - return resp.CommitTs, err + return resp, err } // Rollback is called to cancel the ongoing transaction that has not been diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 686e1c81032..c3eebead750 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -510,6 +510,62 @@ func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]S return rowCount, attempts, err } +func TestReadWriteStmtBasedTransactionWithOptions(t *testing.T) { + t.Parallel() + + _, client, teardown := setupMockedTestServer(t) + defer teardown() + ctx := context.Background() + + f := func(tx *ReadWriteStmtBasedTransaction) (int64, error) { + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + rowCount := int64(0) + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return 0, err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return 0, err + } + rowCount++ + } + return rowCount, nil + } + + var resp CommitResponse + for { + tx, err := NewReadWriteStmtBasedTransactionWithOptions( + ctx, + client, + TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}}, + ) + _, err = f(tx) + if err != nil && status.Code(err) != codes.Aborted { + tx.Rollback(ctx) + break + } else if err == nil { + resp, err = tx.CommitWithReturnResp(ctx) + break + } + // Set a default sleep time if the server delay is absent. + delay := 10 * time.Millisecond + if serverDelay, hasServerDelay := ExtractRetryDelay(err); hasServerDelay { + delay = serverDelay + } + time.Sleep(delay) + } + if got, want := resp.CommitStats.MutationCount, int64(1); got != want { + t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want) + } +} + func TestBatchDML_StatementBased_WithMultipleDML(t *testing.T) { t.Parallel() ctx := context.Background()