diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml index cfbcb88f85..1f7beb76e9 100644 --- a/google-cloud-spanner/clirr-ignored-differences.xml +++ b/google-cloud-spanner/clirr-ignored-differences.xml @@ -371,4 +371,10 @@ com/google/cloud/spanner/ResultSets com.google.cloud.spanner.AsyncResultSet toAsyncResultSet(com.google.cloud.spanner.ResultSet, com.google.api.gax.core.ExecutorProvider) + + + 7012 + com/google/cloud/spanner/AsyncTransactionManager + com.google.api.core.ApiFuture closeAsync() + diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java index d519c68013..02d4a9dbd2 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java @@ -18,9 +18,6 @@ import com.google.api.core.ApiFuture; import com.google.cloud.Timestamp; -import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionFunction; -import com.google.cloud.spanner.AsyncTransactionManager.CommitTimestampFuture; -import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; import com.google.cloud.spanner.TransactionManager.TransactionState; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; @@ -200,4 +197,11 @@ public interface AsyncTransactionFunction { */ @Override void close(); + + /** + * Closes the transaction manager. If there is an active transaction, it will be rolled back. The + * underlying session will be released back to the session pool. The returned {@link ApiFuture} is + * done when the transaction (if any) has been rolled back. + */ + ApiFuture closeAsync(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java index 082fa827e7..350349af16 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java @@ -24,6 +24,7 @@ import com.google.cloud.spanner.SessionImpl.SessionTransaction; import com.google.cloud.spanner.TransactionContextFutureImpl.CommittableAsyncTransactionManager; import com.google.cloud.spanner.TransactionManager.TransactionState; +import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; import io.opencensus.trace.Span; @@ -54,7 +55,17 @@ public void setSpan(Span span) { @Override public void close() { + closeAsync(); + } + + @Override + public ApiFuture closeAsync() { + ApiFuture res = null; + if (txnState == TransactionState.STARTED) { + res = rollbackAsync(); + } txn.close(); + return MoreObjects.firstNonNull(res, ApiFutures.immediateFuture(null)); } @Override diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java index 55b6102a27..54b621b93b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java @@ -22,7 +22,6 @@ import com.google.api.core.ApiFutures; import com.google.api.core.SettableApiFuture; import com.google.cloud.Timestamp; -import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; import com.google.cloud.spanner.SessionPool.PooledSessionFuture; import com.google.cloud.spanner.TransactionContextFutureImpl.CommittableAsyncTransactionManager; import com.google.cloud.spanner.TransactionManager.TransactionState; @@ -59,14 +58,41 @@ public void run() { @Override public void close() { - delegate.addListener( - new Runnable() { + SpannerApiFutures.get(closeAsync()); + } + + @Override + public ApiFuture closeAsync() { + final SettableApiFuture res = SettableApiFuture.create(); + ApiFutures.addCallback( + delegate, + new ApiFutureCallback() { @Override - public void run() { + public void onFailure(Throwable t) { session.close(); } + + @Override + public void onSuccess(AsyncTransactionManagerImpl result) { + ApiFutures.addCallback( + result.closeAsync(), + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + res.setException(t); + } + + @Override + public void onSuccess(Void result) { + session.close(); + res.set(result); + } + }, + MoreExecutors.directExecutor()); + } }, MoreExecutors.directExecutor()); + return res; } @Override diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java index c7b95f33f6..bf1f214a71 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java @@ -36,7 +36,9 @@ import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.Options.ReadOption; +import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; import com.google.common.base.Function; +import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.collect.Range; @@ -47,6 +49,8 @@ import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.ExecuteBatchDmlRequest; import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.TransactionSelector; import io.grpc.Status; import java.util.Arrays; import java.util.Collection; @@ -181,6 +185,30 @@ public void onSuccess(long[] input) { } } + @Test + public void asyncTransactionManager_shouldRollbackOnCloseAsync() throws Exception { + AsyncTransactionManager manager = client().transactionManagerAsync(); + TransactionContext txn = manager.beginAsync().get(); + txn.executeUpdateAsync(UPDATE_STATEMENT).get(); + final TransactionSelector selector = ((TransactionContextImpl) txn).getTransactionSelector(); + + SpannerApiFutures.get(manager.closeAsync()); + // The mock server should already have the Rollback request, as we are waiting for the returned + // ApiFuture to be done. + mockSpanner.waitForRequestsToContain( + new Predicate() { + @Override + public boolean apply(AbstractMessage input) { + if (input instanceof RollbackRequest) { + RollbackRequest request = (RollbackRequest) input; + return request.getTransactionId().equals(selector.getId()); + } + return false; + } + }, + 0L); + } + @Test public void asyncTransactionManagerUpdate() throws Exception { final SettableApiFuture updateCount = SettableApiFuture.create(); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 4f55cd5ebd..5ecf9607a4 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -23,8 +23,10 @@ import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; import com.google.common.base.Optional; import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; import com.google.common.base.Stopwatch; import com.google.common.base.Throwables; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.AbstractMessage; import com.google.protobuf.ByteString; @@ -1927,6 +1929,23 @@ public void waitForRequestsToContain(Class type, long } } + public void waitForRequestsToContain( + Predicate predicate, long timeoutMillis) + throws InterruptedException, TimeoutException { + Stopwatch watch = Stopwatch.createStarted(); + while (true) { + Iterable msg = Iterables.filter(getRequests(), predicate); + if (msg.iterator().hasNext()) { + break; + } + Thread.sleep(10L); + if (watch.elapsed(TimeUnit.MILLISECONDS) > timeoutMillis) { + throw new TimeoutException( + "Timeout while waiting for requests to contain the wanted request"); + } + } + } + @Override public void addResponse(AbstractMessage response) { throw new UnsupportedOperationException();