diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index 21812aa96a..dc6cb56f30 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -540,7 +540,7 @@ public ApiFuture batchUpdateAsync(Iterable statements) { decreaseAsyncOperations(); throw t; } - final ApiFuture updateCounts = + ApiFuture updateCounts = ApiFutures.transform( response, new ApiFunction() { @@ -565,19 +565,24 @@ public long[] apply(ExecuteBatchDmlResponse input) { } }, MoreExecutors.directExecutor()); + updateCounts = + ApiFutures.catching( + updateCounts, + Throwable.class, + new ApiFunction() { + @Override + public long[] apply(Throwable input) { + SpannerException e = SpannerExceptionFactory.newSpannerException(input); + onError(e); + throw e; + } + }, + MoreExecutors.directExecutor()); updateCounts.addListener( new Runnable() { @Override public void run() { - try { - updateCounts.get(); - } catch (ExecutionException e) { - onError(SpannerExceptionFactory.newSpannerException(e.getCause())); - } catch (InterruptedException e) { - onError(SpannerExceptionFactory.propagateInterrupt(e)); - } finally { - decreaseAsyncOperations(); - } + decreaseAsyncOperations(); } }, MoreExecutors.directExecutor()); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java index 70b3fd3f92..c7b95f33f6 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java @@ -734,6 +734,47 @@ public ApiFuture apply(TransactionContext txn, Void input) CommitRequest.class); } + @Test + public void asyncTransactionManagerBatchUpdateAbortedBeforeFirstStatement() throws Exception { + final AtomicInteger attempt = new AtomicInteger(); + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + return txn.batchUpdateAsync( + ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + } + }, + executor) + .commitAsync() + .get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(attempt.get()).isEqualTo(2); + // There should only be 1 CommitRequest, as the first attempt should abort already after the + // ExecuteBatchDmlRequest. + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + @Test public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Exception { try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) {