diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml index c6f4c1f3a9..c6a936c51c 100644 --- a/google-cloud-spanner/clirr-ignored-differences.xml +++ b/google-cloud-spanner/clirr-ignored-differences.xml @@ -605,4 +605,17 @@ com/google/cloud/spanner/StructReader com.google.cloud.spanner.Value getValue(java.lang.String) + + + + + 7012 + com/google/cloud/spanner/TransactionContext + com.google.api.core.ApiFuture bufferAsync(com.google.cloud.spanner.Mutation) + + + 7012 + com/google/cloud/spanner/TransactionContext + com.google.api.core.ApiFuture bufferAsync(java.lang.Iterable) + diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java index c648b567d7..391be3d190 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java @@ -18,6 +18,7 @@ import com.google.api.core.ApiFuture; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; import com.google.cloud.spanner.TransactionManager.TransactionState; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; @@ -98,31 +99,21 @@ Timestamp get(long timeout, TimeUnit unit) *

Example usage: * *

{@code
-   * TransactionContextFuture txnFuture = manager.beginAsync();
    * final String column = "FirstName";
-   * txnFuture.then(
-   *         new AsyncTransactionFunction() {
-   *           @Override
-   *           public ApiFuture apply(TransactionContext txn, Void input)
-   *               throws Exception {
-   *             return txn.readRowAsync(
-   *                 "Singers", Key.of(singerId), Collections.singleton(column));
-   *           }
-   *         })
-   *     .then(
-   *         new AsyncTransactionFunction() {
-   *           @Override
-   *           public ApiFuture apply(TransactionContext txn, Struct input)
-   *               throws Exception {
-   *             String name = input.getString(column);
-   *             txn.buffer(
-   *                 Mutation.newUpdateBuilder("Singers")
-   *                     .set(column)
-   *                     .to(name.toUpperCase())
-   *                     .build());
-   *             return ApiFutures.immediateFuture(null);
-   *           }
-   *         })
+   * final long singerId = 1L;
+   * AsyncTransactionManager manager = client.transactionManagerAsync();
+   * TransactionContextFuture txnFuture = manager.beginAsync();
+   * txnFuture
+   *   .then((transaction, ignored) ->
+   *     transaction.readRowAsync("Singers", Key.of(singerId), Collections.singleton(column)),
+   *     executor)
+   *   .then((transaction, row) ->
+   *     transaction.bufferAsync(
+   *         Mutation.newUpdateBuilder("Singers")
+   *           .set(column).to(row.getString(column).toUpperCase())
+   *           .build()),
+   *     executor)
+   *   .commitAsync();
    * }
