diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml
index cfbcb88f85..1f7beb76e9 100644
--- a/google-cloud-spanner/clirr-ignored-differences.xml
+++ b/google-cloud-spanner/clirr-ignored-differences.xml
@@ -371,4 +371,10 @@
com/google/cloud/spanner/ResultSets
com.google.cloud.spanner.AsyncResultSet toAsyncResultSet(com.google.cloud.spanner.ResultSet, com.google.api.gax.core.ExecutorProvider)
+
+
+ 7012
+ com/google/cloud/spanner/AsyncTransactionManager
+ com.google.api.core.ApiFuture closeAsync()
+
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java
index d519c68013..02d4a9dbd2 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java
@@ -18,9 +18,6 @@
import com.google.api.core.ApiFuture;
import com.google.cloud.Timestamp;
-import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionFunction;
-import com.google.cloud.spanner.AsyncTransactionManager.CommitTimestampFuture;
-import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture;
import com.google.cloud.spanner.TransactionManager.TransactionState;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
@@ -200,4 +197,11 @@ public interface AsyncTransactionFunction {
*/
@Override
void close();
+
+ /**
+ * Closes the transaction manager. If there is an active transaction, it will be rolled back. The
+ * underlying session will be released back to the session pool. The returned {@link ApiFuture} is
+ * done when the transaction (if any) has been rolled back.
+ */
+ ApiFuture closeAsync();
}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java
index 082fa827e7..350349af16 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java
@@ -24,6 +24,7 @@
import com.google.cloud.spanner.SessionImpl.SessionTransaction;
import com.google.cloud.spanner.TransactionContextFutureImpl.CommittableAsyncTransactionManager;
import com.google.cloud.spanner.TransactionManager.TransactionState;
+import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
import io.opencensus.trace.Span;
@@ -54,7 +55,17 @@ public void setSpan(Span span) {
@Override
public void close() {
+ closeAsync();
+ }
+
+ @Override
+ public ApiFuture closeAsync() {
+ ApiFuture res = null;
+ if (txnState == TransactionState.STARTED) {
+ res = rollbackAsync();
+ }
txn.close();
+ return MoreObjects.firstNonNull(res, ApiFutures.immediateFuture(null));
}
@Override
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java
index 55b6102a27..54b621b93b 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java
@@ -22,7 +22,6 @@
import com.google.api.core.ApiFutures;
import com.google.api.core.SettableApiFuture;
import com.google.cloud.Timestamp;
-import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture;
import com.google.cloud.spanner.SessionPool.PooledSessionFuture;
import com.google.cloud.spanner.TransactionContextFutureImpl.CommittableAsyncTransactionManager;
import com.google.cloud.spanner.TransactionManager.TransactionState;
@@ -59,14 +58,41 @@ public void run() {
@Override
public void close() {
- delegate.addListener(
- new Runnable() {
+ SpannerApiFutures.get(closeAsync());
+ }
+
+ @Override
+ public ApiFuture closeAsync() {
+ final SettableApiFuture res = SettableApiFuture.create();
+ ApiFutures.addCallback(
+ delegate,
+ new ApiFutureCallback() {
@Override
- public void run() {
+ public void onFailure(Throwable t) {
session.close();
}
+
+ @Override
+ public void onSuccess(AsyncTransactionManagerImpl result) {
+ ApiFutures.addCallback(
+ result.closeAsync(),
+ new ApiFutureCallback() {
+ @Override
+ public void onFailure(Throwable t) {
+ res.setException(t);
+ }
+
+ @Override
+ public void onSuccess(Void result) {
+ session.close();
+ res.set(result);
+ }
+ },
+ MoreExecutors.directExecutor());
+ }
},
MoreExecutors.directExecutor());
+ return res;
}
@Override
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
index c7b95f33f6..bf1f214a71 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java
@@ -36,7 +36,9 @@
import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import com.google.cloud.spanner.Options.ReadOption;
+import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl;
import com.google.common.base.Function;
+import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Range;
@@ -47,6 +49,8 @@
import com.google.spanner.v1.CommitRequest;
import com.google.spanner.v1.ExecuteBatchDmlRequest;
import com.google.spanner.v1.ExecuteSqlRequest;
+import com.google.spanner.v1.RollbackRequest;
+import com.google.spanner.v1.TransactionSelector;
import io.grpc.Status;
import java.util.Arrays;
import java.util.Collection;
@@ -181,6 +185,30 @@ public void onSuccess(long[] input) {
}
}
+ @Test
+ public void asyncTransactionManager_shouldRollbackOnCloseAsync() throws Exception {
+ AsyncTransactionManager manager = client().transactionManagerAsync();
+ TransactionContext txn = manager.beginAsync().get();
+ txn.executeUpdateAsync(UPDATE_STATEMENT).get();
+ final TransactionSelector selector = ((TransactionContextImpl) txn).getTransactionSelector();
+
+ SpannerApiFutures.get(manager.closeAsync());
+ // The mock server should already have the Rollback request, as we are waiting for the returned
+ // ApiFuture to be done.
+ mockSpanner.waitForRequestsToContain(
+ new Predicate() {
+ @Override
+ public boolean apply(AbstractMessage input) {
+ if (input instanceof RollbackRequest) {
+ RollbackRequest request = (RollbackRequest) input;
+ return request.getTransactionId().equals(selector.getId());
+ }
+ return false;
+ }
+ },
+ 0L);
+ }
+
@Test
public void asyncTransactionManagerUpdate() throws Exception {
final SettableApiFuture updateCount = SettableApiFuture.create();
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
index 4f55cd5ebd..5ecf9607a4 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
@@ -23,8 +23,10 @@
import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
+import com.google.common.base.Predicate;
import com.google.common.base.Stopwatch;
import com.google.common.base.Throwables;
+import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.AbstractMessage;
import com.google.protobuf.ByteString;
@@ -1927,6 +1929,23 @@ public void waitForRequestsToContain(Class extends AbstractMessage> type, long
}
}
+ public void waitForRequestsToContain(
+ Predicate super AbstractMessage> predicate, long timeoutMillis)
+ throws InterruptedException, TimeoutException {
+ Stopwatch watch = Stopwatch.createStarted();
+ while (true) {
+ Iterable msg = Iterables.filter(getRequests(), predicate);
+ if (msg.iterator().hasNext()) {
+ break;
+ }
+ Thread.sleep(10L);
+ if (watch.elapsed(TimeUnit.MILLISECONDS) > timeoutMillis) {
+ throw new TimeoutException(
+ "Timeout while waiting for requests to contain the wanted request");
+ }
+ }
+ }
+
@Override
public void addResponse(AbstractMessage response) {
throw new UnsupportedOperationException();