Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner): support CommitStats #3444

Merged
merged 16 commits into from Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions spanner/client.go
Expand Up @@ -472,6 +472,8 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txOpts = options

trace.TracePrintf(ctx, map[string]interface{}{"transactionID": string(sh.getTransactionID())},
"Starting transaction attempt")
if err = t.begin(ctx); err != nil {
Expand Down
36 changes: 36 additions & 0 deletions spanner/client_test.go
Expand Up @@ -828,6 +828,42 @@ func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) {
}
}

func TestClient_ReadWriteTransactionWithOptions(t *testing.T) {
_, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()
resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
rowCount := int64(0)
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
var singerID, albumID int64
var albumTitle string
if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
return err
}
rowCount++
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
return nil
}, TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}})
if err != nil {
t.Fatalf("Failed to execute the transaction: %s", err)
}
if got, want := resp.CommitStats.MutationCount, int64(1); got != want {
t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want)
}
}

func TestClient_SessionNotFound(t *testing.T) {
// Ensure we always have at least one session in the pool.
sc := SessionPoolConfig{
Expand Down
148 changes: 148 additions & 0 deletions spanner/integration_test.go
Expand Up @@ -964,6 +964,72 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) {
}
}

// Test ReadWriteTransactionWithOptions.
func TestIntegration_ReadWriteTransactionWithOptions(t *testing.T) {
t.Parallel()

// Give a longer deadline because of transaction backoffs.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
defer cleanup()

// Set up two accounts
accounts := []*Mutation{
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
}
if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil {
t.Fatal(err)
}

readBalance := func(iter *RowIterator) (int64, error) {
defer iter.Stop()
var bal int64
for {
row, err := iter.Next()
if err == iterator.Done {
return bal, nil
}
if err != nil {
return 0, err
}
if err := row.Column(0, &bal); err != nil {
return 0, err
}
}
}

txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}}
resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
// Query Foo's balance and Bar's balance.
bf, e := readBalance(tx.Query(ctx,
Statement{"SELECT Balance FROM Accounts WHERE AccountId = @id", map[string]interface{}{"id": int64(1)}}))
if e != nil {
return e
}
bb, e := readBalance(tx.Read(ctx, "Accounts", KeySets(Key{int64(2)}), []string{"Balance"}))
if e != nil {
return e
}
if bf <= 0 {
return nil
}
bf--
bb++
return tx.BufferWrite([]*Mutation{
Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(1), bf}),
Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(2), bb}),
})
}, txOpts)
if err != nil {
t.Fatalf("Failed to execute transaction: %v", err)
}
if got, want := resp.CommitStats.MutationCount, int64(8); got != want {
t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want)
}
}

func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1056,6 +1122,88 @@ func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) {
}
}

func TestIntegration_ReadWriteTransaction_StatementBasedWithOptions(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
defer cleanup()

// Set up two accounts
accounts := []*Mutation{
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
}
if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil {
t.Fatal(err)
}

getBalance := func(txn *ReadWriteStmtBasedTransaction, key Key) (int64, error) {
row, err := txn.ReadRow(ctx, "Accounts", key, []string{"Balance"})
if err != nil {
return 0, err
}
var balance int64
if err := row.Column(0, &balance); err != nil {
return 0, err
}
return balance, nil
}

statements := func(txn *ReadWriteStmtBasedTransaction) error {
outBalance, err := getBalance(txn, Key{1})
if err != nil {
return err
}
const transferAmt = 20
if outBalance >= transferAmt {
inBalance, err := getBalance(txn, Key{2})
if err != nil {
return err
}
inBalance += transferAmt
outBalance -= transferAmt
cols := []string{"AccountId", "Balance"}
txn.BufferWrite([]*Mutation{
Update("Accounts", cols, []interface{}{1, outBalance}),
Update("Accounts", cols, []interface{}{2, inBalance}),
})
}
return nil
}

var resp CommitResponse
txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}}
for {
tx, err := NewReadWriteStmtBasedTransactionWithOptions(ctx, client, txOpts)
if err != nil {
t.Fatalf("failed to begin a transaction: %v", err)
}
err = statements(tx)
if err != nil && status.Code(err) != codes.Aborted {
tx.Rollback(ctx)
t.Fatalf("failed to execute statements: %v", err)
} else if err == nil {
resp, err = tx.CommitWithReturnResp(ctx)
if err == nil {
break
} else if status.Code(err) != codes.Aborted {
t.Fatalf("failed to commit a transaction: %v", err)
}
}
// Set a default sleep time if the server delay is absent.
delay := 10 * time.Millisecond
if serverDelay, hasServerDelay := ExtractRetryDelay(err); hasServerDelay {
delay = serverDelay
}
time.Sleep(delay)
}
if got, want := resp.CommitStats.MutationCount, int64(8); got != want {
t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want)
}
}

func TestIntegration_Reads(t *testing.T) {
t.Parallel()

Expand Down
8 changes: 7 additions & 1 deletion spanner/internal/testutil/inmem_spanner_server.go
Expand Up @@ -984,7 +984,13 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
}
s.removeTransaction(tx)
return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil
resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
if req.ReturnCommitStats {
resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
MutationCount: int64(1),
}
}
return resp, nil
}

