Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner): support CommitStats #3444

Merged
merged 16 commits into from Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions spanner/client.go
Expand Up @@ -459,6 +459,8 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txOpts = options

trace.TracePrintf(ctx, map[string]interface{}{"transactionID": string(sh.getTransactionID())},
"Starting transaction attempt")
if err = t.begin(ctx); err != nil {
Expand Down
75 changes: 75 additions & 0 deletions spanner/client_test.go
Expand Up @@ -829,6 +829,81 @@ func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) {
}
}

func TestClient_ReadWriteTransactionWithOptions(t *testing.T) {
_, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()
resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
rowCount := int64(0)
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
var singerID, albumID int64
var albumTitle string
if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
return err
}
rowCount++
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
return nil
}, TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}})
if err != nil {
t.Fatalf("Failed to execute the transaction: %s", err)
}
if got, want := resp.CommitStats.MutationCount, int64(1); got != want {
t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want)
}
}

func TestClient_ReadWriteStmtBasedTransactionWithOptions(t *testing.T) {
_, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()
tx, err := NewReadWriteStmtBasedTransactionWithOptions(ctx, client, TransactionOptions{CommitOptions{ReturnCommitStats: true}})
if err != nil {
t.Fatalf("Unexpected error when creating transaction: %v", err)
}

iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
rowCount := int64(0)
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatalf("Unexpected error when fetching query results: %v", err)
}
var singerID, albumID int64
var albumTitle string
if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
t.Fatalf("Unexpected error when getting query data: %v", err)
}
rowCount++
}
resp, err := tx.CommitWithReturnResp(ctx)
if err != nil {
t.Fatalf("Unexpected error when committing transaction: %v", err)
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
t.Errorf("Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
if got, want := resp.CommitStats.MutationCount, int64(1); got != want {
t.Fatalf("Mismatch mutation count - got: %d, want: %d", got, want)
}
}

func TestClient_ReadWriteTransaction_DoNotLeakSessionOnPanic(t *testing.T) {
// Make sure that there is always only one session in the pool.
sc := SessionPoolConfig{
Expand Down
157 changes: 157 additions & 0 deletions spanner/integration_test.go
Expand Up @@ -1021,8 +1021,79 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) {
verifyDirectPathRemoteAddress(t)
}

// Test ReadWriteTransactionWithOptions.
func TestIntegration_ReadWriteTransactionWithOptions(t *testing.T) {
t.Parallel()
skipEmulatorTest(t)

// Give a longer deadline because of transaction backoffs.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
defer cleanup()

// Set up two accounts
accounts := []*Mutation{
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
}
if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil {
t.Fatal(err)
}

readBalance := func(iter *RowIterator) (int64, error) {
defer iter.Stop()
var bal int64
for {
row, err := iter.Next()
if err == iterator.Done {
return bal, nil
}
if err != nil {
return 0, err
}
if err := row.Column(0, &bal); err != nil {
return 0, err
}
}
}

txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}}
resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
// Query Foo's balance and Bar's balance.
bf, e := readBalance(tx.Query(ctx,
Statement{"SELECT Balance FROM Accounts WHERE AccountId = @id", map[string]interface{}{"id": int64(1)}}))
if e != nil {
return e
}
bb, e := readBalance(tx.Read(ctx, "Accounts", KeySets(Key{int64(2)}), []string{"Balance"}))
if e != nil {
return e
}
if bf <= 0 {
return nil
}
bf--
bb++
return tx.BufferWrite([]*Mutation{
Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(1), bf}),
Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(2), bb}),
})
}, txOpts)
if err != nil {
t.Fatalf("Failed to execute transaction: %v", err)
}
if resp.CommitStats == nil {
t.Fatal("Missing commit stats in commit response")
}
if got, want := resp.CommitStats.MutationCount, int64(8); got != want {
t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want)
}
}

func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) {
t.Parallel()
skipEmulatorTest(t)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
Expand Down Expand Up @@ -1113,6 +1184,92 @@ func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) {
}
}

func TestIntegration_ReadWriteTransaction_StatementBasedWithOptions(t *testing.T) {
t.Parallel()
skipEmulatorTest(t)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
defer cleanup()

// Set up two accounts
accounts := []*Mutation{
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
}
if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil {
t.Fatal(err)
}

getBalance := func(txn *ReadWriteStmtBasedTransaction, key Key) (int64, error) {
row, err := txn.ReadRow(ctx, "Accounts", key, []string{"Balance"})
if err != nil {
return 0, err
}
var balance int64
if err := row.Column(0, &balance); err != nil {
return 0, err
}
return balance, nil
}

statements := func(txn *ReadWriteStmtBasedTransaction) error {
outBalance, err := getBalance(txn, Key{1})
if err != nil {
return err
}
const transferAmt = 20
if outBalance >= transferAmt {
inBalance, err := getBalance(txn, Key{2})
if err != nil {
return err
}
inBalance += transferAmt
outBalance -= transferAmt
cols := []string{"AccountId", "Balance"}
txn.BufferWrite([]*Mutation{
Update("Accounts", cols, []interface{}{1, outBalance}),
Update("Accounts", cols, []interface{}{2, inBalance}),
})
}
return nil
}

var resp CommitResponse
txOpts := TransactionOptions{CommitOptions{ReturnCommitStats: true}}
for {
tx, err := NewReadWriteStmtBasedTransactionWithOptions(ctx, client, txOpts)
if err != nil {
t.Fatalf("failed to begin a transaction: %v", err)
}
err = statements(tx)
if err != nil && status.Code(err) != codes.Aborted {
tx.Rollback(ctx)
t.Fatalf("failed to execute statements: %v", err)
} else if err == nil {
resp, err = tx.CommitWithReturnResp(ctx)
if err == nil {
break
} else if status.Code(err) != codes.Aborted {
t.Fatalf("failed to commit a transaction: %v", err)
}
}
// Set a default sleep time if the server delay is absent.
delay := 10 * time.Millisecond
if serverDelay, hasServerDelay := ExtractRetryDelay(err); hasServerDelay {
delay = serverDelay
}
time.Sleep(delay)
}
if resp.CommitStats == nil {
t.Fatal("Missing commit stats in commit response")
}
if got, want := resp.CommitStats.MutationCount, int64(8); got != want {
t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want)
}
}

