Skip to content

Commit

Permalink
feat(spanner): support CommitStats (#3444)
Browse files Browse the repository at this point in the history
Adds support for `CommitStats` to the Spanner client.

Based on #3056.
  • Loading branch information
olavloite committed Feb 4, 2021
1 parent 6b2807f commit b7c3ca6
Show file tree
Hide file tree
Showing 6 changed files with 327 additions and 6 deletions.
2 changes: 2 additions & 0 deletions spanner/client.go
Expand Up @@ -459,6 +459,8 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txOpts = options

trace.TracePrintf(ctx, map[string]interface{}{"transactionID": string(sh.getTransactionID())},
"Starting transaction attempt")
if err = t.begin(ctx); err != nil {
Expand Down
75 changes: 75 additions & 0 deletions spanner/client_test.go
Expand Up @@ -829,6 +829,81 @@ func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) {
}
}

func TestClient_ReadWriteTransactionWithOptions(t *testing.T) {
_, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()
resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
rowCount := int64(0)
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
var singerID, albumID int64
var albumTitle string
if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
return err
}
rowCount++
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
return nil
}, TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}})
if err != nil {
t.Fatalf("Failed to execute the transaction: %s", err)
}
if got, want := resp.CommitStats.MutationCount, int64(1); got != want {
t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want)
}
}

func TestClient_ReadWriteStmtBasedTransactionWithOptions(t *testing.T) {
_, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()
tx, err := NewReadWriteStmtBasedTransactionWithOptions(ctx, client, TransactionOptions{CommitOptions{ReturnCommitStats: true}})
if err != nil {
t.Fatalf("Unexpected error when creating transaction: %v", err)
}

iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
rowCount := int64(0)
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatalf("Unexpected error when fetching query results: %v", err)
}
var singerID, albumID int64
var albumTitle string
if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
t.Fatalf("Unexpected error when getting query data: %v", err)
}
rowCount++
}
resp, err := tx.CommitWithReturnResp(ctx)
if err != nil {
t.Fatalf("Unexpected error when committing transaction: %v", err)
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
t.Errorf("Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
if got, want := resp.CommitStats.MutationCount, int64(1); got != want {
t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want)
}
}

func TestClient_ReadWriteTransaction_DoNotLeakSessionOnPanic(t *testing.T) {
// Make sure that there is always only one session in the pool.
sc := SessionPoolConfig{
Expand Down
157 changes: 157 additions & 0 deletions spanner/integration_test.go
Expand Up @@ -1021,8 +1021,79 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) {
verifyDirectPathRemoteAddress(t)
}

// Test ReadWriteTransactionWithOptions.
func TestIntegration_ReadWriteTransactionWithOptions(t *testing.T) {
t.Parallel()
skipEmulatorTest(t)

// Give a longer deadline because of transaction backoffs.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
defer cleanup()

// Set up two accounts
accounts := []*Mutation{
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
}
if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil {
t.Fatal(err)
}

readBalance := func(iter *RowIterator) (int64, error) {
defer iter.Stop()
var bal int64
for {
row, err := iter.Next()
if err == iterator.Done {
return bal, nil
}
if err != nil {
return 0, err
}
if err := row.Column(0, &bal); err != nil {
return 0, err
}
}
}

txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}}
resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
// Query Foo's balance and Bar's balance.
bf, e := readBalance(tx.Query(ctx,
Statement{"SELECT Balance FROM Accounts WHERE AccountId = @id", map[string]interface{}{"id": int64(1)}}))
if e != nil {
return e
}
bb, e := readBalance(tx.Read(ctx, "Accounts", KeySets(Key{int64(2)}), []string{"Balance"}))
if e != nil {
return e
}
if bf <= 0 {
return nil
}
bf--
bb++
return tx.BufferWrite([]*Mutation{
Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(1), bf}),
Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(2), bb}),
})
}, txOpts)
if err != nil {
t.Fatalf("Failed to execute transaction: %v", err)
}
if resp.CommitStats == nil {
t.Fatal("Missing commit stats in commit response")
}
if got, want := resp.CommitStats.MutationCount, int64(8); got != want {
t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want)
}
}

func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) {
t.Parallel()
skipEmulatorTest(t)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
Expand Down Expand Up @@ -1113,6 +1184,92 @@ func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) {
}
}

