diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunner.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunner.java index 432d6a8645..de15d79c7a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunner.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunner.java @@ -48,10 +48,6 @@ interface AsyncWork { * * @param txn the transaction * @return future over the result of the work - *

TODO(loite): It's probably better to let this method return `R` instead of - * `ApiFuture`, as we need to wait until the result of the work has actually finished - * before we can commit the transaction. Returning an ApiFuture here just means that the - * underlying framework code still has to call {@link ApiFuture#get()} before committing. */ ApiFuture doWorkAsync(TransactionContext txn); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java index 5782ed8e66..5dbdd1092f 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java @@ -23,9 +23,12 @@ import com.google.api.core.ApiFunction; import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; import com.google.api.gax.grpc.testing.LocalChannelProvider; import com.google.cloud.NoCredentials; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; import com.google.cloud.spanner.AsyncRunner.AsyncWork; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; @@ -35,10 +38,15 @@ import io.grpc.Server; import io.grpc.Status; import io.grpc.inprocess.InProcessServerBuilder; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.SynchronousQueue; import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.AfterClass; @@ -302,6 +310,82 @@ public ApiFuture doWorkAsync(TransactionContext txn) { } } + @Test + public void asyncRunnerWaitsUntilAsyncUpdateHasFinished() { + AsyncRunner runner = client.runAsync(); + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + txn.executeUpdateAsync(UPDATE_STATEMENT); + return ApiFutures.immediateFuture(null); + } + }, + executor); + } + @Test + public void closeTransactionBeforeEndOfAsyncQuery() throws Exception { + final BlockingQueue results = new SynchronousQueue<>(); + final SettableApiFuture finished = SettableApiFuture.create(); + DatabaseClientImpl clientImpl = (DatabaseClientImpl) client; + + // There should currently not be any sessions checked out of the pool. + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(0); + + AsyncRunner runner = client.runAsync(); + final CountDownLatch dataReceived = new CountDownLatch(1); + ApiFuture res = runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + try (AsyncResultSet rs = + txn.readAsync(READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES, Options.bufferRows(1))) { + rs.setCallback( + Executors.newSingleThreadExecutor(), + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + finished.set(true); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + dataReceived.countDown(); + results.put(resultSet.getString(0)); + } + } + } catch (Throwable t) { + finished.setException(t); + dataReceived.countDown(); + return CallbackResponse.DONE; + } + } + }); + } + return ApiFutures.immediateFuture(null); + } + }, + executor); + // Wait until at least one row has been fetched. At that moment there should be one session + // checked out. + dataReceived.await(); + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(1); + assertThat(res.isDone()).isFalse(); + // Get the data from the transaction. + List resultList = new ArrayList<>(); + do { + results.drainTo(resultList); + } while (!finished.isDone() || results.size() > 0); + assertThat(finished.get()).isTrue(); + assertThat(resultList).containsExactly("k1", "k2", "k3"); + assertThat(res.get()).isNull(); + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(0); + } + @Test public void asyncRunnerReadRow() throws Exception { AsyncRunner runner = client.runAsync();