func TestIntegration_Reads(t *testing.T) {
t.Parallel()

Expand Down
8 changes: 7 additions & 1 deletion spanner/internal/testutil/inmem_spanner_server.go
Expand Up @@ -984,7 +984,13 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
}
s.removeTransaction(tx)
return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil
resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
if req.ReturnCommitStats {
resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
MutationCount: int64(1),
}
}
return resp, nil
}

func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
Expand Down
35 changes: 30 additions & 5 deletions spanner/transaction.go
Expand Up @@ -72,10 +72,14 @@ type txReadOnly struct {

// qo provides options for executing a sql query.
qo QueryOptions

// txOpts provides options for a transaction.
txOpts TransactionOptions
}

// TransactionOptions provides options for a transaction.
type TransactionOptions struct {
CommitOptions CommitOptions
}

// errSessionClosed returns error for using a recycled/destroyed session
Expand Down Expand Up @@ -962,11 +966,18 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error {
type CommitResponse struct {
// CommitTs is the commit time for a transaction.
CommitTs time.Time
// CommitStats is the commit statistics for a transaction.
CommitStats *sppb.CommitResponse_CommitStats
}

// CommitOptions provides options for commiting a transaction in a database.
type CommitOptions struct {
ReturnCommitStats bool
}

// commit tries to commit a readwrite transaction to Cloud Spanner. It also
// returns the commit response for the transactions.
func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, error) {
func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions) (CommitResponse, error) {
resp := CommitResponse{}
t.mu.Lock()
t.state = txClosed // No further operations after commit.
Expand All @@ -987,14 +998,18 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (CommitResponse, erro
Transaction: &sppb.CommitRequest_TransactionId{
TransactionId: t.tx,
},
Mutations: mPb,
Mutations: mPb,
ReturnCommitStats: options.ReturnCommitStats,
})
if e != nil {
return resp, toSpannerErrorWithCommitInfo(e, true)
}
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
}
if options.ReturnCommitStats {
resp.CommitStats = res.CommitStats
}
if isSessionNotFoundError(err) {
t.sh.destroy()
}
Expand Down Expand Up @@ -1032,7 +1047,7 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont
)
if err = f(context.WithValue(ctx, transactionInProgressKey{}, 1), t); err == nil {
// Try to commit if transaction body returns no error.
resp, err = t.commit(ctx)
resp, err = t.commit(ctx, t.txOpts.CommitOptions)
errDuringCommit = err != nil
}
if err != nil {
Expand Down Expand Up @@ -1069,6 +1084,8 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont
type ReadWriteStmtBasedTransaction struct {
// ReadWriteTransaction contains methods for performing transactional reads.
ReadWriteTransaction

options TransactionOptions
}

// NewReadWriteStmtBasedTransaction starts a read-write transaction. Commit() or
Expand Down Expand Up @@ -1114,6 +1131,7 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client,
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txOpts = options

if err = t.begin(ctx); err != nil {
if sh != nil {
Expand All @@ -1127,15 +1145,22 @@ func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client,
// Commit tries to commit a readwrite transaction to Cloud Spanner. It also
// returns the commit timestamp for the transactions.
func (t *ReadWriteStmtBasedTransaction) Commit(ctx context.Context) (time.Time, error) {
resp, err := t.commit(ctx)
resp, err := t.CommitWithReturnResp(ctx)
return resp.CommitTs, err
}

// CommitWithReturnResp tries to commit a readwrite transaction. It also returns
// the commit timestamp and stats for the transactions.
func (t *ReadWriteStmtBasedTransaction) CommitWithReturnResp(ctx context.Context) (CommitResponse, error) {
resp, err := t.commit(ctx, t.txOpts.CommitOptions)
// Rolling back an aborted transaction is not necessary.
if err != nil && status.Code(err) != codes.Aborted {
t.rollback(ctx)
}
if t.sh != nil {
t.sh.recycle()
}
return resp.CommitTs, err
return resp, err
}

// Rollback is called to cancel the ongoing transaction that has not been
Expand Down