From 436833b8dc5d43d285b5af35059f35c9aa1fd5cb Mon Sep 17 00:00:00 2001 From: Olav Loite Date: Fri, 17 Jul 2020 14:59:11 +0200 Subject: [PATCH] feat: add inline begin for async runner --- .../cloud/spanner/AbstractReadContext.java | 8 +- .../cloud/spanner/AbstractResultSet.java | 13 +- .../cloud/spanner/DatabaseClientImpl.java | 14 + .../com/google/cloud/spanner/SessionImpl.java | 5 +- .../cloud/spanner/TransactionRunnerImpl.java | 94 +++++-- .../cloud/spanner/GrpcResultSetTest.java | 8 +- .../cloud/spanner/InlineBeginBenchmark.java | 101 ++++--- .../spanner/InlineBeginTransactionTest.java | 246 +++++++++++++++++- .../cloud/spanner/ReadFormatTestRunner.java | 4 +- 9 files changed, 412 insertions(+), 81 deletions(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index d4e04942cd..5f5720048c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -633,7 +633,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, request.hasTransaction() && request.getTransaction().hasBegin()); } /** @@ -685,7 +686,7 @@ public void close() { public void onTransactionMetadata(Transaction transaction) {} @Override - public void onError(SpannerException e) {} + public void onError(SpannerException e, boolean withBeginTransaction) {} @Override public void onDone() {} @@ -746,7 +747,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - GrpcResultSet resultSet = new GrpcResultSet(stream, this); + GrpcResultSet resultSet = + new GrpcResultSet(stream, this, selector != null && selector.hasBegin()); return resultSet; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java index 11f36f4a92..17f62bf3a7 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java @@ -81,7 +81,7 @@ interface Listener { void onTransactionMetadata(Transaction transaction) throws SpannerException; /** Called when the read finishes with an error. */ - void onError(SpannerException e); + void onError(SpannerException e, boolean withBeginTransaction); /** Called when the read finishes normally. */ void onDone(); @@ -91,14 +91,17 @@ interface Listener { static class GrpcResultSet extends AbstractResultSet> { private final GrpcValueIterator iterator; private final Listener listener; + private final boolean beginTransaction; private GrpcStruct currRow; private SpannerException error; private ResultSetStats statistics; private boolean closed; - GrpcResultSet(CloseableIterator iterator, Listener listener) { + GrpcResultSet( + CloseableIterator iterator, Listener listener, boolean beginTransaction) { this.iterator = new GrpcValueIterator(iterator); this.listener = listener; + this.beginTransaction = beginTransaction; } @Override @@ -127,7 +130,7 @@ public boolean next() throws SpannerException { } return hasNext; } catch (SpannerException e) { - throw yieldError(e); + throw yieldError(e, beginTransaction && currRow == null); } } @@ -149,9 +152,9 @@ public Type getType() { return currRow.getType(); } - private SpannerException yieldError(SpannerException e) { + private SpannerException yieldError(SpannerException e, boolean beginTransaction) { close(); - listener.onError(e); + listener.onError(e, beginTransaction); throw e; } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java index 68f69fc97e..dd16e0a616 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java @@ -231,6 +231,10 @@ private TransactionManager inlinedTransactionManager() { @Override public AsyncRunner runAsync() { + return inlineBeginReadWriteTransactions ? inlinedRunAsync() : preparedRunAsync(); + } + + private AsyncRunner preparedRunAsync() { Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { return getReadWriteSession().runAsync(); @@ -240,6 +244,16 @@ public AsyncRunner runAsync() { } } + private AsyncRunner inlinedRunAsync() { + Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION_WITH_INLINE_BEGIN).startSpan(); + try (Scope s = tracer.withSpan(span)) { + return getReadSession().runAsync(); + } catch (RuntimeException e) { + TraceUtil.endSpanWithFailure(span, e); + throw e; + } + } + @Override public AsyncTransactionManager transactionManagerAsync() { Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 0990fbd53e..65c4b3a64d 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -238,7 +238,10 @@ public AsyncRunner runAsync() { return new AsyncRunnerImpl( setActive( new TransactionRunnerImpl( - this, spanner.getRpc(), spanner.getDefaultPrefetchChunks(), false))); + this, + spanner.getRpc(), + spanner.getDefaultPrefetchChunks(), + spanner.getOptions().isInlineBeginForReadWriteTransaction()))); } @Override diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index 59b2219e19..4187b5358a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -55,10 +55,10 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -158,7 +158,8 @@ public void removeListener(Runnable listener) { * been created, the lock is released and concurrent requests can be executed on the * transaction. */ - private final ReentrantLock transactionLock = new ReentrantLock(); + // private final ReentrantLock transactionLock = new ReentrantLock(); + private volatile CountDownLatch transactionLatch = new CountDownLatch(0); private volatile ByteString transactionId; private Timestamp commitTimestamp; @@ -333,7 +334,7 @@ public void run() { } span.addAnnotation("Commit Failed", TraceUtil.getExceptionAnnotations(e)); TraceUtil.endSpanWithFailure(opSpan, e); - onError((SpannerException) e); + onError((SpannerException) e, false); res.setException(e); } } @@ -401,20 +402,38 @@ TransactionSelector getTransactionSelector() { try { // Wait if another request is already beginning, committing or rolling back the // transaction. - transactionLock.lockInterruptibly(); - // Check again if a transactionId is now available. It could be that the thread that was - // holding the lock and that had sent a statement with a BeginTransaction request caused - // an error and did not return a transaction. - if (transactionId == null) { - // Return a TransactionSelector that will start a new transaction as part of the - // statement that is being executed. - return TransactionSelector.newBuilder() - .setBegin( - TransactionOptions.newBuilder() - .setReadWrite(TransactionOptions.ReadWrite.getDefaultInstance())) - .build(); - } else { - transactionLock.unlock(); + + // transactionLock.lockInterruptibly(); + while (true) { + CountDownLatch latch; + synchronized (lock) { + latch = transactionLatch; + } + latch.await(); + + synchronized (lock) { + if (transactionLatch.getCount() > 0L) { + continue; + } + // Check again if a transactionId is now available. It could be that the thread that + // was + // holding the lock and that had sent a statement with a BeginTransaction request + // caused + // an error and did not return a transaction. + if (transactionId == null) { + transactionLatch = new CountDownLatch(1); + // Return a TransactionSelector that will start a new transaction as part of the + // statement that is being executed. + return TransactionSelector.newBuilder() + .setBegin( + TransactionOptions.newBuilder() + .setReadWrite(TransactionOptions.ReadWrite.getDefaultInstance())) + .build(); + } else { + // transactionLock.unlock(); + break; + } + } } } catch (InterruptedException e) { throw SpannerExceptionFactory.newSpannerExceptionForCancellation(null, e); @@ -430,18 +449,24 @@ public void onTransactionMetadata(Transaction transaction) { // transaction on this instance and release the lock to allow other statements to proceed. if (this.transactionId == null && transaction != null && transaction.getId() != null) { this.transactionId = transaction.getId(); - transactionLock.unlock(); + transactionLatch.countDown(); + // transactionLock.unlock(); } } @Override - public void onError(SpannerException e) { + public void onError(SpannerException e, boolean withBeginTransaction) { // Release the transactionLock if that is being held by this thread. That would mean that the // statement that was trying to start a transaction caused an error. The next statement should // in that case also include a BeginTransaction option. - if (transactionLock.isHeldByCurrentThread()) { - transactionLock.unlock(); + + // if (transactionLock.isHeldByCurrentThread()) { + // transactionLock.unlock(); + // } + if (withBeginTransaction) { + transactionLatch.countDown(); } + if (e.getErrorCode() == ErrorCode.ABORTED) { long delay = -1L; if (e instanceof AbortedException) { @@ -494,7 +519,7 @@ public long executeUpdate(Statement statement) { // For standard DML, using the exact row count. return resultSet.getStats().getRowCountExact(); } catch (SpannerException e) { - onError(e); + onError(e, builder.hasTransaction() && builder.getTransaction().hasBegin()); throw e; } } @@ -504,7 +529,7 @@ public ApiFuture executeUpdateAsync(Statement statement) { beforeReadOrQuery(); final ExecuteSqlRequest.Builder builder = getExecuteSqlRequestBuilder(statement, QueryMode.NORMAL); - ApiFuture resultSet; + final ApiFuture resultSet; try { // Register the update as an async operation that must finish before the transaction may // commit. @@ -538,7 +563,7 @@ public Long apply(ResultSet input) { @Override public Long apply(Throwable input) { SpannerException e = SpannerExceptionFactory.newSpannerException(input); - onError(e); + onError(e, builder.hasTransaction() && builder.getTransaction().hasBegin()); throw e; } }, @@ -547,6 +572,14 @@ public Long apply(Throwable input) { new Runnable() { @Override public void run() { + try { + if (resultSet.get().getMetadata().hasTransaction()) { + onTransactionMetadata(resultSet.get().getMetadata().getTransaction()); + } + } catch (ExecutionException | InterruptedException e) { + // Ignore this error here as it is handled by the future that is returned by the + // executeUpdateAsync method. + } decreaseAsyncOperations(); } }, @@ -582,7 +615,7 @@ public long[] batchUpdate(Iterable statements) { } return results; } catch (SpannerException e) { - onError(e); + onError(e, builder.hasTransaction() && builder.getTransaction().hasBegin()); throw e; } } @@ -610,6 +643,9 @@ public long[] apply(ExecuteBatchDmlResponse input) { long[] results = new long[input.getResultSetsCount()]; for (int i = 0; i < input.getResultSetsCount(); ++i) { results[i] = input.getResultSets(i).getStats().getRowCountExact(); + if (input.getResultSets(i).getMetadata().hasTransaction()) { + onTransactionMetadata(input.getResultSets(i).getMetadata().getTransaction()); + } } // If one of the DML statements was aborted, we should throw an aborted exception. // In all other cases, we should throw a BatchUpdateException. @@ -633,9 +669,13 @@ public void run() { try { updateCounts.get(); } catch (ExecutionException e) { - onError(SpannerExceptionFactory.newSpannerException(e.getCause())); + onError( + SpannerExceptionFactory.newSpannerException(e.getCause()), + builder.hasTransaction() && builder.getTransaction().hasBegin()); } catch (InterruptedException e) { - onError(SpannerExceptionFactory.propagateInterrupt(e)); + onError( + SpannerExceptionFactory.propagateInterrupt(e), + builder.hasTransaction() && builder.getTransaction().hasBegin()); } finally { decreaseAsyncOperations(); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java index 4952e179ad..9cf66dd222 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java @@ -58,7 +58,7 @@ private static class NoOpListener implements AbstractResultSet.Listener { public void onTransactionMetadata(Transaction transaction) throws SpannerException {} @Override - public void onError(SpannerException e) {} + public void onError(SpannerException e, boolean withBeginTransaction) {} @Override public void onDone() {} @@ -76,11 +76,11 @@ public void cancel(@Nullable String message) {} public void request(int numMessages) {} }); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); } public AbstractResultSet.GrpcResultSet resultSetWithMode(QueryMode queryMode) { - return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); } @Test @@ -641,7 +641,7 @@ public com.google.protobuf.Value apply(@Nullable Value input) { private void verifySerialization( Function protoFn, Value... values) { - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); PartialResultSet.Builder builder = PartialResultSet.newBuilder(); List types = new ArrayList<>(); for (Value value : values) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginBenchmark.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginBenchmark.java index c0d192dd2a..3e08f0f633 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginBenchmark.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginBenchmark.java @@ -21,10 +21,12 @@ import com.google.api.gax.rpc.TransportChannelProvider; import com.google.cloud.NoCredentials; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import com.google.common.base.Stopwatch; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Random; @@ -46,10 +48,10 @@ import org.openjdk.jmh.annotations.Warmup; /** - * Benchmarks for inlining the BeginTransaction RPC with the first statement of a transaction. The simulated execution times are based on - * reasonable estimates and are primarily intended to keep the benchmarks comparable with each other - * before and after changes have been made to the pool. The benchmarks are bound to the Maven - * profile `benchmark` and can be executed like this: + * Benchmarks for inlining the BeginTransaction RPC with the first statement of a transaction. The + * simulated execution times are based on reasonable estimates and are primarily intended to keep + * the benchmarks comparable with each other before and after changes have been made to the pool. + * The benchmarks are bound to the Maven profile `benchmark` and can be executed like this: * mvn clean test -DskipTests -Pbenchmark -Dbenchmark.name=InlineBeginBenchmark * */ @@ -68,10 +70,13 @@ public class InlineBeginBenchmark { @State(Scope.Thread) public static class BenchmarkState { + private final boolean useRealServer = Boolean.valueOf(System.getProperty("useRealServer")); + private final String instance = System.getProperty("instance", TEST_INSTANCE); + private final String database = System.getProperty("database", TEST_DATABASE); private StandardBenchmarkMockServer mockServer; private Spanner spanner; private DatabaseClientImpl client; - + @Param({"false", "true"}) boolean inlineBegin; @@ -80,36 +85,61 @@ public static class BenchmarkState { @Setup(Level.Invocation) public void setup() throws Exception { - mockServer = new StandardBenchmarkMockServer(); - TransportChannelProvider channelProvider = mockServer.start(); - - SpannerOptions options = - SpannerOptions.newBuilder() - .setProjectId(TEST_PROJECT) - .setChannelProvider(channelProvider) - .setCredentials(NoCredentials.getInstance()) - .setSessionPoolOption( - SessionPoolOptions.newBuilder() - .setWriteSessionsFraction(writeFraction) - .build()) - .setInlineBeginForReadWriteTransaction(inlineBegin) - .build(); + System.out.println("useRealServer: " + System.getProperty("useRealServer")); + System.out.println("instance: " + System.getProperty("instance")); + SpannerOptions options; + if (useRealServer) { + System.out.println("running benchmark with **REAL** server"); + System.out.println("instance: " + instance); + System.out.println("database: " + database); + options = createRealServerOptions(); + } else { + System.out.println("running benchmark with **MOCK** server"); + mockServer = new StandardBenchmarkMockServer(); + TransportChannelProvider channelProvider = mockServer.start(); + options = createBenchmarkServerOptions(channelProvider); + } spanner = options.getService(); client = (DatabaseClientImpl) - spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + spanner.getDatabaseClient(DatabaseId.of(options.getProjectId(), instance, database)); + Stopwatch watch = Stopwatch.createStarted(); // Wait until the session pool has initialized. while (client.pool.getNumberOfSessionsInPool() < spanner.getOptions().getSessionPoolOptions().getMinSessions()) { Thread.sleep(1L); + if (watch.elapsed(TimeUnit.SECONDS) > 10L) { + break; + } } } + SpannerOptions createBenchmarkServerOptions(TransportChannelProvider channelProvider) { + return SpannerOptions.newBuilder() + .setProjectId(TEST_PROJECT) + .setChannelProvider(channelProvider) + .setCredentials(NoCredentials.getInstance()) + .setSessionPoolOption( + SessionPoolOptions.newBuilder().setWriteSessionsFraction(writeFraction).build()) + .setInlineBeginForReadWriteTransaction(inlineBegin) + .build(); + } + + SpannerOptions createRealServerOptions() throws IOException { + return SpannerOptions.newBuilder() + .setSessionPoolOption( + SessionPoolOptions.newBuilder().setWriteSessionsFraction(writeFraction).build()) + .setInlineBeginForReadWriteTransaction(inlineBegin) + .build(); + } + @TearDown(Level.Invocation) public void teardown() throws Exception { spanner.close(); - mockServer.shutdown(); + if (mockServer != null) { + mockServer.shutdown(); + } } } @@ -118,10 +148,9 @@ public void teardown() throws Exception { public void burstRead(final BenchmarkState server) throws Exception { int totalQueries = server.spanner.getOptions().getSessionPoolOptions().getMaxSessions() * 8; int parallelThreads = server.spanner.getOptions().getSessionPoolOptions().getMaxSessions() * 2; - final DatabaseClient client = - server.spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); - SessionPool pool = ((DatabaseClientImpl) client).pool; - assertThat(pool.totalSessions()).isEqualTo(server.spanner.getOptions().getSessionPoolOptions().getMinSessions()); + SessionPool pool = server.client.pool; + assertThat(pool.totalSessions()) + .isEqualTo(server.spanner.getOptions().getSessionPoolOptions().getMinSessions()); ListeningScheduledExecutorService service = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(parallelThreads)); @@ -134,7 +163,7 @@ public void burstRead(final BenchmarkState server) throws Exception { public Void call() throws Exception { Thread.sleep(RND.nextInt(RND_WAIT_TIME_BETWEEN_REQUESTS)); try (ResultSet rs = - client.singleUse().executeQuery(StandardBenchmarkMockServer.SELECT1)) { + server.client.singleUse().executeQuery(StandardBenchmarkMockServer.SELECT1)) { while (rs.next()) { Thread.sleep(RND.nextInt(HOLD_SESSION_TIME)); } @@ -152,10 +181,9 @@ public Void call() throws Exception { public void burstWrite(final BenchmarkState server) throws Exception { int totalWrites = server.spanner.getOptions().getSessionPoolOptions().getMaxSessions() * 8; int parallelThreads = server.spanner.getOptions().getSessionPoolOptions().getMaxSessions() * 2; - final DatabaseClient client = - server.spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); - SessionPool pool = ((DatabaseClientImpl) client).pool; - assertThat(pool.totalSessions()).isEqualTo(server.spanner.getOptions().getSessionPoolOptions().getMinSessions()); + SessionPool pool = server.client.pool; + assertThat(pool.totalSessions()) + .isEqualTo(server.spanner.getOptions().getSessionPoolOptions().getMinSessions()); ListeningScheduledExecutorService service = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(parallelThreads)); @@ -167,7 +195,7 @@ public void burstWrite(final BenchmarkState server) throws Exception { @Override public Long call() throws Exception { Thread.sleep(RND.nextInt(RND_WAIT_TIME_BETWEEN_REQUESTS)); - TransactionRunner runner = client.readWriteTransaction(); + TransactionRunner runner = server.client.readWriteTransaction(); return runner.run( new TransactionCallable() { @Override @@ -189,10 +217,9 @@ public void burstReadAndWrite(final BenchmarkState server) throws Exception { int totalWrites = server.spanner.getOptions().getSessionPoolOptions().getMaxSessions() * 4; int totalReads = server.spanner.getOptions().getSessionPoolOptions().getMaxSessions() * 4; int parallelThreads = server.spanner.getOptions().getSessionPoolOptions().getMaxSessions() * 2; - final DatabaseClient client = - server.spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); - SessionPool pool = ((DatabaseClientImpl) client).pool; - assertThat(pool.totalSessions()).isEqualTo(server.spanner.getOptions().getSessionPoolOptions().getMinSessions()); + SessionPool pool = server.client.pool; + assertThat(pool.totalSessions()) + .isEqualTo(server.spanner.getOptions().getSessionPoolOptions().getMinSessions()); ListeningScheduledExecutorService service = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(parallelThreads)); @@ -204,7 +231,7 @@ public void burstReadAndWrite(final BenchmarkState server) throws Exception { @Override public Long call() throws Exception { Thread.sleep(RND.nextInt(RND_WAIT_TIME_BETWEEN_REQUESTS)); - TransactionRunner runner = client.readWriteTransaction(); + TransactionRunner runner = server.client.readWriteTransaction(); return runner.run( new TransactionCallable() { @Override @@ -224,7 +251,7 @@ public Long run(TransactionContext transaction) throws Exception { public Void call() throws Exception { Thread.sleep(RND.nextInt(RND_WAIT_TIME_BETWEEN_REQUESTS)); try (ResultSet rs = - client.singleUse().executeQuery(StandardBenchmarkMockServer.SELECT1)) { + server.client.singleUse().executeQuery(StandardBenchmarkMockServer.SELECT1)) { while (rs.next()) { Thread.sleep(RND.nextInt(HOLD_SESSION_TIME)); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java index f7c10848ad..f9f6b0cd2d 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/InlineBeginTransactionTest.java @@ -19,10 +19,18 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.api.core.ApiAsyncFunction; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; import com.google.api.gax.grpc.testing.LocalChannelProvider; import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.cloud.spanner.AsyncRunner.AsyncWork; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.AbstractMessage; import com.google.protobuf.ListValue; import com.google.spanner.v1.BeginTransactionRequest; @@ -36,8 +44,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; @@ -49,10 +61,24 @@ import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class InlineBeginTransactionTest { + @Parameter public Executor executor; + + @Parameters(name = "executor = {0}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {MoreExecutors.directExecutor()}, + {Executors.newSingleThreadExecutor()}, + {Executors.newFixedThreadPool(4)} + }); + } + private static MockSpannerServiceImpl mockSpanner; private static Server server; private static LocalChannelProvider channelProvider; @@ -414,6 +440,222 @@ public void testTransactionManagerInlinedBeginTxWithError() { assertThat(countTransactionsStarted()).isEqualTo(2); } + @Test + public void testInlinedBeginAsyncTx() throws InterruptedException, ExecutionException { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + ApiFuture updateCount = + client + .runAsync() + .runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return txn.executeUpdateAsync(UPDATE_STATEMENT); + } + }, + executor); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(countRequests(BeginTransactionRequest.class)).isEqualTo(0); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + + @Test + public void testInlinedBeginAsyncTxAborted() throws InterruptedException, ExecutionException { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + final AtomicBoolean firstAttempt = new AtomicBoolean(true); + ApiFuture updateCount = + client + .runAsync() + .runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + ApiFuture res = txn.executeUpdateAsync(UPDATE_STATEMENT); + if (firstAttempt.getAndSet(false)) { + mockSpanner.abortTransaction(txn); + } + return res; + } + }, + executor); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(countRequests(BeginTransactionRequest.class)).isEqualTo(0); + // We have started 2 transactions, because the first transaction aborted. + assertThat(countTransactionsStarted()).isEqualTo(2); + } + + @Test + public void testInlinedBeginAsyncTxWithQuery() throws InterruptedException, ExecutionException { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + final ExecutorService queryExecutor = Executors.newSingleThreadExecutor(); + ApiFuture updateCount = + client + .runAsync() + .runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + final SettableApiFuture res = SettableApiFuture.create(); + try (AsyncResultSet rs = txn.executeQueryAsync(SELECT1)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + switch (resultSet.tryNext()) { + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + res.set(resultSet.getLong(0)); + default: + throw new IllegalStateException(); + } + } + }); + } + return res; + } + }, + queryExecutor); + assertThat(updateCount.get()).isEqualTo(1L); + assertThat(countRequests(BeginTransactionRequest.class)).isEqualTo(0); + assertThat(countTransactionsStarted()).isEqualTo(1); + queryExecutor.shutdown(); + } + + @Test + public void testInlinedBeginAsyncTxWithBatchDml() + throws InterruptedException, ExecutionException { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + ApiFuture updateCounts = + client + .runAsync() + .runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext transaction) { + return transaction.batchUpdateAsync( + Arrays.asList(UPDATE_STATEMENT, UPDATE_STATEMENT)); + } + }, + executor); + assertThat(updateCounts.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + assertThat(countRequests(BeginTransactionRequest.class)).isEqualTo(0); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + + @Test + public void testInlinedBeginAsyncTxWithError() throws InterruptedException, ExecutionException { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + ApiFuture updateCount = + client + .runAsync() + .runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext transaction) { + transaction.executeUpdateAsync(INVALID_UPDATE_STATEMENT); + return transaction.executeUpdateAsync(UPDATE_STATEMENT); + } + }, + executor); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(countRequests(BeginTransactionRequest.class)).isEqualTo(0); + // The first update will start a transaction, but then fail the update statement. This will + // start a transaction on the mock server, but that transaction will never be returned to the + // client. + assertThat(countTransactionsStarted()).isEqualTo(2); + } + + @Test + public void testInlinedBeginAsyncTxWithParallelQueries() + throws InterruptedException, ExecutionException { + final int numQueries = 100; + final ScheduledExecutorService executor = Executors.newScheduledThreadPool(16); + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + ApiFuture updateCount = + client + .runAsync() + .runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(final TransactionContext txn) { + List> futures = new ArrayList<>(numQueries); + for (int i = 0; i < numQueries; i++) { + final SettableApiFuture res = SettableApiFuture.create(); + try (AsyncResultSet rs = txn.executeQueryAsync(SELECT1)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + switch (resultSet.tryNext()) { + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + res.set(resultSet.getLong(0)); + default: + throw new IllegalStateException(); + } + } + }); + } + futures.add(res); + } + return ApiFutures.transformAsync( + ApiFutures.allAsList(futures), + new ApiAsyncFunction, Long>() { + @Override + public ApiFuture apply(List input) throws Exception { + long sum = 0L; + for (Long l : input) { + sum += l; + } + return ApiFutures.immediateFuture(sum); + } + }, + MoreExecutors.directExecutor()); + } + }, + executor); + assertThat(updateCount.get()).isEqualTo(1L * numQueries); + assertThat(countRequests(BeginTransactionRequest.class)).isEqualTo(0); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + + @Test + public void testInlinedBeginAsyncTxWithOnlyMutations() + throws InterruptedException, ExecutionException { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + client + .runAsync() + .runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext transaction) { + transaction.buffer(Mutation.delete("FOO", Key.of(1L))); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .get(); + // There should be 1 call to BeginTransaction because there is no statement that we can use to + // inline the BeginTransaction call with. + assertThat(countRequests(BeginTransactionRequest.class)).isEqualTo(1); + assertThat(countTransactionsStarted()).isEqualTo(1); + } + private int countRequests(Class requestType) { int count = 0; for (AbstractMessage msg : mockSpanner.getRequests()) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java index 475d8325a9..50cf96ff3c 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java @@ -47,7 +47,7 @@ private static class NoOpListener implements AbstractResultSet.Listener { public void onTransactionMetadata(Transaction transaction) throws SpannerException {} @Override - public void onError(SpannerException e) {} + public void onError(SpannerException e, boolean withBeginTransaction) {} @Override public void onDone() {} @@ -119,7 +119,7 @@ public void cancel(@Nullable String message) {} public void request(int numMessages) {} }); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false); JSONArray chunks = testCase.getJSONArray("chunks"); JSONObject expectedResult = testCase.getJSONObject("result");