*/ interface AsyncTransactionStep extends ApiFuture { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java index 799c80b707..60e8b6910c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java @@ -431,8 +431,7 @@ CommitResponse writeAtLeastOnceWithOptions( * lifecycle. This API is meant for advanced users. Most users should instead use the {@link * #runAsync()} API instead. * - *

Example of using {@link AsyncTransactionManager} with lambda expressions (Java 8 and - * higher). + *

Example of using {@link AsyncTransactionManager}. * *

{@code
    * long singerId = 1L;
@@ -449,56 +448,11 @@ CommitResponse writeAtLeastOnceWithOptions(
    *             .then(
    *                 (transaction, row) -> {
    *                   String name = row.getString(column);
-   *                   transaction.buffer(
+   *                   return transaction.bufferAsync(
    *                       Mutation.newUpdateBuilder("Singers")
    *                           .set(column)
    *                           .to(name.toUpperCase())
    *                           .build());
-   *                   return ApiFutures.immediateFuture(null);
-   *                 })
-   *             .commitAsync();
-   *     try {
-   *       commitTimestamp.get();
-   *       break;
-   *     } catch (AbortedException e) {
-   *       Thread.sleep(e.getRetryDelayInMillis());
-   *       transactionFuture = manager.resetForRetryAsync();
-   *     }
-   *   }
-   * }
-   * }
- * - *

Example of using {@link AsyncTransactionManager} (Java 7). - * - *

{@code
-   * final long singerId = 1L;
-   * try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-   *   TransactionContextFuture transactionFuture = manager.beginAsync();
-   *   while (true) {
-   *     final String column = "FirstName";
-   *     CommitTimestampFuture commitTimestamp =
-   *         transactionFuture.then(
-   *                 new AsyncTransactionFunction() {
-   *                   @Override
-   *                   public ApiFuture apply(TransactionContext transaction, Void input)
-   *                       throws Exception {
-   *                     return transaction.readRowAsync(
-   *                         "Singers", Key.of(singerId), Collections.singleton(column));
-   *                   }
-   *                 })
-   *             .then(
-   *                 new AsyncTransactionFunction() {
-   *                   @Override
-   *                   public ApiFuture apply(TransactionContext transaction, Struct input)
-   *                       throws Exception {
-   *                     String name = input.getString(column);
-   *                     transaction.buffer(
-   *                         Mutation.newUpdateBuilder("Singers")
-   *                             .set(column)
-   *                             .to(name.toUpperCase())
-   *                             .build());
-   *                     return ApiFutures.immediateFuture(null);
-   *                   }
    *                 })
    *             .commitAsync();
    *     try {
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
index 47f2c33899..fbfc472bf5 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
@@ -675,6 +675,11 @@ public void buffer(Mutation mutation) {
       delegate.buffer(mutation);
     }
 
+    @Override
+    public ApiFuture bufferAsync(Mutation mutation) {
+      return delegate.bufferAsync(mutation);
+    }
+
     @Override
     public Struct readRowUsingIndex(String table, String index, Key key, Iterable columns) {
       try {
@@ -703,6 +708,11 @@ public void buffer(Iterable mutations) {
       delegate.buffer(mutations);
     }
 
+    @Override
+    public ApiFuture bufferAsync(Iterable mutations) {
+      return delegate.bufferAsync(mutations);
+    }
+
     @Override
     public long executeUpdate(Statement statement, UpdateOption... options) {
       try {
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java
index 64c45b12c0..2590d5b309 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java
@@ -91,6 +91,11 @@ public interface TransactionContext extends ReadContext {
    */
   void buffer(Mutation mutation);
 
+  /** Same as {@link #buffer(Mutation)}, but is guaranteed to be non-blocking. */
+  default ApiFuture bufferAsync(Mutation mutation) {
+    throw new UnsupportedOperationException("method should be overwritten");
+  }
+
   /**
    * Buffers mutations to be applied if the transaction commits successfully. The effects of the
    * mutations will not be visible to subsequent operations in the transaction. All buffered
@@ -98,6 +103,11 @@ public interface TransactionContext extends ReadContext {
    */
   void buffer(Iterable mutations);
 
+  /** Same as {@link #buffer(Iterable)}, but is guaranteed to be non-blocking. */
+  default ApiFuture bufferAsync(Iterable mutations) {
+    throw new UnsupportedOperationException("method should be overwritten");
+  }
+
   /**
    * Executes the DML statement(s) and returns the number of rows modified. For non-DML statements,
    * it will result in an {@code IllegalArgumentException}. The effects of the DML statement will be
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
index 2484b9d3c6..e04dace003 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
@@ -54,7 +54,9 @@
 import io.opencensus.trace.Tracing;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Queue;
 import java.util.concurrent.Callable;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
@@ -75,6 +77,9 @@ class TransactionRunnerImpl implements SessionTransaction, TransactionRunner {
    */
   private static final String TRANSACTION_CANCELLED_MESSAGE = "invalidated by a later transaction";
 
+  private static final String TRANSACTION_ALREADY_COMMITTED_MESSAGE =
+      "Transaction has already committed";
+
   @VisibleForTesting
   static class TransactionContextImpl extends AbstractReadContext implements TransactionContext {
     static class Builder extends AbstractReadContext.Builder {
@@ -146,7 +151,9 @@ public void removeListener(Runnable listener) {
       }
     }
 
-    @GuardedBy("lock")
+    private final Object committingLock = new Object();
+
+    @GuardedBy("committingLock")
     private volatile boolean committing;
 
     @GuardedBy("lock")
@@ -155,8 +162,7 @@ public void removeListener(Runnable listener) {
     @GuardedBy("lock")
     private volatile int runningAsyncOperations;
 
-    @GuardedBy("lock")
-    private List mutations = new ArrayList<>();
+    private final Queue mutations = new ConcurrentLinkedQueue<>();
 
     @GuardedBy("lock")
     private boolean aborted;
@@ -280,6 +286,16 @@ void commit() {
     volatile ApiFuture commitFuture;
 
     ApiFuture commitAsync() {
+      List mutationsProto = new ArrayList<>();
+      synchronized (committingLock) {
+        if (committing) {
+          throw new IllegalStateException(TRANSACTION_ALREADY_COMMITTED_MESSAGE);
+        }
+        committing = true;
+        if (!mutations.isEmpty()) {
+          Mutation.toProto(mutations, mutationsProto);
+        }
+      }
       final SettableApiFuture res = SettableApiFuture.create();
       final SettableApiFuture finishOps;
       CommitRequest.Builder builder =
@@ -303,14 +319,8 @@ ApiFuture commitAsync() {
         } else {
           finishOps = finishedAsyncOperations;
         }
-        if (!mutations.isEmpty()) {
-          List mutationsProto = new ArrayList<>();
-          Mutation.toProto(mutations, mutationsProto);
-          builder.addAllMutations(mutationsProto);
-        }
-        // Ensure that no call to buffer mutations that would be lost can succeed.
-        mutations = null;
       }
+      builder.addAllMutations(mutationsProto);
       finishOps.addListener(
           new CommitRunnable(res, finishOps, builder), MoreExecutors.directExecutor());
       return res;
@@ -603,22 +613,44 @@ public void onDone(boolean withBeginTransaction) {
 
     @Override
     public void buffer(Mutation mutation) {
-      synchronized (lock) {
-        checkNotNull(mutations, "Context is closed");
+      synchronized (committingLock) {
+        if (committing) {
+          throw new IllegalStateException(TRANSACTION_ALREADY_COMMITTED_MESSAGE);
+        }
         mutations.add(checkNotNull(mutation));
       }
     }
 
+    @Override
+    public ApiFuture bufferAsync(Mutation mutation) {
+      // Normally, we would call the async method from the sync method, but this is also safe as
+      // both are non-blocking anyways, and this prevents the creation of an ApiFuture that is not
+      // really used when the sync method is called.
+      buffer(mutation);
+      return ApiFutures.immediateFuture(null);
+    }
+
     @Override
     public void buffer(Iterable mutations) {
-      synchronized (lock) {
-        checkNotNull(this.mutations, "Context is closed");
+      synchronized (committingLock) {
+        if (committing) {
+          throw new IllegalStateException(TRANSACTION_ALREADY_COMMITTED_MESSAGE);
+        }
         for (Mutation mutation : mutations) {
           this.mutations.add(checkNotNull(mutation));
         }
       }
     }
 
+    @Override
+    public ApiFuture bufferAsync(Iterable mutations) {
+      // Normally, we would call the async method from the sync method, but this is also safe as
+      // both are non-blocking anyways, and this prevents the creation of an ApiFuture that is not
+      // really used when the sync method is called.
+      buffer(mutations);
+      return ApiFutures.immediateFuture(null);
+    }
+
     @Override
     public long executeUpdate(Statement statement, UpdateOption... options) {
       beforeReadOrQuery();
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
index 345e58b7cb..de3e267434 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
@@ -197,13 +197,15 @@ public void asyncTransactionManager_shouldRollbackOnCloseAsync() throws Exceptio
   public void testAsyncTransactionManager_returnsCommitStats() throws Exception {
     try (AsyncTransactionManager manager =
         client().transactionManagerAsync(Options.commitStats())) {
-      TransactionContextFuture transaction = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
           CommitTimestampFuture commitTimestamp =
-              transaction
+              transactionContextFuture
                   .then(
-                      AsyncTransactionManagerHelper.buffer(Mutation.delete("FOO", Key.of("foo"))),
+                      (transactionContext, ignored) ->
+                          transactionContext.bufferAsync(
+                              Collections.singleton(Mutation.delete("FOO", Key.of("foo")))),
                       executor)
                   .commitAsync();
           assertNotNull(commitTimestamp.get());
@@ -212,7 +214,7 @@ public void testAsyncTransactionManager_returnsCommitStats() throws Exception {
           assertEquals(1L, manager.getCommitResponse().get().getCommitStats().getMutationCount());
           break;
         } catch (AbortedException e) {
-          transaction = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -220,23 +222,21 @@ public void testAsyncTransactionManager_returnsCommitStats() throws Exception {
 
   @Test
   public void asyncTransactionManagerUpdate() throws Exception {
-    final SettableApiFuture updateCount = SettableApiFuture.create();
-
     try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture commitTimestamp =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(
-                          updateCount, UPDATE_STATEMENT),
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCount =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext.executeUpdateAsync(UPDATE_STATEMENT),
+                  executor);
+          CommitTimestampFuture commitTimestamp = updateCount.commitAsync();
           assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT);
           assertThat(commitTimestamp.get()).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -244,25 +244,23 @@ public void asyncTransactionManagerUpdate() throws Exception {
 
   @Test
   public void asyncTransactionManagerIsNonBlocking() throws Exception {
-    SettableApiFuture updateCount = SettableApiFuture.create();
-
     mockSpanner.freeze();
     try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture commitTimestamp =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(
-                          updateCount, UPDATE_STATEMENT),
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCount =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext.executeUpdateAsync(UPDATE_STATEMENT),
+                  executor);
+          CommitTimestampFuture commitTimestamp = updateCount.commitAsync();
           mockSpanner.unfreeze();
           assertThat(updateCount.get(10L, TimeUnit.SECONDS)).isEqualTo(UPDATE_COUNT);
           assertThat(commitTimestamp.get(10L, TimeUnit.SECONDS)).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -271,9 +269,10 @@ public void asyncTransactionManagerIsNonBlocking() throws Exception {
   @Test
   public void asyncTransactionManagerInvalidUpdate() throws Exception {
     try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       CommitTimestampFuture commitTimestamp =
-          txn.then(
+          transactionContextFuture
+              .then(
                   (transaction, ignored) ->
                       transaction.executeUpdateAsync(INVALID_UPDATE_STATEMENT),
                   executor)
@@ -286,33 +285,31 @@ public void asyncTransactionManagerInvalidUpdate() throws Exception {
 
   @Test
   public void asyncTransactionManagerCommitAborted() throws Exception {
-    SettableApiFuture updateCount = SettableApiFuture.create();
     final AtomicInteger attempt = new AtomicInteger();
     try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
           attempt.incrementAndGet();
-          CommitTimestampFuture commitTimestamp =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(
-                          updateCount, UPDATE_STATEMENT),
-                      executor)
-                  .then(
-                      (transaction, ignored) -> {
-                        if (attempt.get() == 1) {
-                          mockSpanner.abortTransaction(transaction);
-                        }
-                        return ApiFutures.immediateFuture(null);
-                      },
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCount =
+              transactionContextFuture.then(
+                  (transaction, ignored) -> transaction.executeUpdateAsync(UPDATE_STATEMENT),
+                  executor);
+          updateCount.then(
+              (transaction, ignored) -> {
+                if (attempt.get() == 1) {
+                  mockSpanner.abortTransaction(transaction);
+                }
+                return ApiFutures.immediateFuture(null);
+              },
+              executor);
+          CommitTimestampFuture commitTimestamp = updateCount.commitAsync();
           assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT);
           assertThat(commitTimestamp.get()).isNotNull();
           assertThat(attempt.get()).isEqualTo(2);
           break;
         } catch (AbortedException e) {
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -320,42 +317,26 @@ public void asyncTransactionManagerCommitAborted() throws Exception {
 
   @Test
   public void asyncTransactionManagerFireAndForgetInvalidUpdate() throws Exception {
-    final SettableApiFuture updateCount = SettableApiFuture.create();
-
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
-                      (transaction, ignored) -> {
-                        // This fire-and-forget update statement should not fail the transaction.
-                        // The exception will however cause the transaction to be retried, as the
-                        // statement will not return a transaction id.
-                        transaction.executeUpdateAsync(INVALID_UPDATE_STATEMENT);
-                        ApiFutures.addCallback(
-                            transaction.executeUpdateAsync(UPDATE_STATEMENT),
-                            new ApiFutureCallback() {
-                              @Override
-                              public void onFailure(Throwable t) {
-                                updateCount.setException(t);
-                              }
-
-                              @Override
-                              public void onSuccess(Long result) {
-                                updateCount.set(result);
-                              }
-                            },
-                            MoreExecutors.directExecutor());
-                        return updateCount;
-                      },
-                      executor)
-                  .commitAsync();
-          assertThat(ts.get()).isNotNull();
-          assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT);
+          AsyncTransactionStep transaction =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) -> {
+                    // This fire-and-forget update statement should not fail the transaction.
+                    // The exception will however cause the transaction to be retried, as the
+                    // statement will not return a transaction id.
+                    transactionContext.executeUpdateAsync(INVALID_UPDATE_STATEMENT);
+                    return transactionContext.executeUpdateAsync(UPDATE_STATEMENT);
+                  },
+                  executor);
+          CommitTimestampFuture commitTimestamp = transaction.commitAsync();
+          assertThat(commitTimestamp.get()).isNotNull();
+          assertThat(transaction.get()).isEqualTo(UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -375,15 +356,19 @@ public void onSuccess(Long result) {
 
   @Test
   public void asyncTransactionManagerChain() throws Exception {
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(AsyncTransactionManagerHelper.executeUpdateAsync(UPDATE_STATEMENT), executor)
+          CommitTimestampFuture commitTimestamp =
+              transactionContextFuture
+                  .then(
+                      (transaction, ignored) -> transaction.executeUpdateAsync(UPDATE_STATEMENT),
+                      executor)
                   .then(
-                      AsyncTransactionManagerHelper.readRowAsync(
-                          READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES),
+                      (transactionContext, ignored) ->
+                          transactionContext.readRowAsync(
+                              READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES),
                       executor)
                   .then(
                       (ignored, input) -> ApiFutures.immediateFuture(input.getString("Value")),
@@ -395,10 +380,10 @@ public void asyncTransactionManagerChain() throws Exception {
                       },
                       executor)
                   .commitAsync();
-          assertThat(ts.get()).isNotNull();
+          assertThat(commitTimestamp.get()).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -406,13 +391,15 @@ public void asyncTransactionManagerChain() throws Exception {
 
   @Test
   public void asyncTransactionManagerChainWithErrorInTheMiddle() throws Exception {
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(INVALID_UPDATE_STATEMENT),
+          CommitTimestampFuture commitTimestampFuture =
+              transactionContextFuture
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.executeUpdateAsync(INVALID_UPDATE_STATEMENT),
                       executor)
                   .then(
                       (ignored1, ignored2) -> {
@@ -420,16 +407,12 @@ public void asyncTransactionManagerChainWithErrorInTheMiddle() throws Exception
                       },
                       executor)
                   .commitAsync();
-          ts.get();
+          SpannerException e =
+              assertThrows(SpannerException.class, () -> get(commitTimestampFuture));
+          assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
-        } catch (ExecutionException e) {
-          mgr.rollbackAsync();
-          assertThat(e.getCause()).isInstanceOf(SpannerException.class);
-          SpannerException se = (SpannerException) e.getCause();
-          assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT);
-          break;
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -437,16 +420,17 @@ public void asyncTransactionManagerChainWithErrorInTheMiddle() throws Exception
 
   @Test
   public void asyncTransactionManagerUpdateAborted() throws Exception {
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
       // Temporarily set the result of the update to 2 rows.
       mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L));
       final AtomicInteger attempt = new AtomicInteger();
 
-      TransactionContextFuture txn = mgr.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
+          CommitTimestampFuture commitTimestampFuture =
+              transactionContextFuture
+                  .then(
                       (ignored1, ignored2) -> {
                         if (attempt.incrementAndGet() == 1) {
                           // Abort the first attempt.
@@ -460,12 +444,14 @@ public void asyncTransactionManagerUpdateAborted() throws Exception {
                       },
                       executor)
                   .then(
-                      AsyncTransactionManagerHelper.executeUpdateAsync(UPDATE_STATEMENT), executor)
+                      (transactionContext, ignored) ->
+                          transactionContext.executeUpdateAsync(UPDATE_STATEMENT),
+                      executor)
                   .commitAsync();
-          assertThat(ts.get()).isNotNull();
+          assertThat(commitTimestampFuture.get()).isNotNull();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
       assertThat(attempt.get()).isEqualTo(2);
@@ -477,12 +463,13 @@ public void asyncTransactionManagerUpdateAborted() throws Exception {
   @Test
   public void asyncTransactionManagerUpdateAbortedWithoutGettingResult() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
+          CommitTimestampFuture commitTimestampFuture =
+              transactionContextFuture
+                  .then(
                       (transaction, ignored) -> {
                         if (attempt.incrementAndGet() == 1) {
                           mockSpanner.abortNextStatement();
@@ -498,7 +485,7 @@ public void asyncTransactionManagerUpdateAbortedWithoutGettingResult() throws Ex
                       },
                       executor)
                   .commitAsync();
-          assertThat(ts.get()).isNotNull();
+          assertThat(commitTimestampFuture.get()).isNotNull();
           assertThat(attempt.get()).isEqualTo(2);
           // The server may receive 1 or 2 commit requests depending on whether the call to
           // commitAsync() already knows that the transaction has aborted. If it does, it will not
@@ -513,7 +500,7 @@ public void asyncTransactionManagerUpdateAbortedWithoutGettingResult() throws Ex
                   CommitRequest.class);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -571,45 +558,45 @@ public void asyncTransactionManagerWaitsUntilAsyncUpdateHasFinished() throws Exc
 
   @Test
   public void asyncTransactionManagerBatchUpdate() throws Exception {
-    final SettableApiFuture result = SettableApiFuture.create();
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  AsyncTransactionManagerHelper.batchUpdateAsync(
-                      result, UPDATE_STATEMENT, UPDATE_STATEMENT),
-                  executor)
-              .commitAsync()
-              .get();
+          AsyncTransactionStep updateCounts =
+              transactionContextFuture.then(
+                  (transaction, ignored) ->
+                      transaction.batchUpdateAsync(
+                          ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
+                  executor);
+          get(updateCounts.commitAsync());
+          assertThat(get(updateCounts)).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
   }
 
   @Test
   public void asyncTransactionManagerIsNonBlockingWithBatchUpdate() throws Exception {
-    SettableApiFuture res = SettableApiFuture.create();
     mockSpanner.freeze();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          CommitTimestampFuture ts =
-              txn.then(
-                      AsyncTransactionManagerHelper.batchUpdateAsync(res, UPDATE_STATEMENT),
-                      executor)
-                  .commitAsync();
+          AsyncTransactionStep updateCounts =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext.batchUpdateAsync(Collections.singleton(UPDATE_STATEMENT)),
+                  executor);
+          CommitTimestampFuture commitTimestampFuture = updateCounts.commitAsync();
           mockSpanner.unfreeze();
-          assertThat(ts.get()).isNotNull();
-          assertThat(res.get()).asList().containsExactly(UPDATE_COUNT);
+          assertThat(commitTimestampFuture.get()).isNotNull();
+          assertThat(updateCounts.get()).asList().containsExactly(UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -617,17 +604,18 @@ public void asyncTransactionManagerIsNonBlockingWithBatchUpdate() throws Excepti
 
   @Test
   public void asyncTransactionManagerInvalidBatchUpdate() throws Exception {
-    SettableApiFuture result = SettableApiFuture.create();
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       SpannerException e =
           assertThrows(
               SpannerException.class,
               () ->
                   get(
-                      txn.then(
-                              AsyncTransactionManagerHelper.batchUpdateAsync(
-                                  result, UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT),
+                      transactionContextFuture
+                          .then(
+                              (transactionContext, ignored) ->
+                                  transactionContext.batchUpdateAsync(
+                                      ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT)),
                               executor)
                           .commitAsync()));
       assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT);
@@ -637,31 +625,32 @@ public void asyncTransactionManagerInvalidBatchUpdate() throws Exception {
 
   @Test
   public void asyncTransactionManagerFireAndForgetInvalidBatchUpdate() throws Exception {
-    SettableApiFuture result = SettableApiFuture.create();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
-                    transaction.batchUpdateAsync(
-                        ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT));
-                    return ApiFutures.immediateFuture(null);
-                  },
-                  executor)
-              .then(
-                  AsyncTransactionManagerHelper.batchUpdateAsync(
-                      result, UPDATE_STATEMENT, UPDATE_STATEMENT),
-                  executor)
-              .commitAsync()
-              .get();
+          AsyncTransactionStep updateCounts =
+              transactionContextFuture
+                  .then(
+                      (transactionContext, ignored) -> {
+                        transactionContext.batchUpdateAsync(
+                            ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT));
+                        return ApiFutures.immediateFuture(null);
+                      },
+                      executor)
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.batchUpdateAsync(
+                              ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
+                      executor);
+          updateCounts.commitAsync().get();
+          assertThat(updateCounts.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
     assertThat(mockSpanner.getRequestTypes())
         .containsExactly(
             BatchCreateSessionsRequest.class,
@@ -673,11 +662,12 @@ public void asyncTransactionManagerFireAndForgetInvalidBatchUpdate() throws Exce
   @Test
   public void asyncTransactionManagerBatchUpdateAborted() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
+          transactionContextFuture
+              .then(
                   (transaction, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
                       return transaction.batchUpdateAsync(
@@ -692,7 +682,7 @@ public void asyncTransactionManagerBatchUpdateAborted() throws Exception {
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -711,16 +701,17 @@ public void asyncTransactionManagerBatchUpdateAborted() throws Exception {
   @Test
   public void asyncTransactionManagerBatchUpdateAbortedBeforeFirstStatement() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
+          transactionContextFuture
+              .then(
+                  (transactionContext, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
                       mockSpanner.abortNextStatement();
                     }
-                    return transaction.batchUpdateAsync(
+                    return transactionContext.batchUpdateAsync(
                         ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT));
                   },
                   executor)
@@ -728,7 +719,7 @@ public void asyncTransactionManagerBatchUpdateAbortedBeforeFirstStatement() thro
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -746,28 +737,30 @@ public void asyncTransactionManagerBatchUpdateAbortedBeforeFirstStatement() thro
 
   @Test
   public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Exception {
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
       // Temporarily set the result of the update to 2 rows.
       mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L));
       final AtomicInteger attempt = new AtomicInteger();
-      TransactionContextFuture txn = mgr.beginAsync();
+      TransactionContextFuture txn = manager.beginAsync();
       while (true) {
-        final SettableApiFuture result = SettableApiFuture.create();
         try {
-          txn.then(
-                  (ignored1, ignored2) -> {
-                    if (attempt.get() > 0) {
-                      // Set the result of the update statement back to 1 row.
-                      mockSpanner.putStatementResult(
-                          StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
-                    }
-                    return ApiFutures.immediateFuture(null);
-                  },
-                  executor)
-              .then(
-                  AsyncTransactionManagerHelper.batchUpdateAsync(
-                      result, UPDATE_STATEMENT, UPDATE_STATEMENT),
-                  executor)
+          AsyncTransactionStep updateCounts =
+              txn.then(
+                      (ignored1, ignored2) -> {
+                        if (attempt.get() > 0) {
+                          // Set the result of the update statement back to 1 row.
+                          mockSpanner.putStatementResult(
+                              StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT));
+                        }
+                        return ApiFutures.immediateFuture(null);
+                      },
+                      executor)
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.batchUpdateAsync(
+                              ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
+                      executor);
+          updateCounts
               .then(
                   (transaction, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
@@ -778,11 +771,11 @@ public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Excepti
                   executor)
               .commitAsync()
               .get();
-          assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
+          assertThat(updateCounts.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT);
           assertThat(attempt.get()).isEqualTo(2);
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          txn = manager.resetForRetryAsync();
         }
       }
     } finally {
@@ -801,12 +794,13 @@ public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Excepti
   @Test
   public void asyncTransactionManagerBatchUpdateAbortedWithoutGettingResult() throws Exception {
     final AtomicInteger attempt = new AtomicInteger();
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
+          transactionContextFuture
+              .then(
+                  (transactionContext, ignored) -> {
                     if (attempt.incrementAndGet() == 1) {
                       mockSpanner.abortNextStatement();
                     }
@@ -816,7 +810,7 @@ public void asyncTransactionManagerBatchUpdateAbortedWithoutGettingResult() thro
                     // directly in the transaction manager if the ABORTED error has already been
                     // returned by the batch update call before the commit call starts.
                     // Otherwise, the backend will return an ABORTED error for the commit call.
-                    transaction.batchUpdateAsync(
+                    transactionContext.batchUpdateAsync(
                         ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT));
                     return ApiFutures.immediateFuture(null);
                   },
@@ -825,7 +819,7 @@ public void asyncTransactionManagerBatchUpdateAbortedWithoutGettingResult() thro
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -860,16 +854,18 @@ public void asyncTransactionManagerWithBatchUpdateCommitFails() throws Exception
             Status.RESOURCE_EXHAUSTED
                 .withDescription("mutation limit exceeded")
                 .asRuntimeException()));
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       SpannerException e =
           assertThrows(
               SpannerException.class,
               () ->
                   get(
-                      txn.then(
-                              AsyncTransactionManagerHelper.batchUpdateAsync(
-                                  UPDATE_STATEMENT, UPDATE_STATEMENT),
+                      transactionContextFuture
+                          .then(
+                              (transactionContext, ignored) ->
+                                  transactionContext.batchUpdateAsync(
+                                      ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)),
                               executor)
                           .commitAsync()));
       assertThat(e.getErrorCode()).isEqualTo(ErrorCode.RESOURCE_EXHAUSTED);
@@ -882,13 +878,14 @@ public void asyncTransactionManagerWithBatchUpdateCommitFails() throws Exception
 
   @Test
   public void asyncTransactionManagerWaitsUntilAsyncBatchUpdateHasFinished() throws Exception {
-    try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          txn.then(
-                  (transaction, ignored) -> {
-                    transaction.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT));
+          transactionContextFuture
+              .then(
+                  (transactionContext, ignored) -> {
+                    transactionContext.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT));
                     return ApiFutures.immediateFuture(null);
                   },
                   executor)
@@ -896,7 +893,7 @@ public void asyncTransactionManagerWaitsUntilAsyncBatchUpdateHasFinished() throw
               .get();
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
@@ -907,55 +904,53 @@ public void asyncTransactionManagerWaitsUntilAsyncBatchUpdateHasFinished() throw
 
   @Test
   public void asyncTransactionManagerReadRow() throws Exception {
-    ApiFuture val;
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          AsyncTransactionStep step;
-          val =
-              step =
-                  txn.then(
-                          AsyncTransactionManagerHelper.readRowAsync(
+          AsyncTransactionStep value =
+              transactionContextFuture
+                  .then(
+                      (transactionContext, ignored) ->
+                          transactionContext.readRowAsync(
                               READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES),
-                          executor)
-                      .then(
-                          (ignored, input) -> ApiFutures.immediateFuture(input.getString("Value")),
-                          executor);
-          step.commitAsync().get();
+                      executor)
+                  .then(
+                      (ignored, input) -> ApiFutures.immediateFuture(input.getString("Value")),
+                      executor);
+          value.commitAsync().get();
+          assertThat(value.get()).isEqualTo("v1");
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(val.get()).isEqualTo("v1");
   }
 
   @Test
   public void asyncTransactionManagerRead() throws Exception {
-    AsyncTransactionStep> res;
-    try (AsyncTransactionManager mgr = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = mgr.beginAsync();
+    try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         try {
-          res =
-              txn.then(
-                  (transaction, ignored) ->
-                      transaction
+          AsyncTransactionStep> values =
+              transactionContextFuture.then(
+                  (transactionContext, ignored) ->
+                      transactionContext
                           .readAsync(READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES)
                           .toListAsync(
                               input -> input.getString("Value"), MoreExecutors.directExecutor()),
                   executor);
           // Commit the transaction.
-          res.commitAsync().get();
+          values.commitAsync().get();
+          assertThat(values.get()).containsExactly("v1", "v2", "v3");
           break;
         } catch (AbortedException e) {
-          txn = mgr.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
-    assertThat(res.get()).containsExactly("v1", "v2", "v3");
   }
 
   @Test
@@ -966,24 +961,24 @@ public void asyncTransactionManagerQuery() throws Exception {
             MockSpannerTestUtil.READ_FIRST_NAME_SINGERS_RESULTSET));
     final long singerId = 1L;
     try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
-      TransactionContextFuture txn = manager.beginAsync();
+      TransactionContextFuture transactionContextFuture = manager.beginAsync();
       while (true) {
         final String column = "FirstName";
         CommitTimestampFuture commitTimestamp =
-            txn.then(
-                    (transaction, ignored) ->
-                        transaction.readRowAsync(
+            transactionContextFuture
+                .then(
+                    (transactionContext, ignored) ->
+                        transactionContext.readRowAsync(
                             "Singers", Key.of(singerId), Collections.singleton(column)),
                     executor)
                 .then(
                     (transaction, input) -> {
                       String name = input.getString(column);
-                      transaction.buffer(
+                      return transaction.bufferAsync(
                           Mutation.newUpdateBuilder("Singers")
                               .set(column)
                               .to(name.toUpperCase())
                               .build());
-                      return ApiFutures.immediateFuture(null);
                     },
                     executor)
                 .commitAsync();
@@ -991,8 +986,7 @@ public void asyncTransactionManagerQuery() throws Exception {
           commitTimestamp.get();
           break;
         } catch (AbortedException e) {
-          Thread.sleep(e.getRetryDelayInMillis());
-          txn = manager.resetForRetryAsync();
+          transactionContextFuture = manager.resetForRetryAsync();
         }
       }
     }
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java
index 369385478d..b7035f64fa 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java
@@ -16,6 +16,7 @@
 
 package com.google.cloud.spanner;
 
+import static org.junit.Assert.assertThrows;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyMap;
 import static org.mockito.Mockito.mock;
@@ -26,20 +27,125 @@
 import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl;
 import com.google.cloud.spanner.spi.v1.SpannerRpc;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.Timestamp;
 import com.google.rpc.Code;
 import com.google.rpc.Status;
 import com.google.spanner.v1.CommitRequest;
 import com.google.spanner.v1.ExecuteBatchDmlRequest;
 import com.google.spanner.v1.ExecuteBatchDmlResponse;
 import java.util.Collections;
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.Mock;
 import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
 
 @RunWith(JUnit4.class)
 public class TransactionContextImplTest {
 
+  @Mock private SpannerRpc rpc;
+
+  @Mock private SessionImpl session;
+
+  @SuppressWarnings("unchecked")
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    when(rpc.commitAsync(any(CommitRequest.class), anyMap()))
+        .thenReturn(
+            ApiFutures.immediateFuture(
+                com.google.spanner.v1.CommitResponse.newBuilder()
+                    .setCommitTimestamp(Timestamp.newBuilder().setSeconds(99L).setNanos(10).build())
+                    .build()));
+    when(session.getName()).thenReturn("test");
+  }
+
+  private TransactionContextImpl createContext() {
+    return TransactionContextImpl.newBuilder()
+        .setSession(session)
+        .setRpc(rpc)
+        .setTransactionId(ByteString.copyFromUtf8("test"))
+        .setOptions(Options.fromTransactionOptions())
+        .build();
+  }
+
+  @Test
+  public void testCanBufferBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.buffer(Mutation.delete("test", KeySet.all()));
+    }
+  }
+
+  @Test
+  public void testCanBufferAsyncBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.bufferAsync(Mutation.delete("test", KeySet.all()));
+    }
+  }
+
+  @Test
+  public void testCanBufferIterableBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.buffer(Collections.singleton(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCanBufferIterableAsyncBeforeCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.bufferAsync(Collections.singleton(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCannotBufferAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class, () -> context.buffer(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCannotBufferAsyncAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class,
+          () -> context.bufferAsync(Mutation.delete("test", KeySet.all())));
+    }
+  }
+
+  @Test
+  public void testCannotBufferIterableAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class,
+          () -> context.buffer(Collections.singleton(Mutation.delete("test", KeySet.all()))));
+    }
+  }
+
+  @Test
+  public void testCannotBufferIterableAsyncAfterCommit() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(
+          IllegalStateException.class,
+          () -> context.bufferAsync(Collections.singleton(Mutation.delete("test", KeySet.all()))));
+    }
+  }
+
+  @Test
+  public void testCannotCommitTwice() {
+    try (TransactionContextImpl context = createContext()) {
+      context.commit();
+      assertThrows(IllegalStateException.class, () -> context.commit());
+    }
+  }
+
   @Test(expected = AbortedException.class)
   public void batchDmlAborted() {
     batchDml(Code.ABORTED_VALUE);
@@ -53,13 +159,7 @@ public void batchDmlException() {
   @SuppressWarnings("unchecked")
   @Test
   public void testReturnCommitStats() {
-    SessionImpl session = mock(SessionImpl.class);
-    when(session.getName()).thenReturn("test");
     ByteString transactionId = ByteString.copyFromUtf8("test");
-    SpannerRpc rpc = mock(SpannerRpc.class);
-    when(rpc.commitAsync(any(CommitRequest.class), anyMap()))
-        .thenReturn(
-            ApiFutures.immediateFuture(com.google.spanner.v1.CommitResponse.getDefaultInstance()));
 
     try (TransactionContextImpl context =
         TransactionContextImpl.newBuilder()
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextTest.java
new file mode 100644
index 0000000000..045c58d837
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextTest.java
@@ -0,0 +1,144 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import static org.junit.Assert.assertThrows;
+
+import com.google.api.core.ApiFuture;
+import com.google.cloud.spanner.Options.QueryOption;
+import com.google.cloud.spanner.Options.ReadOption;
+import com.google.cloud.spanner.Options.UpdateOption;
+import java.util.Collections;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class TransactionContextTest {
+
+  @Test
+  public void testDefaultImplementations() {
+    try (TransactionContext context =
+        new TransactionContext() {
+          @Override
+          public AsyncResultSet readUsingIndexAsync(
+              String table,
+              String index,
+              KeySet keys,
+              Iterable columns,
+              ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public ResultSet readUsingIndex(
+              String table,
+              String index,
+              KeySet keys,
+              Iterable columns,
+              ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public ApiFuture readRowUsingIndexAsync(
+              String table, String index, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public Struct readRowUsingIndex(
+              String table, String index, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public ApiFuture readRowAsync(String table, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public Struct readRow(String table, Key key, Iterable columns) {
+            return null;
+          }
+
+          @Override
+          public AsyncResultSet readAsync(
+              String table, KeySet keys, Iterable columns, ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public ResultSet read(
+              String table, KeySet keys, Iterable columns, ReadOption... options) {
+            return null;
+          }
+
+          @Override
+          public AsyncResultSet executeQueryAsync(Statement statement, QueryOption... options) {
+            return null;
+          }
+
+          @Override
+          public ResultSet executeQuery(Statement statement, QueryOption... options) {
+            return null;
+          }
+
+          @Override
+          public void close() {}
+
+          @Override
+          public ResultSet analyzeQuery(Statement statement, QueryAnalyzeMode queryMode) {
+            return null;
+          }
+
+          @Override
+          public ApiFuture executeUpdateAsync(Statement statement, UpdateOption... options) {
+            return null;
+          }
+
+          @Override
+          public long executeUpdate(Statement statement, UpdateOption... options) {
+            return 0;
+          }
+
+          @Override
+          public void buffer(Iterable mutations) {}
+
+          @Override
+          public void buffer(Mutation mutation) {}
+
+          @Override
+          public ApiFuture batchUpdateAsync(
+              Iterable statements, UpdateOption... options) {
+            return null;
+          }
+
+          @Override
+          public long[] batchUpdate(Iterable statements, UpdateOption... options) {
+            return null;
+          }
+        }) {
+      assertThrows(
+          UnsupportedOperationException.class,
+          () -> context.bufferAsync(Mutation.delete("foo", KeySet.all())));
+      assertThrows(
+          UnsupportedOperationException.class,
+          () -> context.bufferAsync(Collections.singleton(Mutation.delete("foo", KeySet.all()))));
+    }
+  }
+}