From 7d6816f1fd14bcd2c7f91d814855b5d921ba970d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 18 May 2021 09:23:09 +0200 Subject: [PATCH] feat: add bufferAsync methods (#1145) * feat: add bufferAsync methods Adds bufferAsync methods to TransactionContext. The existing buffer methods were already non-blocking, but the async versions also return an ApiFuture, which make them easier to use when chaining multiple async calls together. Also changes some calls in the AsyncTransactionManagerTest to use lambdas instead of the test helper methods. Fixes #1126 * fix: do not take lock on async method * build: remove custom skip tests variable * test: add test for committing twice * fix: synchronize buffering and committing --- .../clirr-ignored-differences.xml | 13 + .../spanner/AsyncTransactionManager.java | 39 +- .../google/cloud/spanner/DatabaseClient.java | 50 +- .../com/google/cloud/spanner/SessionPool.java | 10 + .../cloud/spanner/TransactionContext.java | 10 + .../cloud/spanner/TransactionRunnerImpl.java | 60 ++- .../spanner/AsyncTransactionManagerTest.java | 462 +++++++++--------- .../spanner/TransactionContextImplTest.java | 112 ++++- .../cloud/spanner/TransactionContextTest.java | 144 ++++++ 9 files changed, 574 insertions(+), 326 deletions(-) create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextTest.java diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml index c6f4c1f3a9..c6a936c51c 100644 --- a/google-cloud-spanner/clirr-ignored-differences.xml +++ b/google-cloud-spanner/clirr-ignored-differences.xml @@ -605,4 +605,17 @@ com/google/cloud/spanner/StructReader com.google.cloud.spanner.Value getValue(java.lang.String) + + + + + 7012 + com/google/cloud/spanner/TransactionContext + com.google.api.core.ApiFuture bufferAsync(com.google.cloud.spanner.Mutation) + + + 7012 + com/google/cloud/spanner/TransactionContext + com.google.api.core.ApiFuture bufferAsync(java.lang.Iterable) + diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java index c648b567d7..391be3d190 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java @@ -18,6 +18,7 @@ import com.google.api.core.ApiFuture; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; import com.google.cloud.spanner.TransactionManager.TransactionState; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; @@ -98,31 +99,21 @@ Timestamp get(long timeout, TimeUnit unit) *

Example usage: * *

{@code
-   * TransactionContextFuture txnFuture = manager.beginAsync();
    * final String column = "FirstName";
-   * txnFuture.then(
-   *         new AsyncTransactionFunction() {
-   *           @Override
-   *           public ApiFuture apply(TransactionContext txn, Void input)
-   *               throws Exception {
-   *             return txn.readRowAsync(
-   *                 "Singers", Key.of(singerId), Collections.singleton(column));
-   *           }
-   *         })
-   *     .then(
-   *         new AsyncTransactionFunction() {
-   *           @Override
-   *           public ApiFuture apply(TransactionContext txn, Struct input)
-   *               throws Exception {
-   *             String name = input.getString(column);
-   *             txn.buffer(
-   *                 Mutation.newUpdateBuilder("Singers")
-   *                     .set(column)
-   *                     .to(name.toUpperCase())
-   *                     .build());
-   *             return ApiFutures.immediateFuture(null);
-   *           }
-   *         })
+   * final long singerId = 1L;
+   * AsyncTransactionManager manager = client.transactionManagerAsync();
+   * TransactionContextFuture txnFuture = manager.beginAsync();
+   * txnFuture
+   *   .then((transaction, ignored) ->
+   *     transaction.readRowAsync("Singers", Key.of(singerId), Collections.singleton(column)),
+   *     executor)
+   *   .then((transaction, row) ->
+   *     transaction.bufferAsync(
+   *         Mutation.newUpdateBuilder("Singers")
+   *           .set(column).to(row.getString(column).toUpperCase())
+   *           .build()),
+   *     executor)
+   *   .commitAsync();
    * }
*/ interface AsyncTransactionStep extends ApiFuture { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java index 799c80b707..60e8b6910c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java @@ -431,8 +431,7 @@ CommitResponse writeAtLeastOnceWithOptions( * lifecycle. This API is meant for advanced users. Most users should instead use the {@link * #runAsync()} API instead. * - *

Example of using {@link AsyncTransactionManager} with lambda expressions (Java 8 and - * higher). + *

Example of using {@link AsyncTransactionManager}. * *

{@code
    * long singerId = 1L;
@@ -449,56 +448,11 @@ CommitResponse writeAtLeastOnceWithOptions(
    *             .then(
    *                 (transaction, row) -> {
    *                   String name = row.getString(column);
-   *                   transaction.buffer(
+   *                   return transaction.bufferAsync(
    *                       Mutation.newUpdateBuilder("Singers")
    *                           .set(column)
    *                           .to(name.toUpperCase())
    *                           .build());
-   *                   return ApiFutures.immediateFuture(null);
-   *                 })
-   *             .commitAsync();
-   *     try {
-   *       commitTimestamp.get();
-   *       break;
-   *     } catch (AbortedException e) {
-   *       Thread.sleep(e.getRetryDelayInMillis());
-   *       transactionFuture = manager.resetForRetryAsync();
-   *     }
-   *   }
-   * }
-   * }
- * - *

Example of using {@link AsyncTransactionManager} (Java 7). - * - *

{@code
-   * final long singerId = 1L;
-   * try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-   *   TransactionContextFuture transactionFuture = manager.beginAsync();
-   *   while (true) {
-   *     final String column = "FirstName";
-   *     CommitTimestampFuture commitTimestamp =
-   *         transactionFuture.then(
-   *                 new AsyncTransactionFunction() {
-   *                   @Override
-   *                   public ApiFuture apply(TransactionContext transaction, Void input)
-   *                       throws Exception {
-   *                     return transaction.readRowAsync(
-   *                         "Singers", Key.of(singerId), Collections.singleton(column));
-   *                   }
-   *                 })
-   *             .then(
-   *                 new AsyncTransactionFunction() {
-   *                   @Override
-   *                   public ApiFuture apply(TransactionContext transaction, Struct input)
-   *                       throws Exception {
-   *                     String name = input.getString(column);
-   *                     transaction.buffer(
-   *                         Mutation.newUpdateBuilder("Singers")
-   *                             .set(column)
-   *                             .to(name.toUpperCase())
-   *                             .build());
-   *                     return ApiFutures.immediateFuture(null);
-   *                   }
    *                 })
    *             .commitAsync();
    *     try {
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
index 47f2c33899..fbfc472bf5 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
@@ -675,6 +675,11 @@ public void buffer(Mutation mutation) {
       delegate.buffer(mutation);
     }
 
+    @Override
+    public ApiFuture bufferAsync(Mutation mutation) {
+      return delegate.bufferAsync(mutation);
+    }
+
     @Override
     public Struct readRowUsingIndex(String table, String index, Key key, Iterable columns) {
       try {
@@ -703,6 +708,11 @@ public void buffer(Iterable mutations) {
       delegate.buffer(mutations);
     }
 
+    @Override
+    public ApiFuture bufferAsync(Iterable mutations) {
+      return delegate.bufferAsync(mutations);
+    }
+
     @Override
     public long executeUpdate(Statement statement, UpdateOption... options) {
       try {
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java
index 64c45b12c0..2590d5b309 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java
@@ -91,6 +91,11 @@ public interface TransactionContext extends ReadContext {
    */
   void buffer(Mutation mutation);
 
+  /** Same as {@link #buffer(Mutation)}, but is guaranteed to be non-blocking. */
+  default ApiFuture bufferAsync(Mutation mutation) {
+    throw new UnsupportedOperationException("method should be overwritten");
+  }
+
   /**
    * Buffers mutations to be applied if the transaction commits successfully. The effects of the
    * mutations will not be visible to subsequent operations in the transaction. All buffered
@@ -98,6 +103,11 @@ public interface TransactionContext extends ReadContext {
    */
   void buffer(Iterable mutations);
 
+  /** Same as {@link #buffer(Iterable)}, but is guaranteed to be non-blocking. */
+  default ApiFuture bufferAsync(Iterable mutations) {
+    throw new UnsupportedOperationException("method should be overwritten");
+  }
+
   /**
    * Executes the DML statement(s) and returns the number of rows modified. For non-DML statements,
    * it will result in an {@code IllegalArgumentException}. The effects of the DML statement will be
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
index 2484b9d3c6..e04dace003 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
@@ -54,7 +54,9 @@
 import io.opencensus.trace.Tracing;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Queue;
 import java.util.concurrent.Callable;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
@@ -75,6 +77,9 @@ class TransactionRunnerImpl implements SessionTransaction, TransactionRunner {
    */
   private static final String TRANSACTION_CANCELLED_MESSAGE = "invalidated by a later transaction";
 
+  private static final String TRANSACTION_ALREADY_COMMITTED_MESSAGE =
+      "Transaction has already committed";
+
   @VisibleForTesting
   static class TransactionContextImpl extends AbstractReadContext implements TransactionContext {
     static class Builder extends AbstractReadContext.Builder {
@@ -146,7 +151,9 @@ public void removeListener(Runnable listener) {
       }
     }
 
-    @GuardedBy("lock")
+    private final Object committingLock = new Object();
+
+    @GuardedBy("committingLock")
     private volatile boolean committing;
 
     @GuardedBy("lock")
@@ -155,8 +162,7 @@ public void removeListener(Runnable listener) {
     @GuardedBy("lock")
     private volatile int runningAsyncOperations;
 
-    @GuardedBy("lock")
-    private List mutations = new ArrayList<>();
+    private final Queue mutations = new ConcurrentLinkedQueue<>();
 
     @GuardedBy("lock")
     private boolean aborted;
@@ -280,6 +286,16 @@ void commit() {
     volatile ApiFuture commitFuture;
 
     ApiFuture commitAsync() {
+      List mutationsProto = new ArrayList<>();
+      synchronized (committingLock) {
+        if (committing) {
+          throw new IllegalStateException(TRANSACTION_ALREADY_COMMITTED_MESSAGE);
+        }
+        committing = true;
+        if (!mutations.isEmpty()) {
+          Mutation.toProto(mutations, mutationsProto);
+        }
+      }
       final SettableApiFuture res = SettableApiFuture.create();
       final SettableApiFuture finishOps;
       CommitRequest.Builder builder =
@@ -303,14 +319,8 @@ ApiFuture commitAsync() {
         } else {
           finishOps = finishedAsyncOperations;
         }
-        if (!mutations.isEmpty()) {
-          List mutationsProto = new ArrayList<>();
-          Mutation.toProto(mutations, mutationsProto);
-          builder.addAllMutations(mutationsProto);
-        }
-        // Ensure that no call to buffer mutations that would be lost can succeed.
-        mutations = null;
       }
+      builder.addAllMutations(mutationsProto);
       finishOps.addListener(
           new CommitRunnable(res, finishOps, builder), MoreExecutors.directExecutor());
       return res;
@@ -603,22 +613,44 @@ public void onDone(boolean withBeginTransaction) {
 
     @Override
     public void buffer(Mutation mutation) {
-      synchronized (lock) {
-        checkNotNull(mutations, "Context is closed");
+      synchronized (committingLock) {
+        if (committing) {
+          throw new IllegalStateException(TRANSACTION_ALREADY_COMMITTED_MESSAGE);
+        }
         mutations.add(checkNotNull(mutation));
       }
     }
 
+    @Override
+    public ApiFuture bufferAsync(Mutation mutation) {
+      // Normally, we would call the async method from the sync method, but this is also safe as
+      // both are non-blocking anyways, and this prevents the creation of an ApiFuture that is not
+      // really used when the sync method is called.
+      buffer(mutation);
+      return ApiFutures.immediateFuture(null);
+    }
+
     @Override
     public void buffer(Iterable mutations) {
-      synchronized (lock) {
-        checkNotNull(this.mutations, "Context is closed");
+      synchronized (committingLock) {
+        if (committing) {
+          throw new IllegalStateException(TRANSACTION_ALREADY_COMMITTED_MESSAGE);
+        }
         for (Mutation mutation : mutations) {
           this.mutations.add(checkNotNull(mutation));
         }
       }
     }
 
+    @Override
+    public ApiFuture bufferAsync(Iterable mutations) {
+      // Normally, we would call the async method from the sync method, but this is also safe as
+      // both are non-blocking anyways, and this prevents the creation of an ApiFuture that is not
+      // really used when the sync method is called.
+      buffer(mutations);
+      return ApiFutures.immediateFuture(null);
+    }
+
     @Override
     public long executeUpdate(Statement statement, UpdateOption... options) {
       beforeReadOrQuery();
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
index 345e58b7cb..de3e267434 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
@@ -197,13 +197,15 @@ public void asyncTransactionManager_shouldRollbackOnCloseAsync() throws Exceptio
   public void testAsyncTransactionManager_returnsCommitStats() throws Exception {
     try (AsyncTransactionManager manager =
         client().transactionManagerAsync(Options.commitStats())) {
-      TransactionContextFuture transaction = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
           CommitTimestampFuture commitTimestamp =
-              transaction
+              transactionContextFuture
                   .then(
-                      AsyncTransactionManagerHelper.buffer(Mutation.delete("FOO", Key.of("foo"))),
+                      (transactionContext, ignored) ->
+                          transactionContext.bufferAsync(
+                              Collections.singleton(Mutation.delete("FOO", Key.of("foo")))),
                       executor)
                   .commitAsync();
           assertNotNull(commitTimestamp.get());
@@ -212,7 +214,7 @@ public void testAsyncTransactionManager_returnsCommitStats() throws Exception {
           assertEquals(1L, manager.getCommitResponse().get().getCommitStats().getMutationCount());
           break;
         } catch (AbortedException e) {
-          transaction = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -220,23 +222,21 @@ public void testAsyncTransactionManager_returnsCommitStats() throws Exception {
 
   @Test
   public void asyncTransactionManagerUpdate() throws Exception {
-    final SettableApiFuture updateCount = SettableApiFuture.create();
-
     try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture commitTimestamp =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(
-                          updateCount, UPDATE_STATEMENT),
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCount =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext.executeUpdateAsync(UPDATE_STATEMENT),
+                  executor);
+          CommitTimestampFuture commitTimestamp = updateCount.commitAsync();
           assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT);
           assertThat(commitTimestamp.get()).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -244,25 +244,23 @@ public void asyncTransactionManagerUpdate() throws Exception {
 
   @Test
   public void asyncTransactionManagerIsNonBlocking() throws Exception {
-    SettableApiFuture updateCount = SettableApiFuture.create();
-
     mockSpanner.freeze();
     try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture commitTimestamp =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(
-                          updateCount, UPDATE_STATEMENT),
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCount =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext.executeUpdateAsync(UPDATE_STATEMENT),
+                  executor);
+          CommitTimestampFuture commitTimestamp = updateCount.commitAsync();
           mockSpanner.unfreeze();
           assertThat(updateCount.get(10L, TimeUnit.SECONDS)).isEqualTo(UPDATE_COUNT);
           assertThat(commitTimestamp.get(10L, TimeUnit.SECONDS)).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -271,9 +269,10 @@ public void asyncTransactionManagerIsNonBlocking() throws Exception {
   @Test
   public void asyncTransactionManagerInvalidUpdate() throws Exception {
     try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       CommitTimestampFuture commitTimestamp =
-          txn.then(
+          transactionContextFuture
+              .then(
                   (transaction, ignored) ->
                       transaction.executeUpdateAsync(INVALID_UPDATE_STATEMENT),
                   executor)
@@ -286,33 +285,31 @@ public void asyncTransactionManagerInvalidUpdate() throws Exception {
 
   @Test
   public void asyncTransactionManagerCommitAborted() throws Exception {
-    SettableApiFuture updateCount = SettableApiFuture.create();
     final AtomicInteger attempt = new AtomicInteger();
     try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
           attempt.incrementAndGet();
-          CommitTimestampFuture commitTimestamp =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(
-                          updateCount, UPDATE_STATEMENT),
-                      executor)
-                  .then(
-                      (transaction, ignored) -> {
-                        if (attempt.get() == 1) {
-                          mockSpanner.abortTransaction(transaction);
-                        }
-                        return ApiFutures.immediateFuture(null);
-                      },
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCount =
+              transactionContextFuture.then(
+                  (transaction, ignored) -> transaction.executeUpdateAsync(UPDATE_STATEMENT),
+                  executor);
+          updateCount.then(
+              (transaction, ignored) -> {
+                if (attempt.get() == 1) {
+                  mockSpanner.abortTransaction(transaction);
+                }
+                return ApiFutures.immediateFuture(null);
+              },
+              executor);
+          CommitTimestampFuture commitTimestamp = updateCount.commitAsync();
           assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT);
           assertThat(commitTimestamp.get()).isNotNull();
           assertThat(attempt.get()).isEqualTo(2);
           break;
         } catch (AbortedException e) {
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -320,42 +317,26 @@ public void asyncTransactionManagerCommitAborted() throws Exception {
 
   @Test
   public void asyncTransactionManagerFireAndForgetInvalidUpdate() throws Exception {
-    final SettableApiFuture updateCount = SettableApiFuture.create();
-
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
-                      (transaction, ignored) -> {
-                        // This fire-and-forget update statement should not fail the transaction.
-                        // The exception will however cause the transaction to be retried, as the
-                        // statement will not return a transaction id.
-                        transaction.executeUpdateAsync(INVALID_UPDATE_STATEMENT);
-                        ApiFutures.addCallback(
-                            transaction.executeUpdateAsync(UPDATE_STATEMENT),
-                            new ApiFutureCallback() {
-                              @Override
-                              public void onFailure(Throwable t) {
-                                updateCount.setException(t);
-                              }
-
-                              @Override
-                              public void onSuccess(Long result) {
-                                updateCount.set(result);
-                              }
-                            },
-                            MoreExecutors.directExecutor());
-                        return updateCount;
-                      },
-                      executor)
-                  .commitAsync();
-          assertThat(ts.get()).isNotNull();
-          assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT);
+          AsyncTransactionStep transaction =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) -> {
+                    // This fire-and-forget update statement should not fail the transaction.
+                    // The exception will however cause the transaction to be retried, as the
+                    // statement will not return a transaction id.
+                    transactionContext.executeUpdateAsync(INVALID_UPDATE_STATEMENT);
+                    return transactionContext.executeUpdateAsync(UPDATE_STATEMENT);
+                  },
+                  executor);
+          CommitTimestampFuture commitTimestamp = transaction.commitAsync();
+          assertThat(commitTimestamp.get()).isNotNull();
+          assertThat(transaction.get()).isEqualTo(UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -375,15 +356,19 @@ public void onSuccess(Long result) {
 
   @Test
   public void asyncTransactionManagerChain() throws Exception {
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(AsyncTransactionManagerHelper.executeUpdateAsync(UPDATE_STATEMENT), executor)
+          CommitTimestampFuture commitTimestamp =
+              transactionContextFuture
+                  .then(
+                      (transaction, ignored) -> transaction.executeUpdateAsync(UPDATE_STATEMENT),
+                      executor)
                   .then(
-                      AsyncTransactionManagerHelper.readRowAsync(
-                          READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES),
+                      (transactionContext, ignored) ->
+                          transactionContext.readRowAsync(
+                              READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES),
                       executor)
                   .then(
                       (ignored, input) -> ApiFutures.immediateFuture(input.getString("Value")),
@@ -395,10 +380,10 @@ public void asyncTransactionManagerChain() throws Exception {
                       },
                       executor)
                   .commitAsync();
-          assertThat(ts.get()).isNotNull();
+          assertThat(commitTimestamp.get()).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -406,13 +391,15 @@ public void asyncTransactionManagerChain() throws Exception {
 
   @Test
   public void asyncTransactionManagerChainWithErrorInTheMiddle() throws Exception {
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(INVALID_UPDATE_STATEMENT),
+          CommitTimestampFuture commitTimestampFuture =
+              transactionContextFuture
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.executeUpdateAsync(INVALID_UPDATE_STATEMENT),
                       executor)
                   .then(
                       (ignored1, ignored2) -> {
@@ -420,16 +407,12 @@ public void asyncTransactionManagerChainWithErrorInTheMiddle() throws Exception
                       },
                       executor)
                   .commitAsync();
-          ts.get();
+          SpannerException e =
+              assertThrows(SpannerException.class, () -> get(commitTimestampFuture));
+          assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
-        } catch (ExecutionException e) {
-          mgr.rollbackAsync();
-          assertThat(e.getCause()).isInstanceOf(SpannerException.class);
-          SpannerException se = (SpannerException) e.getCause();
-          assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT);
-          break;
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -437,16 +420,17 @@ public void asyncTransactionManagerChainWithErrorInTheMiddle() throws Exception
 
   @Test
   public void asyncTransactionManagerUpdateAborted() throws Exception {
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
       // Temporarily set the result of the update to 2 rows.
       mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L));
       final AtomicInteger attempt = new AtomicInteger();
 
-      TransactionContextFuture txn = mgr.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
+          CommitTimestampFuture commitTimestampFuture =
+              transactionContextFuture
+                  .then(
                       (ignored1, ignored2) -> {
                         if (attempt.incrementAndGet() == 1) {
                           // Abort the first attempt.
@@ -460,12 +444,14 @@ public void asyncTransactionManagerUpdateAborted() throws Exception {
                       },
                       executor)
                   .then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(UPDATE_STATEMENT), executor)
+                      (transactionContext, ignored) ->
+                          transactionContext.executeUpdateAsync(UPDATE_STATEMENT),
+                      executor)
                   .commitAsync();
-          assertThat(ts.get()).isNotNull();
+          assertThat(commitTimestampFuture.get()).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
       assertThat(attempt.get()).isEqualTo(2);
@@ -477,12 +463,13 @@ public void asyncTransactionManagerUpdateAborted() throws Exception {
   @Test
   public void asyncTransactionManagerUpdateAbortedWithoutGettingResult() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
+          CommitTimestampFuture commitTimestampFuture =
+              transactionContextFuture
+                  .then(
                       (transaction, ignored) -> {
                         if (attempt.incrementAndGet() == 1) {
                           mockSpanner.abortNextStatement();
@@ -498,7 +485,7 @@ public void asyncTransactionManagerUpdateAbortedWithoutGettingResult() throws Ex
                       },
                       executor)
                   .commitAsync();
-          assertThat(ts.get()).isNotNull();
+          assertThat(commitTimestampFuture.get()).isNotNull();
           assertThat(attempt.get()).isEqualTo(2);
           // The server may receive 1 or 2 commit requests depending on whether the call to
           // commitAsync() already knows that the transaction has aborted. If it does, it will not
@@ -513,7 +500,7 @@ public void asyncTransactionManagerUpdateAbortedWithoutGettingResult() throws Ex
                   CommitRequest.class);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -571,45 +558,45 @@ public void asyncTransactionManagerWaitsUntilAsyncUpdateHasFinished() throws Exc
 
   @Test
   public void asyncTransactionManagerBatchUpdate() throws Exception {
-    final SettableApiFuture result = SettableApiFuture.create();
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  AsyncTransactionManagerHelper.batchUpdateAsync(
-                      result, UPDATE_STATEMENT, UPDATE_STATEMENT),
-                  executor)
-              .commitAsync()
-              .get();
+          AsyncTransactionStep updateCounts =
+              transactionContextFuture.then(
+                  (transaction, ignored) ->
+                      transaction.batchUpdateAsync(
+                          ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
+                  executor);
+          get(updateCounts.commitAsync());
+          assertThat(get(updateCounts)).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
   }
 
   @Test
   public void asyncTransactionManagerIsNonBlockingWithBatchUpdate() throws Exception {
-    SettableApiFuture res = SettableApiFuture.create();
     mockSpanner.freeze();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
-                      AsyncTransactionManagerHelper.batchUpdateAsync(res, UPDATE_STATEMENT),
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCounts =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext.batchUpdateAsync(Collections.singleton(UPDATE_STATEMENT)),
+                  executor);
+          CommitTimestampFuture commitTimestampFuture = updateCounts.commitAsync();
           mockSpanner.unfreeze();
-          assertThat(ts.get()).isNotNull();
-          assertThat(res.get()).asList().containsExactly(UPDATE_COUNT);
+          assertThat(commitTimestampFuture.get()).isNotNull();
+          assertThat(updateCounts.get()).asList().containsExactly(UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -617,17 +604,18 @@ public void asyncTransactionManagerIsNonBlockingWithBatchUpdate() throws Excepti
 
   @Test
   public void asyncTransactionManagerInvalidBatchUpdate() throws Exception {
-    SettableApiFuture result = SettableApiFuture.create();
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       SpannerException e =
           assertThrows(
               SpannerException.class,
               () ->
                   get(
-                      txn.then(
-                              AsyncTransactionManagerHelper.batchUpdateAsync(
-                                  result, UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT),
+                      transactionContextFuture
+                          .then(
+                              (transactionContext, ignored) ->
+                                  transactionContext.batchUpdateAsync(
+                                      ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT)),
                               executor)
                           .commitAsync()));
       assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT);
@@ -637,31 +625,32 @@ public void asyncTransactionManagerInvalidBatchUpdate() throws Exception {
 
   @Test
   public void asyncTransactionManagerFireAndForgetInvalidBatchUpdate() throws Exception {
-    SettableApiFuture result = SettableApiFuture.create();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
-                    transaction.batchUpdateAsync(
-                        ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT));
-                    return ApiFutures.immediateFuture(null);
-                  },
-                  executor)
-              .then(
-                  AsyncTransactionManagerHelper.batchUpdateAsync(
-                      result, UPDATE_STATEMENT, UPDATE_STATEMENT),
-                  executor)
-              .commitAsync()
-              .get();
+          AsyncTransactionStep updateCounts =
+              transactionContextFuture
+                  .then(
+                      (transactionContext, ignored) -> {
+                        transactionContext.batchUpdateAsync(
+                            ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT));
+                        return ApiFutures.immediateFuture(null);
+                      },
+                      executor)
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.batchUpdateAsync(
+                              ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
+                      executor);
+          updateCounts.commitAsync().get();
+          assertThat(updateCounts.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
     assertThat(mockSpanner.getRequestTypes())
         .containsExactly(
             BatchCreateSessionsRequest.class,
@@ -673,11 +662,12 @@ public void asyncTransactionManagerFireAndForgetInvalidBatchUpdate() throws Exce
   @Test
   public void asyncTransactionManagerBatchUpdateAborted() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
+          transactionContextFuture
+              .then(
                   (transaction, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
                       return transaction.batchUpdateAsync(
@@ -692,7 +682,7 @@ public void asyncTransactionManagerBatchUpdateAborted() throws Exception {
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -711,16 +701,17 @@ public void asyncTransactionManagerBatchUpdateAborted() throws Exception {
   @Test
   public void asyncTransactionManagerBatchUpdateAbortedBeforeFirstStatement() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
+          transactionContextFuture
+              .then(
+                  (transactionContext, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
                       mockSpanner.abortNextStatement();
                     }
-                    return transaction.batchUpdateAsync(
+                    return transactionContext.batchUpdateAsync(
                         ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT));
                   },
                   executor)
@@ -728,7 +719,7 @@ public void asyncTransactionManagerBatchUpdateAbortedBeforeFirstStatement() thro
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -746,28 +737,30 @@ public void asyncTransactionManagerBatchUpdateAbortedBeforeFirstStatement() thro
 
   @Test
   public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Exception {
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
       // Temporarily set the result of the update to 2 rows.
       mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L));
       final AtomicInteger attempt = new AtomicInteger();
-      TransactionContextFuture txn = mgr.beginAsync();
+      TransactionContextFuture txn = manager.beginAsync();
       while (true) {
-        final SettableApiFuture result = SettableApiFuture.create();
         try {
-          txn.then(
-                  (ignored1, ignored2) -> {
-                    if (attempt.get() > 0) {
-                      // Set the result of the update statement back to 1 row.
-                      mockSpanner.putStatementResult(
-                          StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
-                    }
-                    return ApiFutures.immediateFuture(null);
-                  },
-                  executor)
-              .then(
-                  AsyncTransactionManagerHelper.batchUpdateAsync(
-                      result, UPDATE_STATEMENT, UPDATE_STATEMENT),
-                  executor)
+          AsyncTransactionStep updateCounts =
+              txn.then(
+                      (ignored1, ignored2) -> {
+                        if (attempt.get() > 0) {
+                          // Set the result of the update statement back to 1 row.
+                          mockSpanner.putStatementResult(
+                              StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
+                        }
+                        return ApiFutures.immediateFuture(null);
+                      },
+                      executor)
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.batchUpdateAsync(
+                              ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
+                      executor);
+          updateCounts
               .then(
                   (transaction, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
@@ -778,11 +771,11 @@ public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Excepti
                   executor)
               .commitAsync()
               .get();
-          assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
+          assertThat(updateCounts.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
           assertThat(attempt.get()).isEqualTo(2);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          txn = manager.resetForRetryAsync();
         }
       }
     } finally {
@@ -801,12 +794,13 @@ public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Excepti
   @Test
   public void asyncTransactionManagerBatchUpdateAbortedWithoutGettingResult() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
+          transactionContextFuture
+              .then(
+                  (transactionContext, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
                       mockSpanner.abortNextStatement();
                     }
@@ -816,7 +810,7 @@ public void asyncTransactionManagerBatchUpdateAbortedWithoutGettingResult() thro
                     // directly in the transaction manager if the ABORTED error has already been
                     // returned by the batch update call before the commit call starts.
                     // Otherwise, the backend will return an ABORTED error for the commit call.
-                    transaction.batchUpdateAsync(
+                    transactionContext.batchUpdateAsync(
                         ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT));
                     return ApiFutures.immediateFuture(null);
                   },
@@ -825,7 +819,7 @@ public void asyncTransactionManagerBatchUpdateAbortedWithoutGettingResult() thro
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -860,16 +854,18 @@ public void asyncTransactionManagerWithBatchUpdateCommitFails() throws Exception
             Status.RESOURCE_EXHAUSTED
                 .withDescription("mutation limit exceeded")
                 .asRuntimeException()));
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       SpannerException e =
           assertThrows(
               SpannerException.class,
               () ->
                   get(
-                      txn.then(
-                              AsyncTransactionManagerHelper.batchUpdateAsync(
-                                  UPDATE_STATEMENT, UPDATE_STATEMENT),
+                      transactionContextFuture
+                          .then(
+                              (transactionContext, ignored) ->
+                                  transactionContext.batchUpdateAsync(
+                                      ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
                               executor)
                           .commitAsync()));
       assertThat(e.getErrorCode()).isEqualTo(ErrorCode.RESOURCE_EXHAUSTED);
@@ -882,13 +878,14 @@ public void asyncTransactionManagerWithBatchUpdateCommitFails() throws Exception
 
   @Test
   public void asyncTransactionManagerWaitsUntilAsyncBatchUpdateHasFinished() throws Exception {
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
-                    transaction.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT));
+          transactionContextFuture
+              .then(
+                  (transactionContext, ignored) -> {
+                    transactionContext.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT));
                     return ApiFutures.immediateFuture(null);
                   },
                   executor)
@@ -896,7 +893,7 @@ public void asyncTransactionManagerWaitsUntilAsyncBatchUpdateHasFinished() throw
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -907,55 +904,53 @@ public void asyncTransactionManagerWaitsUntilAsyncBatchUpdateHasFinished() throw
 
   @Test
   public void asyncTransactionManagerReadRow() throws Exception {
-    ApiFuture val;
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          AsyncTransactionStep step;
-          val =
-              step =
-                  txn.then(
-                          AsyncTransactionManagerHelper.readRowAsync(
+          AsyncTransactionStep value =
+              transactionContextFuture
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.readRowAsync(
                               READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES),
-                          executor)
-                      .then(
-                          (ignored, input) -> ApiFutures.immediateFuture(input.getString("Value")),
-                          executor);
-          step.commitAsync().get();
+                      executor)
+                  .then(
+                      (ignored, input) -> ApiFutures.immediateFuture(input.getString("Value")),
+                      executor);
+          value.commitAsync().get();
+          assertThat(value.get()).isEqualTo("v1");
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(val.get()).isEqualTo("v1");
   }
 
   @Test
   public void asyncTransactionManagerRead() throws Exception {
-    AsyncTransactionStep> res;
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          res =
-              txn.then(
-                  (transaction, ignored) ->
-                      transaction
+          AsyncTransactionStep> values =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext
                           .readAsync(READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES)
                           .toListAsync(
                               input -> input.getString("Value"), MoreExecutors.directExecutor()),
                   executor);
           // Commit the transaction.
-          res.commitAsync().get();
+          values.commitAsync().get();
+          assertThat(values.get()).containsExactly("v1", "v2", "v3");
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(res.get()).containsExactly("v1", "v2", "v3");
   }
 
   @Test
@@ -966,24 +961,24 @@ public void asyncTransactionManagerQuery() throws Exception {
             MockSpannerTestUtil.READ_FIRST_NAME_SINGERS_RESULTSET));
     final long singerId = 1L;
     try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         final String column = "FirstName";
         CommitTimestampFuture commitTimestamp =
-            txn.then(
-                    (transaction, ignored) ->
-                        transaction.readRowAsync(
+            transactionContextFuture
+                .then(
+                    (transactionContext, ignored) ->
+                        transactionContext.readRowAsync(
                             "Singers", Key.of(singerId), Collections.singleton(column)),
                     executor)
                 .then(
                     (transaction, input) -> {
                       String name = input.getString(column);
-                      transaction.buffer(
+                      return transaction.bufferAsync(
                           Mutation.newUpdateBuilder("Singers")
                               .set(column)
                               .to(name.toUpperCase())
                               .build());
-                      return ApiFutures.immediateFuture(null);
                     },
                     executor)
                 .commitAsync();
@@ -991,8 +986,7 @@ public void asyncTransactionManagerQuery() throws Exception {
           commitTimestamp.get();
           break;
         } catch (AbortedException e) {
-          Thread.sleep(e.getRetryDelayInMillis());
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java
index 369385478d..b7035f64fa 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java
@@ -16,6 +16,7 @@
 
 package com.google.cloud.spanner;
 
+import static org.junit.Assert.assertThrows;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyMap;
 import static org.mockito.Mockito.mock;
@@ -26,20 +27,125 @@
 import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl;
 import com.google.cloud.spanner.spi.v1.SpannerRpc;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.Timestamp;
 import com.google.rpc.Code;
 import com.google.rpc.Status;
 import com.google.spanner.v1.CommitRequest;
 import com.google.spanner.v1.ExecuteBatchDmlRequest;
 import com.google.spanner.v1.ExecuteBatchDmlResponse;
 import java.util.Collections;
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.Mock;
 import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
 
 @RunWith(JUnit4.class)
 public class TransactionContextImplTest {
 
+  @Mock private SpannerRpc rpc;
+
+  @Mock private SessionImpl session;
+
+  @SuppressWarnings("unchecked")
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    when(rpc.commitAsync(any(CommitRequest.class), anyMap()))
+        .thenReturn(
+            ApiFutures.immediateFuture(
+                com.google.spanner.v1.CommitResponse.newBuilder()
+                    .setCommitTimestamp(Timestamp.newBuilder().setSeconds(99L).setNanos(10).build())
+                    .build()));
+    when(session.getName()).thenReturn("test");
+  }
+
+  private TransactionContextImpl createContext() {
+    return TransactionContextImpl.newBuilder()
+        .setSession(session)
+        .setRpc(rpc)
+        .setTransactionId(ByteString.copyFromUtf8("test"))
+        .setOptions(Options.fromTransactionOptions())
+        .build();
+  }
+
+  @Test
+  public void testCanBufferBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.buffer(Mutation.delete("test", KeySet.all()));
+    }
+  }
+
+  @Test
+  public void testCanBufferAsyncBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.bufferAsync(Mutation.delete("test", KeySet.all()));
+    }
+  }
+
+  @Test
+  public void testCanBufferIterableBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.buffer(Collections.singleton(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCanBufferIterableAsyncBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.bufferAsync(Collections.singleton(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCannotBufferAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class, () -> context.buffer(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCannotBufferAsyncAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class,
+          () -> context.bufferAsync(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCannotBufferIterableAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class,
+          () -> context.buffer(Collections.singleton(Mutation.delete("test", KeySet.all()))));
+    }
+  }
+
+  @Test
+  public void testCannotBufferIterableAsyncAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class,
+          () -> context.bufferAsync(Collections.singleton(Mutation.delete("test", KeySet.all()))));
+    }
+  }
+
+  @Test
+  public void testCannotCommitTwice() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(IllegalStateException.class, () -> context.commit());
+    }
+  }
+
   @Test(expected = AbortedException.class)
   public void batchDmlAborted() {
     batchDml(Code.ABORTED_VALUE);
@@ -53,13 +159,7 @@ public void batchDmlException() {
   @SuppressWarnings("unchecked")
   @Test
   public void testReturnCommitStats() {
-    SessionImpl session = mock(SessionImpl.class);
-    when(session.getName()).thenReturn("test");
     ByteString transactionId = ByteString.copyFromUtf8("test");
-    SpannerRpc rpc = mock(SpannerRpc.class);
-    when(rpc.commitAsync(any(CommitRequest.class), anyMap()))
-        .thenReturn(
-            ApiFutures.immediateFuture(com.google.spanner.v1.CommitResponse.getDefaultInstance()));
 
     try (TransactionContextImpl context =
         TransactionContextImpl.newBuilder()
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextTest.java
new file mode 100644
index 0000000000..045c58d837
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextTest.java
@@ -0,0 +1,144 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import static org.junit.Assert.assertThrows;
+
+import com.google.api.core.ApiFuture;
+import com.google.cloud.spanner.Options.QueryOption;
+import com.google.cloud.spanner.Options.ReadOption;
+import com.google.cloud.spanner.Options.UpdateOption;
+import java.util.Collections;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class TransactionContextTest {
+
+  @Test
+  public void testDefaultImplementations() {
+    try (TransactionContext context =
+        new TransactionContext() {
+          @Override
+          public AsyncResultSet readUsingIndexAsync(
+              String table,
+              String index,
+              KeySet keys,
+              Iterable columns,
+              ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public ResultSet readUsingIndex(
+              String table,
+              String index,
+              KeySet keys,
+              Iterable columns,
+              ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public ApiFuture readRowUsingIndexAsync(
+              String table, String index, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public Struct readRowUsingIndex(
+              String table, String index, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public ApiFuture readRowAsync(String table, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public Struct readRow(String table, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public AsyncResultSet readAsync(
+              String table, KeySet keys, Iterable columns, ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public ResultSet read(
+              String table, KeySet keys, Iterable columns, ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public AsyncResultSet executeQueryAsync(Statement statement, QueryOption... options) {
+            return null;
+          }
+
+          @Override
+          public ResultSet executeQuery(Statement statement, QueryOption... options) {
+            return null;
+          }
+
+          @Override
+          public void close() {}
+
+          @Override
+          public ResultSet analyzeQuery(Statement statement, QueryAnalyzeMode queryMode) {
+            return null;
+          }
+
+          @Override
+          public ApiFuture executeUpdateAsync(Statement statement, UpdateOption... options) {
+            return null;
+          }
+
+          @Override
+          public long executeUpdate(Statement statement, UpdateOption... options) {
+            return 0;
+          }
+
+          @Override
+          public void buffer(Iterable mutations) {}
+
+          @Override
+          public void buffer(Mutation mutation) {}
+
+          @Override
+          public ApiFuture batchUpdateAsync(
+              Iterable statements, UpdateOption... options) {
+            return null;
+          }
+
+          @Override
+          public long[] batchUpdate(Iterable statements, UpdateOption... options) {
+            return null;
+          }
+        }) {
+      assertThrows(
+          UnsupportedOperationException.class,
+          () -> context.bufferAsync(Mutation.delete("foo", KeySet.all())));
+      assertThrows(
+          UnsupportedOperationException.class,
+          () -> context.bufferAsync(Collections.singleton(Mutation.delete("foo", KeySet.all()))));
+    }
+  }
+}