func TestIntegration_ReadWriteTransaction_StatementBasedWithOptions(t *testing.T) {
t.Parallel()
skipEmulatorTest(t)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
defer cleanup()

// Set up two accounts
accounts := []*Mutation{
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
}
if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil {
t.Fatal(err)
}

getBalance := func(txn *ReadWriteStmtBasedTransaction, key Key) (int64, error) {
row, err := txn.ReadRow(ctx, "Accounts", key, []string{"Balance"})
if err != nil {
return 0, err
}
var balance int64
if err := row.Column(0, &balance); err != nil {
return 0, err
}
return balance, nil
}

statements := func(txn *ReadWriteStmtBasedTransaction) error {
outBalance, err := getBalance(txn, Key{1})
if err != nil {
return err
}
const transferAmt = 20
if outBalance >= transferAmt {
inBalance, err := getBalance(txn, Key{2})
if err != nil {
return err
}
inBalance += transferAmt
outBalance -= transferAmt
cols := []string{"AccountId", "Balance"}
txn.BufferWrite([]*Mutation{
Update("Accounts", cols, []interface{}{1, outBalance}),
Update("Accounts", cols, []interface{}{2, inBalance}),
})
}
return nil
}

var resp CommitResponse
txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}}
for {
tx, err := NewReadWriteStmtBasedTransactionWithOptions(ctx, client, txOpts)
if err != nil {
t.Fatalf("failed to begin a transaction: %v", err)
}
err = statements(tx)
if err != nil && status.Code(err) != codes.Aborted {
tx.Rollback(ctx)
t.Fatalf("failed to execute statements: %v", err)
} else if err == nil {
resp, err = tx.CommitWithReturnResp(ctx)
if err == nil {
break
} else if status.Code(err) != codes.Aborted {
t.Fatalf("failed to commit a transaction: %v", err)
}
}
// Set a default sleep time if the server delay is absent.
delay := 10 * time.Millisecond
if serverDelay, hasServerDelay := ExtractRetryDelay(err); hasServerDelay {
delay = serverDelay
}
time.Sleep(delay)
}
if resp.CommitStats == nil {
t.Fatal("Missing commit stats in commit response")
}
if got, want := resp.CommitStats.MutationCount, int64(8); got != want {
t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want)
}
}

func TestIntegration_Reads(t *testing.T) {
t.Parallel()

Expand Down
8 changes: 7 additions & 1 deletion spanner/internal/testutil/inmem_spanner_server.go
Expand Up @@ -984,7 +984,13 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
}
s.removeTransaction(tx)
return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil
resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
if req.ReturnCommitStats {
resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
MutationCount: int64(1),
}
}
return resp, nil
}

func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
Expand Down
35 changes: 30 additions & 5 deletions spanner/transaction.go
Expand Up @@ -72,10 +72,14 @@ type txReadOnly struct {

// qo provides options for executing a sql query.
qo QueryOptions

// txOpts provides options for a transaction.
txOpts TransactionOptions
}

// TransactionOptions provides options for a transaction.
type TransactionOptions struct {
CommitOptions CommitOptions
}

// errSessionClosed returns error for using a recycled/destroyed session
Expand Down Expand Up @@ -962,11 +966,18 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error {
type CommitResponse struct {
// CommitTs is the commit time for a transaction.
CommitTs time.Time
// CommitStats is the commit statistics for a transaction.
CommitStats *sppb.CommitResponse_CommitStats
}

// CommitOptions provides options for commiting a transaction in a database.
type CommitOptions struct {
ReturnCommitStats bool
}

// commit tries to commit a readwrite transaction to Cloud Spanner. It also
// returns the commit response for the transactions.
func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, error) {
func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions) (CommitResponse, error) {
resp := CommitResponse{}
t.mu.Lock()
t.state = txClosed // No further operations after commit.
Expand All @@ -987,14 +998,18 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, erro
Transaction: &sppb.CommitRequest_TransactionId{
TransactionId: t.tx,
},
Mutations: mPb,
Mutations: mPb,
ReturnCommitStats: options.ReturnCommitStats,
})
if e != nil {
return resp, toSpannerErrorWithCommitInfo(e, true)
}
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
}
if options.ReturnCommitStats {
resp.CommitStats = res.CommitStats
}
if isSessionNotFoundError(err) {
t.sh.destroy()
}
Expand Down Expand Up @@ -1032,7 +1047,7 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont
)
if err = f(context.WithValue(ctx, transactionInProgressKey{}, 1), t); err == nil {
// Try to commit if transaction body returns no error.
resp, err = t.commit(ctx)
resp, err = t.commit(ctx, t.txOpts.CommitOptions)
errDuringCommit = err != nil
}
if err != nil {
Expand Down Expand Up @@ -1069,6 +1084,8 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont
type ReadWriteStmtBasedTransaction struct {
// ReadWriteTransaction contains methods for performing transactional reads.
ReadWriteTransaction

options TransactionOptions
}

// NewReadWriteStmtBasedTransaction starts a read-write transaction. Commit() or
Expand Down Expand Up @@ -1114,6 +1131,7 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client,
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txOpts = options

if err = t.begin(ctx); err != nil {
if sh != nil {
Expand All @@ -1127,15 +1145,22 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client,
// Commit tries to commit a readwrite transaction to Cloud Spanner. It also
// returns the commit timestamp for the transactions.
func (t *ReadWriteStmtBasedTransaction) Commit(ctx context.Context) (time.Time, error) {
resp, err := t.commit(ctx)
resp, err := t.CommitWithReturnResp(ctx)
return resp.CommitTs, err
}

// CommitWithReturnResp tries to commit a readwrite transaction. It also returns
// the commit timestamp and stats for the transactions.
func (t *ReadWriteStmtBasedTransaction) CommitWithReturnResp(ctx context.Context) (CommitResponse, error) {
resp, err := t.commit(ctx, t.txOpts.CommitOptions)
// Rolling back an aborted transaction is not necessary.
if err != nil && status.Code(err) != codes.Aborted {
t.rollback(ctx)
}
if t.sh != nil {
t.sh.recycle()
}
return resp.CommitTs, err
return resp, err
}

// Rollback is called to cancel the ongoing transaction that has not been
Expand Down

0 comments on commit b7c3ca6

Please sign in to comment.