diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 4ec5133b1a..cff8ec1f57 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -608,14 +608,10 @@ ExecuteBatchDmlRequest.Builder getExecuteBatchDmlRequestBuilder(Iterable startStream(@Nullable ByteString resumeToken) { GrpcStreamIterator stream = new GrpcStreamIterator(statement, prefetchChunks); + final ExecuteSqlRequest.Builder request = + getExecuteSqlRequestBuilder(statement, queryMode); + if (partitionToken != null) { + request.setPartitionToken(partitionToken); + } if (resumeToken != null) { request.setResumeToken(resumeToken); } SpannerRpc.StreamingCall call = rpc.executeQuery(request.build(), stream.consumer(), session.getOptions()); call.request(prefetchChunks); - stream.setCall(call); + stream.setCall(call, request.hasTransaction() && request.getTransaction().hasBegin()); return stream; } }; - return new GrpcResultSet( - stream, this, request.hasTransaction() && request.getTransaction().hasBegin()); + return new GrpcResultSet(stream, this); } /** @@ -723,10 +723,6 @@ ResultSet readInternalWithOptions( if (index != null) { builder.setIndex(index); } - TransactionSelector selector = getTransactionSelector(); - if (selector != null) { - builder.setTransaction(selector); - } if (partitionToken != null) { builder.setPartitionToken(partitionToken); } @@ -740,15 +736,18 @@ CloseableIterator startStream(@Nullable ByteString resumeToken if (resumeToken != null) { builder.setResumeToken(resumeToken); } + TransactionSelector selector = getTransactionSelector(); + if (selector != null) { + builder.setTransaction(selector); + } SpannerRpc.StreamingCall call = rpc.read(builder.build(), stream.consumer(), session.getOptions()); call.request(prefetchChunks); - stream.setCall(call); + stream.setCall(call, selector != null && selector.hasBegin()); return stream; } }; - GrpcResultSet resultSet = - new GrpcResultSet(stream, this, selector != null && selector.hasBegin()); + GrpcResultSet resultSet = new GrpcResultSet(stream, this); return resultSet; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java index 3c5e60f51a..6520b7b8fd 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java @@ -91,17 +91,14 @@ interface Listener { static class GrpcResultSet extends AbstractResultSet> { private final GrpcValueIterator iterator; private final Listener listener; - private final boolean beginTransaction; private GrpcStruct currRow; private SpannerException error; private ResultSetStats statistics; private boolean closed; - GrpcResultSet( - CloseableIterator iterator, Listener listener, boolean beginTransaction) { + GrpcResultSet(CloseableIterator iterator, Listener listener) { this.iterator = new GrpcValueIterator(iterator); this.listener = listener; - this.beginTransaction = beginTransaction; } @Override @@ -130,7 +127,7 @@ public boolean next() throws SpannerException { } return hasNext; } catch (SpannerException e) { - throw yieldError(e, beginTransaction && currRow == null); + throw yieldError(e, iterator.isWithBeginTransaction() && currRow == null); } } @@ -297,6 +294,10 @@ void close(@Nullable String message) { stream.close(message); } + boolean isWithBeginTransaction() { + return stream.isWithBeginTransaction(); + } + /** @param a is a mutable list and b will be concatenated into a. */ private void concatLists(List a, List b) { if (a.size() == 0 || b.size() == 0) { @@ -760,6 +761,8 @@ interface CloseableIterator extends Iterator { * @param message a message to include in the final RPC status */ void close(@Nullable String message); + + boolean isWithBeginTransaction(); } /** Adapts a streaming read/query call into an iterator over partial result sets. */ @@ -774,6 +777,7 @@ static class GrpcStreamIterator extends AbstractIterator private final Statement statement; private SpannerRpc.StreamingCall call; + private boolean withBeginTransaction; private SpannerException error; @VisibleForTesting @@ -792,8 +796,9 @@ protected final SpannerRpc.ResultStreamConsumer consumer() { return consumer; } - public void setCall(SpannerRpc.StreamingCall call) { + public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { this.call = call; + this.withBeginTransaction = withBeginTransaction; } @Override @@ -803,6 +808,11 @@ public void close(@Nullable String message) { } } + @Override + public boolean isWithBeginTransaction() { + return withBeginTransaction; + } + @Override protected final PartialResultSet computeNext() { PartialResultSet next; @@ -873,8 +883,8 @@ public void onError(SpannerException e) { // Visible only for testing. @VisibleForTesting - void setCall(SpannerRpc.StreamingCall call) { - GrpcStreamIterator.this.setCall(call); + void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { + GrpcStreamIterator.this.setCall(call, withBeginTransaction); } } } @@ -987,6 +997,11 @@ public void close(@Nullable String message) { } } + @Override + public boolean isWithBeginTransaction() { + return stream != null && stream.isWithBeginTransaction(); + } + @Override protected PartialResultSet computeNext() { Context context = Context.current(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index e38b704f70..02119b13a1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -267,7 +267,7 @@ ApiFuture commitAsync() { final SettableApiFuture finishOps; CommitRequest.Builder builder = CommitRequest.newBuilder().setSession(session.getName()); synchronized (lock) { - if (transactionIdFuture == null && transactionId == null) { + if (transactionIdFuture == null && transactionId == null && runningAsyncOperations == 0) { finishOps = SettableApiFuture.create(); createTxnAsync(finishOps); } else { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java index a2ab8dbc90..f1a1a3e296 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java @@ -75,13 +75,14 @@ public void cancel(@Nullable String message) {} @Override public void request(int numMessages) {} - }); + }, + false); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); + resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); } public AbstractResultSet.GrpcResultSet resultSetWithMode(QueryMode queryMode) { - return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); + return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); } @Test @@ -642,7 +643,7 @@ public com.google.protobuf.Value apply(@Nullable Value input) { private void verifySerialization( Function protoFn, Value... values) { - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); + resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); PartialResultSet.Builder builder = PartialResultSet.newBuilder(); List types = new ArrayList<>(); for (Value value : values) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java index d1e3d93cb7..b7c7d91720 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import static com.google.cloud.spanner.MockSpannerTestUtil.SELECT1; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; @@ -65,6 +66,7 @@ import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.After; import org.junit.AfterClass; @@ -1139,6 +1141,123 @@ public ApiFuture apply(TransactionContext txn, Long input) assertThat(countTransactionsStarted()).isEqualTo(1); } + @Test + public void queryWithoutNext() { + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + assertThat( + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + // This will not actually send an RPC, so it will also not request a + // transaction. + transaction.executeQuery(SELECT1); + return transaction.executeUpdate(UPDATE_STATEMENT); + } + })) + .isEqualTo(UPDATE_COUNT); + assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L); + assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + + @Test + public void queryAsyncWithoutCallback() { + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + assertThat( + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + transaction.executeQueryAsync(SELECT1); + return transaction.executeUpdate(UPDATE_STATEMENT); + } + })) + .isEqualTo(UPDATE_COUNT); + assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L); + assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + + @Test + public void readWithoutNext() { + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + assertThat( + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + transaction.read("FOO", KeySet.all(), Arrays.asList("ID")); + return transaction.executeUpdate(UPDATE_STATEMENT); + } + })) + .isEqualTo(UPDATE_COUNT); + assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L); + assertThat(mockSpanner.countRequestsOfType(ReadRequest.class)).isEqualTo(0L); + assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + + @Test + public void readAsyncWithoutCallback() { + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + assertThat( + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + transaction.readAsync("FOO", KeySet.all(), Arrays.asList("ID")); + return transaction.executeUpdate(UPDATE_STATEMENT); + } + })) + .isEqualTo(UPDATE_COUNT); + assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L); + assertThat(mockSpanner.countRequestsOfType(ReadRequest.class)).isEqualTo(0L); + assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + + @Test + public void query_ThenUpdate_ThenConsumeResultSet() + throws InterruptedException, TimeoutException { + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + assertThat( + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + ResultSet rs = transaction.executeQuery(SELECT1); + long updateCount = transaction.executeUpdate(UPDATE_STATEMENT); + // Consume the result set. + while (rs.next()) {} + return updateCount; + } + })) + .isEqualTo(UPDATE_COUNT); + // The update statement should start the transaction, and the query should use the transaction + // id returned by the update. + assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L); + assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(2L); + assertThat(countTransactionsStarted()).isEqualTo(1); + List requests = mockSpanner.getRequests(); + requests = requests.subList(requests.size() - 3, requests.size()); + assertThat(requests.get(0)).isInstanceOf(ExecuteSqlRequest.class); + assertThat(((ExecuteSqlRequest) requests.get(0)).getSql()).isEqualTo(UPDATE_STATEMENT.getSql()); + assertThat(requests.get(1)).isInstanceOf(ExecuteSqlRequest.class); + assertThat(((ExecuteSqlRequest) requests.get(1)).getSql()).isEqualTo(SELECT1.getSql()); + assertThat(requests.get(2)).isInstanceOf(CommitRequest.class); + } + private int countRequests(Class requestType) { int count = 0; for (AbstractMessage msg : mockSpanner.getRequests()) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java index 50cf96ff3c..aa479f71d4 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java @@ -117,9 +117,10 @@ public void cancel(@Nullable String message) {} @Override public void request(int numMessages) {} - }); + }, + false); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); + resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); JSONArray chunks = testCase.getJSONArray("chunks"); JSONObject expectedResult = testCase.getJSONObject("result"); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java index ef744d31a1..4f38aee940 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java @@ -116,6 +116,11 @@ protected PartialResultSet computeNext() { public void close(@Nullable String message) { stream.close(); } + + @Override + public boolean isWithBeginTransaction() { + return false; + } } Starter starter = Mockito.mock(Starter.class);