func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
Expand Down
35 changes: 30 additions & 5 deletions spanner/transaction.go
Expand Up @@ -72,10 +72,14 @@ type txReadOnly struct {

// qo provides options for executing a sql query.
qo QueryOptions

// txOpts provides options for a transaction.
txOpts TransactionOptions
}

// TransactionOptions provides options for a transaction.
type TransactionOptions struct {
CommitOptions CommitOptions
}

// errSessionClosed returns error for using a recycled/destroyed session
Expand Down Expand Up @@ -962,11 +966,18 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error {
type CommitResponse struct {
// CommitTs is the commit time for a transaction.
CommitTs time.Time
// CommitStats is the commit statistics for a transaction.
CommitStats *sppb.CommitResponse_CommitStats
}

// CommitOptions provides options for commiting a transaction in a database.
type CommitOptions struct {
ReturnCommitStats bool
}

// commit tries to commit a readwrite transaction to Cloud Spanner. It also
// returns the commit response for the transactions.
func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, error) {
func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions) (CommitResponse, error) {
resp := CommitResponse{}
t.mu.Lock()
t.state = txClosed // No further operations after commit.
Expand All @@ -987,14 +998,18 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, erro
Transaction: &sppb.CommitRequest_TransactionId{
TransactionId: t.tx,
},
Mutations: mPb,
Mutations: mPb,
ReturnCommitStats: options.ReturnCommitStats,
})
if e != nil {
return resp, toSpannerErrorWithCommitInfo(e, true)
}
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
}
if options.ReturnCommitStats {
resp.CommitStats = res.CommitStats
}
if isSessionNotFoundError(err) {
t.sh.destroy()
}
Expand Down Expand Up @@ -1032,7 +1047,7 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont
)
if err = f(context.WithValue(ctx, transactionInProgressKey{}, 1), t); err == nil {
// Try to commit if transaction body returns no error.
resp, err = t.commit(ctx)
resp, err = t.commit(ctx, t.txOpts.CommitOptions)
errDuringCommit = err != nil
}
if err != nil {
Expand Down Expand Up @@ -1069,6 +1084,8 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont
type ReadWriteStmtBasedTransaction struct {
// ReadWriteTransaction contains methods for performing transactional reads.
ReadWriteTransaction

options TransactionOptions
}

// NewReadWriteStmtBasedTransaction starts a read-write transaction. Commit() or
Expand Down Expand Up @@ -1114,6 +1131,7 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client,
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txOpts = options

if err = t.begin(ctx); err != nil {
if sh != nil {
Expand All @@ -1127,15 +1145,22 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client,
// Commit tries to commit a readwrite transaction to Cloud Spanner. It also
// returns the commit timestamp for the transactions.
func (t *ReadWriteStmtBasedTransaction) Commit(ctx context.Context) (time.Time, error) {
resp, err := t.commit(ctx)
resp, err := t.CommitWithReturnResp(ctx)
return resp.CommitTs, err
}

// CommitWithReturnResp tries to commit a readwrite transaction. It also returns
// the commit timestamp and stats for the transactions.
func (t *ReadWriteStmtBasedTransaction) CommitWithReturnResp(ctx context.Context) (CommitResponse, error) {
resp, err := t.commit(ctx, t.txOpts.CommitOptions)
// Rolling back an aborted transaction is not necessary.
if err != nil && status.Code(err) != codes.Aborted {
t.rollback(ctx)
}
if t.sh != nil {
t.sh.recycle()
}
return resp.CommitTs, err
return resp, err
}

// Rollback is called to cancel the ongoing transaction that has not been
Expand Down
56 changes: 56 additions & 0 deletions spanner/transaction_test.go
Expand Up @@ -510,6 +510,62 @@ func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]S
return rowCount, attempts, err
}

func TestReadWriteStmtBasedTransactionWithOptions(t *testing.T) {
t.Parallel()

_, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()

f := func(tx *ReadWriteStmtBasedTransaction) (int64, error) {
iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
rowCount := int64(0)
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return 0, err
}
var singerID, albumID int64
var albumTitle string
if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
return 0, err
}
rowCount++
}
return rowCount, nil
}

var resp CommitResponse
for {
tx, err := NewReadWriteStmtBasedTransactionWithOptions(
ctx,
client,
TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}},
)
_, err = f(tx)
if err != nil && status.Code(err) != codes.Aborted {
tx.Rollback(ctx)
break
} else if err == nil {
resp, err = tx.CommitWithReturnResp(ctx)
break
}
// Set a default sleep time if the server delay is absent.
delay := 10 * time.Millisecond
if serverDelay, hasServerDelay := ExtractRetryDelay(err); hasServerDelay {
delay = serverDelay
}
time.Sleep(delay)
}
if got, want := resp.CommitStats.MutationCount, int64(1); got != want {
t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want)
}
}

func TestBatchDML_StatementBased_WithMultipleDML(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down