diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml index a10e096af4..13462a986c 100644 --- a/google-cloud-spanner/clirr-ignored-differences.xml +++ b/google-cloud-spanner/clirr-ignored-differences.xml @@ -93,7 +93,6 @@ com/google/cloud/spanner/DatabaseAdminClient com.google.cloud.spanner.Backup updateBackup(java.lang.String, java.lang.String, com.google.cloud.Timestamp) - 7012 com/google/cloud/spanner/spi/v1/SpannerRpc @@ -147,6 +146,88 @@ com.google.api.gax.paging.Page listDatabases() + + + 7012 + com/google/cloud/spanner/spi/v1/SpannerRpc + com.google.api.core.ApiFuture executeQueryAsync(com.google.spanner.v1.ExecuteSqlRequest, java.util.Map) + + + 7012 + com/google/cloud/spanner/DatabaseClient + * runAsync(*) + + + 7012 + com/google/cloud/spanner/DatabaseClient + * transactionManagerAsync(*) + + + 7012 + com/google/cloud/spanner/Spanner + * getAsyncExecutorProvider(*) + + + 7012 + com/google/cloud/spanner/ReadContext + * executeQueryAsync(*) + + + 7012 + com/google/cloud/spanner/ReadContext + * readAsync(*) + + + 7012 + com/google/cloud/spanner/ReadContext + * readRowAsync(*) + + + 7012 + com/google/cloud/spanner/ReadContext + * readUsingIndexAsync(*) + + + 7012 + com/google/cloud/spanner/ReadContext + * readRowUsingIndexAsync(*) + + + 7012 + com/google/cloud/spanner/TransactionContext + * batchUpdateAsync(*) + + + 7012 + com/google/cloud/spanner/TransactionContext + * executeUpdateAsync(*) + + + 7012 + com/google/cloud/spanner/spi/v1/SpannerRpc + * beginTransactionAsync(*) + + + 7012 + com/google/cloud/spanner/spi/v1/SpannerRpc + * commitAsync(*) + + + 7012 + com/google/cloud/spanner/spi/v1/SpannerRpc + * rollbackAsync(*) + + + 7012 + com/google/cloud/spanner/spi/v1/SpannerRpc + * executeBatchDmlAsync(*) + + + 7012 + com/google/cloud/spanner/connection/Connection + * executeQueryAsync(*) + + 7012 diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 685e9a1e31..bc4a868564 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -21,16 +21,24 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutureCallback; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; import com.google.cloud.spanner.AbstractResultSet.GrpcResultSet; import com.google.cloud.spanner.AbstractResultSet.GrpcStreamIterator; import com.google.cloud.spanner.AbstractResultSet.ResumableStreamIterator; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; import com.google.cloud.spanner.Options.QueryOption; import com.google.cloud.spanner.Options.ReadOption; import com.google.cloud.spanner.SessionImpl.SessionTransaction; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.ExecuteBatchDmlRequest; @@ -62,6 +70,7 @@ abstract static class Builder, T extends AbstractReadCon private Span span = Tracing.getTracer().getCurrentSpan(); private int defaultPrefetchChunks = SpannerOptions.Builder.DEFAULT_PREFETCH_CHUNKS; private QueryOptions defaultQueryOptions = SpannerOptions.Builder.DEFAULT_QUERY_OPTIONS; + private ExecutorProvider executorProvider; Builder() {} @@ -95,9 +104,25 @@ B setDefaultQueryOptions(QueryOptions defaultQueryOptions) { return self(); } + B setExecutorProvider(ExecutorProvider executorProvider) { + this.executorProvider = executorProvider; + return self(); + } + abstract T build(); } + /** + * {@link AsyncResultSet} that supports adding listeners that are called when all rows from the + * underlying result stream have been fetched. + */ + interface ListenableAsyncResultSet extends AsyncResultSet { + /** Adds a listener to this {@link AsyncResultSet}. */ + void addListener(Runnable listener); + + void removeListener(Runnable listener); + } + /** * A {@code ReadContext} for standalone reads. This can only be used for a single operation, since * each standalone read may see a different timestamp of Cloud Spanner data. @@ -350,7 +375,8 @@ void initTransaction() { final Object lock = new Object(); final SessionImpl session; final SpannerRpc rpc; - final Span span; + final ExecutorProvider executorProvider; + Span span; private final int defaultPrefetchChunks; private final QueryOptions defaultQueryOptions; @@ -374,6 +400,12 @@ void initTransaction() { this.defaultPrefetchChunks = builder.defaultPrefetchChunks; this.defaultQueryOptions = builder.defaultQueryOptions; this.span = builder.span; + this.executorProvider = builder.executorProvider; + } + + @Override + public void setSpan(Span span) { + this.span = span; } long getSeqNo() { @@ -386,12 +418,38 @@ public final ResultSet read( return readInternal(table, null, keys, columns, options); } + @Override + public ListenableAsyncResultSet readAsync( + String table, KeySet keys, Iterable columns, ReadOption... options) { + Options readOptions = Options.fromReadOptions(options); + final int bufferRows = + readOptions.hasBufferRows() + ? readOptions.bufferRows() + : AsyncResultSetImpl.DEFAULT_BUFFER_SIZE; + return new AsyncResultSetImpl( + executorProvider, readInternal(table, null, keys, columns, options), bufferRows); + } + @Override public final ResultSet readUsingIndex( String table, String index, KeySet keys, Iterable columns, ReadOption... options) { return readInternal(table, checkNotNull(index), keys, columns, options); } + @Override + public ListenableAsyncResultSet readUsingIndexAsync( + String table, String index, KeySet keys, Iterable columns, ReadOption... options) { + Options readOptions = Options.fromReadOptions(options); + final int bufferRows = + readOptions.hasBufferRows() + ? readOptions.bufferRows() + : AsyncResultSetImpl.DEFAULT_BUFFER_SIZE; + return new AsyncResultSetImpl( + executorProvider, + readInternal(table, checkNotNull(index), keys, columns, options), + bufferRows); + } + @Nullable @Override public final Struct readRow(String table, Key key, Iterable columns) { @@ -400,6 +458,13 @@ public final Struct readRow(String table, Key key, Iterable columns) { } } + @Override + public final ApiFuture readRowAsync(String table, Key key, Iterable columns) { + try (AsyncResultSet resultSet = readAsync(table, KeySet.singleKey(key), columns)) { + return consumeSingleRowAsync(resultSet); + } + } + @Nullable @Override public final Struct readRowUsingIndex( @@ -409,12 +474,35 @@ public final Struct readRowUsingIndex( } } + @Override + public final ApiFuture readRowUsingIndexAsync( + String table, String index, Key key, Iterable columns) { + try (AsyncResultSet resultSet = + readUsingIndexAsync(table, index, KeySet.singleKey(key), columns)) { + return consumeSingleRowAsync(resultSet); + } + } + @Override public final ResultSet executeQuery(Statement statement, QueryOption... options) { return executeQueryInternal( statement, com.google.spanner.v1.ExecuteSqlRequest.QueryMode.NORMAL, options); } + @Override + public ListenableAsyncResultSet executeQueryAsync(Statement statement, QueryOption... options) { + Options readOptions = Options.fromQueryOptions(options); + final int bufferRows = + readOptions.hasBufferRows() + ? readOptions.bufferRows() + : AsyncResultSetImpl.DEFAULT_BUFFER_SIZE; + return new AsyncResultSetImpl( + executorProvider, + executeQueryInternal( + statement, com.google.spanner.v1.ExecuteSqlRequest.QueryMode.NORMAL, options), + bufferRows); + } + @Override public final ResultSet analyzeQuery(Statement statement, QueryAnalyzeMode readContextQueryMode) { switch (readContextQueryMode) { @@ -666,4 +754,71 @@ private Struct consumeSingleRow(ResultSet resultSet) { } return row; } + + static ApiFuture consumeSingleRowAsync(AsyncResultSet resultSet) { + final SettableApiFuture result = SettableApiFuture.create(); + // We can safely use a directExecutor here, as we will only be consuming one row, and we will + // not be doing any blocking stuff in the handler. + final SettableApiFuture row = SettableApiFuture.create(); + ApiFutures.addCallback( + resultSet.setCallback(MoreExecutors.directExecutor(), ConsumeSingleRowCallback.create(row)), + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + result.setException(t); + } + + @Override + public void onSuccess(Void input) { + try { + result.set(row.get()); + } catch (Throwable t) { + result.setException(t); + } + } + }, + MoreExecutors.directExecutor()); + return result; + } + + /** + * {@link ReadyCallback} for returning the first row in a result set as a future {@link Struct}. + */ + private static class ConsumeSingleRowCallback implements ReadyCallback { + private final SettableApiFuture result; + private Struct row; + + static ConsumeSingleRowCallback create(SettableApiFuture result) { + return new ConsumeSingleRowCallback(result); + } + + private ConsumeSingleRowCallback(SettableApiFuture result) { + this.result = result; + } + + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + switch (resultSet.tryNext()) { + case DONE: + result.set(row); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + if (row != null) { + throw newSpannerException( + ErrorCode.INTERNAL, "Multiple rows returned for single key"); + } + row = resultSet.getCurrentRowAsStruct(); + return CallbackResponse.CONTINUE; + default: + throw new IllegalStateException(); + } + } catch (Throwable t) { + result.setException(t); + return CallbackResponse.DONE; + } + } + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java index 7b248bfb9d..6b0681b588 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java @@ -495,7 +495,7 @@ private static Struct decodeStructValue(Type structType, ListValue structValue) return new GrpcStruct(structType, fields); } - private static Object decodeArrayValue(Type elementType, ListValue listValue) { + static Object decodeArrayValue(Type elementType, ListValue listValue) { switch (elementType.getCode()) { case BOOL: // Use a view: element conversion is virtually free. @@ -1009,7 +1009,7 @@ protected PartialResultSet computeNext() { } } - private static double valueProtoToFloat64(com.google.protobuf.Value proto) { + static double valueProtoToFloat64(com.google.protobuf.Value proto) { if (proto.getKindCase() == KindCase.STRING_VALUE) { switch (proto.getStringValue()) { case "-Infinity": @@ -1037,7 +1037,7 @@ private static double valueProtoToFloat64(com.google.protobuf.Value proto) { return proto.getNumberValue(); } - private static NullPointerException throwNotNull(int columnIndex) { + static NullPointerException throwNotNull(int columnIndex) { throw new NullPointerException( "Cannot call array getter for column " + columnIndex + " with null elements"); } @@ -1048,7 +1048,7 @@ private static NullPointerException throwNotNull(int columnIndex) { * {@code BigDecimal} respectively. Rather than construct new wrapper objects for each array * element, we use primitive arrays and a {@code BitSet} to track nulls. */ - private abstract static class PrimitiveArray extends AbstractList { + abstract static class PrimitiveArray extends AbstractList { private final A data; private final BitSet nulls; private final int size; @@ -1103,7 +1103,7 @@ A toPrimitiveArray(int columnIndex) { } } - private static class Int64Array extends PrimitiveArray { + static class Int64Array extends PrimitiveArray { Int64Array(ListValue protoList) { super(protoList); } @@ -1128,7 +1128,7 @@ Long get(long[] array, int i) { } } - private static class Float64Array extends PrimitiveArray { + static class Float64Array extends PrimitiveArray { Float64Array(ListValue protoList) { super(protoList); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java new file mode 100644 index 0000000000..c44a42994e --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java @@ -0,0 +1,226 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiFuture; +import com.google.common.base.Function; +import com.google.common.collect.ImmutableList; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; + +/** Interface for result sets returned by async query methods. */ +public interface AsyncResultSet extends ResultSet { + + /** Response code from {@code tryNext()}. */ + enum CursorState { + /** Cursor has been moved to a new row. */ + OK, + /** Read is complete, all rows have been consumed, and there are no more. */ + DONE, + /** No further information known at this time, thus current row not available. */ + NOT_READY + } + + /** + * Non-blocking call that attempts to step the cursor to the next position in the stream. The + * cursor may be inspected only if the cursor returns {@code CursorState.OK}. + * + *

A caller will typically call {@link #tryNext()} in a loop inside the ReadyCallback, + * consuming all results available. For more information see {@link #setCallback(Executor, + * ReadyCallback)}. + * + *

Currently this method may only be called if a ReadyCallback has been registered. This is for + * safety purposes only, and may be relaxed in future. + * + * @return current cursor readiness state + * @throws SpannerException When an unrecoverable problem downstream occurs. Once this occurs you + * will get no further callbacks. You should return CallbackResponse.DONE back from callback. + */ + CursorState tryNext() throws SpannerException; + + enum CallbackResponse { + /** + * Tell the cursor to continue issuing callbacks when data is available. This is the standard + * "I'm ready for more" response. If cursor is not completely drained of all ready results the + * callback will be called again immediately. + */ + CONTINUE, + + /** + * Tell the cursor to suspend all callbacks until application calls {@link RowCursor#resume()}. + */ + PAUSE, + + /** + * Tell the cursor you are done receiving results, even if there are more results sitting in the + * buffer. Once you return DONE, you will receive no further callbacks. + * + *

Approximately equivalent to calling {@link RowCursor#cancel()}, and then returning {@code + * PAUSE}, but more clear, immediate, and idiomatic. + * + *

It is legal to commit a transaction that owns this read before actually returning {@code + * DONE}. + */ + DONE, + } + + /** + * Interface for receiving asynchronous callbacks when new data is ready. See {@link + * AsyncResultSet#setCallback(Executor, ReadyCallback)}. + */ + interface ReadyCallback { + CallbackResponse cursorReady(AsyncResultSet resultSet); + } + + /** + * Register a callback with the ResultSet to be made aware when more data is available, changing + * the usage pattern from sync to async. Details: + * + *

    + *
  • The callback will be called at least once. + *
  • The callback is run each time more results are available, or when we discover that there + * will be no more results. (unless paused, see below). Spurious callbacks are possible, see + * below. + *
  • Spanner guarantees that one callback is ever outstanding at a time. Also, future + * callbacks guarantee the "happens before" property with previous callbacks. + *
  • A callback normally consumes all available data in the ResultSet, and then returns {@link + * CallbackResponse#CONTINUE}. + *
  • If a callback returns {@link CallbackResponse#CONTINUE} with data still in the ResultSet, + * the callback is invoked again immediately! + *
  • Once a callback has returned {@link CallbackResponse#PAUSE} on the cursor no more + * callbacks will be run until a corresponding {@link #resume()}. + *
  • Callback will stop being called once any of the following occurs: + *
      + *
    1. Callback returns {@link CallbackResponse#DONE}. + *
    2. {@link ResultSet#tryNext()} returns {@link CursorState#DONE}. + *
    3. {@link ResultSet#tryNext()} throws an exception. + *
    + *
  • Callback may possibly be invoked after a call to {@link ResultSet#cancel()} call, but the + * subsequent call to {@link #tryNext()} will yield a SpannerException. + *
  • Spurious callbacks are possible where cursors are not actually ready. Typically callback + * should return {@link CallbackResponse#CONTINUE} any time it sees {@link + * CursorState#NOT_READY}. + *
+ * + *

Flow Control

+ * + * If no flow control is needed (say because result sizes are known in advance to be finite in + * size) then async processing is simple. The following is a code example that transfers work from + * the cursor to an upstream sink: + * + *
{@code
+   * @Override
+   * public CallbackResponse cursorReady(ResultSet cursor) {
+   *   try {
+   *     while (true) {
+   *       switch (cursor.tryNext()) {
+   *         case OK:    upstream.emit(cursor.getRow()); break;
+   *         case DONE:  upstream.done(); return CallbackResponse.DONE;
+   *         case NOT_READY:  return CallbackResponse.CONTINUE;
+   *       }
+   *     }
+   *   } catch (SpannerException e) {
+   *     upstream.doneWithError(e);
+   *     return CallbackResponse.DONE;
+   *   }
+   * }
+   * }
+ * + * Flow control may be needed if for example the upstream system may not always be ready to handle + * more data. In this case the app developer has two main options: + * + *
    + *
  • Semi-async: make {@code upstream.emit()} a blocking call. This will block the callback + * thread until progress is possible. When coding in this way the threads in the Executor + * provided to {@link #setCallback(Executor, ReadyCallback)} must be blockable without + * causing harm to progress in your system. + *
  • Full-async: call {@code cursor.pause()} and return from the callback with data still in + * the Cursor. Once in this state cursor waits until resume() is called before calling + * callback again. + *
+ * + * @param exec executor on which to run all callbacks. Typically use a threadpool. If the executor + * is one that runs the work on the submitting thread, you must be very careful not to throw + * RuntimeException up the stack, lest you do damage to calling components. For example, it + * may cause an event dispatcher thread to crash. + * @param cb ready callback + * @return An {@link ApiFuture} that returns null when the consumption of the {@link + * AsyncResultSet} has finished successfully. No more calls to the {@link ReadyCallback} will + * follow and all resources used by the {@link AsyncResultSet} have been cleaned up. The + * {@link ApiFuture} throws an {@link ExecutionException} if the consumption of the {@link + * AsyncResultSet} finished with an error. + */ + ApiFuture setCallback(Executor exec, ReadyCallback cb); + + /** + * Attempt to cancel this operation and free all resources. Non-blocking. This is a no-op for + * child row cursors and does not cancel the parent cursor. + */ + void cancel(); + + /** + * Resume callbacks from the cursor. If there is more data available, a callback will be + * dispatched immediately. This can be called from any thread. + */ + void resume(); + + /** + * Transforms the row cursor into an immutable list using the given transformer function. {@code + * transformer} will be called once per row, thus the returned list will contain one entry per + * row. The returned future will throw a {@link SpannerException} if the row cursor encountered + * any error or if the transformer threw an exception on any row. + * + *

The transformer will be run on the supplied executor. The implementation may batch multiple + * transformer invocations together into a single {@code Runnable} when possible to increase + * efficiency. At any point in time, there will be at most one invocation of the transformer in + * progress. + * + *

WARNING: This will result in materializing the entire list so this should be used + * judiciously after considering the memory requirements of the returned list. + * + *

WARNING: The {@code RowBase} object passed to transformer function is not immutable and is + * not guaranteed to remain valid after the transformer function returns. The same {@code RowBase} + * object might be passed multiple times to the transformer with different underlying data each + * time. So *NEVER* keep a reference to the {@code RowBase} outside of the transformer. + * Specifically do not use {@link com.google.common.base.Functions#identity()} function. + * + * @param transformer function which will be used to transform the row. It should not return null. + * @param executor executor on which the transformer will be run. This should ideally not be an + * inline executor such as {@code MoreExecutors.directExecutor()}; using such an executor may + * degrade the performance of the Spanner library. + */ + ApiFuture> toListAsync( + Function transformer, Executor executor); + + /** + * Transforms the row cursor into an immutable list using the given transformer function. {@code + * transformer} will be called once per row, thus the returned list will contain one entry per + * row. This method will block until all the rows have been yielded by the cursor. + * + *

WARNING: This will result in consuming the entire list so this should be used judiciously + * after considering the memory requirements of the returned list. + * + *

WARNING: The {@code RowBase} object passed to transformer function is not immutable and is + * not guaranteed to remain valid after the transformer function returns. The same {@code RowBase} + * object might be passed multiple times to the transformer with different underlying data each + * time. So *NEVER* keep a reference to the {@code RowBase} outside of the transformer. + * Specifically do not use {@link com.google.common.base.Functions#identity()} function. + * + * @param transformer function which will be used to transform the row. It should not return null. + */ + ImmutableList toList(Function transformer) throws SpannerException; +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java new file mode 100644 index 0000000000..f277388b0b --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java @@ -0,0 +1,586 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiAsyncFunction; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.core.ListenableFutureToApiFuture; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.core.ExecutorProvider; +import com.google.cloud.spanner.AbstractReadContext.ListenableAsyncResultSet; +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.spanner.v1.ResultSetStats; +import java.util.Collection; +import java.util.LinkedList; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** Default implementation for {@link AsyncResultSet}. */ +class AsyncResultSetImpl extends ForwardingStructReader implements ListenableAsyncResultSet { + private static final Logger log = Logger.getLogger(AsyncResultSetImpl.class.getName()); + + /** State of an {@link AsyncResultSetImpl}. */ + private enum State { + INITIALIZED, + /** SYNC indicates that the {@link ResultSet} is used in sync pattern. */ + SYNC, + CONSUMING, + RUNNING, + PAUSED, + CANCELLED(true), + DONE(true); + + /** Does this state mean that the result set should permanently stop producing rows. */ + private final boolean shouldStop; + + private State() { + shouldStop = false; + } + + private State(boolean shouldStop) { + this.shouldStop = shouldStop; + } + } + + static final int DEFAULT_BUFFER_SIZE = 10; + private static final int MAX_WAIT_FOR_BUFFER_CONSUMPTION = 10; + private static final SpannerException CANCELLED_EXCEPTION = + SpannerExceptionFactory.newSpannerException( + ErrorCode.CANCELLED, "This AsyncResultSet has been cancelled"); + + private final Object monitor = new Object(); + private boolean closed; + + /** + * {@link ExecutorProvider} provides executor services that are used to fetch data from the + * backend and put these into the buffer for further consumption by the callback. + */ + private final ExecutorProvider executorProvider; + + private final ListeningScheduledExecutorService service; + + private final BlockingDeque buffer; + private Struct currentRow; + /** The underlying synchronous {@link ResultSet} that is producing the rows. */ + private final ResultSet delegateResultSet; + + /** + * Any exception that occurs while executing the query and iterating over the result set will be + * stored in this variable and propagated to the user through {@link #tryNext()}. + */ + private volatile SpannerException executionException; + + /** + * Executor for callbacks. Regardless of the type of executor that is provided, the {@link + * AsyncResultSetImpl} will ensure that at most 1 callback call will be active at any one time. + */ + private Executor executor; + + private ReadyCallback callback; + + /** + * Listeners that will be called when the {@link AsyncResultSetImpl} has finished fetching all + * rows and any underlying transaction or session can be closed. + */ + private Collection listeners = new LinkedList<>(); + + private State state = State.INITIALIZED; + + /** + * {@link #finished} indicates whether all the results from the underlying result set have been + * read. + */ + private volatile boolean finished; + + private volatile ApiFuture result; + + /** + * {@link #cursorReturnedDoneOrException} indicates whether {@link #tryNext()} has returned {@link + * CursorState#DONE} or a {@link SpannerException}. + */ + private volatile boolean cursorReturnedDoneOrException; + + /** + * {@link #pausedLatch} is used to pause the producer when the {@link AsyncResultSet} is paused. + * The production of rows that are put into the buffer is only paused once the buffer is full. + */ + private volatile CountDownLatch pausedLatch = new CountDownLatch(1); + /** + * {@link #bufferConsumptionLatch} is used to pause the producer when the buffer is full and the + * consumer needs some time to catch up. + */ + private volatile CountDownLatch bufferConsumptionLatch = new CountDownLatch(0); + /** + * {@link #consumingLatch} is used to pause the producer when all rows have been put into the + * buffer, but the consumer (the callback) has not yet received and processed all rows. + */ + private volatile CountDownLatch consumingLatch = new CountDownLatch(0); + + AsyncResultSetImpl(ExecutorProvider executorProvider, ResultSet delegate, int bufferSize) { + super(delegate); + this.executorProvider = Preconditions.checkNotNull(executorProvider); + this.delegateResultSet = Preconditions.checkNotNull(delegate); + this.service = MoreExecutors.listeningDecorator(executorProvider.getExecutor()); + this.buffer = new LinkedBlockingDeque<>(bufferSize); + } + + /** + * Closes the {@link AsyncResultSet}. {@link #close()} is non-blocking and may be called multiple + * times without side effects. An {@link AsyncResultSet} may be closed before all rows have been + * returned to the callback, and calling {@link #tryNext()} on a closed {@link AsyncResultSet} is + * allowed as long as this is done from within a {@link ReadyCallback}. Calling {@link #resume()} + * on a closed {@link AsyncResultSet} is also allowed. + */ + @Override + public void close() { + synchronized (monitor) { + if (this.closed) { + return; + } + if (state == State.INITIALIZED || state == State.SYNC) { + delegateResultSet.close(); + } + this.closed = true; + } + } + + /** + * Adds a listener that will be called when no more rows will be read from the underlying {@link + * ResultSet}, either because all rows have been read, or because {@link + * ReadyCallback#cursorReady(AsyncResultSet)} returned {@link CallbackResponse#DONE}. + */ + @Override + public void addListener(Runnable listener) { + Preconditions.checkState(state == State.INITIALIZED); + listeners.add(listener); + } + + @Override + public void removeListener(Runnable listener) { + Preconditions.checkState(state == State.INITIALIZED); + listeners.remove(listener); + } + + /** + * Tries to advance this {@link AsyncResultSet} to the next row. This method may only be called + * from within a {@link ReadyCallback}. + */ + @Override + public CursorState tryNext() throws SpannerException { + synchronized (monitor) { + if (state == State.CANCELLED) { + cursorReturnedDoneOrException = true; + throw CANCELLED_EXCEPTION; + } + if (buffer.isEmpty() && executionException != null) { + cursorReturnedDoneOrException = true; + throw executionException; + } + Preconditions.checkState( + this.callback != null, "tryNext may only be called after a callback has been set."); + Preconditions.checkState( + this.state == State.CONSUMING, + "tryNext may only be called from a DataReady callback. Current state: " + + this.state.name()); + + if (finished && buffer.isEmpty()) { + cursorReturnedDoneOrException = true; + return CursorState.DONE; + } + } + if (!buffer.isEmpty()) { + // Set the next row from the buffer as the current row of the StructReader. + replaceDelegate(currentRow = buffer.pop()); + synchronized (monitor) { + bufferConsumptionLatch.countDown(); + } + return CursorState.OK; + } + return CursorState.NOT_READY; + } + + private void closeDelegateResultSet() { + try { + delegateResultSet.close(); + } catch (Throwable t) { + log.log(Level.FINE, "Ignoring error from closing delegate result set", t); + } + } + + /** + * {@link CallbackRunnable} calls the {@link ReadyCallback} registered for this {@link + * AsyncResultSet}. + */ + private class CallbackRunnable implements Runnable { + @Override + public void run() { + try { + while (true) { + synchronized (monitor) { + if (cursorReturnedDoneOrException) { + break; + } + } + CallbackResponse response; + try { + response = callback.cursorReady(AsyncResultSetImpl.this); + } catch (Throwable e) { + synchronized (monitor) { + if (cursorReturnedDoneOrException + && state == State.CANCELLED + && e instanceof SpannerException + && ((SpannerException) e).getErrorCode() == ErrorCode.CANCELLED) { + // The callback did not catch the cancelled exception (which it should have), but + // we'll keep the cancelled state. + return; + } + executionException = SpannerExceptionFactory.newSpannerException(e); + cursorReturnedDoneOrException = true; + } + return; + } + synchronized (monitor) { + if (state == State.CANCELLED) { + if (cursorReturnedDoneOrException) { + return; + } + } else { + switch (response) { + case DONE: + state = State.DONE; + closeDelegateResultSet(); + return; + case PAUSE: + state = State.PAUSED; + // Make sure no-one else is waiting on the current pause latch and create a new + // one. + pausedLatch.countDown(); + pausedLatch = new CountDownLatch(1); + return; + case CONTINUE: + if (buffer.isEmpty()) { + // Call the callback once more if the entire result set has been processed but + // the callback has not yet received a CursorState.DONE or a CANCELLED error. + if (finished && !cursorReturnedDoneOrException) { + break; + } + state = State.RUNNING; + return; + } + break; + default: + throw new IllegalStateException("Unknown response: " + response); + } + } + } + } + } finally { + synchronized (monitor) { + // Count down all latches that the producer might be waiting on. + consumingLatch.countDown(); + while (bufferConsumptionLatch.getCount() > 0L) { + bufferConsumptionLatch.countDown(); + } + } + } + } + } + + private final CallbackRunnable callbackRunnable = new CallbackRunnable(); + + /** + * {@link ProduceRowsCallable} reads data from the underlying {@link ResultSet}, places these in + * the buffer and dispatches the {@link CallbackRunnable} when data is ready to be consumed. + */ + private class ProduceRowsCallable implements Callable { + @Override + public Void call() throws Exception { + boolean stop = false; + boolean hasNext = false; + try { + hasNext = delegateResultSet.next(); + } catch (Throwable e) { + synchronized (monitor) { + executionException = SpannerExceptionFactory.newSpannerException(e); + } + } + try { + while (!stop && hasNext) { + try { + synchronized (monitor) { + stop = state.shouldStop; + } + if (!stop) { + while (buffer.remainingCapacity() == 0 && !stop) { + waitIfPaused(); + // The buffer is full and we should let the callback consume a number of rows before + // we proceed with producing any more rows to prevent us from potentially waiting on + // a full buffer repeatedly. + // Wait until at least half of the buffer is available, or if it's a bigger buffer, + // wait until at least 10 rows can be placed in it. + // TODO: Make this more dynamic / configurable? + startCallbackWithBufferLatchIfNecessary( + Math.min( + Math.min(buffer.size() / 2 + 1, buffer.size()), + MAX_WAIT_FOR_BUFFER_CONSUMPTION)); + bufferConsumptionLatch.await(); + synchronized (monitor) { + stop = state.shouldStop; + } + } + } + if (!stop) { + buffer.put(delegateResultSet.getCurrentRowAsStruct()); + startCallbackIfNecessary(); + hasNext = delegateResultSet.next(); + } + } catch (Throwable e) { + synchronized (monitor) { + executionException = SpannerExceptionFactory.newSpannerException(e); + stop = true; + } + } + } + // We don't need any more data from the underlying result set, so we close it as soon as + // possible. Any error that might occur during this will be ignored. + closeDelegateResultSet(); + + // Ensure that the callback has been called at least once, even if the result set was + // cancelled. + synchronized (monitor) { + finished = true; + stop = cursorReturnedDoneOrException; + } + // Call the callback if there are still rows in the buffer that need to be processed. + while (!stop) { + waitIfPaused(); + startCallbackIfNecessary(); + synchronized (monitor) { + stop = state.shouldStop || cursorReturnedDoneOrException; + } + // Make sure we wait until the callback runner has actually finished. + consumingLatch.await(); + } + } finally { + if (executorProvider.shouldAutoClose()) { + service.shutdown(); + } + for (Runnable listener : listeners) { + listener.run(); + } + synchronized (monitor) { + if (executionException != null) { + throw executionException; + } + if (state == State.CANCELLED) { + throw CANCELLED_EXCEPTION; + } + } + } + return null; + } + + private void waitIfPaused() throws InterruptedException { + CountDownLatch pause; + synchronized (monitor) { + pause = pausedLatch; + } + pause.await(); + } + + private void startCallbackIfNecessary() { + startCallbackWithBufferLatchIfNecessary(0); + } + + private void startCallbackWithBufferLatchIfNecessary(int bufferLatch) { + synchronized (monitor) { + if ((state == State.RUNNING || state == State.CANCELLED) + && !cursorReturnedDoneOrException) { + consumingLatch = new CountDownLatch(1); + if (bufferLatch > 0) { + bufferConsumptionLatch = new CountDownLatch(bufferLatch); + } + if (state == State.RUNNING) { + state = State.CONSUMING; + } + executor.execute(callbackRunnable); + } + } + } + } + + /** Sets the callback for this {@link AsyncResultSet}. */ + @Override + public ApiFuture setCallback(Executor exec, ReadyCallback cb) { + synchronized (monitor) { + Preconditions.checkState(!closed, "This AsyncResultSet has been closed"); + Preconditions.checkState( + this.state == State.INITIALIZED, "callback may not be set multiple times"); + + // Start to fetch data and buffer these. + this.result = + new ListenableFutureToApiFuture<>(this.service.submit(new ProduceRowsCallable())); + this.executor = MoreExecutors.newSequentialExecutor(Preconditions.checkNotNull(exec)); + this.callback = Preconditions.checkNotNull(cb); + this.state = State.RUNNING; + pausedLatch.countDown(); + return result; + } + } + + Future getResult() { + return result; + } + + @Override + public void cancel() { + synchronized (monitor) { + Preconditions.checkState( + state != State.INITIALIZED && state != State.SYNC, + "cannot cancel a result set without a callback"); + state = State.CANCELLED; + pausedLatch.countDown(); + } + } + + @Override + public void resume() { + synchronized (monitor) { + Preconditions.checkState( + state != State.INITIALIZED && state != State.SYNC, + "cannot resume a result set without a callback"); + if (state == State.PAUSED) { + state = State.RUNNING; + pausedLatch.countDown(); + } + } + } + + private static class CreateListCallback implements ReadyCallback { + private final SettableApiFuture> future; + private final Function transformer; + private final ImmutableList.Builder builder = ImmutableList.builder(); + + private CreateListCallback( + SettableApiFuture> future, Function transformer) { + this.future = future; + this.transformer = transformer; + } + + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + future.set(builder.build()); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + builder.add(transformer.apply(resultSet)); + break; + } + } + } catch (Throwable t) { + future.setException(t); + return CallbackResponse.DONE; + } + } + } + + @Override + public ApiFuture> toListAsync( + Function transformer, Executor executor) { + synchronized (monitor) { + Preconditions.checkState(!closed, "This AsyncResultSet has been closed"); + Preconditions.checkState( + this.state == State.INITIALIZED, "This AsyncResultSet has already been used."); + final SettableApiFuture> res = SettableApiFuture.>create(); + CreateListCallback callback = new CreateListCallback(res, transformer); + ApiFuture finished = setCallback(executor, callback); + return ApiFutures.transformAsync( + finished, + new ApiAsyncFunction>() { + @Override + public ApiFuture> apply(Void input) throws Exception { + return res; + } + }, + MoreExecutors.directExecutor()); + } + } + + @Override + public ImmutableList toList(Function transformer) + throws SpannerException { + ApiFuture> future = toListAsync(transformer, MoreExecutors.directExecutor()); + try { + return future.get(); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause()); + } catch (Throwable e) { + throw SpannerExceptionFactory.newSpannerException(e); + } + } + + @Override + public boolean next() throws SpannerException { + synchronized (monitor) { + Preconditions.checkState( + this.state == State.INITIALIZED || this.state == State.SYNC, + "Cannot call next() on a result set with a callback."); + this.state = State.SYNC; + } + boolean res = delegateResultSet.next(); + currentRow = delegateResultSet.getCurrentRowAsStruct(); + return res; + } + + @Override + public ResultSetStats getStats() { + return delegateResultSet.getStats(); + } + + @Override + protected void checkValidState() { + synchronized (monitor) { + Preconditions.checkState( + state == State.SYNC || state == State.CONSUMING || state == State.CANCELLED, + "only allowed after a next() call or from within a ReadyCallback#cursorReady callback"); + Preconditions.checkState(state != State.SYNC || !closed, "ResultSet is closed"); + } + } + + @Override + public Struct getCurrentRowAsStruct() { + checkValidState(); + return currentRow; + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunner.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunner.java new file mode 100644 index 0000000000..3cae49e65b --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunner.java @@ -0,0 +1,59 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiFuture; +import com.google.cloud.Timestamp; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; + +public interface AsyncRunner { + + /** + * Functional interface for executing a read/write transaction asynchronously that returns a + * result of type R. + */ + interface AsyncWork { + /** + * Performs a single transaction attempt. All reads/writes should be performed using {@code + * txn}. + * + *

Implementations of this method should not attempt to commit the transaction directly: + * returning normally will result in the runner attempting to commit the transaction once the + * returned future completes, retrying on abort. + * + *

In most cases, the implementation will not need to catch {@code SpannerException}s from + * Spanner operations, instead letting these propagate to the framework. The transaction runner + * will take appropriate action based on the type of exception. In particular, implementations + * should never catch an exception of type {@link SpannerErrors#isAborted}: these indicate that + * some reads may have returned inconsistent data and the transaction attempt must be aborted. + * + * @param txn the transaction + * @return future over the result of the work + */ + ApiFuture doWorkAsync(TransactionContext txn); + } + + /** Executes a read/write transaction asynchronously using the given executor. */ + ApiFuture runAsync(AsyncWork work, Executor executor); + + /** + * Returns the timestamp at which the transaction committed. {@link ApiFuture#get()} will throw an + * {@link ExecutionException} if the transaction did not commit. + */ + ApiFuture getCommitTimestamp(); +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunnerImpl.java new file mode 100644 index 0000000000..5b83402919 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncRunnerImpl.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiFuture; +import com.google.api.core.SettableApiFuture; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; + +class AsyncRunnerImpl implements AsyncRunner { + private final TransactionRunnerImpl delegate; + private final SettableApiFuture commitTimestamp = SettableApiFuture.create(); + + AsyncRunnerImpl(TransactionRunnerImpl delegate) { + this.delegate = delegate; + } + + @Override + public ApiFuture runAsync(final AsyncWork work, Executor executor) { + final SettableApiFuture res = SettableApiFuture.create(); + executor.execute( + new Runnable() { + @Override + public void run() { + try { + res.set(runTransaction(work)); + } catch (Throwable t) { + res.setException(t); + } finally { + setCommitTimestamp(); + } + } + }); + return res; + } + + private R runTransaction(final AsyncWork work) { + return delegate.run( + new TransactionCallable() { + @Override + public R run(TransactionContext transaction) throws Exception { + try { + return work.doWorkAsync(transaction).get(); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause()); + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } + } + }); + } + + private void setCommitTimestamp() { + try { + commitTimestamp.set(delegate.getCommitTimestamp()); + } catch (Throwable t) { + commitTimestamp.setException(t); + } + } + + @Override + public ApiFuture getCommitTimestamp() { + return commitTimestamp; + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java new file mode 100644 index 0000000000..d519c68013 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManager.java @@ -0,0 +1,203 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiFuture; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionFunction; +import com.google.cloud.spanner.AsyncTransactionManager.CommitTimestampFuture; +import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; +import com.google.cloud.spanner.TransactionManager.TransactionState; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * An interface for managing the life cycle of a read write transaction including all its retries. + * See {@link TransactionContext} for a description of transaction semantics. + * + *

At any point in time there can be at most one active transaction in this manager. When that + * transaction is committed, if it fails with an {@code ABORTED} error, calling {@link + * #resetForRetryAsync()} would create a new {@link TransactionContextFuture}. The newly created + * transaction would use the same session thus increasing its lock priority. If the transaction is + * committed successfully, or is rolled back or commit fails with any error other than {@code + * ABORTED}, the manager is considered complete and no further transactions are allowed to be + * created in it. + * + *

Every {@code AsyncTransactionManager} should either be committed or rolled back. Failure to do + * so can cause resources to be leaked and deadlocks. Easiest way to guarantee this is by calling + * {@link #close()} in a finally block. + * + * @see DatabaseClient#transactionManagerAsync() + */ +public interface AsyncTransactionManager extends AutoCloseable { + /** + * {@link ApiFuture} that returns a {@link TransactionContext} and that supports chaining of + * multiple {@link TransactionContextFuture}s to form a transaction. + */ + public interface TransactionContextFuture extends ApiFuture { + /** + * Sets the first step to execute as part of this transaction after the transaction has started + * using the specified executor. {@link MoreExecutors#directExecutor()} can be be used for + * lightweight functions, but should be avoided for heavy or blocking operations. See also + * {@link ListenableFuture#addListener(Runnable, Executor)} for further information. + */ + AsyncTransactionStep then( + AsyncTransactionFunction function, Executor executor); + } + + /** + * {@link ApiFuture} that returns the commit {@link Timestamp} of a Cloud Spanner transaction that + * is executed using an {@link AsyncTransactionManager}. This future is returned by the call to + * {@link AsyncTransactionStep#commitAsync()} of the last step in the transaction. + */ + public interface CommitTimestampFuture extends ApiFuture { + /** + * Returns the commit timestamp of the transaction. Getting this value should always be done in + * order to ensure that the transaction succeeded. If any of the steps in the transaction fails + * with an uncaught exception, this method will automatically stop the transaction at that point + * and the exception will be returned as the cause of the {@link ExecutionException} that is + * thrown by this method. + * + * @throws AbortedException if the transaction was aborted by Cloud Spanner and needs to be + * retried. + */ + @Override + Timestamp get() throws AbortedException, InterruptedException, ExecutionException; + + /** + * Same as {@link #get()}, but will throw a {@link TimeoutException} if the transaction does not + * finish within the timeout. + */ + @Override + Timestamp get(long timeout, TimeUnit unit) + throws AbortedException, InterruptedException, ExecutionException, TimeoutException; + } + + /** + * {@link AsyncTransactionStep} is returned by {@link + * TransactionContextFuture#then(AsyncTransactionFunction)} and {@link + * AsyncTransactionStep#then(AsyncTransactionFunction)} and allows transaction steps that should + * be executed serially to be chained together. Each step can contain one or more statements that + * may execute in parallel. + * + *

Example usage: + * + *

{@code
+   * TransactionContextFuture txnFuture = manager.beginAsync();
+   * final String column = "FirstName";
+   * txnFuture.then(
+   *         new AsyncTransactionFunction() {
+   *           @Override
+   *           public ApiFuture apply(TransactionContext txn, Void input)
+   *               throws Exception {
+   *             return txn.readRowAsync(
+   *                 "Singers", Key.of(singerId), Collections.singleton(column));
+   *           }
+   *         })
+   *     .then(
+   *         new AsyncTransactionFunction() {
+   *           @Override
+   *           public ApiFuture apply(TransactionContext txn, Struct input)
+   *               throws Exception {
+   *             String name = input.getString(column);
+   *             txn.buffer(
+   *                 Mutation.newUpdateBuilder("Singers")
+   *                     .set(column)
+   *                     .to(name.toUpperCase())
+   *                     .build());
+   *             return ApiFutures.immediateFuture(null);
+   *           }
+   *         })
+   * }
+ */ + public interface AsyncTransactionStep extends ApiFuture { + /** + * Adds a step to the transaction chain that should be executed using the specified executor. + * This step is guaranteed to be executed only after the previous step executed successfully. + * {@link MoreExecutors#directExecutor()} can be be used for lightweight functions, but should + * be avoided for heavy or blocking operations. See also {@link + * ListenableFuture#addListener(Runnable, Executor)} for further information. + */ + AsyncTransactionStep then( + AsyncTransactionFunction next, Executor executor); + + /** + * Commits the transaction and returns a {@link CommitTimestampFuture} that will return the + * commit timestamp of the transaction, or throw the first uncaught exception in the transaction + * chain as an {@link ExecutionException}. + */ + CommitTimestampFuture commitAsync(); + } + + /** + * Each step in a transaction chain is defined by an {@link AsyncTransactionFunction}. It receives + * a {@link TransactionContext} and the output value of the previous transaction step as its input + * parameters. The method should return an {@link ApiFuture} that will return the result of this + * step. + */ + public interface AsyncTransactionFunction { + /** + * {@link #apply(TransactionContext, Object)} is called when this transaction step is executed. + * The input value is the result of the previous step, and this method will only be called if + * the previous step executed successfully. + * + * @param txn the {@link TransactionContext} that can be used to execute statements. + * @param input the result of the previous transaction step. + * @return an {@link ApiFuture} that will return the result of this step, and that will be the + * input of the next transaction step. This method should never return null. + * Instead, if the method does not have a return value, the method should return {@link + * ApiFutures#immediateFuture(null)}. + */ + ApiFuture apply(TransactionContext txn, I input) throws Exception; + } + + /** + * Creates a new read write transaction. This must be called before doing any other operation and + * can only be called once. To create a new transaction for subsequent retries, see {@link + * #resetForRetry()}. + */ + TransactionContextFuture beginAsync(); + + /** + * Rolls back the currently active transaction. In most cases there should be no need to call this + * explicitly since {@link #close()} would automatically roll back any active transaction. + */ + ApiFuture rollbackAsync(); + + /** + * Creates a new transaction for retry. This should only be called if the previous transaction + * failed with {@code ABORTED}. In all other cases, this will throw an {@link + * IllegalStateException}. Users should backoff before calling this method. Backoff delay is + * specified by {@link SpannerException#getRetryDelayInMillis()} on the {@code SpannerException} + * throw by the previous commit call. + */ + TransactionContextFuture resetForRetryAsync(); + + /** Returns the state of the transaction. */ + TransactionState getState(); + + /** + * Closes the manager. If there is an active transaction, it will be rolled back. Underlying + * session will be released back to the session pool. + */ + @Override + void close(); +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java new file mode 100644 index 0000000000..082fa827e7 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java @@ -0,0 +1,167 @@ +/* + * Copyright 2017 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutureCallback; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.SessionImpl.SessionTransaction; +import com.google.cloud.spanner.TransactionContextFutureImpl.CommittableAsyncTransactionManager; +import com.google.cloud.spanner.TransactionManager.TransactionState; +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.MoreExecutors; +import io.opencensus.trace.Span; +import io.opencensus.trace.Tracer; +import io.opencensus.trace.Tracing; + +/** Implementation of {@link AsyncTransactionManager}. */ +final class AsyncTransactionManagerImpl + implements CommittableAsyncTransactionManager, SessionTransaction { + private static final Tracer tracer = Tracing.getTracer(); + + private final SessionImpl session; + private Span span; + + private TransactionRunnerImpl.TransactionContextImpl txn; + private TransactionState txnState; + private final SettableApiFuture commitTimestamp = SettableApiFuture.create(); + + AsyncTransactionManagerImpl(SessionImpl session, Span span) { + this.session = session; + this.span = span; + } + + @Override + public void setSpan(Span span) { + this.span = span; + } + + @Override + public void close() { + txn.close(); + } + + @Override + public TransactionContextFutureImpl beginAsync() { + Preconditions.checkState(txn == null, "begin can only be called once"); + TransactionContextFutureImpl begin = + new TransactionContextFutureImpl(this, internalBeginAsync(true)); + return begin; + } + + private ApiFuture internalBeginAsync(boolean setActive) { + txnState = TransactionState.STARTED; + txn = session.newTransaction(); + if (setActive) { + session.setActive(this); + } + final SettableApiFuture res = SettableApiFuture.create(); + final ApiFuture fut = txn.ensureTxnAsync(); + ApiFutures.addCallback( + fut, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + res.setException(SpannerExceptionFactory.newSpannerException(t)); + } + + @Override + public void onSuccess(Void result) { + res.set(txn); + } + }, + MoreExecutors.directExecutor()); + return res; + } + + @Override + public void onError(Throwable t) { + if (t instanceof AbortedException) { + txnState = TransactionState.ABORTED; + } + } + + @Override + public ApiFuture commitAsync() { + Preconditions.checkState( + txnState == TransactionState.STARTED, + "commit can only be invoked if the transaction is in progress. Current state: " + txnState); + if (txn.isAborted()) { + txnState = TransactionState.ABORTED; + return ApiFutures.immediateFailedFuture( + SpannerExceptionFactory.newSpannerException( + ErrorCode.ABORTED, "Transaction already aborted")); + } + ApiFuture res = txn.commitAsync(); + txnState = TransactionState.COMMITTED; + ApiFutures.addCallback( + res, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + if (t instanceof AbortedException) { + txnState = TransactionState.ABORTED; + } else { + txnState = TransactionState.COMMIT_FAILED; + commitTimestamp.setException(t); + } + } + + @Override + public void onSuccess(Timestamp result) { + commitTimestamp.set(result); + } + }, + MoreExecutors.directExecutor()); + return res; + } + + @Override + public ApiFuture rollbackAsync() { + Preconditions.checkState( + txnState == TransactionState.STARTED, + "rollback can only be called if the transaction is in progress"); + try { + return txn.rollbackAsync(); + } finally { + txnState = TransactionState.ROLLED_BACK; + } + } + + @Override + public TransactionContextFuture resetForRetryAsync() { + if (txn == null || !txn.isAborted() && txnState != TransactionState.ABORTED) { + throw new IllegalStateException( + "resetForRetry can only be called if the previous attempt aborted"); + } + return new TransactionContextFutureImpl(this, internalBeginAsync(false)); + } + + @Override + public TransactionState getState() { + return txnState; + } + + @Override + public void invalidate() { + if (txnState == TransactionState.STARTED || txnState == null) { + txnState = TransactionState.ROLLED_BACK; + } + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index 43de2be092..c84bef77cf 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -30,6 +30,7 @@ import com.google.spanner.v1.PartitionReadRequest; import com.google.spanner.v1.PartitionResponse; import com.google.spanner.v1.TransactionSelector; +import io.opencensus.trace.Tracing; import java.util.List; import java.util.Map; @@ -51,6 +52,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) { .setTimestampBound(bound) .setDefaultQueryOptions( sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) + .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()), checkNotNull(bound)); } @@ -67,6 +69,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(BatchTransactionId batc .setTimestamp(batchTransactionId.getTimestamp()) .setDefaultQueryOptions( sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) + .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()), batchTransactionId); } @@ -81,6 +84,7 @@ private static class BatchReadOnlyTransactionImpl extends MultiUseReadOnlyTransa super(builder.setTimestampBound(bound)); this.sessionName = session.getName(); this.options = session.getOptions(); + setSpan(Tracing.getTracer().getCurrentSpan()); initTransaction(); } @@ -89,6 +93,7 @@ private static class BatchReadOnlyTransactionImpl extends MultiUseReadOnlyTransa super(builder.setTransactionId(batchTransactionId.getTransactionId())); this.sessionName = session.getName(); this.options = session.getOptions(); + setSpan(Tracing.getTracer().getCurrentSpan()); } @Override diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java index ac29ba2b37..d52d1d892e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClient.java @@ -278,6 +278,127 @@ public interface DatabaseClient { */ TransactionManager transactionManager(); + /** + * Returns an asynchronous transaction runner for executing a single logical transaction with + * retries. The returned runner can only be used once. + * + *

Example of a read write transaction. + * + *

 
+   * Executor executor = Executors.newSingleThreadExecutor();
+   * final long singerId = my_singer_id;
+   * AsyncRunner runner = client.runAsync();
+   * ApiFuture rowCount =
+   *     runner.runAsync(
+   *         new AsyncWork() {
+   *           @Override
+   *           public ApiFuture doWorkAsync(TransactionContext txn) {
+   *             String column = "FirstName";
+   *             Struct row =
+   *                 txn.readRow("Singers", Key.of(singerId), Collections.singleton("Name"));
+   *             String name = row.getString("Name");
+   *             return txn.executeUpdateAsync(
+   *                 Statement.newBuilder("UPDATE Singers SET Name=@name WHERE SingerId=@id")
+   *                     .bind("id")
+   *                     .to(singerId)
+   *                     .bind("name")
+   *                     .to(name.toUpperCase())
+   *                     .build());
+   *           }
+   *         },
+   *         executor);
+   * 
+ */ + AsyncRunner runAsync(); + + /** + * Returns an asynchronous transaction manager which allows manual management of transaction + * lifecycle. This API is meant for advanced users. Most users should instead use the {@link + * #runAsync()} API instead. + * + *

Example of using {@link AsyncTransactionManager} with lambda expressions (Java 8 and + * higher). + * + *

{@code
+   * long singerId = 1L;
+   * try (AsyncTransactionManager manager = client.transactionManagerAsync()) {
+   *   TransactionContextFuture txnFut = manager.beginAsync();
+   *   while (true) {
+   *     String column = "FirstName";
+   *     CommitTimestampFuture commitTimestamp =
+   *         txnFut
+   *             .then(
+   *                 (txn, __) ->
+   *                     txn.readRowAsync(
+   *                         "Singers", Key.of(singerId), Collections.singleton(column)))
+   *             .then(
+   *                 (txn, row) -> {
+   *                   String name = row.getString(column);
+   *                   txn.buffer(
+   *                       Mutation.newUpdateBuilder("Singers")
+   *                           .set(column)
+   *                           .to(name.toUpperCase())
+   *                           .build());
+   *                   return ApiFutures.immediateFuture(null);
+   *                 })
+   *             .commitAsync();
+   *     try {
+   *       commitTimestamp.get();
+   *       break;
+   *     } catch (AbortedException e) {
+   *       Thread.sleep(e.getRetryDelayInMillis() / 1000);
+   *       txnFut = manager.resetForRetryAsync();
+   *     }
+   *   }
+   * }
+   * }
+ * + *

Example of using {@link AsyncTransactionManager} (Java 7). + * + *

{@code
+   * final long singerId = 1L;
+   * try (AsyncTransactionManager manager = client().transactionManagerAsync()) {
+   *   TransactionContextFuture txn = manager.beginAsync();
+   *   while (true) {
+   *     final String column = "FirstName";
+   *     CommitTimestampFuture commitTimestamp =
+   *         txn.then(
+   *                 new AsyncTransactionFunction() {
+   *                   @Override
+   *                   public ApiFuture apply(TransactionContext txn, Void input)
+   *                       throws Exception {
+   *                     return txn.readRowAsync(
+   *                         "Singers", Key.of(singerId), Collections.singleton(column));
+   *                   }
+   *                 })
+   *             .then(
+   *                 new AsyncTransactionFunction() {
+   *                   @Override
+   *                   public ApiFuture apply(TransactionContext txn, Struct input)
+   *                       throws Exception {
+   *                     String name = input.getString(column);
+   *                     txn.buffer(
+   *                         Mutation.newUpdateBuilder("Singers")
+   *                             .set(column)
+   *                             .to(name.toUpperCase())
+   *                             .build());
+   *                     return ApiFutures.immediateFuture(null);
+   *                   }
+   *                 })
+   *             .commitAsync();
+   *     try {
+   *       commitTimestamp.get();
+   *       break;
+   *     } catch (AbortedException e) {
+   *       Thread.sleep(e.getRetryDelayInMillis() / 1000);
+   *       txn = manager.resetForRetryAsync();
+   *     }
+   *   }
+   * }
+   * }
+ */ + AsyncTransactionManager transactionManagerAsync(); + /** * Returns the lower bound of rows modified by this DML statement. * diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java index ec83d06335..4dd10001c7 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java @@ -17,7 +17,7 @@ package com.google.cloud.spanner; import com.google.cloud.Timestamp; -import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.SessionPool.PooledSessionFuture; import com.google.cloud.spanner.SpannerImpl.ClosedException; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; @@ -52,12 +52,12 @@ private enum SessionMode { } @VisibleForTesting - PooledSession getReadSession() { + PooledSessionFuture getReadSession() { return pool.getReadSession(); } @VisibleForTesting - PooledSession getReadWriteSession() { + PooledSessionFuture getReadWriteSession() { return pool.getReadWriteSession(); } @@ -191,6 +191,28 @@ public TransactionManager transactionManager() { } } + @Override + public AsyncRunner runAsync() { + Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); + try (Scope s = tracer.withSpan(span)) { + return getReadWriteSession().runAsync(); + } catch (RuntimeException e) { + TraceUtil.endSpanWithFailure(span, e); + throw e; + } + } + + @Override + public AsyncTransactionManager transactionManagerAsync() { + Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); + try (Scope s = tracer.withSpan(span)) { + return getReadWriteSession().transactionManagerAsync(); + } catch (RuntimeException e) { + TraceUtil.endSpanWithFailure(span, e); + throw e; + } + } + @Override public long executePartitionedUpdate(final Statement stmt) { Span span = tracer.spanBuilder(PARTITION_DML_TRANSACTION).startSpan(); @@ -212,7 +234,7 @@ public Long apply(Session session) { } private T runWithSessionRetry(SessionMode mode, Function callable) { - PooledSession session = + PooledSessionFuture session = mode == SessionMode.READ_WRITE ? getReadWriteSession() : getReadSession(); while (true) { try { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseId.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseId.java index d2c732750e..dd13df65e8 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseId.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseId.java @@ -81,7 +81,7 @@ public String toString() { * projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID} * @throws IllegalArgumentException if {@code name} does not conform to the expected pattern */ - static DatabaseId of(String name) { + public static DatabaseId of(String name) { Preconditions.checkNotNull(name); Map parts = NAME_TEMPLATE.match(name); Preconditions.checkArgument( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingAsyncResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingAsyncResultSet.java new file mode 100644 index 0000000000..78e3505998 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingAsyncResultSet.java @@ -0,0 +1,65 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiFuture; +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import java.util.concurrent.Executor; + +/** Forwarding implementation of {@link AsyncResultSet} that forwards all calls to a delegate. */ +public class ForwardingAsyncResultSet extends ForwardingResultSet implements AsyncResultSet { + final AsyncResultSet delegate; + + public ForwardingAsyncResultSet(AsyncResultSet delegate) { + super(Preconditions.checkNotNull(delegate)); + this.delegate = delegate; + } + + @Override + public CursorState tryNext() throws SpannerException { + return delegate.tryNext(); + } + + @Override + public ApiFuture setCallback(Executor exec, ReadyCallback cb) { + return delegate.setCallback(exec, cb); + } + + @Override + public void cancel() { + delegate.cancel(); + } + + @Override + public void resume() { + delegate.resume(); + } + + @Override + public ApiFuture> toListAsync( + Function transformer, Executor executor) { + return delegate.toListAsync(transformer, executor); + } + + @Override + public ImmutableList toList(Function transformer) + throws SpannerException { + return delegate.toList(transformer); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index 753c3f6f39..4cc0ab9b9e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -17,16 +17,23 @@ package com.google.cloud.spanner; import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; import com.google.spanner.v1.ResultSetStats; /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ public class ForwardingResultSet extends ForwardingStructReader implements ResultSet { - private ResultSet delegate; + private Supplier delegate; public ForwardingResultSet(ResultSet delegate) { super(delegate); - this.delegate = Preconditions.checkNotNull(delegate); + this.delegate = Suppliers.ofInstance(Preconditions.checkNotNull(delegate)); + } + + public ForwardingResultSet(Supplier supplier) { + super(supplier); + this.delegate = supplier; } /** @@ -39,26 +46,26 @@ public ForwardingResultSet(ResultSet delegate) { void replaceDelegate(ResultSet newDelegate) { Preconditions.checkNotNull(newDelegate); super.replaceDelegate(newDelegate); - this.delegate = newDelegate; + this.delegate = Suppliers.ofInstance(Preconditions.checkNotNull(newDelegate)); } @Override public boolean next() throws SpannerException { - return delegate.next(); + return delegate.get().next(); } @Override public Struct getCurrentRowAsStruct() { - return delegate.getCurrentRowAsStruct(); + return delegate.get().getCurrentRowAsStruct(); } @Override public void close() { - delegate.close(); + delegate.get().close(); } @Override public ResultSetStats getStats() { - return delegate.getStats(); + return delegate.get().getStats(); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java index 9b30b89985..67e546ad5a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java @@ -20,14 +20,20 @@ import com.google.cloud.Date; import com.google.cloud.Timestamp; import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; import java.util.List; /** Forwarding implements of StructReader */ public class ForwardingStructReader implements StructReader { - private StructReader delegate; + private Supplier delegate; public ForwardingStructReader(StructReader delegate) { + this.delegate = Suppliers.ofInstance(Preconditions.checkNotNull(delegate)); + } + + public ForwardingStructReader(Supplier delegate) { this.delegate = Preconditions.checkNotNull(delegate); } @@ -39,221 +45,271 @@ public ForwardingStructReader(StructReader delegate) { * returned to the user. */ void replaceDelegate(StructReader newDelegate) { - this.delegate = Preconditions.checkNotNull(newDelegate); + this.delegate = Suppliers.ofInstance(Preconditions.checkNotNull(newDelegate)); } + /** + * Called before each forwarding call to allow sub classes to do additional state checking. Sub + * classes should throw an {@link Exception} if the current state is not valid for reading data + * from this {@link ForwardingStructReader}. The default implementation does nothing. + */ + protected void checkValidState() {} + @Override public Type getType() { - return delegate.getType(); + checkValidState(); + return delegate.get().getType(); } @Override public int getColumnCount() { - return delegate.getColumnCount(); + checkValidState(); + return delegate.get().getColumnCount(); } @Override public int getColumnIndex(String columnName) { - return delegate.getColumnIndex(columnName); + checkValidState(); + return delegate.get().getColumnIndex(columnName); } @Override public Type getColumnType(int columnIndex) { - return delegate.getColumnType(columnIndex); + checkValidState(); + return delegate.get().getColumnType(columnIndex); } @Override public Type getColumnType(String columnName) { - return delegate.getColumnType(columnName); + checkValidState(); + return delegate.get().getColumnType(columnName); } @Override public boolean isNull(int columnIndex) { - return delegate.isNull(columnIndex); + checkValidState(); + return delegate.get().isNull(columnIndex); } @Override public boolean isNull(String columnName) { - return delegate.isNull(columnName); + checkValidState(); + return delegate.get().isNull(columnName); } @Override public boolean getBoolean(int columnIndex) { - return delegate.getBoolean(columnIndex); + checkValidState(); + return delegate.get().getBoolean(columnIndex); } @Override public boolean getBoolean(String columnName) { - return delegate.getBoolean(columnName); + checkValidState(); + return delegate.get().getBoolean(columnName); } @Override public long getLong(int columnIndex) { - return delegate.getLong(columnIndex); + checkValidState(); + return delegate.get().getLong(columnIndex); } @Override public long getLong(String columnName) { - return delegate.getLong(columnName); + checkValidState(); + return delegate.get().getLong(columnName); } @Override public double getDouble(int columnIndex) { - return delegate.getDouble(columnIndex); + checkValidState(); + return delegate.get().getDouble(columnIndex); } @Override public double getDouble(String columnName) { - return delegate.getDouble(columnName); + checkValidState(); + return delegate.get().getDouble(columnName); } @Override public String getString(int columnIndex) { - return delegate.getString(columnIndex); + checkValidState(); + return delegate.get().getString(columnIndex); } @Override public String getString(String columnName) { - return delegate.getString(columnName); + checkValidState(); + return delegate.get().getString(columnName); } @Override public ByteArray getBytes(int columnIndex) { - return delegate.getBytes(columnIndex); + checkValidState(); + return delegate.get().getBytes(columnIndex); } @Override public ByteArray getBytes(String columnName) { - return delegate.getBytes(columnName); + checkValidState(); + return delegate.get().getBytes(columnName); } @Override public Timestamp getTimestamp(int columnIndex) { - return delegate.getTimestamp(columnIndex); + checkValidState(); + return delegate.get().getTimestamp(columnIndex); } @Override public Timestamp getTimestamp(String columnName) { - return delegate.getTimestamp(columnName); + checkValidState(); + return delegate.get().getTimestamp(columnName); } @Override public Date getDate(int columnIndex) { - return delegate.getDate(columnIndex); + checkValidState(); + return delegate.get().getDate(columnIndex); } @Override public Date getDate(String columnName) { - return delegate.getDate(columnName); + checkValidState(); + return delegate.get().getDate(columnName); } @Override public boolean[] getBooleanArray(int columnIndex) { - return delegate.getBooleanArray(columnIndex); + checkValidState(); + return delegate.get().getBooleanArray(columnIndex); } @Override public boolean[] getBooleanArray(String columnName) { - return delegate.getBooleanArray(columnName); + checkValidState(); + return delegate.get().getBooleanArray(columnName); } @Override public List getBooleanList(int columnIndex) { - return delegate.getBooleanList(columnIndex); + checkValidState(); + return delegate.get().getBooleanList(columnIndex); } @Override public List getBooleanList(String columnName) { - return delegate.getBooleanList(columnName); + checkValidState(); + return delegate.get().getBooleanList(columnName); } @Override public long[] getLongArray(int columnIndex) { - return delegate.getLongArray(columnIndex); + checkValidState(); + return delegate.get().getLongArray(columnIndex); } @Override public long[] getLongArray(String columnName) { - return delegate.getLongArray(columnName); + checkValidState(); + return delegate.get().getLongArray(columnName); } @Override public List getLongList(int columnIndex) { - return delegate.getLongList(columnIndex); + checkValidState(); + return delegate.get().getLongList(columnIndex); } @Override public List getLongList(String columnName) { - return delegate.getLongList(columnName); + checkValidState(); + return delegate.get().getLongList(columnName); } @Override public double[] getDoubleArray(int columnIndex) { - return delegate.getDoubleArray(columnIndex); + checkValidState(); + return delegate.get().getDoubleArray(columnIndex); } @Override public double[] getDoubleArray(String columnName) { - return delegate.getDoubleArray(columnName); + checkValidState(); + return delegate.get().getDoubleArray(columnName); } @Override public List getDoubleList(int columnIndex) { - return delegate.getDoubleList(columnIndex); + checkValidState(); + return delegate.get().getDoubleList(columnIndex); } @Override public List getDoubleList(String columnName) { - return delegate.getDoubleList(columnName); + checkValidState(); + return delegate.get().getDoubleList(columnName); } @Override public List getStringList(int columnIndex) { - return delegate.getStringList(columnIndex); + checkValidState(); + return delegate.get().getStringList(columnIndex); } @Override public List getStringList(String columnName) { - return delegate.getStringList(columnName); + checkValidState(); + return delegate.get().getStringList(columnName); } @Override public List getBytesList(int columnIndex) { - return delegate.getBytesList(columnIndex); + checkValidState(); + return delegate.get().getBytesList(columnIndex); } @Override public List getBytesList(String columnName) { - return delegate.getBytesList(columnName); + checkValidState(); + return delegate.get().getBytesList(columnName); } @Override public List getTimestampList(int columnIndex) { - return delegate.getTimestampList(columnIndex); + checkValidState(); + return delegate.get().getTimestampList(columnIndex); } @Override public List getTimestampList(String columnName) { - return delegate.getTimestampList(columnName); + checkValidState(); + return delegate.get().getTimestampList(columnName); } @Override public List getDateList(int columnIndex) { - return delegate.getDateList(columnIndex); + checkValidState(); + return delegate.get().getDateList(columnIndex); } @Override public List getDateList(String columnName) { - return delegate.getDateList(columnName); + checkValidState(); + return delegate.get().getDateList(columnName); } @Override public List getStructList(int columnIndex) { - return delegate.getStructList(columnIndex); + checkValidState(); + return delegate.get().getStructList(columnIndex); } @Override public List getStructList(String columnName) { - return delegate.getStructList(columnName); + checkValidState(); + return delegate.get().getStructList(columnName); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index d193ad1c75..879b632d17 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -59,6 +59,11 @@ public static ReadAndQueryOption prefetchChunks(int prefetchChunks) { return new FlowControlOption(prefetchChunks); } + public static ReadAndQueryOption bufferRows(int bufferRows) { + Preconditions.checkArgument(bufferRows > 0, "bufferRows should be greater than 0"); + return new BufferRowsOption(bufferRows); + } + /** * Specifying this will cause the list operations to fetch at most this many records in a page. */ @@ -115,8 +120,22 @@ void appendToOptions(Options options) { } } + static final class BufferRowsOption extends InternalOption implements ReadAndQueryOption { + final int bufferRows; + + BufferRowsOption(int bufferRows) { + this.bufferRows = bufferRows; + } + + @Override + void appendToOptions(Options options) { + options.bufferRows = bufferRows; + } + } + private Long limit; private Integer prefetchChunks; + private Integer bufferRows; private Integer pageSize; private String pageToken; private String filter; @@ -140,6 +159,14 @@ int prefetchChunks() { return prefetchChunks; } + boolean hasBufferRows() { + return bufferRows != null; + } + + int bufferRows() { + return bufferRows; + } + boolean hasPageSize() { return pageSize != null; } @@ -203,6 +230,10 @@ public boolean equals(Object o) { || hasPrefetchChunks() && that.hasPrefetchChunks() && Objects.equals(prefetchChunks(), that.prefetchChunks())) + && (!hasBufferRows() && !that.hasBufferRows() + || hasBufferRows() + && that.hasBufferRows() + && Objects.equals(bufferRows(), that.bufferRows())) && (!hasPageSize() && !that.hasPageSize() || hasPageSize() && that.hasPageSize() && Objects.equals(pageSize(), that.pageSize())) && Objects.equals(pageToken(), that.pageToken()) @@ -218,6 +249,9 @@ public int hashCode() { if (prefetchChunks != null) { result = 31 * result + prefetchChunks.hashCode(); } + if (bufferRows != null) { + result = 31 * result + bufferRows.hashCode(); + } if (pageSize != null) { result = 31 * result + pageSize.hashCode(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java index 638c567a03..96ae390dd6 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java @@ -33,6 +33,7 @@ import com.google.spanner.v1.TransactionOptions; import com.google.spanner.v1.TransactionSelector; import io.grpc.Status.Code; +import io.opencensus.trace.Span; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.logging.Level; @@ -160,4 +161,8 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo public void invalidate() { isValid = false; } + + // No-op method needed to implement SessionTransaction interface. + @Override + public void setSpan(Span span) {} } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ReadContext.java index 16f40769fa..e87d40fb20 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ReadContext.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import com.google.api.core.ApiFuture; import com.google.cloud.spanner.Options.QueryOption; import com.google.cloud.spanner.Options.ReadOption; import javax.annotation.Nullable; @@ -65,6 +66,13 @@ enum QueryAnalyzeMode { */ ResultSet read(String table, KeySet keys, Iterable columns, ReadOption... options); + /** + * Same as {@link #read(String, KeySet, Iterable, ReadOption...)}, but is guaranteed to be + * non-blocking and will return the results as an {@link AsyncResultSet}. + */ + AsyncResultSet readAsync( + String table, KeySet keys, Iterable columns, ReadOption... options); + /** * Reads zero or more rows from a database using an index. * @@ -93,6 +101,13 @@ enum QueryAnalyzeMode { ResultSet readUsingIndex( String table, String index, KeySet keys, Iterable columns, ReadOption... options); + /** + * Same as {@link #readUsingIndex(String, String, KeySet, Iterable, ReadOption...)}, but is + * guaranteed to be non-blocking and will return its results as an {@link AsyncResultSet}. + */ + AsyncResultSet readUsingIndexAsync( + String table, String index, KeySet keys, Iterable columns, ReadOption... options); + /** * Reads a single row from a database, returning {@code null} if the row does not exist. * @@ -112,6 +127,9 @@ ResultSet readUsingIndex( @Nullable Struct readRow(String table, Key key, Iterable columns); + /** Same as {@link #readRow(String, Key, Iterable)}, but is guaranteed to be non-blocking. */ + ApiFuture readRowAsync(String table, Key key, Iterable columns); + /** * Reads a single row from a database using an index, returning {@code null} if the row does not * exist. @@ -134,6 +152,13 @@ ResultSet readUsingIndex( @Nullable Struct readRowUsingIndex(String table, String index, Key key, Iterable columns); + /** + * Same as {@link #readRowUsingIndex(String, String, Key, Iterable)}, but is guaranteed to be + * non-blocking. + */ + ApiFuture readRowUsingIndexAsync( + String table, String index, Key key, Iterable columns); + /** * Executes a query against the database. * @@ -160,6 +185,12 @@ ResultSet readUsingIndex( */ ResultSet executeQuery(Statement statement, QueryOption... options); + /** + * Same as {@link #executeQuery(Statement, QueryOption...)}, but is guaranteed to be non-blocking + * and returns its results as an {@link AsyncResultSet}. + */ + AsyncResultSet executeQueryAsync(Statement statement, QueryOption... options); + /** * Analyzes a query and returns query plan and/or query execution statistics information. * diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java index 29c3e52c6a..278b15d967 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java @@ -16,6 +16,8 @@ package com.google.cloud.spanner; +import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.core.InstantiatingExecutorProvider; import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; @@ -23,6 +25,7 @@ import com.google.cloud.spanner.Type.StructField; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.spanner.v1.ResultSetStats; import java.util.List; @@ -41,6 +44,30 @@ public static ResultSet forRows(Type type, Iterable rows) { return new PrePopulatedResultSet(type, rows); } + /** Converts the given {@link ResultSet} to an {@link AsyncResultSet}. */ + public static AsyncResultSet toAsyncResultSet(ResultSet delegate) { + return new AsyncResultSetImpl( + InstantiatingExecutorProvider.newBuilder() + .setExecutorThreadCount(1) + .setThreadFactory( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("test-async-resultset-%d") + .build()) + .build(), + delegate, + 100); + } + + /** + * Converts the given {@link ResultSet} to an {@link AsyncResultSet} using the given {@link + * ExecutorProvider}. + */ + public static AsyncResultSet toAsyncResultSet( + ResultSet delegate, ExecutorProvider executorProvider) { + return new AsyncResultSetImpl(executorProvider, delegate, 100); + } + private static class PrePopulatedResultSet implements ResultSet { private final List rows; private final Type type; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index b865efa2d9..ce4d27e94e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.api.core.ApiFuture; +import com.google.api.core.SettableApiFuture; import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbstractReadContext.MultiUseReadOnlyTransaction; import com.google.cloud.spanner.AbstractReadContext.SingleReadContext; @@ -28,6 +29,7 @@ import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.collect.Lists; +import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import com.google.spanner.v1.BeginTransactionRequest; @@ -43,6 +45,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; import javax.annotation.Nullable; /** @@ -76,14 +79,17 @@ static void throwIfTransactionsPending() { static interface SessionTransaction { /** Invalidates the transaction, generally because a new one has been started on the session. */ void invalidate(); + /** Registers the current span on the transaction. */ + void setSpan(Span span); } private final SpannerImpl spanner; private final String name; private final DatabaseId databaseId; private SessionTransaction activeTransaction; - private ByteString readyTransactionId; + ByteString readyTransactionId; private final Map options; + private Span currentSpan; SessionImpl(SpannerImpl spanner, String name, Map options) { this.spanner = spanner; @@ -101,6 +107,10 @@ public String getName() { return options; } + void setCurrentSpan(Span span) { + currentSpan = span; + } + @Override public long executePartitionedUpdate(Statement stmt) { setActive(null); @@ -170,6 +180,8 @@ public ReadContext singleUse(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setSpan(currentSpan) + .setExecutorProvider(spanner.getAsyncExecutorProvider()) .build()); } @@ -187,6 +199,8 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setSpan(currentSpan) + .setExecutorProvider(spanner.getAsyncExecutorProvider()) .buildSingleUseReadOnlyTransaction()); } @@ -204,6 +218,8 @@ public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setSpan(currentSpan) + .setExecutorProvider(spanner.getAsyncExecutorProvider()) .build()); } @@ -213,6 +229,23 @@ public TransactionRunner readWriteTransaction() { new TransactionRunnerImpl(this, spanner.getRpc(), spanner.getDefaultPrefetchChunks())); } + @Override + public AsyncRunner runAsync() { + return new AsyncRunnerImpl( + setActive( + new TransactionRunnerImpl(this, spanner.getRpc(), spanner.getDefaultPrefetchChunks()))); + } + + @Override + public TransactionManager transactionManager() { + return new TransactionManagerImpl(this, currentSpan); + } + + @Override + public AsyncTransactionManagerImpl transactionManagerAsync() { + return new AsyncTransactionManagerImpl(this, currentSpan); + } + @Override public void prepareReadWriteTransaction() { setActive(null); @@ -238,27 +271,59 @@ public void close() { } ByteString beginTransaction() { - Span span = tracer.spanBuilder(SpannerImpl.BEGIN_TRANSACTION).startSpan(); - try (Scope s = tracer.withSpan(span)) { - final BeginTransactionRequest request = - BeginTransactionRequest.newBuilder() - .setSession(name) - .setOptions( - TransactionOptions.newBuilder() - .setReadWrite(TransactionOptions.ReadWrite.getDefaultInstance())) - .build(); - Transaction txn = spanner.getRpc().beginTransaction(request, options); - if (txn.getId().isEmpty()) { - throw newSpannerException(ErrorCode.INTERNAL, "Missing id in transaction\n" + getName()); - } - span.end(TraceUtil.END_SPAN_OPTIONS); - return txn.getId(); - } catch (RuntimeException e) { - TraceUtil.endSpanWithFailure(span, e); - throw e; + try { + return beginTransactionAsync().get(); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause() == null ? e : e.getCause()); + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); } } + ApiFuture beginTransactionAsync() { + final SettableApiFuture res = SettableApiFuture.create(); + final Span span = tracer.spanBuilder(SpannerImpl.BEGIN_TRANSACTION).startSpan(); + final BeginTransactionRequest request = + BeginTransactionRequest.newBuilder() + .setSession(name) + .setOptions( + TransactionOptions.newBuilder() + .setReadWrite(TransactionOptions.ReadWrite.getDefaultInstance())) + .build(); + final ApiFuture requestFuture = + spanner.getRpc().beginTransactionAsync(request, options); + requestFuture.addListener( + tracer.withSpan( + span, + new Runnable() { + @Override + public void run() { + try { + Transaction txn = requestFuture.get(); + if (txn.getId().isEmpty()) { + throw newSpannerException( + ErrorCode.INTERNAL, "Missing id in transaction\n" + getName()); + } + span.end(TraceUtil.END_SPAN_OPTIONS); + res.set(txn.getId()); + } catch (ExecutionException e) { + TraceUtil.endSpanWithFailure(span, e); + res.setException( + SpannerExceptionFactory.newSpannerException( + e.getCause() == null ? e : e.getCause())); + } catch (InterruptedException e) { + TraceUtil.endSpanWithFailure(span, e); + res.setException(SpannerExceptionFactory.propagateInterrupt(e)); + } catch (Exception e) { + TraceUtil.endSpanWithFailure(span, e); + res.setException(e); + } + } + }), + MoreExecutors.directExecutor()); + return res; + } + TransactionContextImpl newTransaction() { return TransactionContextImpl.newBuilder() .setSession(this) @@ -266,6 +331,8 @@ TransactionContextImpl newTransaction() { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setSpan(currentSpan) + .setExecutorProvider(spanner.getAsyncExecutorProvider()) .build(); } @@ -277,11 +344,13 @@ T setActive(@Nullable T ctx) { } activeTransaction = ctx; readyTransactionId = null; + if (activeTransaction != null) { + activeTransaction.setSpan(currentSpan); + } return ctx; } - @Override - public TransactionManager transactionManager() { - return new TransactionManagerImpl(this); + boolean hasReadyTransaction() { + return readyTransactionId != null; } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index 2286c3b34b..90e399fad6 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -40,6 +40,8 @@ import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.Timestamp; import com.google.cloud.grpc.GrpcTransportOptions; import com.google.cloud.grpc.GrpcTransportOptions.ExecutorFactory; @@ -48,6 +50,7 @@ import com.google.cloud.spanner.SessionClient.SessionConsumer; import com.google.cloud.spanner.SpannerException.ResourceNotFoundException; import com.google.cloud.spanner.SpannerImpl.ClosedException; +import com.google.cloud.spanner.TransactionManager.TransactionState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.base.MoreObjects; @@ -56,11 +59,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ForwardingListenableFuture; +import com.google.common.util.concurrent.ForwardingListenableFuture.SimpleForwardingListenableFuture; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.Empty; import io.opencensus.common.Scope; import io.opencensus.common.ToLongFunction; @@ -85,12 +89,17 @@ import java.util.Queue; import java.util.Random; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -132,53 +141,115 @@ Instant instant() { } } + private abstract static class CachedResultSetSupplier implements Supplier { + private ResultSet cached; + + abstract ResultSet load(); + + ResultSet reload() { + return cached = load(); + } + + @Override + public ResultSet get() { + if (cached == null) { + cached = load(); + } + return cached; + } + } + /** * Wrapper around {@code ReadContext} that releases the session to the pool once the call is * finished, if it is a single use context. */ private static class AutoClosingReadContext implements ReadContext { - private final Function readContextDelegateSupplier; + /** + * {@link AsyncResultSet} implementation that keeps track of the async operations that are still + * running for this {@link ReadContext} and that should finish before the {@link ReadContext} + * releases its session back into the pool. + */ + private class AutoClosingReadContextAsyncResultSetImpl extends AsyncResultSetImpl { + private AutoClosingReadContextAsyncResultSetImpl( + ExecutorProvider executorProvider, ResultSet delegate, int bufferRows) { + super(executorProvider, delegate, bufferRows); + } + + @Override + public ApiFuture setCallback(Executor exec, ReadyCallback cb) { + Runnable listener = + new Runnable() { + @Override + public void run() { + synchronized (lock) { + if (asyncOperationsCount.decrementAndGet() == 0 && closed) { + // All async operations for this read context have finished. + AutoClosingReadContext.this.close(); + } + } + } + }; + try { + asyncOperationsCount.incrementAndGet(); + addListener(listener); + return super.setCallback(exec, cb); + } catch (Throwable t) { + removeListener(listener); + asyncOperationsCount.decrementAndGet(); + throw t; + } + } + } + + private final Function readContextDelegateSupplier; private T readContextDelegate; private final SessionPool sessionPool; - private PooledSession session; private final boolean isSingleUse; - private boolean closed; + private final AtomicInteger asyncOperationsCount = new AtomicInteger(); + + private Object lock = new Object(); + + @GuardedBy("lock") private boolean sessionUsedForQuery = false; + @GuardedBy("lock") + private PooledSessionFuture session; + + @GuardedBy("lock") + private boolean closed; + + @GuardedBy("lock") + private boolean delegateClosed; + private AutoClosingReadContext( - Function delegateSupplier, + Function delegateSupplier, SessionPool sessionPool, - PooledSession session, + PooledSessionFuture session, boolean isSingleUse) { this.readContextDelegateSupplier = delegateSupplier; this.sessionPool = sessionPool; this.session = session; this.isSingleUse = isSingleUse; - while (true) { - try { - this.readContextDelegate = readContextDelegateSupplier.apply(this.session); - break; - } catch (SessionNotFoundException e) { - replaceSessionIfPossible(e); - } - } } T getReadContextDelegate() { + synchronized (lock) { + if (readContextDelegate == null) { + while (true) { + try { + this.readContextDelegate = readContextDelegateSupplier.apply(this.session); + break; + } catch (SessionNotFoundException e) { + replaceSessionIfPossible(e); + } + } + } + } return readContextDelegate; } - private ResultSet wrap(final Supplier resultSetSupplier) { - ResultSet res; - while (true) { - try { - res = resultSetSupplier.get(); - break; - } catch (SessionNotFoundException e) { - replaceSessionIfPossible(e); - } - } - return new ForwardingResultSet(res) { + private ResultSet wrap(final CachedResultSetSupplier resultSetSupplier) { + return new ForwardingResultSet(resultSetSupplier) { private boolean beforeFirst = true; @Override @@ -187,8 +258,18 @@ public boolean next() throws SpannerException { try { return internalNext(); } catch (SessionNotFoundException e) { - replaceSessionIfPossible(e); - replaceDelegate(resultSetSupplier.get()); + while (true) { + // Keep the replace-if-possible outside the try-block to let the exception bubble up + // if it's too late to replace the session. + replaceSessionIfPossible(e); + try { + replaceDelegate(resultSetSupplier.reload()); + break; + } catch (SessionNotFoundException snfe) { + e = snfe; + // retry on yet another session. + } + } } } } @@ -197,9 +278,11 @@ private boolean internalNext() { try { boolean ret = super.next(); if (beforeFirst) { - session.markUsed(); - beforeFirst = false; - sessionUsedForQuery = true; + synchronized (lock) { + session.get().markUsed(); + beforeFirst = false; + sessionUsedForQuery = true; + } } if (!ret && isSingleUse) { close(); @@ -208,9 +291,11 @@ private boolean internalNext() { } catch (SessionNotFoundException e) { throw e; } catch (SpannerException e) { - if (!closed && isSingleUse) { - session.lastException = e; - AutoClosingReadContext.this.close(); + synchronized (lock) { + if (!closed && isSingleUse) { + session.get().lastException = e; + AutoClosingReadContext.this.close(); + } } throw e; } @@ -218,22 +303,27 @@ private boolean internalNext() { @Override public void close() { - super.close(); - if (isSingleUse) { - AutoClosingReadContext.this.close(); + try { + super.close(); + } finally { + if (isSingleUse) { + AutoClosingReadContext.this.close(); + } } } }; } - private void replaceSessionIfPossible(SessionNotFoundException e) { - if (isSingleUse || !sessionUsedForQuery) { - // This class is only used by read-only transactions, so we know that we only need a - // read-only session. - session = sessionPool.replaceReadSession(e, session); - readContextDelegate = readContextDelegateSupplier.apply(session); - } else { - throw e; + private void replaceSessionIfPossible(SessionNotFoundException notFound) { + synchronized (lock) { + if (isSingleUse || !sessionUsedForQuery) { + // This class is only used by read-only transactions, so we know that we only need a + // read-only session. + session = sessionPool.replaceReadSession(notFound, session); + readContextDelegate = readContextDelegateSupplier.apply(session); + } else { + throw notFound; + } } } @@ -244,14 +334,37 @@ public ResultSet read( final Iterable columns, final ReadOption... options) { return wrap( - new Supplier() { + new CachedResultSetSupplier() { @Override - public ResultSet get() { - return readContextDelegate.read(table, keys, columns, options); + ResultSet load() { + return getReadContextDelegate().read(table, keys, columns, options); } }); } + @Override + public AsyncResultSet readAsync( + final String table, + final KeySet keys, + final Iterable columns, + final ReadOption... options) { + Options readOptions = Options.fromReadOptions(options); + final int bufferRows = + readOptions.hasBufferRows() + ? readOptions.bufferRows() + : AsyncResultSetImpl.DEFAULT_BUFFER_SIZE; + return new AutoClosingReadContextAsyncResultSetImpl( + sessionPool.sessionClient.getSpanner().getAsyncExecutorProvider(), + wrap( + new CachedResultSetSupplier() { + @Override + ResultSet load() { + return getReadContextDelegate().read(table, keys, columns, options); + } + }), + bufferRows); + } + @Override public ResultSet readUsingIndex( final String table, @@ -260,84 +373,159 @@ public ResultSet readUsingIndex( final Iterable columns, final ReadOption... options) { return wrap( - new Supplier() { + new CachedResultSetSupplier() { @Override - public ResultSet get() { - return readContextDelegate.readUsingIndex(table, index, keys, columns, options); + ResultSet load() { + return getReadContextDelegate().readUsingIndex(table, index, keys, columns, options); } }); } + @Override + public AsyncResultSet readUsingIndexAsync( + final String table, + final String index, + final KeySet keys, + final Iterable columns, + final ReadOption... options) { + Options readOptions = Options.fromReadOptions(options); + final int bufferRows = + readOptions.hasBufferRows() + ? readOptions.bufferRows() + : AsyncResultSetImpl.DEFAULT_BUFFER_SIZE; + return new AutoClosingReadContextAsyncResultSetImpl( + sessionPool.sessionClient.getSpanner().getAsyncExecutorProvider(), + wrap( + new CachedResultSetSupplier() { + @Override + ResultSet load() { + return getReadContextDelegate() + .readUsingIndex(table, index, keys, columns, options); + } + }), + bufferRows); + } + @Override @Nullable public Struct readRow(String table, Key key, Iterable columns) { try { while (true) { try { - session.markUsed(); - return readContextDelegate.readRow(table, key, columns); + synchronized (lock) { + session.get().markUsed(); + } + return getReadContextDelegate().readRow(table, key, columns); } catch (SessionNotFoundException e) { replaceSessionIfPossible(e); } } } finally { - sessionUsedForQuery = true; + synchronized (lock) { + sessionUsedForQuery = true; + } if (isSingleUse) { close(); } } } + @Override + public ApiFuture readRowAsync(String table, Key key, Iterable columns) { + try (AsyncResultSet rs = readAsync(table, KeySet.singleKey(key), columns)) { + return AbstractReadContext.consumeSingleRowAsync(rs); + } + } + @Override @Nullable public Struct readRowUsingIndex(String table, String index, Key key, Iterable columns) { try { while (true) { try { - session.markUsed(); - return readContextDelegate.readRowUsingIndex(table, index, key, columns); + synchronized (lock) { + session.get().markUsed(); + } + return getReadContextDelegate().readRowUsingIndex(table, index, key, columns); } catch (SessionNotFoundException e) { replaceSessionIfPossible(e); } } } finally { - sessionUsedForQuery = true; + synchronized (lock) { + sessionUsedForQuery = true; + } if (isSingleUse) { close(); } } } + @Override + public ApiFuture readRowUsingIndexAsync( + String table, String index, Key key, Iterable columns) { + try (AsyncResultSet rs = readUsingIndexAsync(table, index, KeySet.singleKey(key), columns)) { + return AbstractReadContext.consumeSingleRowAsync(rs); + } + } + @Override public ResultSet executeQuery(final Statement statement, final QueryOption... options) { return wrap( - new Supplier() { + new CachedResultSetSupplier() { @Override - public ResultSet get() { - return readContextDelegate.executeQuery(statement, options); + ResultSet load() { + return getReadContextDelegate().executeQuery(statement, options); } }); } + @Override + public AsyncResultSet executeQueryAsync( + final Statement statement, final QueryOption... options) { + Options queryOptions = Options.fromQueryOptions(options); + final int bufferRows = + queryOptions.hasBufferRows() + ? queryOptions.bufferRows() + : AsyncResultSetImpl.DEFAULT_BUFFER_SIZE; + return new AutoClosingReadContextAsyncResultSetImpl( + sessionPool.sessionClient.getSpanner().getAsyncExecutorProvider(), + wrap( + new CachedResultSetSupplier() { + @Override + ResultSet load() { + return getReadContextDelegate().executeQuery(statement, options); + } + }), + bufferRows); + } + @Override public ResultSet analyzeQuery(final Statement statement, final QueryAnalyzeMode queryMode) { return wrap( - new Supplier() { + new CachedResultSetSupplier() { @Override - public ResultSet get() { - return readContextDelegate.analyzeQuery(statement, queryMode); + ResultSet load() { + return getReadContextDelegate().analyzeQuery(statement, queryMode); } }); } @Override public void close() { - if (closed) { - return; + synchronized (lock) { + if (closed && delegateClosed) { + return; + } + closed = true; + if (asyncOperationsCount.get() == 0) { + if (readContextDelegate != null) { + readContextDelegate.close(); + } + session.close(); + delegateClosed = true; + } } - closed = true; - readContextDelegate.close(); - session.close(); } } @@ -345,9 +533,9 @@ private static class AutoClosingReadTransaction extends AutoClosingReadContext implements ReadOnlyTransaction { AutoClosingReadTransaction( - Function txnSupplier, + Function txnSupplier, SessionPool sessionPool, - PooledSession session, + PooledSessionFuture session, boolean isSingleUse) { super(txnSupplier, sessionPool, session, isSingleUse); } @@ -394,6 +582,13 @@ public ResultSet read( return new SessionPoolResultSet(delegate.read(table, keys, columns, options)); } + @Override + public AsyncResultSet readAsync( + String table, KeySet keys, Iterable columns, ReadOption... options) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.UNIMPLEMENTED, "not yet implemented"); + } + @Override public ResultSet readUsingIndex( String table, @@ -405,6 +600,17 @@ public ResultSet readUsingIndex( delegate.readUsingIndex(table, index, keys, columns, options)); } + @Override + public AsyncResultSet readUsingIndexAsync( + String table, + String index, + KeySet keys, + Iterable columns, + ReadOption... options) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.UNIMPLEMENTED, "not yet implemented"); + } + @Override public Struct readRow(String table, Key key, Iterable columns) { try { @@ -414,6 +620,13 @@ public Struct readRow(String table, Key key, Iterable columns) { } } + @Override + public ApiFuture readRowAsync(String table, Key key, Iterable columns) { + try (AsyncResultSet rs = readAsync(table, KeySet.singleKey(key), columns)) { + return AbstractReadContext.consumeSingleRowAsync(rs); + } + } + @Override public void buffer(Mutation mutation) { delegate.buffer(mutation); @@ -429,6 +642,15 @@ public Struct readRowUsingIndex( } } + @Override + public ApiFuture readRowUsingIndexAsync( + String table, String index, Key key, Iterable columns) { + try (AsyncResultSet rs = + readUsingIndexAsync(table, index, KeySet.singleKey(key), columns)) { + return AbstractReadContext.consumeSingleRowAsync(rs); + } + } + @Override public void buffer(Iterable mutations) { delegate.buffer(mutations); @@ -443,6 +665,15 @@ public long executeUpdate(Statement statement) { } } + @Override + public ApiFuture executeUpdateAsync(Statement statement) { + try { + return delegate.executeUpdateAsync(statement); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + @Override public long[] batchUpdate(Iterable statements) { try { @@ -452,11 +683,29 @@ public long[] batchUpdate(Iterable statements) { } } + @Override + public ApiFuture batchUpdateAsync(Iterable statements) { + try { + return delegate.batchUpdateAsync(statements); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + @Override public ResultSet executeQuery(Statement statement, QueryOption... options) { return new SessionPoolResultSet(delegate.executeQuery(statement, options)); } + @Override + public AsyncResultSet executeQueryAsync(Statement statement, QueryOption... options) { + try { + return delegate.executeQueryAsync(statement, options); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + @Override public ResultSet analyzeQuery(Statement statement, QueryAnalyzeMode queryMode) { return new SessionPoolResultSet(delegate.analyzeQuery(statement, queryMode)); @@ -470,39 +719,40 @@ public void close() { private TransactionManager delegate; private final SessionPool sessionPool; - private PooledSession session; + private PooledSessionFuture session; private boolean closed; private boolean restartedAfterSessionNotFound; - AutoClosingTransactionManager(SessionPool sessionPool, PooledSession session) { + AutoClosingTransactionManager(SessionPool sessionPool, PooledSessionFuture session) { this.sessionPool = sessionPool; this.session = session; - this.delegate = session.delegate.transactionManager(); } @Override public TransactionContext begin() { + this.delegate = session.get().transactionManager(); while (true) { try { return internalBegin(); } catch (SessionNotFoundException e) { session = sessionPool.replaceReadWriteSession(e, session); - delegate = session.delegate.transactionManager(); + delegate = session.get().delegate.transactionManager(); } } } private TransactionContext internalBegin() { TransactionContext res = new SessionPoolTransactionContext(delegate.begin()); - session.markUsed(); + session.get().markUsed(); return res; } - private SpannerException handleSessionNotFound(SessionNotFoundException e) { - session = sessionPool.replaceReadWriteSession(e, session); - delegate = session.delegate.transactionManager(); + private SpannerException handleSessionNotFound(SessionNotFoundException notFound) { + session = sessionPool.replaceReadWriteSession(notFound, session); + delegate = session.get().delegate.transactionManager(); restartedAfterSessionNotFound = true; - return SpannerExceptionFactory.newSpannerException(ErrorCode.ABORTED, e.getMessage(), e); + return SpannerExceptionFactory.newSpannerException( + ErrorCode.ABORTED, notFound.getMessage(), notFound); } @Override @@ -540,7 +790,7 @@ public TransactionContext resetForRetry() { } } catch (SessionNotFoundException e) { session = sessionPool.replaceReadWriteSession(e, session); - delegate = session.delegate.transactionManager(); + delegate = session.get().delegate.transactionManager(); restartedAfterSessionNotFound = true; } } @@ -558,7 +808,9 @@ public void close() { } closed = true; try { - delegate.close(); + if (delegate != null) { + delegate.close(); + } } finally { session.close(); } @@ -569,7 +821,7 @@ public TransactionState getState() { if (restartedAfterSessionNotFound) { return TransactionState.ABORTED; } else { - return delegate.getState(); + return delegate == null ? null : delegate.getState(); } } } @@ -580,13 +832,19 @@ public TransactionState getState() { */ private static final class SessionPoolTransactionRunner implements TransactionRunner { private final SessionPool sessionPool; - private PooledSession session; + private PooledSessionFuture session; private TransactionRunner runner; - private SessionPoolTransactionRunner(SessionPool sessionPool, PooledSession session) { + private SessionPoolTransactionRunner(SessionPool sessionPool, PooledSessionFuture session) { this.sessionPool = sessionPool; this.session = session; - this.runner = session.delegate.readWriteTransaction(); + } + + private TransactionRunner getRunner() { + if (this.runner == null) { + this.runner = session.get().readWriteTransaction(); + } + return runner; } @Override @@ -596,17 +854,17 @@ public T run(TransactionCallable callable) { T result; while (true) { try { - result = runner.run(callable); + result = getRunner().run(callable); break; } catch (SessionNotFoundException e) { session = sessionPool.replaceReadWriteSession(e, session); - runner = session.delegate.readWriteTransaction(); + runner = session.get().delegate.readWriteTransaction(); } } - session.markUsed(); + session.get().markUsed(); return result; } catch (SpannerException e) { - throw session.lastException = e; + throw session.get().lastException = e; } finally { session.close(); } @@ -614,19 +872,86 @@ public T run(TransactionCallable callable) { @Override public Timestamp getCommitTimestamp() { - return runner.getCommitTimestamp(); + return getRunner().getCommitTimestamp(); } @Override public TransactionRunner allowNestedTransaction() { - runner.allowNestedTransaction(); + getRunner().allowNestedTransaction(); return this; } } + private static class SessionPoolAsyncRunner implements AsyncRunner { + private final SessionPool sessionPool; + private volatile PooledSessionFuture session; + private final SettableApiFuture commitTimestamp = SettableApiFuture.create(); + + private SessionPoolAsyncRunner(SessionPool sessionPool, PooledSessionFuture session) { + this.sessionPool = sessionPool; + this.session = session; + } + + @Override + public ApiFuture runAsync(final AsyncWork work, Executor executor) { + final SettableApiFuture res = SettableApiFuture.create(); + executor.execute( + new Runnable() { + @Override + public void run() { + SpannerException se = null; + R r = null; + AsyncRunner runner = null; + while (true) { + try { + runner = session.get().runAsync(); + r = runner.runAsync(work, MoreExecutors.directExecutor()).get(); + break; + } catch (ExecutionException e) { + se = SpannerExceptionFactory.newSpannerException(e.getCause()); + } catch (InterruptedException e) { + se = SpannerExceptionFactory.propagateInterrupt(e); + } catch (Throwable t) { + se = SpannerExceptionFactory.newSpannerException(t); + } finally { + if (se != null && se instanceof SessionNotFoundException) { + session = + sessionPool.replaceReadWriteSession((SessionNotFoundException) se, session); + } else { + break; + } + } + } + session.get().markUsed(); + session.close(); + setCommitTimestamp(runner); + if (se != null) { + res.setException(se); + } else { + res.set(r); + } + } + }); + return res; + } + + private void setCommitTimestamp(AsyncRunner delegate) { + try { + commitTimestamp.set(delegate.getCommitTimestamp().get()); + } catch (Throwable t) { + commitTimestamp.setException(t); + } + } + + @Override + public ApiFuture getCommitTimestamp() { + return commitTimestamp; + } + } + // Exception class used just to track the stack trace at the point when a session was handed out // from the pool. - private final class LeakedSessionException extends RuntimeException { + final class LeakedSessionException extends RuntimeException { private static final long serialVersionUID = 1451131180314064914L; private LeakedSessionException() { @@ -640,25 +965,124 @@ private enum SessionState { CLOSING, } - final class PooledSession implements Session { - @VisibleForTesting SessionImpl delegate; - private volatile Instant lastUseTime; - private volatile SpannerException lastException; - private volatile LeakedSessionException leakedException; - private volatile boolean allowReplacing = true; + /** + * Forwarding future that will return a {@link PooledSession}. If {@link #inProcessPrepare} has + * been set to true, the returned session will be prepared with a read/write session using the + * thread of the caller to {@link #get()}. This ensures that the executor that is responsible for + * background preparing of read/write transactions is not overwhelmed by requests in case of a + * large burst of write requests. Instead of filling up the queue of the background executor, the + * caller threads will be used for the BeginTransaction call. + */ + private final class ForwardingListenablePooledSessionFuture + extends SimpleForwardingListenableFuture { + private final boolean inProcessPrepare; + private final Span span; + private volatile boolean initialized = false; + private final Object prepareLock = new Object(); + private volatile PooledSession result; + private volatile SpannerException error; + + private ForwardingListenablePooledSessionFuture( + ListenableFuture delegate, boolean inProcessPrepare, Span span) { + super(delegate); + this.inProcessPrepare = inProcessPrepare; + this.span = span; + } - @GuardedBy("lock") - private SessionState state; + @Override + public PooledSession get() throws InterruptedException, ExecutionException { + try { + return initialize(super.get()); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause()); + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } + } - private PooledSession(SessionImpl delegate) { - this.delegate = delegate; - this.state = SessionState.AVAILABLE; - this.lastUseTime = clock.instant(); + @Override + public PooledSession get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + try { + return initialize(super.get(timeout, unit)); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause()); + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } catch (TimeoutException e) { + throw SpannerExceptionFactory.propagateTimeout(e); + } } - @VisibleForTesting - void setAllowReplacing(boolean allowReplacing) { - this.allowReplacing = allowReplacing; + private PooledSession initialize(PooledSession sess) { + if (!initialized) { + synchronized (prepareLock) { + if (!initialized) { + try { + result = prepare(sess); + } catch (Throwable t) { + error = SpannerExceptionFactory.newSpannerException(t); + } finally { + initialized = true; + } + } + } + } + if (error != null) { + throw error; + } + return result; + } + + private PooledSession prepare(PooledSession sess) { + if (inProcessPrepare && !sess.delegate.hasReadyTransaction()) { + while (true) { + try { + sess.prepareReadWriteTransaction(); + synchronized (lock) { + stopAutomaticPrepare = false; + } + break; + } catch (Throwable t) { + if (isClosed()) { + span.addAnnotation("Pool has been closed"); + throw new IllegalStateException("Pool has been closed"); + } + SpannerException e = newSpannerException(t); + WaiterFuture waiter = new WaiterFuture(); + synchronized (lock) { + handlePrepareSessionFailure(e, sess, false); + if (!isSessionNotFound(e)) { + throw e; + } + readWaiters.add(waiter); + } + sess = waiter.get(); + if (sess.delegate.hasReadyTransaction()) { + break; + } + } + } + } + return sess; + } + } + + private PooledSessionFuture createPooledSessionFuture( + ListenableFuture future, Span span) { + return new PooledSessionFuture(future, span); + } + + final class PooledSessionFuture extends SimpleForwardingListenableFuture + implements Session { + private volatile LeakedSessionException leakedException; + private volatile AtomicBoolean inUse = new AtomicBoolean(); + private volatile CountDownLatch initialized = new CountDownLatch(1); + private final Span span; + + private PooledSessionFuture(ListenableFuture delegate, Span span) { + super(delegate); + this.span = span; } @VisibleForTesting @@ -666,34 +1090,14 @@ void clearLeakedException() { this.leakedException = null; } - private void markBusy() { - this.state = SessionState.BUSY; + private void markCheckedOut() { this.leakedException = new LeakedSessionException(); } - private void markClosing() { - this.state = SessionState.CLOSING; - } - @Override public Timestamp write(Iterable mutations) throws SpannerException { try { - markUsed(); - return delegate.write(mutations); - } catch (SpannerException e) { - throw lastException = e; - } finally { - close(); - } - } - - @Override - public long executePartitionedUpdate(Statement stmt) throws SpannerException { - try { - markUsed(); - return delegate.executePartitionedUpdate(stmt); - } catch (SpannerException e) { - throw lastException = e; + return get().write(mutations); } finally { close(); } @@ -702,10 +1106,7 @@ public long executePartitionedUpdate(Statement stmt) throws SpannerException { @Override public Timestamp writeAtLeastOnce(Iterable mutations) throws SpannerException { try { - markUsed(); - return delegate.writeAtLeastOnce(mutations); - } catch (SpannerException e) { - throw lastException = e; + return get().writeAtLeastOnce(mutations); } finally { close(); } @@ -715,10 +1116,10 @@ public Timestamp writeAtLeastOnce(Iterable mutations) throws SpannerEx public ReadContext singleUse() { try { return new AutoClosingReadContext<>( - new Function() { + new Function() { @Override - public ReadContext apply(PooledSession session) { - return session.delegate.singleUse(); + public ReadContext apply(PooledSessionFuture session) { + return session.get().delegate.singleUse(); } }, SessionPool.this, @@ -734,10 +1135,10 @@ public ReadContext apply(PooledSession session) { public ReadContext singleUse(final TimestampBound bound) { try { return new AutoClosingReadContext<>( - new Function() { + new Function() { @Override - public ReadContext apply(PooledSession session) { - return session.delegate.singleUse(bound); + public ReadContext apply(PooledSessionFuture session) { + return session.get().delegate.singleUse(bound); } }, SessionPool.this, @@ -752,10 +1153,10 @@ public ReadContext apply(PooledSession session) { @Override public ReadOnlyTransaction singleUseReadOnlyTransaction() { return internalReadOnlyTransaction( - new Function() { + new Function() { @Override - public ReadOnlyTransaction apply(PooledSession session) { - return session.delegate.singleUseReadOnlyTransaction(); + public ReadOnlyTransaction apply(PooledSessionFuture session) { + return session.get().delegate.singleUseReadOnlyTransaction(); } }, true); @@ -764,10 +1165,10 @@ public ReadOnlyTransaction apply(PooledSession session) { @Override public ReadOnlyTransaction singleUseReadOnlyTransaction(final TimestampBound bound) { return internalReadOnlyTransaction( - new Function() { + new Function() { @Override - public ReadOnlyTransaction apply(PooledSession session) { - return session.delegate.singleUseReadOnlyTransaction(bound); + public ReadOnlyTransaction apply(PooledSessionFuture session) { + return session.get().delegate.singleUseReadOnlyTransaction(bound); } }, true); @@ -776,10 +1177,10 @@ public ReadOnlyTransaction apply(PooledSession session) { @Override public ReadOnlyTransaction readOnlyTransaction() { return internalReadOnlyTransaction( - new Function() { + new Function() { @Override - public ReadOnlyTransaction apply(PooledSession session) { - return session.delegate.readOnlyTransaction(); + public ReadOnlyTransaction apply(PooledSessionFuture session) { + return session.get().delegate.readOnlyTransaction(); } }, false); @@ -788,17 +1189,18 @@ public ReadOnlyTransaction apply(PooledSession session) { @Override public ReadOnlyTransaction readOnlyTransaction(final TimestampBound bound) { return internalReadOnlyTransaction( - new Function() { + new Function() { @Override - public ReadOnlyTransaction apply(PooledSession session) { - return session.delegate.readOnlyTransaction(bound); + public ReadOnlyTransaction apply(PooledSessionFuture session) { + return session.get().delegate.readOnlyTransaction(bound); } }, false); } private ReadOnlyTransaction internalReadOnlyTransaction( - Function transactionSupplier, boolean isSingleUse) { + Function transactionSupplier, + boolean isSingleUse) { try { return new AutoClosingReadTransaction( transactionSupplier, SessionPool.this, this, isSingleUse); @@ -813,6 +1215,188 @@ public TransactionRunner readWriteTransaction() { return new SessionPoolTransactionRunner(SessionPool.this, this); } + @Override + public TransactionManager transactionManager() { + return new AutoClosingTransactionManager(SessionPool.this, this); + } + + @Override + public AsyncRunner runAsync() { + return new SessionPoolAsyncRunner(SessionPool.this, this); + } + + @Override + public AsyncTransactionManager transactionManagerAsync() { + return new SessionPoolAsyncTransactionManager(this); + } + + @Override + public long executePartitionedUpdate(Statement stmt) { + try { + return get().executePartitionedUpdate(stmt); + } finally { + close(); + } + } + + @Override + public String getName() { + return get().getName(); + } + + @Override + public void prepareReadWriteTransaction() { + get().prepareReadWriteTransaction(); + } + + @Override + public void close() { + synchronized (lock) { + leakedException = null; + checkedOutSessions.remove(this); + } + get().close(); + } + + @Override + public ApiFuture asyncClose() { + synchronized (lock) { + leakedException = null; + checkedOutSessions.remove(this); + } + return get().asyncClose(); + } + + @Override + public PooledSession get() { + if (inUse.compareAndSet(false, true)) { + PooledSession res = null; + try { + res = super.get(); + } catch (Throwable e) { + // ignore the exception as it will be handled by the call to super.get() below. + } + if (res != null) { + res.markBusy(span); + span.addAnnotation(sessionAnnotation(res)); + synchronized (lock) { + incrementNumSessionsInUse(); + checkedOutSessions.add(this); + } + } + initialized.countDown(); + } + try { + initialized.await(); + return super.get(); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause()); + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } + } + } + + final class PooledSession implements Session { + @VisibleForTesting SessionImpl delegate; + private volatile Instant lastUseTime; + private volatile SpannerException lastException; + private volatile boolean allowReplacing = true; + + @GuardedBy("lock") + private SessionState state; + + private PooledSession(SessionImpl delegate) { + this.delegate = delegate; + this.state = SessionState.AVAILABLE; + this.lastUseTime = clock.instant(); + } + + @Override + public String toString() { + return getName(); + } + + @VisibleForTesting + void setAllowReplacing(boolean allowReplacing) { + this.allowReplacing = allowReplacing; + } + + @Override + public Timestamp write(Iterable mutations) throws SpannerException { + try { + markUsed(); + return delegate.write(mutations); + } catch (SpannerException e) { + throw lastException = e; + } + } + + @Override + public Timestamp writeAtLeastOnce(Iterable mutations) throws SpannerException { + try { + markUsed(); + return delegate.writeAtLeastOnce(mutations); + } catch (SpannerException e) { + throw lastException = e; + } + } + + @Override + public long executePartitionedUpdate(Statement stmt) throws SpannerException { + try { + markUsed(); + return delegate.executePartitionedUpdate(stmt); + } catch (SpannerException e) { + throw lastException = e; + } + } + + @Override + public ReadContext singleUse() { + return delegate.singleUse(); + } + + @Override + public ReadContext singleUse(TimestampBound bound) { + return delegate.singleUse(bound); + } + + @Override + public ReadOnlyTransaction singleUseReadOnlyTransaction() { + return delegate.singleUseReadOnlyTransaction(); + } + + @Override + public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { + return delegate.singleUseReadOnlyTransaction(bound); + } + + @Override + public ReadOnlyTransaction readOnlyTransaction() { + return delegate.readOnlyTransaction(); + } + + @Override + public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { + return delegate.readOnlyTransaction(bound); + } + + @Override + public TransactionRunner readWriteTransaction() { + return delegate.readWriteTransaction(); + } + + @Override + public AsyncRunner runAsync() { + return delegate.runAsync(); + } + + @Override + public AsyncTransactionManagerImpl transactionManagerAsync() { + return delegate.transactionManagerAsync(); + } + @Override public ApiFuture asyncClose() { close(); @@ -825,7 +1409,6 @@ public void close() { numSessionsInUse--; numSessionsReleased++; } - leakedException = null; if (lastException != null && isSessionNotFound(lastException)) { invalidateSession(this); } else { @@ -868,59 +1451,56 @@ private void keepAlive() { } } + private void markBusy(Span span) { + this.delegate.setCurrentSpan(span); + this.state = SessionState.BUSY; + } + + private void markClosing() { + this.state = SessionState.CLOSING; + } + void markUsed() { lastUseTime = clock.instant(); } @Override public TransactionManager transactionManager() { - return new AutoClosingTransactionManager(SessionPool.this, this); + return delegate.transactionManager(); } } - private static final class SessionOrError { - private final PooledSession session; - private final SpannerException e; - - SessionOrError(PooledSession session) { - this.session = session; - this.e = null; - } + private final class WaiterFuture extends ForwardingListenableFuture { + private static final long MAX_SESSION_WAIT_TIMEOUT = 240_000L; + private final SettableFuture waiter = SettableFuture.create(); - SessionOrError(SpannerException e) { - this.session = null; - this.e = e; + @Override + protected ListenableFuture delegate() { + return waiter; } - } - - private final class Waiter { - private static final long MAX_SESSION_WAIT_TIMEOUT = 240_000L; - private final SynchronousQueue waiter = new SynchronousQueue<>(); private void put(PooledSession session) { - Uninterruptibles.putUninterruptibly(waiter, new SessionOrError(session)); + waiter.set(session); } private void put(SpannerException e) { - Uninterruptibles.putUninterruptibly(waiter, new SessionOrError(e)); + waiter.setException(e); } - private PooledSession take() throws SpannerException { + @Override + public PooledSession get() { long currentTimeout = options.getInitialWaitForSessionTimeoutMillis(); while (true) { Span span = tracer.spanBuilder(WAIT_FOR_SESSION).startSpan(); try (Scope waitScope = tracer.withSpan(span)) { - SessionOrError s = pollUninterruptiblyWithTimeout(currentTimeout); + PooledSession s = pollUninterruptiblyWithTimeout(currentTimeout); if (s == null) { // Set the status to DEADLINE_EXCEEDED and retry. numWaiterTimeouts.incrementAndGet(); tracer.getCurrentSpan().setStatus(Status.DEADLINE_EXCEEDED); currentTimeout = Math.min(currentTimeout * 2, MAX_SESSION_WAIT_TIMEOUT); } else { - if (s.e != null) { - throw newSpannerException(s.e); - } - return s.session; + return s; } } catch (Exception e) { TraceUtil.setWithFailure(span, e); @@ -931,14 +1511,18 @@ private PooledSession take() throws SpannerException { } } - private SessionOrError pollUninterruptiblyWithTimeout(long timeoutMillis) { + private PooledSession pollUninterruptiblyWithTimeout(long timeoutMillis) { boolean interrupted = false; try { while (true) { try { - return waiter.poll(timeoutMillis, TimeUnit.MILLISECONDS); + return waiter.get(timeoutMillis, TimeUnit.MILLISECONDS); } catch (InterruptedException e) { interrupted = true; + } catch (TimeoutException e) { + return null; + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause()); } } } finally { @@ -1118,6 +1702,7 @@ private static enum Position { private final ScheduledExecutorService executor; private final ExecutorFactory executorFactory; private final ScheduledExecutorService prepareExecutor; + private final int prepareThreadPoolSize; final PoolMaintainer poolMaintainer; private final Clock clock; @@ -1146,10 +1731,10 @@ private static enum Position { private final LinkedList writePreparedSessions = new LinkedList<>(); @GuardedBy("lock") - private final Queue readWaiters = new LinkedList<>(); + private final Queue readWaiters = new LinkedList<>(); @GuardedBy("lock") - private final Queue readWriteWaiters = new LinkedList<>(); + private final Queue readWriteWaiters = new LinkedList<>(); @GuardedBy("lock") private int numSessionsBeingPrepared = 0; @@ -1183,6 +1768,9 @@ private static enum Position { @GuardedBy("lock") private final Set allSessions = new HashSet<>(); + @GuardedBy("lock") + private final Set checkedOutSessions = new HashSet<>(); + private final SessionConsumer sessionConsumer = new SessionConsumerImpl(); @VisibleForTesting Function idleSessionRemovedListener; @@ -1275,6 +1863,12 @@ private SessionPool( } @VisibleForTesting + int getNumberOfSessionsInUse() { + synchronized (lock) { + return numSessionsInUse; + } + } + long getNumberOfSessionsInProcessPrepared() { synchronized (lock) { return numSessionsInProcessPrepared; @@ -1297,9 +1891,9 @@ void removeFromPool(PooledSession session) { session.markClosing(); allSessions.remove(session); numIdleSessionsRemoved++; - if (idleSessionRemovedListener != null) { - idleSessionRemovedListener.apply(session); - } + } + if (idleSessionRemovedListener != null) { + idleSessionRemovedListener.apply(session); } } @@ -1437,10 +2031,10 @@ boolean isValid() { * session being returned to the pool or a new session being created. * */ - PooledSession getReadSession() throws SpannerException { + PooledSessionFuture getReadSession() throws SpannerException { Span span = Tracing.getTracer().getCurrentSpan(); span.addAnnotation("Acquiring session"); - Waiter waiter = null; + WaiterFuture waiter = null; PooledSession sess = null; synchronized (lock) { if (closureFuture != null) { @@ -1462,7 +2056,7 @@ PooledSession getReadSession() throws SpannerException { if (sess == null) { span.addAnnotation("No session available"); maybeCreateSession(); - waiter = new Waiter(); + waiter = new WaiterFuture(); readWaiters.add(waiter); } else { span.addAnnotation("Acquired read write session"); @@ -1470,18 +2064,8 @@ PooledSession getReadSession() throws SpannerException { } else { span.addAnnotation("Acquired read only session"); } + return checkoutSession(span, sess, waiter, false, false); } - if (waiter != null) { - logger.log( - Level.FINE, - "No session available in the pool. Blocking for one to become available/created"); - span.addAnnotation("Waiting for read only session to be available"); - sess = waiter.take(); - } - sess.markBusy(); - incrementNumSessionsInUse(); - span.addAnnotation(sessionAnnotation(sess)); - return sess; } /** @@ -1502,129 +2086,123 @@ PooledSession getReadSession() throws SpannerException { * to the pool which is then write prepared. * */ - PooledSession getReadWriteSession() { + PooledSessionFuture getReadWriteSession() { Span span = Tracing.getTracer().getCurrentSpan(); span.addAnnotation("Acquiring read write session"); PooledSession sess = null; - // Loop to retry SessionNotFoundExceptions that might occur during in-process prepare of a - // session. - while (true) { - Waiter waiter = null; - boolean inProcessPrepare = stopAutomaticPrepare; - synchronized (lock) { - if (closureFuture != null) { - span.addAnnotation("Pool has been closed"); - throw new IllegalStateException("Pool has been closed", closedException); - } - if (resourceNotFoundException != null) { - span.addAnnotation("Database has been deleted"); - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.NOT_FOUND, - String.format( - "The session pool has been invalidated because a previous RPC returned 'Database not found': %s", - resourceNotFoundException.getMessage()), - resourceNotFoundException); - } - sess = writePreparedSessions.poll(); - if (sess == null) { - if (!inProcessPrepare && numSessionsBeingPrepared <= prepareThreadPoolSize) { - if (numSessionsBeingPrepared <= readWriteWaiters.size()) { - PooledSession readSession = readSessions.poll(); - if (readSession != null) { - span.addAnnotation( - "Acquired read only session. Preparing for read write transaction"); - prepareSession(readSession); - } else { - span.addAnnotation("No session available"); - maybeCreateSession(); - } - } - } else { - inProcessPrepare = true; - numSessionsInProcessPrepared++; + WaiterFuture waiter = null; + boolean inProcessPrepare = stopAutomaticPrepare; + synchronized (lock) { + if (closureFuture != null) { + span.addAnnotation("Pool has been closed"); + throw new IllegalStateException("Pool has been closed", closedException); + } + if (resourceNotFoundException != null) { + span.addAnnotation("Database has been deleted"); + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.NOT_FOUND, + String.format( + "The session pool has been invalidated because a previous RPC returned 'Database not found': %s", + resourceNotFoundException.getMessage()), + resourceNotFoundException); + } + sess = writePreparedSessions.poll(); + if (sess == null) { + if (!inProcessPrepare && numSessionsBeingPrepared <= prepareThreadPoolSize) { + if (numSessionsBeingPrepared <= readWriteWaiters.size()) { PooledSession readSession = readSessions.poll(); if (readSession != null) { - // Create a read/write transaction in-process if there is already a queue for prepared - // sessions. This is more efficient than doing it asynchronously, as it scales with - // the number of user threads. The thread pool for asynchronously preparing sessions - // is fixed. span.addAnnotation( - "Acquired read only session. Preparing in-process for read write transaction"); - sess = readSession; + "Acquired read only session. Preparing for read write transaction"); + prepareSession(readSession); } else { span.addAnnotation("No session available"); maybeCreateSession(); } } - if (sess == null) { - waiter = new Waiter(); - if (inProcessPrepare) { - // inProcessPrepare=true means that we have already determined that the queue for - // preparing read/write sessions is larger than the number of threads in the prepare - // thread pool, and that it's more efficient to do the prepare in-process. We will - // therefore create a waiter for a read-only session, even though a read/write session - // has been requested. - readWaiters.add(waiter); - } else { - readWriteWaiters.add(waiter); - } - } } else { - span.addAnnotation("Acquired read write session"); - } - } - if (waiter != null) { - logger.log( - Level.FINE, - "No session available in the pool. Blocking for one to become available/created"); - span.addAnnotation("Waiting for read write session to be available"); - sess = waiter.take(); - } - if (inProcessPrepare) { - try { - sess.prepareReadWriteTransaction(); - // Session prepare succeeded, restart automatic prepare if it had been stopped. - synchronized (lock) { - stopAutomaticPrepare = false; - } - } catch (Throwable t) { - SpannerException e = newSpannerException(t); - if (!isClosed()) { - handlePrepareSessionFailure(e, sess, false); + inProcessPrepare = true; + numSessionsInProcessPrepared++; + PooledSession readSession = readSessions.poll(); + if (readSession != null) { + // Create a read/write transaction in-process if there is already a queue for prepared + // sessions. This is more efficient than doing it asynchronously, as it scales with + // the number of user threads. The thread pool for asynchronously preparing sessions + // is fixed. + span.addAnnotation( + "Acquired read only session. Preparing in-process for read write transaction"); + sess = readSession; + } else { + span.addAnnotation("No session available"); + maybeCreateSession(); } - sess = null; - if (!isSessionNotFound(e)) { - throw e; + } + if (sess == null) { + waiter = new WaiterFuture(); + if (inProcessPrepare) { + // inProcessPrepare=true means that we have already determined that the queue for + // preparing read/write sessions is larger than the number of threads in the prepare + // thread pool, and that it's more efficient to do the prepare in-process. We will + // therefore create a waiter for a read-only session, even though a read/write session + // has been requested. + readWaiters.add(waiter); + } else { + readWriteWaiters.add(waiter); } } + } else { + span.addAnnotation("Acquired read write session"); } - if (sess != null) { - break; - } + return checkoutSession(span, sess, waiter, true, inProcessPrepare); } - sess.markBusy(); - incrementNumSessionsInUse(); - span.addAnnotation(sessionAnnotation(sess)); - return sess; } - PooledSession replaceReadSession(SessionNotFoundException e, PooledSession session) { + private PooledSessionFuture checkoutSession( + final Span span, + final PooledSession readySession, + WaiterFuture waiter, + boolean write, + final boolean inProcessPrepare) { + ListenableFuture sessionFuture; + if (waiter != null) { + logger.log( + Level.FINE, + "No session available in the pool. Blocking for one to become available/created"); + span.addAnnotation( + String.format( + "Waiting for %s session to be available", write ? "read write" : "read only")); + sessionFuture = waiter; + } else { + SettableFuture fut = SettableFuture.create(); + fut.set(readySession); + sessionFuture = fut; + } + ForwardingListenablePooledSessionFuture forwardingFuture = + new ForwardingListenablePooledSessionFuture(sessionFuture, inProcessPrepare, span); + PooledSessionFuture res = createPooledSessionFuture(forwardingFuture, span); + res.markCheckedOut(); + return res; + } + + PooledSessionFuture replaceReadSession(SessionNotFoundException e, PooledSessionFuture session) { return replaceSession(e, session, false); } - PooledSession replaceReadWriteSession(SessionNotFoundException e, PooledSession session) { + PooledSessionFuture replaceReadWriteSession( + SessionNotFoundException e, PooledSessionFuture session) { return replaceSession(e, session, true); } - private PooledSession replaceSession( - SessionNotFoundException e, PooledSession session, boolean write) { - if (!options.isFailIfSessionNotFound() && session.allowReplacing) { + private PooledSessionFuture replaceSession( + SessionNotFoundException e, PooledSessionFuture session, boolean write) { + if (!options.isFailIfSessionNotFound() && session.get().allowReplacing) { synchronized (lock) { numSessionsInUse--; numSessionsReleased++; + checkedOutSessions.remove(session); } session.leakedException = null; - invalidateSession(session); + invalidateSession(session.get()); return write ? getReadWriteSession() : getReadSession(); } else { throw e; @@ -1787,7 +2365,7 @@ ListenableFuture closeAsync(ClosedException closedException) { } this.closedException = closedException; // Fail all pending waiters. - Waiter waiter = readWaiters.poll(); + WaiterFuture waiter = readWaiters.poll(); while (waiter != null) { waiter.put(newSpannerException(ErrorCode.INTERNAL, "Client has been closed")); waiter = readWaiters.poll(); @@ -1821,10 +2399,16 @@ public void run() { } } }); - for (final PooledSession session : ImmutableList.copyOf(allSessions)) { + for (PooledSessionFuture session : checkedOutSessions) { if (session.leakedException != null) { - logger.log(Level.WARNING, "Leaked session", session.leakedException); + if (options.isFailOnSessionLeak()) { + throw session.leakedException; + } else { + logger.log(Level.WARNING, "Leaked session", session.leakedException); + } } + } + for (final PooledSession session : ImmutableList.copyOf(allSessions)) { if (session.state != SessionState.CLOSING) { closeSessionAsync(session); } @@ -1894,7 +2478,7 @@ public void run() { } } }, - executor); + MoreExecutors.directExecutor()); return res; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java new file mode 100644 index 0000000000..55b6102a27 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolAsyncTransactionManager.java @@ -0,0 +1,216 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiAsyncFunction; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutureCallback; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; +import com.google.cloud.spanner.SessionPool.PooledSessionFuture; +import com.google.cloud.spanner.TransactionContextFutureImpl.CommittableAsyncTransactionManager; +import com.google.cloud.spanner.TransactionManager.TransactionState; +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.MoreExecutors; +import javax.annotation.concurrent.GuardedBy; + +class SessionPoolAsyncTransactionManager implements CommittableAsyncTransactionManager { + private final Object lock = new Object(); + + @GuardedBy("lock") + private TransactionState txnState; + + private volatile PooledSessionFuture session; + private final SettableApiFuture delegate = + SettableApiFuture.create(); + + SessionPoolAsyncTransactionManager(PooledSessionFuture session) { + this.session = session; + this.session.addListener( + new Runnable() { + @Override + public void run() { + try { + delegate.set( + SessionPoolAsyncTransactionManager.this.session.get().transactionManagerAsync()); + } catch (Throwable t) { + delegate.setException(t); + } + } + }, + MoreExecutors.directExecutor()); + } + + @Override + public void close() { + delegate.addListener( + new Runnable() { + @Override + public void run() { + session.close(); + } + }, + MoreExecutors.directExecutor()); + } + + @Override + public TransactionContextFuture beginAsync() { + synchronized (lock) { + Preconditions.checkState(txnState == null, "begin can only be called once"); + txnState = TransactionState.STARTED; + } + final SettableApiFuture delegateTxnFuture = SettableApiFuture.create(); + ApiFutures.addCallback( + delegate, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + delegateTxnFuture.setException(t); + } + + @Override + public void onSuccess(AsyncTransactionManagerImpl result) { + ApiFutures.addCallback( + result.beginAsync(), + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + delegateTxnFuture.setException(t); + } + + @Override + public void onSuccess(TransactionContext result) { + delegateTxnFuture.set(result); + } + }, + MoreExecutors.directExecutor()); + } + }, + MoreExecutors.directExecutor()); + return new TransactionContextFutureImpl(this, delegateTxnFuture); + } + + @Override + public void onError(Throwable t) { + if (t instanceof AbortedException) { + synchronized (lock) { + txnState = TransactionState.ABORTED; + } + } + } + + @Override + public ApiFuture commitAsync() { + synchronized (lock) { + Preconditions.checkState( + txnState == TransactionState.STARTED, + "commit can only be invoked if the transaction is in progress. Current state: " + + txnState); + txnState = TransactionState.COMMITTED; + } + return ApiFutures.transformAsync( + delegate, + new ApiAsyncFunction() { + @Override + public ApiFuture apply(AsyncTransactionManagerImpl input) throws Exception { + final SettableApiFuture res = SettableApiFuture.create(); + ApiFutures.addCallback( + input.commitAsync(), + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + synchronized (lock) { + if (t instanceof AbortedException) { + txnState = TransactionState.ABORTED; + } else { + txnState = TransactionState.COMMIT_FAILED; + } + } + res.setException(t); + } + + @Override + public void onSuccess(Timestamp result) { + res.set(result); + } + }, + MoreExecutors.directExecutor()); + return res; + } + }, + MoreExecutors.directExecutor()); + } + + @Override + public ApiFuture rollbackAsync() { + synchronized (lock) { + Preconditions.checkState( + txnState == TransactionState.STARTED, + "rollback can only be called if the transaction is in progress"); + txnState = TransactionState.ROLLED_BACK; + } + return ApiFutures.transformAsync( + delegate, + new ApiAsyncFunction() { + @Override + public ApiFuture apply(AsyncTransactionManagerImpl input) throws Exception { + ApiFuture res = input.rollbackAsync(); + res.addListener( + new Runnable() { + @Override + public void run() { + session.close(); + } + }, + MoreExecutors.directExecutor()); + return res; + } + }, + MoreExecutors.directExecutor()); + } + + @Override + public TransactionContextFuture resetForRetryAsync() { + synchronized (lock) { + Preconditions.checkState( + txnState == TransactionState.ABORTED, + "resetForRetry can only be called after the transaction aborted."); + txnState = TransactionState.STARTED; + } + return new TransactionContextFutureImpl( + this, + ApiFutures.transformAsync( + delegate, + new ApiAsyncFunction() { + @Override + public ApiFuture apply(AsyncTransactionManagerImpl input) + throws Exception { + return input.resetForRetryAsync(); + } + }, + MoreExecutors.directExecutor())); + } + + @Override + public TransactionState getState() { + synchronized (lock) { + return txnState; + } + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java index 17295a38ab..57dbd4debd 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java @@ -37,6 +37,7 @@ public class SessionPoolOptions { private final int keepAliveIntervalMinutes; private final Duration removeInactiveSessionAfter; private final ActionOnSessionNotFound actionOnSessionNotFound; + private final ActionOnSessionLeak actionOnSessionLeak; private final long initialWaitForSessionTimeoutMillis; private SessionPoolOptions(Builder builder) { @@ -50,6 +51,7 @@ private SessionPoolOptions(Builder builder) { this.writeSessionsFraction = builder.writeSessionsFraction; this.actionOnExhaustion = builder.actionOnExhaustion; this.actionOnSessionNotFound = builder.actionOnSessionNotFound; + this.actionOnSessionLeak = builder.actionOnSessionLeak; this.initialWaitForSessionTimeoutMillis = builder.initialWaitForSessionTimeoutMillis; this.loopFrequency = builder.loopFrequency; this.keepAliveIntervalMinutes = builder.keepAliveIntervalMinutes; @@ -106,6 +108,11 @@ boolean isFailIfSessionNotFound() { return actionOnSessionNotFound == ActionOnSessionNotFound.FAIL; } + @VisibleForTesting + boolean isFailOnSessionLeak() { + return actionOnSessionLeak == ActionOnSessionLeak.FAIL; + } + public static Builder newBuilder() { return new Builder(); } @@ -120,6 +127,11 @@ private static enum ActionOnSessionNotFound { FAIL; } + private static enum ActionOnSessionLeak { + WARN, + FAIL; + } + /** Builder for creating SessionPoolOptions. */ public static class Builder { private boolean minSessionsSet = false; @@ -131,6 +143,7 @@ public static class Builder { private ActionOnExhaustion actionOnExhaustion = DEFAULT_ACTION; private long initialWaitForSessionTimeoutMillis = 30_000L; private ActionOnSessionNotFound actionOnSessionNotFound = ActionOnSessionNotFound.RETRY; + private ActionOnSessionLeak actionOnSessionLeak = ActionOnSessionLeak.WARN; private long loopFrequency = 10 * 1000L; private int keepAliveIntervalMinutes = 30; private Duration removeInactiveSessionAfter = Duration.ofMinutes(55L); @@ -240,6 +253,12 @@ Builder setFailIfSessionNotFound() { return this; } + @VisibleForTesting + Builder setFailOnSessionLeak() { + this.actionOnSessionLeak = ActionOnSessionLeak.FAIL; + return this; + } + /** * Fraction of sessions to be kept prepared for write transactions. This is an optimisation to * avoid the cost of sending a BeginTransaction() rpc. If all such sessions are in use and a diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Spanner.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Spanner.java index 0c6bec4ea8..52c35cb713 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Spanner.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Spanner.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.Service; /** @@ -108,4 +109,7 @@ public interface Spanner extends Service, AutoCloseable { /** @return true if this {@link Spanner} object is closed. */ boolean isClosed(); + + /** @return the {@link ExecutorProvider} that is used for asynchronous queries and operations. */ + ExecutorProvider getAsyncExecutorProvider(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 4e937459cf..2d034eda88 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.core.GaxProperties; import com.google.api.gax.paging.Page; import com.google.cloud.BaseService; @@ -23,9 +24,11 @@ import com.google.cloud.PageImpl.NextPageFetcher; import com.google.cloud.grpc.GrpcTransportOptions; import com.google.cloud.spanner.SessionClient.SessionId; +import com.google.cloud.spanner.SpannerOptions.CloseableExecutorProvider; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.cloud.spanner.spi.v1.SpannerRpc.Paginated; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; @@ -86,6 +89,8 @@ private static String nextDatabaseClientId(DatabaseId databaseId) { @GuardedBy("this") private final Map dbClients = new HashMap<>(); + private final CloseableExecutorProvider asyncExecutorProvider; + @GuardedBy("this") private final List invalidatedDbClients = new ArrayList<>(); @@ -116,6 +121,10 @@ static final class ClosedException extends RuntimeException { SpannerImpl(SpannerRpc gapicRpc, SpannerOptions options) { super(options); this.gapicRpc = gapicRpc; + this.asyncExecutorProvider = + MoreObjects.firstNonNull( + options.getAsyncExecutorProvider(), + SpannerOptions.createDefaultAsyncExecutorProvider()); this.dbAdminClient = new DatabaseAdminClientImpl(options.getProjectId(), gapicRpc); this.instanceClient = new InstanceAdminClientImpl(options.getProjectId(), gapicRpc, dbAdminClient); @@ -140,6 +149,13 @@ QueryOptions getDefaultQueryOptions(DatabaseId databaseId) { return getOptions().getDefaultQueryOptions(databaseId); } + /** + * Returns the {@link ExecutorProvider} to use for async methods that need a background executor. + */ + public ExecutorProvider getAsyncExecutorProvider() { + return asyncExecutorProvider; + } + SessionImpl sessionWithId(String name) { Preconditions.checkArgument(!Strings.isNullOrEmpty(name), "name is null or empty"); SessionId id = SessionId.of(name); @@ -251,6 +267,7 @@ void close(long timeout, TimeUnit unit) { sessionClient.close(); } sessionClients.clear(); + asyncExecutorProvider.close(); try { gapicRpc.shutdown(); } catch (RuntimeException e) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index edeadb7b90..35a288530f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -17,6 +17,7 @@ package com.google.cloud.spanner; import com.google.api.core.ApiFunction; +import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.grpc.GrpcInterceptorProvider; import com.google.api.gax.longrunning.OperationSnapshot; import com.google.api.gax.longrunning.OperationTimedPollAlgorithm; @@ -40,10 +41,12 @@ import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.cloud.spanner.v1.SpannerSettings; import com.google.cloud.spanner.v1.stub.SpannerStubSettings; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.spanner.admin.database.v1.CreateBackupRequest; import com.google.spanner.admin.database.v1.CreateDatabaseRequest; import com.google.spanner.admin.database.v1.RestoreDatabaseRequest; @@ -59,6 +62,11 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.threeten.bp.Duration; @@ -107,6 +115,7 @@ public class SpannerOptions extends ServiceOptions { private final Map mergedQueryOptions; private final CallCredentialsProvider callCredentialsProvider; + private final CloseableExecutorProvider asyncExecutorProvider; private final String compressorName; /** @@ -138,6 +147,67 @@ public ServiceRpc create(SpannerOptions options) { } } + private static final AtomicInteger DEFAULT_POOL_COUNT = new AtomicInteger(); + + /** {@link ExecutorProvider} that is used for {@link AsyncResultSet}. */ + interface CloseableExecutorProvider extends ExecutorProvider, AutoCloseable { + /** Overridden to suppress the throws declaration of the super interface. */ + @Override + public void close(); + } + + static class FixedCloseableExecutorProvider implements CloseableExecutorProvider { + private final ScheduledExecutorService executor; + + private FixedCloseableExecutorProvider(ScheduledExecutorService executor) { + this.executor = Preconditions.checkNotNull(executor); + } + + @Override + public void close() { + executor.shutdown(); + } + + @Override + public ScheduledExecutorService getExecutor() { + return executor; + } + + @Override + public boolean shouldAutoClose() { + return false; + } + + /** Creates a FixedCloseableExecutorProvider. */ + static FixedCloseableExecutorProvider create(ScheduledExecutorService executor) { + return new FixedCloseableExecutorProvider(executor); + } + } + + /** + * Default {@link ExecutorProvider} for high-level async calls that need an executor. The default + * uses a cached thread pool containing a max of 8 threads. The pool is lazily initialized and + * will not create any threads if the user application does not use any async methods. It will + * also scale down the thread usage if the async load allows for that. + */ + @VisibleForTesting + static CloseableExecutorProvider createDefaultAsyncExecutorProvider() { + return createAsyncExecutorProvider(8, 60L, TimeUnit.SECONDS); + } + + @VisibleForTesting + static CloseableExecutorProvider createAsyncExecutorProvider( + int poolSize, long keepAliveTime, TimeUnit unit) { + String format = + String.format("spanner-async-pool-%d-thread-%%d", DEFAULT_POOL_COUNT.incrementAndGet()); + ThreadFactory threadFactory = + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(format).build(); + ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(poolSize, threadFactory); + executor.setKeepAliveTime(keepAliveTime, unit); + executor.allowCoreThreadTimeOut(true); + return FixedCloseableExecutorProvider.create(executor); + } + private SpannerOptions(Builder builder) { super(SpannerFactory.class, SpannerRpcFactory.class, builder, new SpannerDefaults()); numChannels = builder.numChannels; @@ -178,6 +248,7 @@ private SpannerOptions(Builder builder) { this.mergedQueryOptions = ImmutableMap.copyOf(merged); } callCredentialsProvider = builder.callCredentialsProvider; + asyncExecutorProvider = builder.asyncExecutorProvider; compressorName = builder.compressorName; } @@ -243,6 +314,7 @@ public static class Builder private boolean autoThrottleAdministrativeRequests = false; private Map defaultQueryOptions = new HashMap<>(); private CallCredentialsProvider callCredentialsProvider; + private CloseableExecutorProvider asyncExecutorProvider; private String compressorName; private String emulatorHost = System.getenv("SPANNER_EMULATOR_HOST"); @@ -304,6 +376,11 @@ private Builder() { Builder(SpannerOptions options) { super(options); + if (options.getHost() != null + && this.emulatorHost != null + && !options.getHost().equals(this.emulatorHost)) { + this.emulatorHost = null; + } this.numChannels = options.numChannels; this.sessionPoolOptions = options.sessionPoolOptions; this.prefetchChunks = options.prefetchChunks; @@ -315,6 +392,7 @@ private Builder() { this.autoThrottleAdministrativeRequests = options.autoThrottleAdministrativeRequests; this.defaultQueryOptions = options.defaultQueryOptions; this.callCredentialsProvider = options.callCredentialsProvider; + this.asyncExecutorProvider = options.asyncExecutorProvider; this.compressorName = options.compressorName; this.channelProvider = options.channelProvider; this.channelConfigurator = options.channelConfigurator; @@ -736,6 +814,10 @@ public QueryOptions getDefaultQueryOptions(DatabaseId databaseId) { return options; } + CloseableExecutorProvider getAsyncExecutorProvider() { + return asyncExecutorProvider; + } + public int getPrefetchChunks() { return prefetchChunks; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TraceUtil.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TraceUtil.java index c5488ac55d..0d429661ad 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TraceUtil.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TraceUtil.java @@ -40,7 +40,7 @@ static Map getTransactionAnnotations(Transaction t) { AttributeValue.stringAttributeValue(Timestamp.fromProto(t.getReadTimestamp()).toString())); } - static ImmutableMap getExceptionAnnotations(RuntimeException e) { + static ImmutableMap getExceptionAnnotations(Throwable e) { if (e instanceof SpannerException) { return ImmutableMap.of( "Status", diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java index a529c4c492..0b4a92f989 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java @@ -16,6 +16,8 @@ package com.google.cloud.spanner; +import com.google.api.core.ApiFuture; + /** * Context for a single attempt of a locking read-write transaction. This type of transaction is the * only way to write data into Cloud Spanner; {@link Session#write(Iterable)} and {@link @@ -102,6 +104,17 @@ public interface TransactionContext extends ReadContext { */ long executeUpdate(Statement statement); + /** + * Same as {@link #executeUpdate(Statement)}, but is guaranteed to be non-blocking. If multiple + * asynchronous update statements are submitted to the same read/write transaction, the statements + * are guaranteed to be submitted to Cloud Spanner in the order that they were submitted in the + * client. This does however not guarantee that an asynchronous update statement will see the + * results of all previously submitted statements, as the execution of the statements can be + * parallel. If you rely on the results of a previous statement, you should block until the result + * of that statement is known and has been returned to the client. + */ + ApiFuture executeUpdateAsync(Statement statement); + /** * Executes a list of DML statements in a single request. The statements will be executed in order * and the semantics is the same as if each statement is executed by {@code executeUpdate} in a @@ -118,4 +131,15 @@ public interface TransactionContext extends ReadContext { * statement. The 3rd statement will not run. */ long[] batchUpdate(Iterable statements); + + /** + * Same as {@link #batchUpdate(Iterable)}, but is guaranteed to be non-blocking. If multiple + * asynchronous update statements are submitted to the same read/write transaction, the statements + * are guaranteed to be submitted to Cloud Spanner in the order that they were submitted in the + * client. This does however not guarantee that an asynchronous update statement will see the + * results of all previously submitted statements, as the execution of the statements can be + * parallel. If you rely on the results of a previous statement, you should block until the result + * of that statement is known and has been returned to the client. + */ + ApiFuture batchUpdateAsync(Iterable statements); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContextFutureImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContextFutureImpl.java new file mode 100644 index 0000000000..bc8262a535 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContextFutureImpl.java @@ -0,0 +1,258 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutureCallback; +import com.google.api.core.ApiFutures; +import com.google.api.core.ForwardingApiFuture; +import com.google.api.core.InternalApi; +import com.google.api.core.SettableApiFuture; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionFunction; +import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionStep; +import com.google.cloud.spanner.AsyncTransactionManager.CommitTimestampFuture; +import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.MoreExecutors; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +class TransactionContextFutureImpl extends ForwardingApiFuture + implements TransactionContextFuture { + + @InternalApi + interface CommittableAsyncTransactionManager extends AsyncTransactionManager { + void onError(Throwable t); + + ApiFuture commitAsync(); + } + /** + * {@link ApiFuture} that returns a commit timestamp. Any {@link AbortedException} that is thrown + * by either the commit call or any other rpc during the transaction will be thrown by the {@link + * #get()} method of this future as an {@link AbortedException} and not as an {@link + * ExecutionException} with an {@link AbortedException} as its cause. + */ + static class CommitTimestampFutureImpl extends ForwardingApiFuture + implements CommitTimestampFuture { + CommitTimestampFutureImpl(ApiFuture delegate) { + super(Preconditions.checkNotNull(delegate)); + } + + @Override + public Timestamp get() throws AbortedException, ExecutionException, InterruptedException { + try { + return super.get(); + } catch (ExecutionException e) { + if (e.getCause() != null && e.getCause() instanceof AbortedException) { + throw (AbortedException) e.getCause(); + } + throw e; + } + } + + @Override + public Timestamp get(long timeout, TimeUnit unit) + throws AbortedException, ExecutionException, InterruptedException, TimeoutException { + try { + return super.get(timeout, unit); + } catch (ExecutionException e) { + if (e.getCause() != null && e.getCause() instanceof AbortedException) { + throw (AbortedException) e.getCause(); + } + throw e; + } + } + } + + class AsyncTransactionStatementImpl extends ForwardingApiFuture + implements AsyncTransactionStep { + final ApiFuture txnFuture; + final SettableApiFuture statementResult; + + AsyncTransactionStatementImpl( + final ApiFuture txnFuture, + ApiFuture input, + final AsyncTransactionFunction function, + Executor executor) { + this(SettableApiFuture.create(), txnFuture, input, function, executor); + } + + AsyncTransactionStatementImpl( + SettableApiFuture delegate, + final ApiFuture txnFuture, + ApiFuture input, + final AsyncTransactionFunction function, + final Executor executor) { + super(delegate); + this.statementResult = delegate; + this.txnFuture = txnFuture; + ApiFutures.addCallback( + input, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + mgr.onError(t); + txnResult.setException(t); + } + + @Override + public void onSuccess(I result) { + try { + ApiFutures.addCallback( + runAsyncTransactionFunction(function, txnFuture.get(), result, executor), + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + mgr.onError(t); + txnResult.setException(t); + } + + @Override + public void onSuccess(O result) { + statementResult.set(result); + } + }, + MoreExecutors.directExecutor()); + } catch (Throwable t) { + mgr.onError(t); + txnResult.setException(t); + } + } + }, + MoreExecutors.directExecutor()); + } + + @Override + public AsyncTransactionStatementImpl then( + AsyncTransactionFunction next, Executor executor) { + return new AsyncTransactionStatementImpl<>(txnFuture, statementResult, next, executor); + } + + @Override + public CommitTimestampFuture commitAsync() { + ApiFutures.addCallback( + statementResult, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + mgr.onError(t); + txnResult.setException(t); + } + + @Override + public void onSuccess(O result) { + ApiFutures.addCallback( + mgr.commitAsync(), + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + mgr.onError(t); + txnResult.setException(t); + } + + @Override + public void onSuccess(Timestamp result) { + txnResult.set(result); + } + }, + MoreExecutors.directExecutor()); + } + }, + MoreExecutors.directExecutor()); + return new CommitTimestampFutureImpl(txnResult); + } + } + + static ApiFuture runAsyncTransactionFunction( + final AsyncTransactionFunction function, + final TransactionContext txn, + final I input, + Executor executor) + throws Exception { + // Shortcut for common path. + if (executor == MoreExecutors.directExecutor()) { + return Preconditions.checkNotNull( + function.apply(txn, input), + "AsyncTransactionFunction returned . Did you mean to return ApiFutures.immediateFuture(null)?"); + } else { + final SettableApiFuture res = SettableApiFuture.create(); + executor.execute( + new Runnable() { + @Override + public void run() { + try { + ApiFuture functionResult = + Preconditions.checkNotNull( + function.apply(txn, input), + "AsyncTransactionFunction returned . Did you mean to return ApiFutures.immediateFuture(null)?"); + ApiFutures.addCallback( + functionResult, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + res.setException(t); + } + + @Override + public void onSuccess(O result) { + res.set(result); + } + }, + MoreExecutors.directExecutor()); + } catch (Throwable t) { + res.setException(t); + } + } + }); + return res; + } + } + + final CommittableAsyncTransactionManager mgr; + final SettableApiFuture txnResult = SettableApiFuture.create(); + + TransactionContextFutureImpl( + CommittableAsyncTransactionManager mgr, ApiFuture txnFuture) { + super(txnFuture); + this.mgr = mgr; + } + + @Override + public AsyncTransactionStatementImpl then( + AsyncTransactionFunction function, Executor executor) { + final SettableApiFuture input = SettableApiFuture.create(); + ApiFutures.addCallback( + this, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + mgr.onError(t); + input.setException(t); + } + + @Override + public void onSuccess(TransactionContext result) { + input.set(null); + } + }, + MoreExecutors.directExecutor()); + return new AsyncTransactionStatementImpl<>(this, input, function, executor); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java index bdf7ec954f..8dbab88314 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java @@ -29,14 +29,23 @@ final class TransactionManagerImpl implements TransactionManager, SessionTransac private static final Tracer tracer = Tracing.getTracer(); private final SessionImpl session; - private final Span span; + private Span span; private TransactionRunnerImpl.TransactionContextImpl txn; private TransactionState txnState; - TransactionManagerImpl(SessionImpl session) { + TransactionManagerImpl(SessionImpl session, Span span) { this.session = session; - this.span = Tracing.getTracer().getCurrentSpan(); + this.span = span; + } + + Span getSpan() { + return span; + } + + @Override + public void setSpan(Span span) { + this.span = span; } @Override diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index cfa8b73c4a..c10b713285 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -21,18 +21,29 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.api.core.ApiAsyncFunction; +import com.google.api.core.ApiFunction; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.Options.QueryOption; +import com.google.cloud.spanner.Options.ReadOption; import com.google.cloud.spanner.SessionImpl.SessionTransaction; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; import com.google.rpc.Code; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; +import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.TransactionSelector; import io.opencensus.common.Scope; @@ -43,6 +54,8 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; @@ -76,6 +89,57 @@ static Builder newBuilder() { return new Builder(); } + /** + * {@link AsyncResultSet} implementation that keeps track of the async operations that are still + * running for this {@link TransactionContext} and that should finish before the {@link + * TransactionContext} can commit and release its session back into the pool. + */ + private class TransactionContextAsyncResultSetImpl extends ForwardingAsyncResultSet + implements ListenableAsyncResultSet { + private TransactionContextAsyncResultSetImpl(ListenableAsyncResultSet delegate) { + super(delegate); + } + + @Override + public ApiFuture setCallback(Executor exec, ReadyCallback cb) { + Runnable listener = + new Runnable() { + @Override + public void run() { + decreaseAsyncOperations(); + } + }; + try { + increaseAsynOperations(); + addListener(listener); + return super.setCallback(exec, cb); + } catch (Throwable t) { + removeListener(listener); + decreaseAsyncOperations(); + throw t; + } + } + + @Override + public void addListener(Runnable listener) { + ((ListenableAsyncResultSet) this.delegate).addListener(listener); + } + + @Override + public void removeListener(Runnable listener) { + ((ListenableAsyncResultSet) this.delegate).removeListener(listener); + } + } + + @GuardedBy("lock") + private volatile boolean committing; + + @GuardedBy("lock") + private volatile SettableApiFuture finishedAsyncOperations = SettableApiFuture.create(); + + @GuardedBy("lock") + private volatile int runningAsyncOperations; + @GuardedBy("lock") private List mutations = new ArrayList<>(); @@ -92,25 +156,70 @@ static Builder newBuilder() { private TransactionContextImpl(Builder builder) { super(builder); this.transactionId = builder.transactionId; + this.finishedAsyncOperations.set(null); + } + + private void increaseAsynOperations() { + synchronized (lock) { + if (runningAsyncOperations == 0) { + finishedAsyncOperations = SettableApiFuture.create(); + } + runningAsyncOperations++; + } + } + + private void decreaseAsyncOperations() { + synchronized (lock) { + runningAsyncOperations--; + if (runningAsyncOperations == 0) { + finishedAsyncOperations.set(null); + } + } } void ensureTxn() { + try { + ensureTxnAsync().get(); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause() == null ? e : e.getCause()); + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } + } + + ApiFuture ensureTxnAsync() { + final SettableApiFuture res = SettableApiFuture.create(); if (transactionId == null || isAborted()) { span.addAnnotation("Creating Transaction"); - try { - transactionId = session.beginTransaction(); - span.addAnnotation( - "Transaction Creation Done", - ImmutableMap.of( - "Id", AttributeValue.stringAttributeValue(transactionId.toStringUtf8()))); - txnLogger.log( - Level.FINER, - "Started transaction {0}", - txnLogger.isLoggable(Level.FINER) ? transactionId.asReadOnlyByteBuffer() : null); - } catch (SpannerException e) { - span.addAnnotation("Transaction Creation Failed", TraceUtil.getExceptionAnnotations(e)); - throw e; - } + final ApiFuture fut = session.beginTransactionAsync(); + fut.addListener( + new Runnable() { + @Override + public void run() { + try { + transactionId = fut.get(); + span.addAnnotation( + "Transaction Creation Done", + ImmutableMap.of( + "Id", AttributeValue.stringAttributeValue(transactionId.toStringUtf8()))); + txnLogger.log( + Level.FINER, + "Started transaction {0}", + txnLogger.isLoggable(Level.FINER) + ? transactionId.asReadOnlyByteBuffer() + : null); + res.set(null); + } catch (ExecutionException e) { + span.addAnnotation( + "Transaction Creation Failed", + TraceUtil.getExceptionAnnotations(e.getCause() == null ? e : e.getCause())); + res.setException(e.getCause() == null ? e : e.getCause()); + } catch (InterruptedException e) { + res.setException(SpannerExceptionFactory.propagateInterrupt(e)); + } + } + }, + MoreExecutors.directExecutor()); } else { span.addAnnotation( "Transaction Initialized", @@ -120,41 +229,102 @@ void ensureTxn() { Level.FINER, "Using prepared transaction {0}", txnLogger.isLoggable(Level.FINER) ? transactionId.asReadOnlyByteBuffer() : null); + res.set(null); } + return res; } void commit() { - span.addAnnotation("Starting Commit"); - CommitRequest.Builder builder = - CommitRequest.newBuilder().setSession(session.getName()).setTransactionId(transactionId); - synchronized (lock) { - if (!mutations.isEmpty()) { - List mutationsProto = new ArrayList<>(); - Mutation.toProto(mutations, mutationsProto); - builder.addAllMutations(mutationsProto); - } - // Ensure that no call to buffer mutations that would be lost can succeed. - mutations = null; + try { + commitTimestamp = commitAsync().get(); + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause() == null ? e : e.getCause()); } - final CommitRequest commitRequest = builder.build(); - Span opSpan = tracer.spanBuilderWithExplicitParent(SpannerImpl.COMMIT, span).startSpan(); - try (Scope s = tracer.withSpan(opSpan)) { - CommitResponse commitResponse = rpc.commit(commitRequest, session.getOptions()); - if (!commitResponse.hasCommitTimestamp()) { - throw newSpannerException( - ErrorCode.INTERNAL, "Missing commitTimestamp:\n" + session.getName()); - } - commitTimestamp = Timestamp.fromProto(commitResponse.getCommitTimestamp()); - opSpan.end(TraceUtil.END_SPAN_OPTIONS); - } catch (RuntimeException e) { - span.addAnnotation("Commit Failed", TraceUtil.getExceptionAnnotations(e)); - TraceUtil.endSpanWithFailure(opSpan, e); - if (e instanceof SpannerException) { - onError((SpannerException) e); - } - throw e; + } + + ApiFuture commitAsync() { + final SettableApiFuture res = SettableApiFuture.create(); + final SettableApiFuture latch; + synchronized (lock) { + latch = finishedAsyncOperations; } - span.addAnnotation("Commit Done"); + latch.addListener( + new Runnable() { + @Override + public void run() { + try { + latch.get(); + CommitRequest.Builder builder = + CommitRequest.newBuilder() + .setSession(session.getName()) + .setTransactionId(transactionId); + synchronized (lock) { + if (!mutations.isEmpty()) { + List mutationsProto = new ArrayList<>(); + Mutation.toProto(mutations, mutationsProto); + builder.addAllMutations(mutationsProto); + } + // Ensure that no call to buffer mutations that would be lost can succeed. + mutations = null; + } + final CommitRequest commitRequest = builder.build(); + span.addAnnotation("Starting Commit"); + final Span opSpan = + tracer.spanBuilderWithExplicitParent(SpannerImpl.COMMIT, span).startSpan(); + final ApiFuture commitFuture = + rpc.commitAsync(commitRequest, session.getOptions()); + commitFuture.addListener( + tracer.withSpan( + opSpan, + new Runnable() { + @Override + public void run() { + try { + CommitResponse commitResponse = commitFuture.get(); + if (!commitResponse.hasCommitTimestamp()) { + throw newSpannerException( + ErrorCode.INTERNAL, + "Missing commitTimestamp:\n" + session.getName()); + } + Timestamp ts = + Timestamp.fromProto(commitResponse.getCommitTimestamp()); + span.addAnnotation("Commit Done"); + opSpan.end(TraceUtil.END_SPAN_OPTIONS); + res.set(ts); + } catch (Throwable e) { + if (e instanceof ExecutionException) { + e = + SpannerExceptionFactory.newSpannerException( + e.getCause() == null ? e : e.getCause()); + } else if (e instanceof InterruptedException) { + e = + SpannerExceptionFactory.propagateInterrupt( + (InterruptedException) e); + } else { + e = SpannerExceptionFactory.newSpannerException(e); + } + span.addAnnotation( + "Commit Failed", TraceUtil.getExceptionAnnotations(e)); + TraceUtil.endSpanWithFailure(opSpan, e); + onError((SpannerException) e); + res.setException(e); + } + } + }), + MoreExecutors.directExecutor()); + } catch (InterruptedException e) { + res.setException(SpannerExceptionFactory.propagateInterrupt(e)); + } catch (ExecutionException e) { + res.setException( + SpannerExceptionFactory.newSpannerException( + e.getCause() == null ? e : e.getCause())); + } + } + }, + MoreExecutors.directExecutor()); + return res; } Timestamp commitTimestamp() { @@ -190,6 +360,25 @@ void rollback() { } } + ApiFuture rollbackAsync() { + span.addAnnotation("Starting Rollback"); + return ApiFutures.transformAsync( + rpc.rollbackAsync( + RollbackRequest.newBuilder() + .setSession(session.getName()) + .setTransactionId(transactionId) + .build(), + session.getOptions()), + new ApiAsyncFunction() { + @Override + public ApiFuture apply(Empty input) throws Exception { + span.addAnnotation("Rollback Done"); + return ApiFutures.immediateFuture(null); + } + }, + MoreExecutors.directExecutor()); + } + @Nullable @Override TransactionSelector getTransactionSelector() { @@ -252,6 +441,61 @@ public long executeUpdate(Statement statement) { } } + @Override + public ApiFuture executeUpdateAsync(Statement statement) { + beforeReadOrQuery(); + final ExecuteSqlRequest.Builder builder = + getExecuteSqlRequestBuilder(statement, QueryMode.NORMAL); + ApiFuture resultSet; + try { + // Register the update as an async operation that must finish before the transaction may + // commit. + increaseAsynOperations(); + resultSet = rpc.executeQueryAsync(builder.build(), session.getOptions()); + } catch (Throwable t) { + decreaseAsyncOperations(); + throw t; + } + ApiFuture updateCount = + ApiFutures.transform( + resultSet, + new ApiFunction() { + @Override + public Long apply(ResultSet input) { + if (!input.hasStats()) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.INVALID_ARGUMENT, + "DML response missing stats possibly due to non-DML statement as input"); + } + // For standard DML, using the exact row count. + return input.getStats().getRowCountExact(); + } + }, + MoreExecutors.directExecutor()); + updateCount = + ApiFutures.catching( + updateCount, + Throwable.class, + new ApiFunction() { + @Override + public Long apply(Throwable input) { + SpannerException e = SpannerExceptionFactory.newSpannerException(input); + onError(e); + throw e; + } + }, + MoreExecutors.directExecutor()); + updateCount.addListener( + new Runnable() { + @Override + public void run() { + decreaseAsyncOperations(); + } + }, + MoreExecutors.directExecutor()); + return updateCount; + } + @Override public long[] batchUpdate(Iterable statements) { beforeReadOrQuery(); @@ -281,11 +525,91 @@ public long[] batchUpdate(Iterable statements) { throw e; } } + + @Override + public ApiFuture batchUpdateAsync(Iterable statements) { + beforeReadOrQuery(); + final ExecuteBatchDmlRequest.Builder builder = getExecuteBatchDmlRequestBuilder(statements); + ApiFuture response; + try { + // Register the update as an async operation that must finish before the transaction may + // commit. + increaseAsynOperations(); + response = rpc.executeBatchDmlAsync(builder.build(), session.getOptions()); + } catch (Throwable t) { + decreaseAsyncOperations(); + throw t; + } + final ApiFuture updateCounts = + ApiFutures.transform( + response, + new ApiFunction() { + @Override + public long[] apply(ExecuteBatchDmlResponse input) { + long[] results = new long[input.getResultSetsCount()]; + for (int i = 0; i < input.getResultSetsCount(); ++i) { + results[i] = input.getResultSets(i).getStats().getRowCountExact(); + } + // If one of the DML statements was aborted, we should throw an aborted exception. + // In all other cases, we should throw a BatchUpdateException. + if (input.getStatus().getCode() == Code.ABORTED_VALUE) { + throw newSpannerException( + ErrorCode.fromRpcStatus(input.getStatus()), input.getStatus().getMessage()); + } else if (input.getStatus().getCode() != 0) { + throw newSpannerBatchUpdateException( + ErrorCode.fromRpcStatus(input.getStatus()), + input.getStatus().getMessage(), + results); + } + return results; + } + }, + MoreExecutors.directExecutor()); + updateCounts.addListener( + new Runnable() { + @Override + public void run() { + try { + updateCounts.get(); + } catch (ExecutionException e) { + onError(SpannerExceptionFactory.newSpannerException(e.getCause())); + } catch (InterruptedException e) { + onError(SpannerExceptionFactory.propagateInterrupt(e)); + } finally { + decreaseAsyncOperations(); + } + } + }, + MoreExecutors.directExecutor()); + return updateCounts; + } + + private ListenableAsyncResultSet wrap(ListenableAsyncResultSet delegate) { + return new TransactionContextAsyncResultSetImpl(delegate); + } + + @Override + public ListenableAsyncResultSet readAsync( + String table, KeySet keys, Iterable columns, ReadOption... options) { + return wrap(super.readAsync(table, keys, columns, options)); + } + + @Override + public ListenableAsyncResultSet readUsingIndexAsync( + String table, String index, KeySet keys, Iterable columns, ReadOption... options) { + return wrap(super.readUsingIndexAsync(table, index, keys, columns, options)); + } + + @Override + public ListenableAsyncResultSet executeQueryAsync( + final Statement statement, final QueryOption... options) { + return wrap(super.executeQueryAsync(statement, options)); + } } private boolean blockNestedTxn = true; private final SessionImpl session; - private final Span span; + private Span span; private TransactionContextImpl txn; private volatile boolean isValid = true; @@ -297,10 +621,14 @@ public TransactionRunner allowNestedTransaction() { TransactionRunnerImpl(SessionImpl session, SpannerRpc rpc, int defaultPrefetchChunks) { this.session = session; - this.span = Tracing.getTracer().getCurrentSpan(); this.txn = session.newTransaction(); } + @Override + public void setSpan(Span span) { + this.span = span; + } + @Nullable @Override public T run(TransactionCallable callable) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 6f383b1f79..9f7eaa8e8f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -1078,8 +1078,14 @@ public void cancel(String message) { @Override public ResultSet executeQuery(ExecuteSqlRequest request, @Nullable Map options) { + return get(executeQueryAsync(request, options)); + } + + @Override + public ApiFuture executeQueryAsync( + ExecuteSqlRequest request, @Nullable Map options) { GrpcCallContext context = newCallContext(options, request.getSession()); - return get(spannerStub.executeSqlCallable().futureCall(request, context)); + return spannerStub.executeSqlCallable().futureCall(request, context); } @Override @@ -1127,30 +1133,52 @@ public void cancel(String message) { @Override public ExecuteBatchDmlResponse executeBatchDml( ExecuteBatchDmlRequest request, @Nullable Map options) { + return get(executeBatchDmlAsync(request, options)); + } + @Override + public ApiFuture executeBatchDmlAsync( + ExecuteBatchDmlRequest request, @Nullable Map options) { + GrpcCallContext context = newCallContext(options, request.getSession()); + return spannerStub.executeBatchDmlCallable().futureCall(request, context); + } + + @Override + public ApiFuture beginTransactionAsync( + BeginTransactionRequest request, @Nullable Map options) { GrpcCallContext context = newCallContext(options, request.getSession()); - return get(spannerStub.executeBatchDmlCallable().futureCall(request, context)); + return spannerStub.beginTransactionCallable().futureCall(request, context); } @Override public Transaction beginTransaction( BeginTransactionRequest request, @Nullable Map options) throws SpannerException { - GrpcCallContext context = newCallContext(options, request.getSession()); - return get(spannerStub.beginTransactionCallable().futureCall(request, context)); + return get(beginTransactionAsync(request, options)); + } + + @Override + public ApiFuture commitAsync( + CommitRequest commitRequest, @Nullable Map options) { + GrpcCallContext context = newCallContext(options, commitRequest.getSession()); + return spannerStub.commitCallable().futureCall(commitRequest, context); } @Override public CommitResponse commit(CommitRequest commitRequest, @Nullable Map options) throws SpannerException { - GrpcCallContext context = newCallContext(options, commitRequest.getSession()); - return get(spannerStub.commitCallable().futureCall(commitRequest, context)); + return get(commitAsync(commitRequest, options)); + } + + @Override + public ApiFuture rollbackAsync(RollbackRequest request, @Nullable Map options) { + GrpcCallContext context = newCallContext(options, request.getSession()); + return spannerStub.rollbackCallable().futureCall(request, context); } @Override public void rollback(RollbackRequest request, @Nullable Map options) throws SpannerException { - GrpcCallContext context = newCallContext(options, request.getSession()); - get(spannerStub.rollbackCallable().futureCall(request, context)); + get(rollbackAsync(request, options)); } @Override diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 5b6c6756d9..6b42c0a754 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -284,6 +284,9 @@ StreamingCall read( ResultSet executeQuery(ExecuteSqlRequest request, @Nullable Map options); + ApiFuture executeQueryAsync( + ExecuteSqlRequest request, @Nullable Map options); + ResultSet executePartitionedDml(ExecuteSqlRequest request, @Nullable Map options); RetrySettings getPartitionedDmlRetrySettings(); @@ -296,14 +299,25 @@ StreamingCall executeQuery( ExecuteBatchDmlResponse executeBatchDml(ExecuteBatchDmlRequest build, Map options); + ApiFuture executeBatchDmlAsync( + ExecuteBatchDmlRequest build, Map options); + Transaction beginTransaction(BeginTransactionRequest request, @Nullable Map options) throws SpannerException; + ApiFuture beginTransactionAsync( + BeginTransactionRequest request, @Nullable Map options); + CommitResponse commit(CommitRequest commitRequest, @Nullable Map options) throws SpannerException; + ApiFuture commitAsync( + CommitRequest commitRequest, @Nullable Map options); + void rollback(RollbackRequest request, @Nullable Map options) throws SpannerException; + ApiFuture rollbackAsync(RollbackRequest request, @Nullable Map options); + PartitionResponse partitionQuery(PartitionQueryRequest request, @Nullable Map options) throws SpannerException; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractAsyncTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractAsyncTransactionTest.java new file mode 100644 index 0000000000..bf76ea4f39 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractAsyncTransactionTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.MockSpannerTestUtil.EMPTY_KEY_VALUE_RESULTSET; +import static com.google.cloud.spanner.MockSpannerTestUtil.INVALID_UPDATE_STATEMENT; +import static com.google.cloud.spanner.MockSpannerTestUtil.READ_MULTIPLE_KEY_VALUE_RESULTSET; +import static com.google.cloud.spanner.MockSpannerTestUtil.READ_MULTIPLE_KEY_VALUE_STATEMENT; +import static com.google.cloud.spanner.MockSpannerTestUtil.READ_ONE_EMPTY_KEY_VALUE_STATEMENT; +import static com.google.cloud.spanner.MockSpannerTestUtil.READ_ONE_KEY_VALUE_RESULTSET; +import static com.google.cloud.spanner.MockSpannerTestUtil.READ_ONE_KEY_VALUE_STATEMENT; +import static com.google.cloud.spanner.MockSpannerTestUtil.TEST_DATABASE; +import static com.google.cloud.spanner.MockSpannerTestUtil.TEST_INSTANCE; +import static com.google.cloud.spanner.MockSpannerTestUtil.TEST_PROJECT; +import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_ABORTED_STATEMENT; +import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_COUNT; +import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_STATEMENT; + +import com.google.api.core.ApiFunction; +import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import java.net.InetSocketAddress; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +/** Base class for {@link AsyncRunnerTest} and {@link AsyncTransactionManagerTest}. */ +public abstract class AbstractAsyncTransactionTest { + static MockSpannerServiceImpl mockSpanner; + private static Server server; + private static InetSocketAddress address; + static ExecutorService executor; + + Spanner spanner; + Spanner spannerWithEmptySessionPool; + + @BeforeClass + public static void setup() throws Exception { + mockSpanner = new MockSpannerServiceImpl(); + mockSpanner.setAbortProbability(0.0D); + mockSpanner.putStatementResult( + StatementResult.query(READ_ONE_EMPTY_KEY_VALUE_STATEMENT, EMPTY_KEY_VALUE_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.query(READ_ONE_KEY_VALUE_STATEMENT, READ_ONE_KEY_VALUE_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.query( + READ_MULTIPLE_KEY_VALUE_STATEMENT, READ_MULTIPLE_KEY_VALUE_RESULTSET)); + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + mockSpanner.putStatementResult( + StatementResult.exception( + INVALID_UPDATE_STATEMENT, + Status.INVALID_ARGUMENT.withDescription("invalid statement").asRuntimeException())); + mockSpanner.putStatementResult( + StatementResult.exception( + UPDATE_ABORTED_STATEMENT, + Status.ABORTED.withDescription("Transaction was aborted").asRuntimeException())); + + address = new InetSocketAddress("localhost", 0); + server = NettyServerBuilder.forAddress(address).addService(mockSpanner).build().start(); + executor = Executors.newSingleThreadExecutor(); + } + + @AfterClass + public static void teardown() throws Exception { + server.shutdown(); + server.awaitTermination(); + executor.shutdown(); + } + + @Before + public void before() throws Exception { + String endpoint = address.getHostString() + ":" + server.getPort(); + spanner = + SpannerOptions.newBuilder() + .setProjectId(TEST_PROJECT) + .setChannelConfigurator( + new ApiFunction() { + @Override + public ManagedChannelBuilder apply(ManagedChannelBuilder input) { + input.usePlaintext(); + return input; + } + }) + .setHost("http://" + endpoint) + .setCredentials(NoCredentials.getInstance()) + .setSessionPoolOption(SessionPoolOptions.newBuilder().setFailOnSessionLeak().build()) + .build() + .getService(); + spannerWithEmptySessionPool = + spanner + .getOptions() + .toBuilder() + .setSessionPoolOption( + SessionPoolOptions.newBuilder() + .setFailOnSessionLeak() + .setMinSessions(0) + .setIncStep(1) + .build()) + .build() + .getService(); + } + + @After + public void after() throws Exception { + spanner.close(); + spannerWithEmptySessionPool.close(); + mockSpanner.removeAllExecutionTimes(); + mockSpanner.reset(); + } + + DatabaseClient client() { + return spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + } + + DatabaseClient clientWithEmptySessionPool() { + return spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java index bfd739d553..f9cee1c488 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java @@ -20,6 +20,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; @@ -80,6 +81,7 @@ public void setup() { .setSession(session) .setRpc(mock(SpannerRpc.class)) .setDefaultQueryOptions(defaultQueryOptions) + .setExecutorProvider(mock(ExecutorProvider.class)) .build(); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplStressTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplStressTest.java new file mode 100644 index 0000000000..c3383cadda --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplStressTest.java @@ -0,0 +1,464 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.core.ExecutorProvider; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.CursorState; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.cloud.spanner.Type.StructField; +import com.google.common.base.Function; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class AsyncResultSetImplStressTest { + private static final int TEST_RUNS = 1000; + + /** Timeout is applied to each test case individually. */ + @Rule public Timeout timeout = new Timeout(120, TimeUnit.SECONDS); + + @Parameter(0) + public int resultSetSize; + + @Parameters(name = "rows = {0}") + public static Collection data() { + List params = new ArrayList<>(); + for (int rows : new int[] {0, 1, 5, 10}) { + params.add(new Object[] {rows}); + } + return params; + } + + /** POJO representing a row in the test {@link ResultSet}. */ + private static final class Row { + private final Long id; + private final String name; + + static Row create(StructReader reader) { + return new Row(reader.getLong("ID"), reader.getString("NAME")); + } + + private Row(Long id, String name) { + this.id = id; + this.name = name; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Row)) { + return false; + } + Row other = (Row) o; + return Objects.equals(this.id, other.id) && Objects.equals(this.name, other.name); + } + + @Override + public int hashCode() { + return Objects.hash(this.id, this.name); + } + + @Override + public String toString() { + return String.format("ID: %d, NAME: %s", id, name); + } + } + + private static final class ResultSetWithRandomErrors extends ForwardingResultSet { + private final Random random = new Random(); + private final double errorFraction; + + private ResultSetWithRandomErrors(ResultSet delegate, double errorFraction) { + super(delegate); + this.errorFraction = errorFraction; + } + + @Override + public boolean next() { + if (random.nextDouble() < errorFraction) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.INVALID_ARGUMENT, "random error"); + } + return super.next(); + } + } + + /** Creates a simple in-mem {@link ResultSet}. */ + private ResultSet createResultSet() { + List rows = new ArrayList<>(resultSetSize); + for (int i = 0; i < resultSetSize; i++) { + rows.add( + Struct.newBuilder() + .set("ID") + .to(i + 1) + .set("NAME") + .to(String.format("Row %d", (i + 1))) + .build()); + } + return ResultSets.forRows( + Type.struct(StructField.of("ID", Type.int64()), StructField.of("NAME", Type.string())), + rows); + } + + private ResultSet createResultSetWithErrors(double errorFraction) { + return new ResultSetWithRandomErrors(createResultSet(), errorFraction); + } + + /** + * Generates a list of {@link Row} instances that correspond with the rows in {@link + * #createResultSet()}. + */ + private List createExpectedRows() { + List rows = new ArrayList<>(resultSetSize); + for (int i = 0; i < resultSetSize; i++) { + rows.add(new Row(Long.valueOf(i + 1), String.format("Row %d", (i + 1)))); + } + return rows; + } + + /** Creates a single-threaded {@link ExecutorService}. */ + private static ScheduledExecutorService createExecService() { + return createExecService(1); + } + + /** Creates an {@link ExecutorService} using a bounded pool of threadCount threads. */ + private static ScheduledExecutorService createExecService(int threadCount) { + return Executors.newScheduledThreadPool( + threadCount, new ThreadFactoryBuilder().setDaemon(true).build()); + } + + @Test + public void toList() throws Exception { + ExecutorProvider executorProvider = SpannerOptions.createDefaultAsyncExecutorProvider(); + for (int bufferSize = 1; bufferSize < resultSetSize * 2; bufferSize *= 2) { + for (int i = 0; i < TEST_RUNS; i++) { + try (AsyncResultSetImpl impl = + new AsyncResultSetImpl(executorProvider, createResultSet(), bufferSize)) { + ImmutableList list = + impl.toList( + new Function() { + @Override + public Row apply(StructReader input) { + return Row.create(input); + } + }); + assertThat(list).containsExactlyElementsIn(createExpectedRows()); + } + } + } + } + + @Test + public void toListWithErrors() throws Exception { + ExecutorProvider executorProvider = SpannerOptions.createDefaultAsyncExecutorProvider(); + for (int bufferSize = 1; bufferSize < resultSetSize * 2; bufferSize *= 2) { + for (int i = 0; i < TEST_RUNS; i++) { + try (AsyncResultSetImpl impl = + new AsyncResultSetImpl( + executorProvider, createResultSetWithErrors(1.0 / resultSetSize), bufferSize)) { + ImmutableList list = + impl.toList( + new Function() { + @Override + public Row apply(StructReader input) { + return Row.create(input); + } + }); + assertThat(list).containsExactlyElementsIn(createExpectedRows()); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(e.getMessage()).contains("random error"); + } + } + } + } + + @Test + public void asyncToList() throws Exception { + ExecutorProvider executorProvider = SpannerOptions.createDefaultAsyncExecutorProvider(); + for (int bufferSize = 1; bufferSize < resultSetSize * 2; bufferSize *= 2) { + List>> futures = new ArrayList<>(TEST_RUNS); + ExecutorService executor = createExecService(32); + for (int i = 0; i < TEST_RUNS; i++) { + try (AsyncResultSet impl = + new AsyncResultSetImpl(executorProvider, createResultSet(), bufferSize)) { + futures.add( + impl.toListAsync( + new Function() { + @Override + public Row apply(StructReader input) { + return Row.create(input); + } + }, + executor)); + } + } + List> lists = ApiFutures.allAsList(futures).get(); + for (ImmutableList list : lists) { + assertThat(list).containsExactlyElementsIn(createExpectedRows()); + } + executor.shutdown(); + } + } + + @Test + public void consume() throws Exception { + ExecutorProvider executorProvider = SpannerOptions.createDefaultAsyncExecutorProvider(); + final Random random = new Random(); + for (Executor executor : + new Executor[] { + MoreExecutors.directExecutor(), createExecService(), createExecService(32) + }) { + for (int bufferSize = 1; bufferSize < resultSetSize * 2; bufferSize *= 2) { + for (int i = 0; i < TEST_RUNS; i++) { + final SettableApiFuture> future = SettableApiFuture.create(); + try (AsyncResultSetImpl impl = + new AsyncResultSetImpl(executorProvider, createResultSet(), bufferSize)) { + final ImmutableList.Builder builder = ImmutableList.builder(); + impl.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + // Randomly do something with the received data or not. Not calling tryNext() in + // the onDataReady is not 'normal', but users may do it, and the result set + // should be able to handle that. + if (random.nextBoolean()) { + CursorState state; + while ((state = resultSet.tryNext()) == CursorState.OK) { + builder.add(Row.create(resultSet)); + } + if (state == CursorState.DONE) { + future.set(builder.build()); + } + } + return CallbackResponse.CONTINUE; + } + }); + assertThat(future.get()).containsExactlyElementsIn(createExpectedRows()); + } + } + } + } + } + + @Test + public void pauseResume() throws Exception { + ExecutorProvider executorProvider = SpannerOptions.createDefaultAsyncExecutorProvider(); + final Random random = new Random(); + List>> futures = new ArrayList<>(); + for (Executor executor : + new Executor[] { + MoreExecutors.directExecutor(), createExecService(), createExecService(32) + }) { + final List resultSets = + Collections.synchronizedList(new ArrayList()); + for (int bufferSize = 1; bufferSize < resultSetSize * 2; bufferSize *= 2) { + for (int i = 0; i < TEST_RUNS; i++) { + final SettableApiFuture> future = SettableApiFuture.create(); + futures.add(future); + try (AsyncResultSetImpl impl = + new AsyncResultSetImpl(executorProvider, createResultSet(), bufferSize)) { + resultSets.add(impl); + final ImmutableList.Builder builder = ImmutableList.builder(); + impl.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + CursorState state; + while ((state = resultSet.tryNext()) == CursorState.OK) { + builder.add(Row.create(resultSet)); + // Randomly request the iterator to pause. + if (random.nextBoolean()) { + return CallbackResponse.PAUSE; + } + } + if (state == CursorState.DONE) { + future.set(builder.build()); + } + return CallbackResponse.CONTINUE; + } + }); + } + } + } + final AtomicBoolean finished = new AtomicBoolean(false); + ExecutorService resumeService = createExecService(); + resumeService.execute( + new Runnable() { + @Override + public void run() { + while (!finished.get()) { + // Randomly resume result sets. + resultSets.get(random.nextInt(resultSets.size())).resume(); + } + } + }); + List> lists = ApiFutures.allAsList(futures).get(); + for (ImmutableList list : lists) { + assertThat(list).containsExactlyElementsIn(createExpectedRows()); + } + if (executor instanceof ExecutorService) { + ((ExecutorService) executor).shutdown(); + } + finished.set(true); + resumeService.shutdown(); + } + } + + @Test + public void cancel() throws Exception { + ExecutorProvider executorProvider = SpannerOptions.createDefaultAsyncExecutorProvider(); + final Random random = new Random(); + for (Executor executor : + new Executor[] { + MoreExecutors.directExecutor(), createExecService(), createExecService(32) + }) { + List>> futures = new ArrayList<>(); + final List resultSets = + Collections.synchronizedList(new ArrayList()); + final Set cancelledIndexes = new HashSet<>(); + for (int bufferSize = 1; bufferSize < resultSetSize * 2; bufferSize *= 2) { + for (int i = 0; i < TEST_RUNS; i++) { + final SettableApiFuture> future = SettableApiFuture.create(); + futures.add(future); + try (AsyncResultSetImpl impl = + new AsyncResultSetImpl(executorProvider, createResultSet(), bufferSize)) { + resultSets.add(impl); + final ImmutableList.Builder builder = ImmutableList.builder(); + impl.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + CursorState state; + while ((state = resultSet.tryNext()) == CursorState.OK) { + builder.add(Row.create(resultSet)); + // Randomly request the iterator to pause. + if (random.nextBoolean()) { + return CallbackResponse.PAUSE; + } + } + if (state == CursorState.DONE) { + future.set(builder.build()); + } + return CallbackResponse.CONTINUE; + } catch (SpannerException e) { + future.setException(e); + throw e; + } + } + }); + } + } + } + final AtomicBoolean finished = new AtomicBoolean(false); + // Both resume and cancel resultsets randomly. + ExecutorService resumeService = createExecService(); + resumeService.execute( + new Runnable() { + @Override + public void run() { + while (!finished.get()) { + // Randomly resume result sets. + resultSets.get(random.nextInt(resultSets.size())).resume(); + } + // Make sure all result sets finish. + for (AsyncResultSet rs : resultSets) { + rs.resume(); + } + } + }); + ExecutorService cancelService = createExecService(); + cancelService.execute( + new Runnable() { + @Override + public void run() { + while (!finished.get()) { + // Randomly cancel result sets. + int index = random.nextInt(resultSets.size()); + resultSets.get(index).cancel(); + cancelledIndexes.add(index); + } + } + }); + + // First wait until all result sets have finished. + for (ApiFuture> future : futures) { + try { + future.get(); + } catch (Throwable e) { + // ignore for now. + } + } + finished.set(true); + cancelService.shutdown(); + cancelService.awaitTermination(10L, TimeUnit.SECONDS); + + int index = 0; + for (ApiFuture> future : futures) { + try { + ImmutableList list = future.get(30L, TimeUnit.SECONDS); + // Note that the fact that the call succeeded for for this result set, does not + // necessarily mean that the result set was not cancelled. Cancelling a result set is a + // best-effort operation, and the entire result set may still be produced and returned to + // the user. + assertThat(list).containsExactlyElementsIn(createExpectedRows()); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.CANCELLED); + assertThat(cancelledIndexes).contains(index); + } + index++; + } + if (executor instanceof ExecutorService) { + ((ExecutorService) executor).shutdown(); + } + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java new file mode 100644 index 0000000000..9359dc6694 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java @@ -0,0 +1,443 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.api.core.ApiFuture; +import com.google.api.gax.core.ExecutorProvider; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.CursorState; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.common.base.Function; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Range; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AsyncResultSetImplTest { + private ExecutorProvider mockedProvider; + private ExecutorProvider simpleProvider; + + @Before + public void setup() { + mockedProvider = mock(ExecutorProvider.class); + when(mockedProvider.getExecutor()).thenReturn(mock(ScheduledExecutorService.class)); + simpleProvider = SpannerOptions.createAsyncExecutorProvider(1, 1L, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + @Test + public void close() { + AsyncResultSetImpl rs = + new AsyncResultSetImpl( + mockedProvider, mock(ResultSet.class), AsyncResultSetImpl.DEFAULT_BUFFER_SIZE); + rs.close(); + // Closing a second time should be a no-op. + rs.close(); + + // The following methods are not allowed to call after closing the result set. + try { + rs.setCallback(mock(Executor.class), mock(ReadyCallback.class)); + fail("missing expected exception"); + } catch (IllegalStateException e) { + } + try { + rs.toList(mock(Function.class)); + fail("missing expected exception"); + } catch (IllegalStateException e) { + } + try { + rs.toListAsync(mock(Function.class), mock(Executor.class)); + fail("missing expected exception"); + } catch (IllegalStateException e) { + } + + // The following methods are allowed on a closed result set. + AsyncResultSetImpl rs2 = + new AsyncResultSetImpl( + mockedProvider, mock(ResultSet.class), AsyncResultSetImpl.DEFAULT_BUFFER_SIZE); + rs2.setCallback(mock(Executor.class), mock(ReadyCallback.class)); + rs2.close(); + rs2.cancel(); + rs2.resume(); + } + + @Test + public void tryNextNotAllowed() { + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl( + mockedProvider, mock(ResultSet.class), AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.setCallback(mock(Executor.class), mock(ReadyCallback.class)); + try { + rs.tryNext(); + fail("missing expected exception"); + } catch (IllegalStateException e) { + assertThat(e.getMessage()) + .contains("tryNext may only be called from a DataReady callback."); + } + } + } + + @Test + public void toList() { + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()).thenReturn(true, true, true, false); + when(delegate.getCurrentRowAsStruct()).thenReturn(mock(Struct.class)); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + ImmutableList list = + rs.toList( + new Function() { + @Override + public Object apply(StructReader input) { + return new Object(); + } + }); + assertThat(list).hasSize(3); + } + } + + @Test + public void toListPropagatesError() { + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()) + .thenThrow( + SpannerExceptionFactory.newSpannerException( + ErrorCode.INVALID_ARGUMENT, "invalid query")); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.toList( + new Function() { + @Override + public Object apply(StructReader input) { + return new Object(); + } + }); + fail("missing expected exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(e.getMessage()).contains("invalid query"); + } + } + + @Test + public void toListAsync() throws InterruptedException, ExecutionException { + ExecutorService executor = Executors.newFixedThreadPool(1); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()).thenReturn(true, true, true, false); + when(delegate.getCurrentRowAsStruct()).thenReturn(mock(Struct.class)); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + ApiFuture> future = + rs.toListAsync( + new Function() { + @Override + public Object apply(StructReader input) { + return new Object(); + } + }, + executor); + assertThat(future.get()).hasSize(3); + } + executor.shutdown(); + } + + @Test + public void toListAsyncPropagatesError() throws InterruptedException { + ExecutorService executor = Executors.newFixedThreadPool(1); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()) + .thenThrow( + SpannerExceptionFactory.newSpannerException( + ErrorCode.INVALID_ARGUMENT, "invalid query")); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.toListAsync( + new Function() { + @Override + public Object apply(StructReader input) { + return new Object(); + } + }, + executor) + .get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(se.getMessage()).contains("invalid query"); + } + executor.shutdown(); + } + + @Test + public void withCallback() throws InterruptedException { + Executor executor = Executors.newSingleThreadExecutor(); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()).thenReturn(true, true, true, false); + when(delegate.getCurrentRowAsStruct()).thenReturn(mock(Struct.class)); + final AtomicInteger callbackCounter = new AtomicInteger(); + final AtomicInteger rowCounter = new AtomicInteger(); + final CountDownLatch finishedLatch = new CountDownLatch(1); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + callbackCounter.incrementAndGet(); + CursorState state; + while ((state = resultSet.tryNext()) == CursorState.OK) { + rowCounter.incrementAndGet(); + } + if (state == CursorState.DONE) { + finishedLatch.countDown(); + } + return CallbackResponse.CONTINUE; + } + }); + } + finishedLatch.await(); + // There should be between 1 and 4 callbacks, depending on the timing of the threads. + // Normally, there should be just 1 callback. + assertThat(callbackCounter.get()).isIn(Range.closed(1, 4)); + assertThat(rowCounter.get()).isEqualTo(3); + } + + @Test + public void callbackReceivesError() throws InterruptedException { + Executor executor = Executors.newSingleThreadExecutor(); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()) + .thenThrow( + SpannerExceptionFactory.newSpannerException( + ErrorCode.INVALID_ARGUMENT, "invalid query")); + final BlockingDeque receivedErr = new LinkedBlockingDeque<>(1); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + resultSet.tryNext(); + receivedErr.push(new Exception("missing expected exception")); + } catch (SpannerException e) { + receivedErr.push(e); + } + return CallbackResponse.DONE; + } + }); + } + Exception e = receivedErr.take(); + assertThat(e).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e; + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(se.getMessage()).contains("invalid query"); + } + + @Test + public void callbackReceivesErrorHalfwayThrough() throws InterruptedException { + Executor executor = Executors.newSingleThreadExecutor(); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()) + .thenReturn(true) + .thenThrow( + SpannerExceptionFactory.newSpannerException( + ErrorCode.INVALID_ARGUMENT, "invalid query")); + when(delegate.getCurrentRowAsStruct()).thenReturn(mock(Struct.class)); + final AtomicInteger rowCount = new AtomicInteger(); + final BlockingDeque receivedErr = new LinkedBlockingDeque<>(1); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + if (resultSet.tryNext() != CursorState.DONE) { + rowCount.incrementAndGet(); + return CallbackResponse.CONTINUE; + } + } catch (SpannerException e) { + receivedErr.push(e); + } + return CallbackResponse.DONE; + } + }); + } + Exception e = receivedErr.take(); + assertThat(e).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e; + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(se.getMessage()).contains("invalid query"); + assertThat(rowCount.get()).isEqualTo(1); + } + + @Test + public void pauseResume() throws InterruptedException { + Executor executor = Executors.newSingleThreadExecutor(); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()).thenReturn(true, true, true, false); + when(delegate.getCurrentRowAsStruct()).thenReturn(mock(Struct.class)); + final AtomicInteger callbackCounter = new AtomicInteger(); + final BlockingDeque queue = new LinkedBlockingDeque<>(1); + final AtomicBoolean finished = new AtomicBoolean(false); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + callbackCounter.incrementAndGet(); + CursorState state = resultSet.tryNext(); + if (state == CursorState.OK) { + try { + queue.put(new Object()); + } catch (InterruptedException e) { + // Finish early if an error occurs. + return CallbackResponse.DONE; + } + return CallbackResponse.PAUSE; + } + finished.set(true); + return CallbackResponse.DONE; + } + }); + int rowCounter = 0; + while (!finished.get()) { + Object o = queue.poll(1L, TimeUnit.MILLISECONDS); + if (o != null) { + rowCounter++; + } + rs.resume(); + } + // There should be exactly 4 callbacks as we only consume one row per callback. + assertThat(callbackCounter.get()).isEqualTo(4); + assertThat(rowCounter).isEqualTo(3); + } + } + + @Test + public void cancel() throws InterruptedException { + Executor executor = Executors.newSingleThreadExecutor(); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()).thenReturn(true, true, true, false); + when(delegate.getCurrentRowAsStruct()).thenReturn(mock(Struct.class)); + final AtomicInteger callbackCounter = new AtomicInteger(); + final BlockingDeque queue = new LinkedBlockingDeque<>(1); + final AtomicBoolean finished = new AtomicBoolean(false); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + callbackCounter.incrementAndGet(); + try { + CursorState state = resultSet.tryNext(); + if (state == CursorState.OK) { + try { + queue.put(new Object()); + } catch (InterruptedException e) { + // Finish early if an error occurs. + return CallbackResponse.DONE; + } + } + // Pause after 2 rows to make sure that no more data is consumed until the cancel + // call has been received. + return callbackCounter.get() == 2 + ? CallbackResponse.PAUSE + : CallbackResponse.CONTINUE; + } catch (SpannerException e) { + if (e.getErrorCode() == ErrorCode.CANCELLED) { + finished.set(true); + } + } + return CallbackResponse.DONE; + } + }); + int rowCounter = 0; + while (!finished.get()) { + Object o = queue.poll(1L, TimeUnit.MILLISECONDS); + if (o != null) { + rowCounter++; + } + if (rowCounter == 2) { + // Cancel the result set and then resume it to get the cancelled error. + rs.cancel(); + rs.resume(); + } + } + assertThat(callbackCounter.get()).isIn(Range.closed(2, 4)); + assertThat(rowCounter).isIn(Range.closed(2, 3)); + } + } + + @Test + public void callbackReturnsError() throws InterruptedException { + Executor executor = Executors.newSingleThreadExecutor(); + ResultSet delegate = mock(ResultSet.class); + when(delegate.next()).thenReturn(true, true, true, false); + when(delegate.getCurrentRowAsStruct()).thenReturn(mock(Struct.class)); + final AtomicInteger callbackCounter = new AtomicInteger(); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + callbackCounter.incrementAndGet(); + throw new RuntimeException("async test"); + } + }); + rs.getResult().get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.UNKNOWN); + assertThat(se.getMessage()).contains("async test"); + assertThat(callbackCounter.get()).isEqualTo(1); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java new file mode 100644 index 0000000000..eb00047ca4 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java @@ -0,0 +1,618 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.MockSpannerTestUtil.*; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.api.core.ApiFunction; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.cloud.spanner.AsyncRunner.AsyncWork; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.common.base.Function; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.spanner.v1.BatchCreateSessionsRequest; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import io.grpc.Status; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AsyncRunnerTest extends AbstractAsyncTransactionTest { + @Test + public void asyncRunnerUpdate() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return txn.executeUpdateAsync(UPDATE_STATEMENT); + } + }, + executor); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + } + + @Test + public void asyncRunnerIsNonBlocking() throws Exception { + mockSpanner.freeze(); + AsyncRunner runner = clientWithEmptySessionPool().runAsync(); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + txn.executeUpdateAsync(UPDATE_STATEMENT); + return ApiFutures.immediateFuture(null); + } + }, + executor); + ApiFuture ts = runner.getCommitTimestamp(); + mockSpanner.unfreeze(); + assertThat(res.get()).isNull(); + assertThat(ts.get()).isNotNull(); + } + + @Test + public void asyncRunnerInvalidUpdate() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return txn.executeUpdateAsync(INVALID_UPDATE_STATEMENT); + } + }, + executor); + try { + updateCount.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(se.getMessage()).contains("invalid statement"); + } + } + + @Test + public void asyncRunnerFireAndForgetInvalidUpdate() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + txn.executeUpdateAsync(INVALID_UPDATE_STATEMENT); + return txn.executeUpdateAsync(UPDATE_STATEMENT); + } + }, + executor); + assertThat(res.get()).isEqualTo(UPDATE_COUNT); + } + + @Test + public void asyncRunnerUpdateAborted() throws Exception { + try { + // Temporarily set the result of the update to 2 rows. + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L)); + final AtomicInteger attempt = new AtomicInteger(); + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } else { + // Set the result of the update statement back to 1 row. + mockSpanner.putStatementResult( + StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + return txn.executeUpdateAsync(UPDATE_STATEMENT); + } + }, + executor); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(attempt.get()).isEqualTo(2); + } finally { + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + } + + @Test + public void asyncRunnerCommitAborted() throws Exception { + try { + // Temporarily set the result of the update to 2 rows. + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L)); + final AtomicInteger attempt = new AtomicInteger(); + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(final TransactionContext txn) { + if (attempt.get() > 0) { + // Set the result of the update statement back to 1 row. + mockSpanner.putStatementResult( + StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + ApiFuture updateCount = txn.executeUpdateAsync(UPDATE_STATEMENT); + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + return updateCount; + } + }, + executor); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(attempt.get()).isEqualTo(2); + } finally { + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + } + + @Test + public void asyncRunnerUpdateAbortedWithoutGettingResult() throws Exception { + final AtomicInteger attempt = new AtomicInteger(); + AsyncRunner runner = clientWithEmptySessionPool().runAsync(); + ApiFuture result = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + // This update statement will be aborted, but the error will not propagated to the + // transaction runner and cause the transaction to retry. Instead, the commit call + // will do that. + txn.executeUpdateAsync(UPDATE_STATEMENT); + // Resolving this future will not resolve the result of the entire transaction. The + // transaction result will be resolved when the commit has actually finished + // successfully. + return ApiFutures.immediateFuture(null); + } + }, + executor); + assertThat(result.get()).isNull(); + assertThat(attempt.get()).isEqualTo(2); + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteSqlRequest.class, + CommitRequest.class, + BeginTransactionRequest.class, + ExecuteSqlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncRunnerCommitFails() throws Exception { + mockSpanner.setCommitExecutionTime( + SimulatedExecutionTime.ofException( + Status.RESOURCE_EXHAUSTED + .withDescription("mutation limit exceeded") + .asRuntimeException())); + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + // This statement will succeed, but the commit will fail. The error from the commit + // will bubble up to the future that is returned by the transaction, and the update + // count returned here will never reach the user application. + return txn.executeUpdateAsync(UPDATE_STATEMENT); + } + }, + executor); + try { + updateCount.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.RESOURCE_EXHAUSTED); + assertThat(se.getMessage()).contains("mutation limit exceeded"); + } + } + + @Test + public void asyncRunnerWaitsUntilAsyncUpdateHasFinished() throws Exception { + AsyncRunner runner = clientWithEmptySessionPool().runAsync(); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + txn.executeUpdateAsync(UPDATE_STATEMENT); + return ApiFutures.immediateFuture(null); + } + }, + executor); + res.get(); + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteSqlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncRunnerBatchUpdate() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + } + }, + executor); + assertThat(updateCount.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + } + + @Test + public void asyncRunnerIsNonBlockingWithBatchUpdate() throws Exception { + mockSpanner.freeze(); + AsyncRunner runner = clientWithEmptySessionPool().runAsync(); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT)); + return ApiFutures.immediateFuture(null); + } + }, + executor); + ApiFuture ts = runner.getCommitTimestamp(); + mockSpanner.unfreeze(); + assertThat(res.get()).isNull(); + assertThat(ts.get()).isNotNull(); + } + + @Test + public void asyncRunnerInvalidBatchUpdate() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return txn.batchUpdateAsync( + ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT)); + } + }, + executor); + try { + updateCount.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(se.getMessage()).contains("invalid statement"); + } + } + + @Test + public void asyncRunnerFireAndForgetInvalidBatchUpdate() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT)); + return txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + } + }, + executor); + assertThat(res.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + } + + @Test + public void asyncRunnerBatchUpdateAborted() throws Exception { + final AtomicInteger attempt = new AtomicInteger(); + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + if (attempt.incrementAndGet() == 1) { + return txn.batchUpdateAsync( + ImmutableList.of(UPDATE_STATEMENT, UPDATE_ABORTED_STATEMENT)); + } else { + return txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + } + } + }, + executor); + assertThat(updateCount.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + assertThat(attempt.get()).isEqualTo(2); + } + + @Test + public void asyncRunnerWithBatchUpdateCommitAborted() throws Exception { + try { + // Temporarily set the result of the update to 2 rows. + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L)); + final AtomicInteger attempt = new AtomicInteger(); + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(final TransactionContext txn) { + if (attempt.get() > 0) { + // Set the result of the update statement back to 1 row. + mockSpanner.putStatementResult( + StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + ApiFuture updateCount = + txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + return updateCount; + } + }, + executor); + assertThat(updateCount.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + assertThat(attempt.get()).isEqualTo(2); + } finally { + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + } + + @Test + public void asyncRunnerBatchUpdateAbortedWithoutGettingResult() throws Exception { + final AtomicInteger attempt = new AtomicInteger(); + AsyncRunner runner = clientWithEmptySessionPool().runAsync(); + ApiFuture result = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + // This update statement will be aborted, but the error will not propagated to the + // transaction runner and cause the transaction to retry. Instead, the commit call + // will do that. + txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + // Resolving this future will not resolve the result of the entire transaction. The + // transaction result will be resolved when the commit has actually finished + // successfully. + return ApiFutures.immediateFuture(null); + } + }, + executor); + assertThat(result.get()).isNull(); + assertThat(attempt.get()).isEqualTo(2); + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncRunnerWithBatchUpdateCommitFails() throws Exception { + mockSpanner.setCommitExecutionTime( + SimulatedExecutionTime.ofException( + Status.RESOURCE_EXHAUSTED + .withDescription("mutation limit exceeded") + .asRuntimeException())); + AsyncRunner runner = client().runAsync(); + ApiFuture updateCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + // This statement will succeed, but the commit will fail. The error from the commit + // will bubble up to the future that is returned by the transaction, and the update + // count returned here will never reach the user application. + return txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + } + }, + executor); + try { + updateCount.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.RESOURCE_EXHAUSTED); + assertThat(se.getMessage()).contains("mutation limit exceeded"); + } + } + + @Test + public void asyncRunnerWaitsUntilAsyncBatchUpdateHasFinished() throws Exception { + AsyncRunner runner = clientWithEmptySessionPool().runAsync(); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT)); + return ApiFutures.immediateFuture(null); + } + }, + executor); + res.get(); + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + + @Test + public void closeTransactionBeforeEndOfAsyncQuery() throws Exception { + final BlockingQueue results = new SynchronousQueue<>(); + final SettableApiFuture finished = SettableApiFuture.create(); + DatabaseClientImpl clientImpl = (DatabaseClientImpl) client(); + + // There should currently not be any sessions checked out of the pool. + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(0); + + AsyncRunner runner = clientImpl.runAsync(); + final CountDownLatch dataReceived = new CountDownLatch(1); + final CountDownLatch dataChecked = new CountDownLatch(1); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + try (AsyncResultSet rs = + txn.readAsync( + READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES, Options.bufferRows(1))) { + rs.setCallback( + Executors.newSingleThreadExecutor(), + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + dataReceived.countDown(); + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + finished.set(true); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + dataChecked.await(); + results.put(resultSet.getString(0)); + } + } + } catch (Throwable t) { + finished.setException(t); + return CallbackResponse.DONE; + } + } + }); + } + try { + dataReceived.await(); + return ApiFutures.immediateFuture(null); + } catch (InterruptedException e) { + return ApiFutures.immediateFailedFuture( + SpannerExceptionFactory.propagateInterrupt(e)); + } + } + }, + executor); + // Wait until at least one row has been fetched. At that moment there should be one session + // checked out. + dataReceived.await(); + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(1); + assertThat(res.isDone()).isFalse(); + dataChecked.countDown(); + // Get the data from the transaction. + List resultList = new ArrayList<>(); + do { + results.drainTo(resultList); + } while (!finished.isDone() || results.size() > 0); + assertThat(finished.get()).isTrue(); + assertThat(resultList).containsExactly("k1", "k2", "k3"); + assertThat(res.get()).isNull(); + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(0); + } + + @Test + public void asyncRunnerReadRow() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture val = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return ApiFutures.transform( + txn.readRowAsync(READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES), + new ApiFunction() { + @Override + public String apply(Struct input) { + return input.getString("Value"); + } + }, + MoreExecutors.directExecutor()); + } + }, + executor); + assertThat(val.get()).isEqualTo("v1"); + } + + @Test + public void asyncRunnerRead() throws Exception { + AsyncRunner runner = client().runAsync(); + ApiFuture> val = + runner.runAsync( + new AsyncWork>() { + @Override + public ApiFuture> doWorkAsync(TransactionContext txn) { + return txn.readAsync(READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES) + .toListAsync( + new Function() { + @Override + public String apply(StructReader input) { + return input.getString("Value"); + } + }, + MoreExecutors.directExecutor()); + } + }, + executor); + assertThat(val.get()).containsExactly("v1", "v2", "v3"); + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java new file mode 100644 index 0000000000..e2299f3615 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerTest.java @@ -0,0 +1,1045 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.MockSpannerTestUtil.INVALID_UPDATE_STATEMENT; +import static com.google.cloud.spanner.MockSpannerTestUtil.READ_COLUMN_NAMES; +import static com.google.cloud.spanner.MockSpannerTestUtil.READ_TABLE_NAME; +import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_ABORTED_STATEMENT; +import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_COUNT; +import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_STATEMENT; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutureCallback; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionFunction; +import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionStep; +import com.google.cloud.spanner.AsyncTransactionManager.CommitTimestampFuture; +import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Options.ReadOption; +import com.google.common.base.Function; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Range; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.AbstractMessage; +import com.google.spanner.v1.BatchCreateSessionsRequest; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import io.grpc.Status; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class AsyncTransactionManagerTest extends AbstractAsyncTransactionTest { + + @Parameter public Executor executor; + + @Parameters(name = "executor = {0}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {MoreExecutors.directExecutor()}, + {Executors.newSingleThreadExecutor()}, + {Executors.newFixedThreadPool(4)} + }); + } + + /** + * Static helper methods that simplifies creating {@link AsyncTransactionFunction}s for Java7. + * Java8 and higher can use lambda expressions. + */ + public static class AsyncTransactionManagerHelper { + + public static AsyncTransactionFunction readAsync( + final String table, + final KeySet keys, + final Iterable columns, + final ReadOption... options) { + return new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, I input) throws Exception { + return ApiFutures.immediateFuture(txn.readAsync(table, keys, columns, options)); + } + }; + } + + public static AsyncTransactionFunction readRowAsync( + final String table, final Key key, final Iterable columns) { + return new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, I input) throws Exception { + return txn.readRowAsync(table, key, columns); + } + }; + } + + public static AsyncTransactionFunction buffer(Mutation mutation) { + return buffer(ImmutableList.of(mutation)); + } + + public static AsyncTransactionFunction buffer(final Iterable mutations) { + return new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, I input) throws Exception { + txn.buffer(mutations); + return ApiFutures.immediateFuture(null); + } + }; + } + + public static AsyncTransactionFunction executeUpdateAsync(Statement statement) { + return executeUpdateAsync(SettableApiFuture.create(), statement); + } + + public static AsyncTransactionFunction executeUpdateAsync( + final SettableApiFuture result, final Statement statement) { + return new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, I input) throws Exception { + ApiFuture updateCount = txn.executeUpdateAsync(statement); + ApiFutures.addCallback( + updateCount, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + result.setException(t); + } + + @Override + public void onSuccess(Long input) { + result.set(input); + } + }, + MoreExecutors.directExecutor()); + return updateCount; + } + }; + } + + public static AsyncTransactionFunction batchUpdateAsync( + final Statement... statements) { + return batchUpdateAsync(SettableApiFuture.create(), statements); + } + + public static AsyncTransactionFunction batchUpdateAsync( + final SettableApiFuture result, final Statement... statements) { + return new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, I input) throws Exception { + ApiFuture updateCounts = txn.batchUpdateAsync(Arrays.asList(statements)); + ApiFutures.addCallback( + updateCounts, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + result.setException(t); + } + + @Override + public void onSuccess(long[] input) { + result.set(input); + } + }, + MoreExecutors.directExecutor()); + return updateCounts; + } + }; + } + } + + @Test + public void asyncTransactionManagerUpdate() throws Exception { + final SettableApiFuture updateCount = SettableApiFuture.create(); + + try (AsyncTransactionManager manager = client().transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + try { + CommitTimestampFuture commitTimestamp = + txn.then( + AsyncTransactionManagerHelper.executeUpdateAsync( + updateCount, UPDATE_STATEMENT), + executor) + .commitAsync(); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(commitTimestamp.get()).isNotNull(); + break; + } catch (AbortedException e) { + txn = manager.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerIsNonBlocking() throws Exception { + SettableApiFuture updateCount = SettableApiFuture.create(); + + mockSpanner.freeze(); + try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + try { + CommitTimestampFuture commitTimestamp = + txn.then( + AsyncTransactionManagerHelper.executeUpdateAsync( + updateCount, UPDATE_STATEMENT), + executor) + .commitAsync(); + mockSpanner.unfreeze(); + assertThat(updateCount.get(10L, TimeUnit.SECONDS)).isEqualTo(UPDATE_COUNT); + assertThat(commitTimestamp.get(10L, TimeUnit.SECONDS)).isNotNull(); + break; + } catch (AbortedException e) { + txn = manager.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerInvalidUpdate() throws Exception { + try (AsyncTransactionManager manager = client().transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + try { + CommitTimestampFuture commitTimestamp = + txn.then( + AsyncTransactionManagerHelper.executeUpdateAsync( + INVALID_UPDATE_STATEMENT), + executor) + .commitAsync(); + commitTimestamp.get(); + fail("missing expected exception"); + } catch (AbortedException e) { + txn = manager.resetForRetryAsync(); + } catch (ExecutionException e) { + manager.rollbackAsync(); + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(se.getMessage()).contains("invalid statement"); + break; + } + } + } + } + + @Test + public void asyncTransactionManagerCommitAborted() throws Exception { + SettableApiFuture updateCount = SettableApiFuture.create(); + final AtomicInteger attempt = new AtomicInteger(); + try (AsyncTransactionManager manager = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + try { + attempt.incrementAndGet(); + CommitTimestampFuture commitTimestamp = + txn.then( + AsyncTransactionManagerHelper.executeUpdateAsync( + updateCount, UPDATE_STATEMENT), + executor) + .then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Long input) + throws Exception { + if (attempt.get() == 1) { + mockSpanner.abortTransaction(txn); + } + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync(); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(commitTimestamp.get()).isNotNull(); + assertThat(attempt.get()).isEqualTo(2); + break; + } catch (AbortedException e) { + txn = manager.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerFireAndForgetInvalidUpdate() throws Exception { + final SettableApiFuture updateCount = SettableApiFuture.create(); + + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + CommitTimestampFuture ts = + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + // This fire-and-forget update statement should not fail the transaction. + txn.executeUpdateAsync(INVALID_UPDATE_STATEMENT); + ApiFutures.addCallback( + txn.executeUpdateAsync(UPDATE_STATEMENT), + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + updateCount.setException(t); + } + + @Override + public void onSuccess(Long result) { + updateCount.set(result); + } + }, + MoreExecutors.directExecutor()); + return updateCount; + } + }, + executor) + .commitAsync(); + assertThat(updateCount.get()).isEqualTo(UPDATE_COUNT); + assertThat(ts.get()).isNotNull(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerChain() throws Exception { + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + CommitTimestampFuture ts = + txn.then( + AsyncTransactionManagerHelper.executeUpdateAsync(UPDATE_STATEMENT), + executor) + .then( + AsyncTransactionManagerHelper.readRowAsync( + READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES), + executor) + .then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Struct input) + throws Exception { + return ApiFutures.immediateFuture(input.getString("Value")); + } + }, + executor) + .then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, String input) + throws Exception { + assertThat(input).isEqualTo("v1"); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync(); + assertThat(ts.get()).isNotNull(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerChainWithErrorInTheMiddle() throws Exception { + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + CommitTimestampFuture ts = + txn.then( + AsyncTransactionManagerHelper.executeUpdateAsync( + INVALID_UPDATE_STATEMENT), + executor) + .then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Long input) + throws Exception { + throw new IllegalStateException("this should not be executed"); + } + }, + executor) + .commitAsync(); + ts.get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } catch (ExecutionException e) { + mgr.rollbackAsync(); + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + break; + } + } + } + } + + @Test + public void asyncTransactionManagerUpdateAborted() throws Exception { + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + // Temporarily set the result of the update to 2 rows. + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L)); + final AtomicInteger attempt = new AtomicInteger(); + + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + CommitTimestampFuture ts = + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + if (attempt.incrementAndGet() == 1) { + // Abort the first attempt. + mockSpanner.abortTransaction(txn); + } else { + // Set the result of the update statement back to 1 row. + mockSpanner.putStatementResult( + StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + return ApiFutures.immediateFuture(null); + } + }, + executor) + .then( + AsyncTransactionManagerHelper.executeUpdateAsync(UPDATE_STATEMENT), + executor) + .commitAsync(); + assertThat(ts.get()).isNotNull(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + assertThat(attempt.get()).isEqualTo(2); + } finally { + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + } + + @Test + public void asyncTransactionManagerUpdateAbortedWithoutGettingResult() throws Exception { + final AtomicInteger attempt = new AtomicInteger(); + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + CommitTimestampFuture ts = + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + // This update statement will be aborted, but the error will not + // propagated to the transaction runner and cause the transaction to + // retry. Instead, the commit call will do that. + txn.executeUpdateAsync(UPDATE_STATEMENT); + // Resolving this future will not resolve the result of the entire + // transaction. The transaction result will be resolved when the commit + // has actually finished successfully. + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync(); + assertThat(ts.get()).isNotNull(); + assertThat(attempt.get()).isEqualTo(2); + // The server may receive 1 or 2 commit requests depending on whether the call to + // commitAsync() already knows that the transaction has aborted. If it does, it will not + // attempt to call the Commit RPC and instead directly propagate the Aborted error. + assertThat(mockSpanner.getRequestTypes()) + .containsAtLeast( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteSqlRequest.class, + BeginTransactionRequest.class, + ExecuteSqlRequest.class, + CommitRequest.class); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerCommitFails() throws Exception { + mockSpanner.setCommitExecutionTime( + SimulatedExecutionTime.ofException( + Status.RESOURCE_EXHAUSTED + .withDescription("mutation limit exceeded") + .asRuntimeException())); + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + AsyncTransactionManagerHelper.executeUpdateAsync(UPDATE_STATEMENT), + executor) + .commitAsync() + .get(); + fail("missing expected exception"); + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.RESOURCE_EXHAUSTED); + assertThat(se.getMessage()).contains("mutation limit exceeded"); + break; + } + } + } + } + + @Test + public void asyncTransactionManagerWaitsUntilAsyncUpdateHasFinished() throws Exception { + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + // Shoot-and-forget update. The commit will still wait for this request to + // finish. + txn.executeUpdateAsync(UPDATE_STATEMENT); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync() + .get(); + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteSqlRequest.class, + CommitRequest.class); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerBatchUpdate() throws Exception { + final SettableApiFuture result = SettableApiFuture.create(); + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + AsyncTransactionManagerHelper.batchUpdateAsync( + result, UPDATE_STATEMENT, UPDATE_STATEMENT), + executor) + .commitAsync() + .get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + } + + @Test + public void asyncTransactionManagerIsNonBlockingWithBatchUpdate() throws Exception { + SettableApiFuture res = SettableApiFuture.create(); + mockSpanner.freeze(); + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + CommitTimestampFuture ts = + txn.then( + AsyncTransactionManagerHelper.batchUpdateAsync(res, UPDATE_STATEMENT), + executor) + .commitAsync(); + mockSpanner.unfreeze(); + assertThat(ts.get()).isNotNull(); + assertThat(res.get()).asList().containsExactly(UPDATE_COUNT); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + } + + @Test + public void asyncTransactionManagerInvalidBatchUpdate() throws Exception { + SettableApiFuture result = SettableApiFuture.create(); + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + AsyncTransactionManagerHelper.batchUpdateAsync( + result, UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT), + executor) + .commitAsync() + .get(); + fail("missing expected exception"); + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(se.getMessage()).contains("invalid statement"); + break; + } + } + } + } + + @Test + public void asyncTransactionManagerFireAndForgetInvalidBatchUpdate() throws Exception { + SettableApiFuture result = SettableApiFuture.create(); + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + txn.batchUpdateAsync( + ImmutableList.of(UPDATE_STATEMENT, INVALID_UPDATE_STATEMENT)); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .then( + AsyncTransactionManagerHelper.batchUpdateAsync( + result, UPDATE_STATEMENT, UPDATE_STATEMENT), + executor) + .commitAsync() + .get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncTransactionManagerBatchUpdateAborted() throws Exception { + final AtomicInteger attempt = new AtomicInteger(); + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + if (attempt.incrementAndGet() == 1) { + return txn.batchUpdateAsync( + ImmutableList.of(UPDATE_STATEMENT, UPDATE_ABORTED_STATEMENT)); + } else { + return txn.batchUpdateAsync( + ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + } + } + }, + executor) + .commitAsync() + .get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(attempt.get()).isEqualTo(2); + // There should only be 1 CommitRequest, as the first attempt should abort already after the + // ExecuteBatchDmlRequest. + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncTransactionManagerWithBatchUpdateCommitAborted() throws Exception { + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + // Temporarily set the result of the update to 2 rows. + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT + 1L)); + final AtomicInteger attempt = new AtomicInteger(); + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + final SettableApiFuture result = SettableApiFuture.create(); + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + if (attempt.get() > 0) { + // Set the result of the update statement back to 1 row. + mockSpanner.putStatementResult( + StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + return ApiFutures.immediateFuture(null); + } + }, + executor) + .then( + AsyncTransactionManagerHelper.batchUpdateAsync( + result, UPDATE_STATEMENT, UPDATE_STATEMENT), + executor) + .then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, long[] input) + throws Exception { + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync() + .get(); + assertThat(result.get()).asList().containsExactly(UPDATE_COUNT, UPDATE_COUNT); + assertThat(attempt.get()).isEqualTo(2); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } finally { + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + } + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncTransactionManagerBatchUpdateAbortedWithoutGettingResult() throws Exception { + final AtomicInteger attempt = new AtomicInteger(); + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + if (attempt.incrementAndGet() == 1) { + mockSpanner.abortTransaction(txn); + } + // This update statement will be aborted, but the error will not propagated to + // the transaction manager and cause the transaction to retry. Instead, the + // commit call will do that. Depending on the timing, that will happen + // directly in the transaction manager if the ABORTED error has already been + // returned by the batch update call before the commit call starts. Otherwise, + // the backend will return an ABORTED error for the commit call. + txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT, UPDATE_STATEMENT)); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync() + .get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(attempt.get()).isEqualTo(2); + Iterable> requests = mockSpanner.getRequestTypes(); + int size = Iterables.size(requests); + assertThat(size).isIn(Range.closed(6, 7)); + if (size == 6) { + assertThat(requests) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } else { + assertThat(requests) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + } + + @Test + public void asyncTransactionManagerWithBatchUpdateCommitFails() throws Exception { + mockSpanner.setCommitExecutionTime( + SimulatedExecutionTime.ofException( + Status.RESOURCE_EXHAUSTED + .withDescription("mutation limit exceeded") + .asRuntimeException())); + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + AsyncTransactionManagerHelper.batchUpdateAsync( + UPDATE_STATEMENT, UPDATE_STATEMENT), + executor) + .commitAsync() + .get(); + fail("missing expected exception"); + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.RESOURCE_EXHAUSTED); + assertThat(se.getMessage()).contains("mutation limit exceeded"); + break; + } + } + } + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncTransactionManagerWaitsUntilAsyncBatchUpdateHasFinished() throws Exception { + try (AsyncTransactionManager mgr = clientWithEmptySessionPool().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + txn.batchUpdateAsync(ImmutableList.of(UPDATE_STATEMENT)); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync() + .get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + BatchCreateSessionsRequest.class, + BeginTransactionRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } + + @Test + public void asyncTransactionManagerReadRow() throws Exception { + ApiFuture val; + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + AsyncTransactionStep step; + val = + step = + txn.then( + AsyncTransactionManagerHelper.readRowAsync( + READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES), + executor) + .then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Struct input) + throws Exception { + return ApiFutures.immediateFuture(input.getString("Value")); + } + }, + executor); + step.commitAsync().get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(val.get()).isEqualTo("v1"); + } + + @Test + public void asyncTransactionManagerRead() throws Exception { + AsyncTransactionStep> res; + try (AsyncTransactionManager mgr = client().transactionManagerAsync()) { + TransactionContextFuture txn = mgr.beginAsync(); + while (true) { + try { + res = + txn.then( + new AsyncTransactionFunction>() { + @Override + public ApiFuture> apply( + TransactionContext txn, Void input) throws Exception { + return txn.readAsync(READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES) + .toListAsync( + new Function() { + @Override + public String apply(StructReader input) { + return input.getString("Value"); + } + }, + MoreExecutors.directExecutor()); + } + }, + executor); + // Commit the transaction. + res.commitAsync().get(); + break; + } catch (AbortedException e) { + txn = mgr.resetForRetryAsync(); + } + } + } + assertThat(res.get()).containsExactly("v1", "v2", "v3"); + } + + @Test + public void asyncTransactionManagerQuery() throws Exception { + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of("SELECT FirstName FROM Singers WHERE ID=1"), + MockSpannerTestUtil.READ_FIRST_NAME_SINGERS_RESULTSET)); + final long singerId = 1L; + try (AsyncTransactionManager manager = client().transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + final String column = "FirstName"; + CommitTimestampFuture commitTimestamp = + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + return txn.readRowAsync( + "Singers", Key.of(singerId), Collections.singleton(column)); + } + }, + executor) + .then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Struct input) + throws Exception { + String name = input.getString(column); + txn.buffer( + Mutation.newUpdateBuilder("Singers") + .set(column) + .to(name.toUpperCase()) + .build()); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync(); + try { + commitTimestamp.get(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + txn = manager.resetForRetryAsync(); + } + } + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java index 26bbef4535..1bcb303f72 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java @@ -59,7 +59,7 @@ public void release(ScheduledExecutorService executor) { } SessionImpl mockSession() { - SessionImpl session = mock(SessionImpl.class); + final SessionImpl session = mock(SessionImpl.class); when(session.getName()) .thenReturn( "projects/dummy/instances/dummy/database/dummy/sessions/session" + sessionIndex); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index 4f98f59104..3a41a46961 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -16,35 +16,42 @@ package com.google.cloud.spanner; +import static com.google.cloud.spanner.MockSpannerTestUtil.SELECT1; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; import com.google.api.gax.grpc.testing.LocalChannelProvider; import com.google.api.gax.retrying.RetrySettings; import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.cloud.spanner.AsyncRunner.AsyncWork; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; import com.google.common.base.Stopwatch; +import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.AbstractMessage; -import com.google.protobuf.ListValue; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; -import com.google.spanner.v1.ResultSetMetadata; -import com.google.spanner.v1.StructType; -import com.google.spanner.v1.StructType.Field; -import com.google.spanner.v1.TypeCode; import io.grpc.Server; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessServerBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -72,37 +79,17 @@ public class DatabaseClientImplTest { private static final Statement INVALID_UPDATE_STATEMENT = Statement.of("UPDATE NON_EXISTENT_TABLE SET BAR=1 WHERE BAZ=2"); private static final long UPDATE_COUNT = 1L; - private static final Statement SELECT1 = Statement.of("SELECT 1 AS COL1"); - private static final ResultSetMetadata SELECT1_METADATA = - ResultSetMetadata.newBuilder() - .setRowType( - StructType.newBuilder() - .addFields( - Field.newBuilder() - .setName("COL1") - .setType( - com.google.spanner.v1.Type.newBuilder() - .setCode(TypeCode.INT64) - .build()) - .build()) - .build()) - .build(); - private static final com.google.spanner.v1.ResultSet SELECT1_RESULTSET = - com.google.spanner.v1.ResultSet.newBuilder() - .addRows( - ListValue.newBuilder() - .addValues(com.google.protobuf.Value.newBuilder().setStringValue("1").build()) - .build()) - .setMetadata(SELECT1_METADATA) - .build(); private Spanner spanner; + private Spanner spannerWithEmptySessionPool; + private static final ExecutorService executor = Executors.newSingleThreadExecutor(); @BeforeClass public static void startStaticServer() throws IOException { mockSpanner = new MockSpannerServiceImpl(); mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions. mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); - mockSpanner.putStatementResult(StatementResult.query(SELECT1, SELECT1_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.query(SELECT1, MockSpannerTestUtil.SELECT1_RESULTSET)); mockSpanner.putStatementResult( StatementResult.exception( INVALID_UPDATE_STATEMENT, @@ -123,6 +110,7 @@ public static void startStaticServer() throws IOException { public static void stopServer() throws InterruptedException { server.shutdown(); server.awaitTermination(); + executor.shutdown(); } @Before @@ -132,17 +120,473 @@ public void setUp() { .setProjectId(TEST_PROJECT) .setChannelProvider(channelProvider) .setCredentials(NoCredentials.getInstance()) + .setSessionPoolOption(SessionPoolOptions.newBuilder().setFailOnSessionLeak().build()) + .build() + .getService(); + spannerWithEmptySessionPool = + spanner + .getOptions() + .toBuilder() + .setSessionPoolOption( + SessionPoolOptions.newBuilder().setMinSessions(0).setFailOnSessionLeak().build()) .build() .getService(); } @After public void tearDown() { + mockSpanner.unfreeze(); spanner.close(); + spannerWithEmptySessionPool.close(); mockSpanner.reset(); mockSpanner.removeAllExecutionTimes(); } + @Test + public void write() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + client.write( + Arrays.asList( + Mutation.newInsertBuilder("FOO").set("ID").to(1L).set("NAME").to("Bar").build())); + } + + @Test + public void writeAtLeastOnce() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + client.writeAtLeastOnce( + Arrays.asList( + Mutation.newInsertBuilder("FOO").set("ID").to(1L).set("NAME").to("Bar").build())); + } + + @Test + public void singleUse() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = client.singleUse().executeQuery(SELECT1)) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void singleUseIsNonBlocking() { + mockSpanner.freeze(); + // Use a Spanner instance with no initial sessions in the pool to show that getting a session + // from the pool and then preparing a query is non-blocking (i.e. does not wait on a reply from + // the server). + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = client.singleUse().executeQuery(SELECT1)) { + mockSpanner.unfreeze(); + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void singleUseAsync() throws Exception { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + final AtomicInteger rowCount = new AtomicInteger(); + ApiFuture res; + try (AsyncResultSet rs = client.singleUse().executeQueryAsync(SELECT1)) { + res = + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + while (true) { + switch (resultSet.tryNext()) { + case OK: + rowCount.incrementAndGet(); + break; + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + } + } + } + }); + } + res.get(); + assertThat(rowCount.get()).isEqualTo(1); + } + + @Test + public void singleUseBound() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = + client + .singleUse(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS)) + .executeQuery(SELECT1)) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void singleUseBoundIsNonBlocking() { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = + client + .singleUse(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS)) + .executeQuery(SELECT1)) { + mockSpanner.unfreeze(); + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void singleUseBoundAsync() throws Exception { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + final AtomicInteger rowCount = new AtomicInteger(); + ApiFuture res; + try (AsyncResultSet rs = + client + .singleUse(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS)) + .executeQueryAsync(SELECT1)) { + res = + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + while (true) { + switch (resultSet.tryNext()) { + case OK: + rowCount.incrementAndGet(); + break; + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + } + } + } + }); + } + res.get(); + assertThat(rowCount.get()).isEqualTo(1); + } + + @Test + public void singleUseTransaction() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = client.singleUseReadOnlyTransaction().executeQuery(SELECT1)) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void singleUseTransactionIsNonBlocking() { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = client.singleUseReadOnlyTransaction().executeQuery(SELECT1)) { + mockSpanner.unfreeze(); + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void singleUseTransactionBound() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = + client + .singleUseReadOnlyTransaction(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS)) + .executeQuery(SELECT1)) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void singleUseTransactionBoundIsNonBlocking() { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet rs = + client + .singleUseReadOnlyTransaction(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS)) + .executeQuery(SELECT1)) { + mockSpanner.unfreeze(); + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + + @Test + public void readOnlyTransaction() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ReadOnlyTransaction tx = client.readOnlyTransaction()) { + try (ResultSet rs = tx.executeQuery(SELECT1)) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + } + + @Test + public void readOnlyTransactionIsNonBlocking() { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ReadOnlyTransaction tx = client.readOnlyTransaction()) { + try (ResultSet rs = tx.executeQuery(SELECT1)) { + mockSpanner.unfreeze(); + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + } + + @Test + public void readOnlyTransactionBound() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ReadOnlyTransaction tx = + client.readOnlyTransaction(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS))) { + try (ResultSet rs = tx.executeQuery(SELECT1)) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + } + + @Test + public void readOnlyTransactionBoundIsNonBlocking() { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ReadOnlyTransaction tx = + client.readOnlyTransaction(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS))) { + try (ResultSet rs = tx.executeQuery(SELECT1)) { + mockSpanner.unfreeze(); + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + } + + @Test + public void readWriteTransaction() { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + TransactionRunner runner = client.readWriteTransaction(); + runner.run( + new TransactionCallable() { + @Override + public Void run(TransactionContext transaction) throws Exception { + transaction.executeUpdate(UPDATE_STATEMENT); + return null; + } + }); + } + + @Test + public void readWriteTransactionIsNonBlocking() { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + TransactionRunner runner = client.readWriteTransaction(); + // The runner.run(...) method cannot be made non-blocking, as it returns the result of the + // transaction. + mockSpanner.unfreeze(); + runner.run( + new TransactionCallable() { + @Override + public Void run(TransactionContext transaction) throws Exception { + transaction.executeUpdate(UPDATE_STATEMENT); + return null; + } + }); + } + + @Test + public void runAsync() throws Exception { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + ExecutorService executor = Executors.newSingleThreadExecutor(); + AsyncRunner runner = client.runAsync(); + ApiFuture fut = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return ApiFutures.immediateFuture(txn.executeUpdate(UPDATE_STATEMENT)); + } + }, + executor); + assertThat(fut.get()).isEqualTo(UPDATE_COUNT); + executor.shutdown(); + } + + @Test + public void runAsyncIsNonBlocking() throws Exception { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + ExecutorService executor = Executors.newSingleThreadExecutor(); + AsyncRunner runner = client.runAsync(); + ApiFuture fut = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return ApiFutures.immediateFuture(txn.executeUpdate(UPDATE_STATEMENT)); + } + }, + executor); + mockSpanner.unfreeze(); + assertThat(fut.get()).isEqualTo(UPDATE_COUNT); + executor.shutdown(); + } + + @Test + public void runAsyncWithException() throws Exception { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + ExecutorService executor = Executors.newSingleThreadExecutor(); + AsyncRunner runner = client.runAsync(); + ApiFuture fut = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return ApiFutures.immediateFuture(txn.executeUpdate(INVALID_UPDATE_STATEMENT)); + } + }, + executor); + try { + fut.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + } + executor.shutdown(); + } + + @Test + public void transactionManager() throws Exception { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (TransactionManager txManager = client.transactionManager()) { + while (true) { + TransactionContext tx = txManager.begin(); + try { + tx.executeUpdate(UPDATE_STATEMENT); + txManager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + tx = txManager.resetForRetry(); + } + } + } + } + + @Test + public void transactionManagerIsNonBlocking() throws Exception { + mockSpanner.freeze(); + DatabaseClient client = + spannerWithEmptySessionPool.getDatabaseClient( + DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (TransactionManager txManager = client.transactionManager()) { + while (true) { + mockSpanner.unfreeze(); + TransactionContext tx = txManager.begin(); + try { + tx.executeUpdate(UPDATE_STATEMENT); + txManager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + tx = txManager.resetForRetry(); + } + } + } + } + + @Test + public void transactionManagerExecuteQueryAsync() throws Exception { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + final AtomicInteger rowCount = new AtomicInteger(); + try (TransactionManager txManager = client.transactionManager()) { + while (true) { + TransactionContext tx = txManager.begin(); + try { + try (AsyncResultSet rs = tx.executeQueryAsync(SELECT1)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case OK: + rowCount.incrementAndGet(); + break; + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + } + } + } catch (Throwable t) { + return CallbackResponse.DONE; + } + } + }); + } + txManager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + tx = txManager.resetForRetry(); + } + } + } + assertThat(rowCount.get()).isEqualTo(1); + } + /** * Test that the update statement can be executed as a partitioned transaction that returns a * lower bound update count. @@ -470,6 +914,7 @@ public void testDatabaseOrInstanceDoesNotExistOnCreate() { DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); // The create session failure should propagate to the client and not retry. try (ResultSet rs = dbClient.singleUse().executeQuery(SELECT1)) { + rs.next(); fail("missing expected exception"); } catch (DatabaseNotFoundException | InstanceNotFoundException e) { // The server should only receive one BatchCreateSessions request. @@ -933,6 +1378,54 @@ public void testBackendPartitionQueryOptions() { } } + @Test + public void testAsyncQuery() throws Exception { + final int EXPECTED_ROW_COUNT = 10; + RandomResultSetGenerator generator = new RandomResultSetGenerator(EXPECTED_ROW_COUNT); + com.google.spanner.v1.ResultSet resultSet = generator.generate(); + mockSpanner.putStatementResult( + StatementResult.query(Statement.of("SELECT * FROM RANDOM"), resultSet)); + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + ExecutorService executor = Executors.newSingleThreadExecutor(); + ApiFuture resultSetClosed; + final SettableFuture finished = SettableFuture.create(); + final List receivedResults = new ArrayList<>(); + try (AsyncResultSet rs = + client.singleUse().executeQueryAsync(Statement.of("SELECT * FROM RANDOM"))) { + resultSetClosed = + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (rs.tryNext()) { + case DONE: + finished.set(true); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + receivedResults.add(resultSet.getCurrentRowAsStruct()); + break; + default: + throw new IllegalStateException("Unknown cursor state"); + } + } + } catch (Throwable t) { + finished.setException(t); + return CallbackResponse.DONE; + } + } + }); + } + assertThat(finished.get()).isTrue(); + assertThat(receivedResults.size()).isEqualTo(EXPECTED_ROW_COUNT); + resultSetClosed.get(); + } + @Test public void testClientIdReusedOnDatabaseNotFound() { mockSpanner.setBatchCreateSessionsExecutionTime( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestWithClosedSessionsEnv.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestWithClosedSessionsEnv.java index 6b22ba77c3..edbc7976c0 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestWithClosedSessionsEnv.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestWithClosedSessionsEnv.java @@ -17,6 +17,7 @@ package com.google.cloud.spanner; import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.SessionPool.PooledSessionFuture; import com.google.cloud.spanner.testing.RemoteSpannerHelper; /** @@ -73,30 +74,30 @@ public void setAllowSessionReplacing(boolean allow) { } @Override - PooledSession getReadSession() { - PooledSession session = super.getReadSession(); + PooledSessionFuture getReadSession() { + PooledSessionFuture session = super.getReadSession(); if (invalidateNextSession) { - session.delegate.close(); - session.setAllowReplacing(false); - awaitDeleted(session.delegate); - session.setAllowReplacing(allowReplacing); + session.get().delegate.close(); + session.get().setAllowReplacing(false); + awaitDeleted(session.get().delegate); + session.get().setAllowReplacing(allowReplacing); invalidateNextSession = false; } - session.setAllowReplacing(allowReplacing); + session.get().setAllowReplacing(allowReplacing); return session; } @Override - PooledSession getReadWriteSession() { - PooledSession session = super.getReadWriteSession(); + PooledSessionFuture getReadWriteSession() { + PooledSessionFuture session = super.getReadWriteSession(); if (invalidateNextSession) { - session.delegate.close(); - session.setAllowReplacing(false); - awaitDeleted(session.delegate); - session.setAllowReplacing(allowReplacing); + session.get().delegate.close(); + session.get().setAllowReplacing(false); + awaitDeleted(session.get().delegate); + session.get().setAllowReplacing(allowReplacing); invalidateNextSession = false; } - session.setAllowReplacing(allowReplacing); + session.get().setAllowReplacing(allowReplacing); return session; } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockDatabaseAdminServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockDatabaseAdminServiceImpl.java index 832dccb14c..25e162039b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockDatabaseAdminServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockDatabaseAdminServiceImpl.java @@ -79,8 +79,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.concurrent.CountDownLatch; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -434,7 +433,7 @@ private com.google.rpc.Status fromException(Exception e) { private static final String EXPIRE_TIME_MASK = "expire_time"; private static final Random RND = new Random(); private final Queue exceptions = new ConcurrentLinkedQueue<>(); - private final ReadWriteLock freezeLock = new ReentrantReadWriteLock(); + private volatile CountDownLatch freezeLock = new CountDownLatch(0); private final ConcurrentMap databases = new ConcurrentHashMap<>(); private final ConcurrentMap backups = new ConcurrentHashMap<>(); private final ConcurrentMap> filterMatches = new ConcurrentHashMap<>(); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 164a8842c7..cae41510fc 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -19,9 +19,11 @@ import com.google.api.gax.grpc.testing.MockGrpcService; import com.google.cloud.ByteArray; import com.google.cloud.Date; +import com.google.cloud.spanner.AbstractResultSet.GrpcStruct; import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; import com.google.common.base.Optional; import com.google.common.base.Preconditions; +import com.google.common.base.Stopwatch; import com.google.common.base.Throwables; import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.AbstractMessage; @@ -79,6 +81,7 @@ import java.util.Collection; import java.util.Collections; import java.util.Comparator; +import java.util.Deque; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -88,14 +91,15 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; import org.threeten.bp.Instant; /** @@ -232,7 +236,7 @@ private enum StatementResultType { private final StatementResultType type; private final Statement statement; private final Long updateCount; - private final ResultSet resultSet; + private final Deque resultSets; private final StatusRuntimeException exception; /** Creates a {@link StatementResult} for a query that returns a {@link ResultSet}. */ @@ -240,6 +244,15 @@ public static StatementResult query(Statement statement, ResultSet resultSet) { return new StatementResult(statement, resultSet); } + /** + * Creates a {@link StatementResult} for a query that returns a {@link ResultSet} the first + * time, and a different {@link ResultSet} for all subsequent calls. + */ + public static StatementResult queryAndThen( + Statement statement, ResultSet resultSet, ResultSet next) { + return new StatementResult(statement, resultSet); + } + /** Creates a {@link StatementResult} for a read request. */ public static StatementResult read( String table, KeySet keySet, Iterable columns, ResultSet resultSet) { @@ -256,6 +269,25 @@ public static StatementResult exception(Statement statement, StatusRuntimeExcept return new StatementResult(statement, exception); } + private static class KeepLastElementDeque extends LinkedList { + private static KeepLastElementDeque singleton(E item) { + return new KeepLastElementDeque(Collections.singleton(item)); + } + + private static KeepLastElementDeque of(E first, E second) { + return new KeepLastElementDeque(Arrays.asList(first, second)); + } + + private KeepLastElementDeque(Collection coll) { + super(coll); + } + + @Override + public E pop() { + return this.size() == 1 ? super.peek() : super.pop(); + } + } + /** * Creates a {@link Statement} for a read statement. This {@link Statement} can be used to mock * a result for a read request. @@ -275,6 +307,7 @@ public static Statement createReadStatement( builder.append(", "); } builder.append(col); + first = false; } builder.append(" FROM ").append(table); if (keySet.isAll()) { @@ -302,14 +335,24 @@ private static boolean isValidKeySet(KeySet keySet) { private StatementResult(Statement statement, Long updateCount) { this.statement = Preconditions.checkNotNull(statement); this.updateCount = Preconditions.checkNotNull(updateCount); - this.resultSet = null; + this.resultSets = null; this.exception = null; this.type = StatementResultType.UPDATE_COUNT; } private StatementResult(Statement statement, ResultSet resultSet) { this.statement = Preconditions.checkNotNull(statement); - this.resultSet = Preconditions.checkNotNull(resultSet); + this.resultSets = KeepLastElementDeque.singleton(Preconditions.checkNotNull(resultSet)); + this.updateCount = null; + this.exception = null; + this.type = StatementResultType.RESULT_SET; + } + + private StatementResult(Statement statement, ResultSet resultSet, ResultSet andThen) { + this.statement = Preconditions.checkNotNull(statement); + this.resultSets = + KeepLastElementDeque.of( + Preconditions.checkNotNull(resultSet), Preconditions.checkNotNull(andThen)); this.updateCount = null; this.exception = null; this.type = StatementResultType.RESULT_SET; @@ -318,7 +361,7 @@ private StatementResult(Statement statement, ResultSet resultSet) { private StatementResult( String table, KeySet keySet, Iterable columns, ResultSet resultSet) { this.statement = createReadStatement(table, keySet, columns); - this.resultSet = Preconditions.checkNotNull(resultSet); + this.resultSets = KeepLastElementDeque.singleton(Preconditions.checkNotNull(resultSet)); this.updateCount = null; this.exception = null; this.type = StatementResultType.RESULT_SET; @@ -327,7 +370,7 @@ private StatementResult( private StatementResult(Statement statement, StatusRuntimeException exception) { this.statement = Preconditions.checkNotNull(statement); this.exception = Preconditions.checkNotNull(exception); - this.resultSet = null; + this.resultSets = null; this.updateCount = null; this.type = StatementResultType.EXCEPTION; } @@ -340,7 +383,7 @@ private ResultSet getResultSet() { Preconditions.checkState( type == StatementResultType.RESULT_SET, "This statement result does not contain a result set"); - return resultSet; + return resultSets.pop(); } private Long getUpdateCount() { @@ -394,6 +437,11 @@ public static SimulatedExecutionTime ofStickyException(Exception exception) { return new SimulatedExecutionTime(0, 0, Arrays.asList(exception), true); } + public static SimulatedExecutionTime stickyDatabaseNotFoundException(String name) { + return ofStickyException( + SpannerExceptionFactoryTest.newStatusDatabaseNotFoundException(name)); + } + public static SimulatedExecutionTime ofExceptions(Collection exceptions) { return new SimulatedExecutionTime(0, 0, exceptions, false); } @@ -421,19 +469,15 @@ private SimulatedExecutionTime( void simulateExecutionTime( Queue globalExceptions, boolean stickyGlobalExceptions, - ReadWriteLock freezeLock) { - try { - freezeLock.readLock().lock(); - checkException(globalExceptions, stickyGlobalExceptions); - checkException(this.exceptions, stickyException); - if (minimumExecutionTime > 0 || randomExecutionTime > 0) { - Uninterruptibles.sleepUninterruptibly( - (randomExecutionTime == 0 ? 0 : RANDOM.nextInt(randomExecutionTime)) - + minimumExecutionTime, - TimeUnit.MILLISECONDS); - } - } finally { - freezeLock.readLock().unlock(); + CountDownLatch freezeLock) { + Uninterruptibles.awaitUninterruptibly(freezeLock); + checkException(globalExceptions, stickyGlobalExceptions); + checkException(this.exceptions, stickyException); + if (minimumExecutionTime > 0 || randomExecutionTime > 0) { + Uninterruptibles.sleepUninterruptibly( + (randomExecutionTime == 0 ? 0 : RANDOM.nextInt(randomExecutionTime)) + + minimumExecutionTime, + TimeUnit.MILLISECONDS); } } @@ -451,22 +495,23 @@ private static void checkException(Queue exceptions, boolean keepExce private final Random random = new Random(); private double abortProbability = 0.0010D; - private final Queue requests = new ConcurrentLinkedQueue<>(); - private final ReadWriteLock freezeLock = new ReentrantReadWriteLock(); - private final Queue exceptions = new ConcurrentLinkedQueue<>(); + private final Object lock = new Object(); + private Deque requests = new ConcurrentLinkedDeque<>(); + private volatile CountDownLatch freezeLock = new CountDownLatch(0); + private Queue exceptions = new ConcurrentLinkedQueue<>(); private boolean stickyGlobalExceptions = false; - private final ConcurrentMap statementResults = - new ConcurrentHashMap<>(); - private final ConcurrentMap sessions = new ConcurrentHashMap<>(); + private ConcurrentMap statementResults = new ConcurrentHashMap<>(); + private ConcurrentMap statementGetCounts = new ConcurrentHashMap<>(); + private ConcurrentMap sessions = new ConcurrentHashMap<>(); private ConcurrentMap sessionLastUsed = new ConcurrentHashMap<>(); - private final ConcurrentMap transactions = new ConcurrentHashMap<>(); - private final ConcurrentMap isPartitionedDmlTransaction = + private ConcurrentMap transactions = new ConcurrentHashMap<>(); + private ConcurrentMap isPartitionedDmlTransaction = new ConcurrentHashMap<>(); - private final ConcurrentMap abortedTransactions = new ConcurrentHashMap<>(); + private ConcurrentMap abortedTransactions = new ConcurrentHashMap<>(); private final AtomicBoolean abortNextTransaction = new AtomicBoolean(); private final AtomicBoolean abortNextStatement = new AtomicBoolean(); - private final ConcurrentMap transactionCounters = new ConcurrentHashMap<>(); - private final ConcurrentMap> partitionTokens = new ConcurrentHashMap<>(); + private ConcurrentMap transactionCounters = new ConcurrentHashMap<>(); + private ConcurrentMap> partitionTokens = new ConcurrentHashMap<>(); private ConcurrentMap transactionLastUsed = new ConcurrentHashMap<>(); private int maxNumSessionsInOneBatch = 100; private int maxTotalSessions = Integer.MAX_VALUE; @@ -532,11 +577,29 @@ private Timestamp getCurrentGoogleTimestamp() { */ public void putStatementResult(StatementResult result) { Preconditions.checkNotNull(result); - statementResults.put(result.statement, result); + synchronized (lock) { + statementResults.put(result.statement, result); + } + } + + public void putStatementResults(StatementResult... results) { + synchronized (lock) { + for (StatementResult result : results) { + statementResults.put(result.statement, result); + } + } } private StatementResult getResult(Statement statement) { - StatementResult res = statementResults.get(statement); + StatementResult res; + synchronized (lock) { + res = statementResults.get(statement); + if (statementGetCounts.containsKey(statement)) { + statementGetCounts.put(statement, statementGetCounts.get(statement) + 1L); + } else { + statementGetCounts.put(statement, 1L); + } + } if (res == null) { throw Status.INTERNAL .withDescription( @@ -593,11 +656,11 @@ public void abortAllTransactions() { } public void freeze() { - freezeLock.writeLock().lock(); + freezeLock = new CountDownLatch(1); } public void unfreeze() { - freezeLock.writeLock().unlock(); + freezeLock.countDown(); } public void setMaxSessionsInOneBatch(int max) { @@ -935,6 +998,7 @@ public void executeBatchDml( status = com.google.rpc.Status.newBuilder() .setCode(res.getException().getStatus().getCode().value()) + .setMessage(res.getException().getMessage()) .build(); break resultLoop; case RESULT_SET: @@ -1055,6 +1119,7 @@ public void executeStreamingSql( } } + @SuppressWarnings("unchecked") private Statement buildStatement( String sql, Map paramTypes, com.google.protobuf.Struct params) { Statement.Builder builder = Statement.newBuilder(sql); @@ -1063,7 +1128,37 @@ private Statement buildStatement( if (value.getKindCase() == KindCase.NULL_VALUE) { switch (entry.getValue().getCode()) { case ARRAY: - throw new IllegalArgumentException("Array parameters not (yet) supported"); + switch (entry.getValue().getArrayElementType().getCode()) { + case BOOL: + builder.bind(entry.getKey()).toBoolArray((Iterable) null); + break; + case BYTES: + builder.bind(entry.getKey()).toBytesArray(null); + break; + case DATE: + builder.bind(entry.getKey()).toDateArray(null); + break; + case FLOAT64: + builder.bind(entry.getKey()).toFloat64Array((Iterable) null); + break; + case INT64: + builder.bind(entry.getKey()).toInt64Array((Iterable) null); + break; + case STRING: + builder.bind(entry.getKey()).toStringArray(null); + break; + case TIMESTAMP: + builder.bind(entry.getKey()).toTimestampArray(null); + break; + case STRUCT: + case TYPE_CODE_UNSPECIFIED: + case UNRECOGNIZED: + default: + throw new IllegalArgumentException( + "Unknown or invalid array parameter type: " + + entry.getValue().getArrayElementType().getCode()); + } + break; case BOOL: builder.bind(entry.getKey()).to((Boolean) null); break; @@ -1097,7 +1192,72 @@ private Statement buildStatement( } else { switch (entry.getValue().getCode()) { case ARRAY: - throw new IllegalArgumentException("Array parameters not (yet) supported"); + switch (entry.getValue().getArrayElementType().getCode()) { + case BOOL: + builder + .bind(entry.getKey()) + .toBoolArray( + (Iterable) + GrpcStruct.decodeArrayValue( + com.google.cloud.spanner.Type.bool(), value.getListValue())); + break; + case BYTES: + builder + .bind(entry.getKey()) + .toBytesArray( + (Iterable) + GrpcStruct.decodeArrayValue( + com.google.cloud.spanner.Type.bytes(), value.getListValue())); + break; + case DATE: + builder + .bind(entry.getKey()) + .toDateArray( + (Iterable) + GrpcStruct.decodeArrayValue( + com.google.cloud.spanner.Type.date(), value.getListValue())); + break; + case FLOAT64: + builder + .bind(entry.getKey()) + .toFloat64Array( + (Iterable) + GrpcStruct.decodeArrayValue( + com.google.cloud.spanner.Type.float64(), value.getListValue())); + break; + case INT64: + builder + .bind(entry.getKey()) + .toInt64Array( + (Iterable) + GrpcStruct.decodeArrayValue( + com.google.cloud.spanner.Type.int64(), value.getListValue())); + break; + case STRING: + builder + .bind(entry.getKey()) + .toStringArray( + (Iterable) + GrpcStruct.decodeArrayValue( + com.google.cloud.spanner.Type.string(), value.getListValue())); + break; + case TIMESTAMP: + builder + .bind(entry.getKey()) + .toTimestampArray( + (Iterable) + GrpcStruct.decodeArrayValue( + com.google.cloud.spanner.Type.timestamp(), value.getListValue())); + break; + case STRUCT: + case TYPE_CODE_UNSPECIFIED: + case UNRECOGNIZED: + default: + throw new IllegalArgumentException( + "Unknown or invalid array parameter type: " + + entry.getValue().getArrayElementType().getCode()); + } + break; case BOOL: builder.bind(entry.getKey()).to(value.getBoolValue()); break; @@ -1119,6 +1279,9 @@ private Statement buildStatement( case STRUCT: throw new IllegalArgumentException("Struct parameters not (yet) supported"); case TIMESTAMP: + builder + .bind(entry.getKey()) + .to(com.google.cloud.Timestamp.parseTimestamp(value.getStringValue())); break; case TYPE_CODE_UNSPECIFIED: case UNRECOGNIZED: @@ -1200,12 +1363,12 @@ public Iterator iterator() { return request.getColumnsList().iterator(); } }; - StatementResult res = - statementResults.get( - StatementResult.createReadStatement( - request.getTable(), - request.getKeySet().getAll() ? KeySet.all() : KeySet.singleKey(Key.of()), - cols)); + Statement statement = + StatementResult.createReadStatement( + request.getTable(), + request.getKeySet().getAll() ? KeySet.all() : KeySet.singleKey(Key.of()), + cols); + StatementResult res = getResult(statement); returnResultSet( res.getResultSet(), transactionId, request.getTransaction(), responseObserver); responseObserver.onCompleted(); @@ -1250,12 +1413,17 @@ public Iterator iterator() { return request.getColumnsList().iterator(); } }; - StatementResult res = - statementResults.get( - StatementResult.createReadStatement( - request.getTable(), - request.getKeySet().getAll() ? KeySet.all() : KeySet.singleKey(Key.of()), - cols)); + Statement statement = + StatementResult.createReadStatement( + request.getTable(), + request.getKeySet().getAll() ? KeySet.all() : KeySet.singleKey(Key.of()), + cols); + StatementResult res = getResult(statement); + if (res == null) { + throw Status.NOT_FOUND + .withDescription("No result found for " + statement.toString()) + .asRuntimeException(); + } returnPartialResultSet( res.getResultSet(), transactionId, request.getTransaction(), responseObserver); } catch (StatusRuntimeException e) { @@ -1496,7 +1664,10 @@ private void ensureMostRecentTransaction(Session session, ByteString transaction throw Status.FAILED_PRECONDITION .withDescription( String.format( - "This transaction has been invalidated by a later transaction in the same session.", + "This transaction has been invalidated by a later transaction in the same session.\nTransaction id: " + + id + + "\nExpected: " + + counter.get(), session.getName())) .asRuntimeException(); } @@ -1662,6 +1833,37 @@ public List getRequests() { return new ArrayList<>(this.requests); } + public Iterable> getRequestTypes() { + List> res = new LinkedList<>(); + for (AbstractMessage m : this.requests) { + res.add(m.getClass()); + } + return res; + } + + public int countRequestsOfType(Class type) { + int c = 0; + for (AbstractMessage m : this.requests) { + if (m.getClass().equals(type)) { + c++; + } + } + return c; + } + + public void waitForLastRequestToBe(Class type, long timeoutMillis) + throws InterruptedException, TimeoutException { + Stopwatch watch = Stopwatch.createStarted(); + while (!(this.requests.peekLast() != null + && this.requests.peekLast().getClass().equals(type))) { + Thread.sleep(10L); + if (watch.elapsed(TimeUnit.MILLISECONDS) > timeoutMillis) { + throw new TimeoutException( + "Timeout while waiting for last request to become " + type.getName()); + } + } + } + @Override public void addResponse(AbstractMessage response) { throw new UnsupportedOperationException(); @@ -1672,6 +1874,10 @@ public void addException(Exception exception) { exceptions.add(exception); } + public void clearExceptions() { + exceptions.clear(); + } + public void setStickyGlobalExceptions(boolean sticky) { this.stickyGlobalExceptions = sticky; } @@ -1684,18 +1890,21 @@ public ServerServiceDefinition getServiceDefinition() { /** Removes all sessions and transactions. Mocked results are not removed. */ @Override public void reset() { - requests.clear(); - sessions.clear(); + requests = new ConcurrentLinkedDeque<>(); + exceptions = new ConcurrentLinkedQueue<>(); + statementGetCounts = new ConcurrentHashMap<>(); + sessions = new ConcurrentHashMap<>(); + sessionLastUsed = new ConcurrentHashMap<>(); + transactions = new ConcurrentHashMap<>(); + isPartitionedDmlTransaction = new ConcurrentHashMap<>(); + abortedTransactions = new ConcurrentHashMap<>(); + transactionCounters = new ConcurrentHashMap<>(); + partitionTokens = new ConcurrentHashMap<>(); + transactionLastUsed = new ConcurrentHashMap<>(); + numSessionsCreated.set(0); - sessionLastUsed.clear(); - transactions.clear(); - isPartitionedDmlTransaction.clear(); - abortedTransactions.clear(); - transactionCounters.clear(); - partitionTokens.clear(); - transactionLastUsed.clear(); - exceptions.clear(); stickyGlobalExceptions = false; + freezeLock.countDown(); } public void removeAllExecutionTimes() { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerTestUtil.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerTestUtil.java new file mode 100644 index 0000000000..cc6784b679 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerTestUtil.java @@ -0,0 +1,151 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.cloud.spanner.Type.StructField; +import com.google.common.collect.ContiguousSet; +import com.google.protobuf.ListValue; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.TypeCode; +import java.util.Arrays; + +public class MockSpannerTestUtil { + static final Statement SELECT1 = Statement.of("SELECT 1 AS COL1"); + private static final ResultSetMetadata SELECT1_METADATA = + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("COL1") + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.INT64) + .build()) + .build()) + .build()) + .build(); + static final com.google.spanner.v1.ResultSet SELECT1_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("1").build()) + .build()) + .setMetadata(SELECT1_METADATA) + .build(); + + static final String TEST_PROJECT = "my-project"; + static final String TEST_INSTANCE = "my-instance"; + static final String TEST_DATABASE = "my-database"; + + static final Statement UPDATE_STATEMENT = Statement.of("UPDATE FOO SET BAR=1 WHERE BAZ=2"); + static final Statement INVALID_UPDATE_STATEMENT = + Statement.of("UPDATE NON_EXISTENT_TABLE SET BAR=1 WHERE BAZ=2"); + static final Statement UPDATE_ABORTED_STATEMENT = + Statement.of("UPDATE FOO SET BAR=1 WHERE BAZ=2 AND THIS_WILL_ABORT=TRUE"); + static final long UPDATE_COUNT = 1L; + + static final String READ_TABLE_NAME = "TestTable"; + static final String EMPTY_READ_TABLE_NAME = "EmptyTestTable"; + static final Iterable READ_COLUMN_NAMES = Arrays.asList("Key", "Value"); + static final Statement READ_ONE_KEY_VALUE_STATEMENT = + Statement.of("SELECT Key, Value FROM TestTable WHERE ID=1"); + static final Statement READ_MULTIPLE_KEY_VALUE_STATEMENT = + Statement.of("SELECT Key, Value FROM TestTable WHERE 1=1"); + static final Statement READ_ONE_EMPTY_KEY_VALUE_STATEMENT = + Statement.of("SELECT Key, Value FROM EmptyTestTable WHERE ID=1"); + static final Statement READ_ALL_EMPTY_KEY_VALUE_STATEMENT = + Statement.of("SELECT Key, Value FROM EmptyTestTable WHERE 1=1"); + static final ResultSetMetadata READ_KEY_VALUE_METADATA = + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("Key") + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.STRING) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("Value") + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.STRING) + .build()) + .build()) + .build()) + .build(); + static final Type READ_TABLE_TYPE = + Type.struct(StructField.of("Key", Type.string()), StructField.of("Value", Type.string())); + static final com.google.spanner.v1.ResultSet EMPTY_KEY_VALUE_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows(ListValue.newBuilder().build()) + .setMetadata(READ_KEY_VALUE_METADATA) + .build(); + static final com.google.spanner.v1.ResultSet READ_ONE_KEY_VALUE_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("k1").build()) + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("v1").build()) + .build()) + .setMetadata(READ_KEY_VALUE_METADATA) + .build(); + static final com.google.spanner.v1.ResultSet READ_MULTIPLE_KEY_VALUE_RESULTSET = + generateKeyValueResultSet(ContiguousSet.closed(1, 3)); + + static com.google.spanner.v1.ResultSet generateKeyValueResultSet(Iterable rows) { + com.google.spanner.v1.ResultSet.Builder builder = com.google.spanner.v1.ResultSet.newBuilder(); + for (Integer row : rows) { + builder.addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("k" + row).build()) + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("v" + row).build()) + .build()); + } + return builder.setMetadata(READ_KEY_VALUE_METADATA).build(); + } + + static final ResultSetMetadata READ_FIRST_NAME_SINGERS_METADATA = + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("FirstName") + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.STRING) + .build()) + .build()) + .build()) + .build(); + static final com.google.spanner.v1.ResultSet READ_FIRST_NAME_SINGERS_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues( + com.google.protobuf.Value.newBuilder().setStringValue("FirstName").build()) + .build()) + .setMetadata(READ_FIRST_NAME_SINGERS_METADATA) + .build(); +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RandomResultSetGenerator.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RandomResultSetGenerator.java new file mode 100644 index 0000000000..63bc234a41 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RandomResultSetGenerator.java @@ -0,0 +1,166 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.client.util.Base64; +import com.google.cloud.Date; +import com.google.cloud.Timestamp; +import com.google.protobuf.ListValue; +import com.google.protobuf.NullValue; +import com.google.protobuf.Value; +import com.google.protobuf.util.Timestamps; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.util.Random; + +public class RandomResultSetGenerator { + private static final Type TYPES[] = + new Type[] { + Type.newBuilder().setCode(TypeCode.BOOL).build(), + Type.newBuilder().setCode(TypeCode.INT64).build(), + Type.newBuilder().setCode(TypeCode.FLOAT64).build(), + Type.newBuilder().setCode(TypeCode.STRING).build(), + Type.newBuilder().setCode(TypeCode.BYTES).build(), + Type.newBuilder().setCode(TypeCode.DATE).build(), + Type.newBuilder().setCode(TypeCode.TIMESTAMP).build(), + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.BOOL)) + .build(), + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.INT64)) + .build(), + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.FLOAT64)) + .build(), + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.STRING)) + .build(), + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.BYTES)) + .build(), + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.DATE)) + .build(), + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.TIMESTAMP)) + .build(), + }; + + private static final ResultSetMetadata generateMetadata() { + StructType.Builder rowTypeBuilder = StructType.newBuilder(); + for (int col = 0; col < TYPES.length; col++) { + rowTypeBuilder.addFields(Field.newBuilder().setName("COL" + col).setType(TYPES[col])).build(); + } + ResultSetMetadata.Builder builder = ResultSetMetadata.newBuilder(); + builder.setRowType(rowTypeBuilder.build()); + return builder.build(); + } + + private static final ResultSetMetadata METADATA = generateMetadata(); + + private final int rowCount; + private final Random random = new Random(); + + public RandomResultSetGenerator(int rowCount) { + this.rowCount = rowCount; + } + + public ResultSet generate() { + ResultSet.Builder builder = ResultSet.newBuilder(); + for (int row = 0; row < rowCount; row++) { + ListValue.Builder rowBuilder = ListValue.newBuilder(); + for (int col = 0; col < TYPES.length; col++) { + Value.Builder valueBuilder = Value.newBuilder(); + setRandomValue(valueBuilder, TYPES[col]); + rowBuilder.addValues(valueBuilder.build()); + } + builder.addRows(rowBuilder.build()); + } + builder.setMetadata(METADATA); + return builder.build(); + } + + private void setRandomValue(Value.Builder builder, Type type) { + if (randomNull()) { + builder.setNullValue(NullValue.NULL_VALUE); + } else { + switch (type.getCode()) { + case ARRAY: + int length = random.nextInt(20) + 1; + ListValue.Builder arrayBuilder = ListValue.newBuilder(); + for (int i = 0; i < length; i++) { + Value.Builder valueBuilder = Value.newBuilder(); + setRandomValue(valueBuilder, type.getArrayElementType()); + arrayBuilder.addValues(valueBuilder.build()); + } + builder.setListValue(arrayBuilder.build()); + break; + case BOOL: + builder.setBoolValue(random.nextBoolean()); + break; + case STRING: + case BYTES: + byte[] bytes = new byte[random.nextInt(200)]; + random.nextBytes(bytes); + builder.setStringValue(Base64.encodeBase64String(bytes)); + break; + case DATE: + Date date = + Date.fromYearMonthDay( + random.nextInt(2019) + 1, random.nextInt(11) + 1, random.nextInt(28) + 1); + builder.setStringValue(date.toString()); + break; + case FLOAT64: + builder.setNumberValue(random.nextDouble()); + break; + case INT64: + builder.setStringValue(String.valueOf(random.nextLong())); + break; + case TIMESTAMP: + com.google.protobuf.Timestamp ts = + Timestamps.add( + Timestamps.EPOCH, + com.google.protobuf.Duration.newBuilder() + .setSeconds(random.nextInt(100_000_000)) + .setNanos(random.nextInt(1000_000_000)) + .build()); + builder.setStringValue(Timestamp.fromProto(ts).toString()); + break; + case STRUCT: + case TYPE_CODE_UNSPECIFIED: + case UNRECOGNIZED: + default: + throw new IllegalArgumentException("Unknown or unsupported type: " + type.getCode()); + } + } + } + + private boolean randomNull() { + return random.nextInt(10) == 0; + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java new file mode 100644 index 0000000000..13e4c47d08 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java @@ -0,0 +1,510 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.MockSpannerTestUtil.*; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.api.core.ApiFunction; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.grpc.testing.LocalChannelProvider; +import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.common.base.Function; +import com.google.common.collect.ContiguousSet; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.inprocess.InProcessServerBuilder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.SynchronousQueue; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ReadAsyncTest { + private static MockSpannerServiceImpl mockSpanner; + private static Server server; + private static LocalChannelProvider channelProvider; + + private static ExecutorService executor; + private Spanner spanner; + private DatabaseClient client; + + @BeforeClass + public static void setup() throws Exception { + mockSpanner = new MockSpannerServiceImpl(); + mockSpanner.putStatementResult( + StatementResult.query(READ_ONE_KEY_VALUE_STATEMENT, READ_ONE_KEY_VALUE_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.query(READ_ONE_EMPTY_KEY_VALUE_STATEMENT, EMPTY_KEY_VALUE_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.query( + READ_MULTIPLE_KEY_VALUE_STATEMENT, READ_MULTIPLE_KEY_VALUE_RESULTSET)); + + String uniqueName = InProcessServerBuilder.generateName(); + server = + InProcessServerBuilder.forName(uniqueName) + .scheduledExecutorService(new ScheduledThreadPoolExecutor(1)) + .addService(mockSpanner) + .build() + .start(); + channelProvider = LocalChannelProvider.create(uniqueName); + executor = Executors.newScheduledThreadPool(8); + } + + @AfterClass + public static void teardown() throws Exception { + executor.shutdown(); + server.shutdown(); + server.awaitTermination(); + } + + @Before + public void before() { + spanner = + SpannerOptions.newBuilder() + .setProjectId(TEST_PROJECT) + .setChannelProvider(channelProvider) + .setCredentials(NoCredentials.getInstance()) + .setSessionPoolOption( + SessionPoolOptions.newBuilder().setFailOnSessionLeak().setMinSessions(0).build()) + .build() + .getService(); + client = spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + } + + @After + public void after() { + spanner.close(); + mockSpanner.removeAllExecutionTimes(); + } + + @Test + public void readAsyncPropagatesError() throws Exception { + ApiFuture result; + try (AsyncResultSet resultSet = + client + .singleUse(TimestampBound.strong()) + .readAsync(EMPTY_READ_TABLE_NAME, KeySet.singleKey(Key.of("k99")), READ_COLUMN_NAMES)) { + result = + resultSet.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.CANCELLED, "Don't want the data"); + } + }); + } + try { + result.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.CANCELLED); + assertThat(se.getMessage()).contains("Don't want the data"); + } + } + + @Test + public void emptyReadAsync() throws Exception { + ApiFuture result; + try (AsyncResultSet resultSet = + client + .singleUse(TimestampBound.strong()) + .readAsync(EMPTY_READ_TABLE_NAME, KeySet.singleKey(Key.of("k99")), READ_COLUMN_NAMES)) { + result = + resultSet.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + while (true) { + switch (resultSet.tryNext()) { + case OK: + fail("received unexpected data"); + case NOT_READY: + return CallbackResponse.CONTINUE; + case DONE: + assertThat(resultSet.getType()).isEqualTo(READ_TABLE_TYPE); + return CallbackResponse.DONE; + } + } + } + }); + } + assertThat(result.get()).isNull(); + } + + @Test + public void pointReadAsync() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync(READ_TABLE_NAME, Key.of("k1"), READ_COLUMN_NAMES); + assertThat(row.get()).isNotNull(); + assertThat(row.get().getString(0)).isEqualTo("k1"); + assertThat(row.get().getString(1)).isEqualTo("v1"); + } + + @Test + public void pointReadNotFound() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync(EMPTY_READ_TABLE_NAME, Key.of("k999"), READ_COLUMN_NAMES); + assertThat(row.get()).isNull(); + } + + @Test + public void invalidDatabase() throws Exception { + mockSpanner.setBatchCreateSessionsExecutionTime( + SimulatedExecutionTime.stickyDatabaseNotFoundException("invalid-database")); + DatabaseClient invalidClient = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, "invalid-database")); + ApiFuture row = + invalidClient + .singleUse(TimestampBound.strong()) + .readRowAsync(READ_TABLE_NAME, Key.of("k99"), READ_COLUMN_NAMES); + try { + row.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(DatabaseNotFoundException.class); + } + } + + @Test + public void tableNotFound() throws Exception { + mockSpanner.setStreamingReadExecutionTime( + SimulatedExecutionTime.ofException( + Status.NOT_FOUND + .withDescription("Table not found: BadTableName") + .asRuntimeException())); + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync("BadTableName", Key.of("k1"), READ_COLUMN_NAMES); + try { + row.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + assertThat(se.getMessage()).contains("BadTableName"); + } + } + + /** + * Ending a read-only transaction before an asynchronous query that was executed on that + * transaction has finished fetching all rows should keep the session checked out of the pool + * until all the rows have been returned. The session is then automatically returned to the + * session. + */ + @Test + public void closeTransactionBeforeEndOfAsyncQuery() throws Exception { + final BlockingQueue results = new SynchronousQueue<>(); + final SettableApiFuture finished = SettableApiFuture.create(); + ApiFuture closed; + DatabaseClientImpl clientImpl = (DatabaseClientImpl) client; + + // There should currently not be any sessions checked out of the pool. + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(0); + + final CountDownLatch dataReceived = new CountDownLatch(1); + try (ReadOnlyTransaction tx = client.readOnlyTransaction()) { + try (AsyncResultSet rs = + tx.readAsync(READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES, Options.bufferRows(1))) { + closed = + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + finished.set(true); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + dataReceived.countDown(); + results.put(resultSet.getString(0)); + } + } + } catch (Throwable t) { + finished.setException(t); + return CallbackResponse.DONE; + } + } + }); + } + // Wait until at least one row has been fetched. At that moment there should be one session + // checked out. + dataReceived.await(); + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(1); + } + // The read-only transaction is now closed, but the ready callback will continue to receive + // data. As it tries to put the data into a synchronous queue and the underlying buffer can also + // only hold 1 row, the async result set has not yet finished. The read-only transaction will + // release the session back into the pool when all async statements have finished. The number of + // sessions in use is therefore still 1. + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(1); + List resultList = new ArrayList<>(); + do { + results.drainTo(resultList); + } while (!finished.isDone() || results.size() > 0); + assertThat(finished.get()).isTrue(); + assertThat(resultList).containsExactly("k1", "k2", "k3"); + // The session will be released back into the pool by the asynchronous result set when it has + // returned all rows. As this is done in the background, it could take a couple of milliseconds. + closed.get(); + assertThat(clientImpl.pool.getNumberOfSessionsInUse()).isEqualTo(0); + } + + @Test + public void readOnlyTransaction() throws Exception { + Statement statement1 = + Statement.of("SELECT * FROM TestTable WHERE Key IN ('k10', 'k11', 'k12')"); + Statement statement2 = Statement.of("SELECT * FROM TestTable WHERE Key IN ('k1', 'k2', 'k3"); + mockSpanner.putStatementResult( + StatementResult.query(statement1, generateKeyValueResultSet(ContiguousSet.closed(10, 12)))); + mockSpanner.putStatementResult( + StatementResult.query(statement2, generateKeyValueResultSet(ContiguousSet.closed(1, 3)))); + + ApiFuture> values1; + ApiFuture> values2; + try (ReadOnlyTransaction tx = client.readOnlyTransaction()) { + try (AsyncResultSet rs = tx.executeQueryAsync(statement1)) { + values1 = + rs.toListAsync( + new Function() { + @Override + public String apply(StructReader input) { + return input.getString("Value"); + } + }, + executor); + } + try (AsyncResultSet rs = tx.executeQueryAsync(statement2)) { + values2 = + rs.toListAsync( + new Function() { + @Override + public String apply(StructReader input) { + return input.getString("Value"); + } + }, + executor); + } + } + ApiFuture> allValues = + ApiFutures.transform( + ApiFutures.allAsList(Arrays.asList(values1, values2)), + new ApiFunction>, Iterable>() { + @Override + public Iterable apply(List> input) { + return Iterables.mergeSorted( + input, + new Comparator() { + @Override + public int compare(String o1, String o2) { + // Return in numerical order (i.e. without the preceding 'v'). + return Integer.valueOf(o1.substring(1)) + .compareTo(Integer.valueOf(o2.substring(1))); + } + }); + } + }, + executor); + assertThat(allValues.get()).containsExactly("v1", "v2", "v3", "v10", "v11", "v12"); + } + + @Test + public void pauseResume() throws Exception { + Statement unevenStatement = + Statement.of("SELECT * FROM TestTable WHERE MOD(CAST(SUBSTR(Key, 2) AS INT64), 2) = 1"); + Statement evenStatement = + Statement.of("SELECT * FROM TestTable WHERE MOD(CAST(SUBSTR(Key, 2) AS INT64), 2) = 0"); + mockSpanner.putStatementResult( + StatementResult.query( + unevenStatement, generateKeyValueResultSet(ImmutableSet.of(1, 3, 5, 7, 9)))); + mockSpanner.putStatementResult( + StatementResult.query( + evenStatement, generateKeyValueResultSet(ImmutableSet.of(2, 4, 6, 8, 10)))); + + final Object lock = new Object(); + ApiFuture evenFinished; + ApiFuture unevenFinished; + final CountDownLatch unevenReturnedFirstRow = new CountDownLatch(1); + final Deque allValues = new ConcurrentLinkedDeque<>(); + try (ReadOnlyTransaction tx = client.readOnlyTransaction()) { + try (AsyncResultSet evenRs = tx.executeQueryAsync(evenStatement); + AsyncResultSet unevenRs = tx.executeQueryAsync(unevenStatement)) { + unevenFinished = + unevenRs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + synchronized (lock) { + allValues.add(resultSet.getString("Value")); + } + unevenReturnedFirstRow.countDown(); + return CallbackResponse.PAUSE; + } + } + } + }); + evenFinished = + evenRs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + // Make sure the uneven result set has returned the first before we start the + // even + // results. + unevenReturnedFirstRow.await(); + while (true) { + switch (resultSet.tryNext()) { + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + synchronized (lock) { + allValues.add(resultSet.getString("Value")); + } + return CallbackResponse.PAUSE; + } + } + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } + } + }); + while (!(evenFinished.isDone() && unevenFinished.isDone())) { + synchronized (lock) { + if (allValues.peekLast() != null) { + if (Integer.valueOf(allValues.peekLast().substring(1)) % 2 == 1) { + evenRs.resume(); + } else { + unevenRs.resume(); + } + } + if (allValues.size() == 10) { + unevenRs.resume(); + evenRs.resume(); + } + } + } + } + } + assertThat(ApiFutures.allAsList(Arrays.asList(evenFinished, unevenFinished)).get()) + .containsExactly(null, null); + assertThat(allValues) + .containsExactly("v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10"); + } + + @Test + public void cancel() throws Exception { + final List values = new LinkedList<>(); + final CountDownLatch receivedFirstRow = new CountDownLatch(1); + final CountDownLatch cancelled = new CountDownLatch(1); + final ApiFuture res; + try (AsyncResultSet rs = + client.singleUse().readAsync(READ_TABLE_NAME, KeySet.all(), READ_COLUMN_NAMES)) { + res = + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + values.add(resultSet.getString("Value")); + receivedFirstRow.countDown(); + cancelled.await(); + break; + } + } + } catch (Throwable t) { + return CallbackResponse.DONE; + } + } + }); + receivedFirstRow.await(); + rs.cancel(); + } + cancelled.countDown(); + try { + res.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.CANCELLED); + assertThat(values).containsExactly("v1"); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RetryOnInvalidatedSessionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RetryOnInvalidatedSessionTest.java index 29f442a761..7380791eed 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RetryOnInvalidatedSessionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RetryOnInvalidatedSessionTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.api.core.ApiFuture; import com.google.api.gax.core.NoCredentialsProvider; import com.google.api.gax.grpc.testing.LocalChannelProvider; import com.google.cloud.NoCredentials; @@ -28,7 +29,9 @@ import com.google.cloud.spanner.v1.SpannerClient; import com.google.cloud.spanner.v1.SpannerClient.ListSessionsPagedResponse; import com.google.cloud.spanner.v1.SpannerSettings; +import com.google.common.base.Function; import com.google.common.base.Stopwatch; +import com.google.common.collect.ImmutableList; import com.google.protobuf.ListValue; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.StructType; @@ -41,6 +44,9 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.AfterClass; @@ -54,6 +60,14 @@ @RunWith(Parameterized.class) public class RetryOnInvalidatedSessionTest { + private static final class ToLongTransformer implements Function { + @Override + public Long apply(StructReader input) { + return input.getLong(0); + } + } + + private static final ToLongTransformer TO_LONG = new ToLongTransformer(); @Parameter(0) public boolean failOnInvalidatedSession; @@ -138,6 +152,7 @@ public static Collection data() { private static SpannerClient spannerClient; private static Spanner spanner; private static DatabaseClient client; + private static ExecutorService executor; @BeforeClass public static void startStaticServer() throws IOException { @@ -166,6 +181,7 @@ public static void startStaticServer() throws IOException { .setCredentialsProvider(NoCredentialsProvider.create()) .build(); spannerClient = SpannerClient.create(settings); + executor = Executors.newSingleThreadExecutor(); } @AfterClass @@ -173,13 +189,16 @@ public static void stopServer() throws InterruptedException { spannerClient.close(); server.shutdown(); server.awaitTermination(); + executor.shutdown(); } @Before public void setUp() { mockSpanner.reset(); SessionPoolOptions.Builder builder = - SessionPoolOptions.newBuilder().setWriteSessionsFraction(WRITE_SESSIONS_FRACTION); + SessionPoolOptions.newBuilder() + .setWriteSessionsFraction(WRITE_SESSIONS_FRACTION) + .setFailOnSessionLeak(); if (failOnInvalidatedSession) { builder.setFailIfSessionNotFound(); } @@ -253,6 +272,20 @@ public void singleUseSelect() throws InterruptedException { } } + @Test + public void singleUseSelectAsync() throws Exception { + invalidateSessionPool(); + ApiFuture> list; + try (AsyncResultSet rs = client.singleUse().executeQueryAsync(SELECT1AND2)) { + list = rs.toListAsync(TO_LONG, executor); + assertThat(list.get()).containsExactly(1L, 2L); + assertThat(failOnInvalidatedSession).isFalse(); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SessionNotFoundException.class); + assertThat(failOnInvalidatedSession).isTrue(); + } + } + @Test public void singleUseRead() throws InterruptedException { invalidateSessionPool(); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index be4179f21b..c756a7898a 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.api.core.ApiFutures; import com.google.api.core.NanoClock; import com.google.api.gax.retrying.RetrySettings; import com.google.cloud.Timestamp; @@ -41,6 +42,7 @@ import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.Session; import com.google.spanner.v1.Transaction; +import io.opencensus.trace.Span; import java.text.ParseException; import java.util.Arrays; import java.util.Calendar; @@ -98,16 +100,17 @@ public void setUp() { .thenReturn(sessionProto); Transaction txn = Transaction.newBuilder().setId(ByteString.copyFromUtf8("TEST")).build(); Mockito.when( - rpc.beginTransaction( + rpc.beginTransactionAsync( Mockito.any(BeginTransactionRequest.class), Mockito.any(Map.class))) - .thenReturn(txn); + .thenReturn(ApiFutures.immediateFuture(txn)); CommitResponse commitResponse = CommitResponse.newBuilder() .setCommitTimestamp(com.google.protobuf.Timestamp.getDefaultInstance()) .build(); - Mockito.when(rpc.commit(Mockito.any(CommitRequest.class), Mockito.any(Map.class))) - .thenReturn(commitResponse); + Mockito.when(rpc.commitAsync(Mockito.any(CommitRequest.class), Mockito.any(Map.class))) + .thenReturn(ApiFutures.immediateFuture(commitResponse)); session = spanner.getSessionClient(db).createSession(); + ((SessionImpl) session).setCurrentSpan(mock(Span.class)); // We expect the same options, "options", on all calls on "session". options = optionsCaptor.getValue(); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java index 8007ce8385..0e72b2b9bc 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java @@ -25,6 +25,7 @@ import com.google.cloud.spanner.SessionClient.SessionConsumer; import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.SessionPool.PooledSessionFuture; import com.google.cloud.spanner.SessionPool.SessionConsumerImpl; import com.google.common.base.Function; import java.util.ArrayList; @@ -200,7 +201,7 @@ public void testKeepAlive() throws Exception { assertThat(pingedSessions).containsExactly(session1.getName(), 2, session2.getName(), 3); // Update the last use date and release the session to the pool and do another maintenance // cycle. - ((PooledSession) session6).markUsed(); + ((PooledSessionFuture) session6).get().markUsed(); session6.close(); runMaintainanceLoop(clock, pool, 3); assertThat(pingedSessions).containsExactly(session1.getName(), 2, session2.getName(), 3); @@ -261,9 +262,9 @@ public void testIdleSessions() throws Exception { // Now check out three sessions so the pool will create an additional session. The pool will // only keep 2 sessions alive, as that is the setting for MinSessions. - Session session3 = pool.getReadSession(); - Session session4 = pool.getReadSession(); - Session session5 = pool.getReadSession(); + Session session3 = pool.getReadSession().get(); + Session session4 = pool.getReadSession().get(); + Session session5 = pool.getReadSession().get(); // Note that session2 was now the first session in the pool as it was the last to receive a // ping. assertThat(session3.getName()).isEqualTo(session2.getName()); @@ -278,9 +279,9 @@ public void testIdleSessions() throws Exception { assertThat(pool.totalSessions()).isEqualTo(2); // Check out three sessions again and keep one session checked out. - Session session6 = pool.getReadSession(); - Session session7 = pool.getReadSession(); - Session session8 = pool.getReadSession(); + Session session6 = pool.getReadSession().get(); + Session session7 = pool.getReadSession().get(); + Session session8 = pool.getReadSession().get(); session8.close(); session7.close(); // Now advance the clock to idle sessions. This should remove session8 from the pool. @@ -292,9 +293,9 @@ public void testIdleSessions() throws Exception { // Check out three sessions and keep them all checked out. No sessions should be removed from // the pool. - Session session9 = pool.getReadSession(); - Session session10 = pool.getReadSession(); - Session session11 = pool.getReadSession(); + Session session9 = pool.getReadSession().get(); + Session session10 = pool.getReadSession().get(); + Session session11 = pool.getReadSession().get(); runMaintainanceLoop(clock, pool, loopsToIdleSessions); assertThat(idledSessions).containsExactly(session5, session8); assertThat(pool.totalSessions()).isEqualTo(3); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java index e5f5dff463..b806f5fad6 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java @@ -17,7 +17,7 @@ package com.google.cloud.spanner; import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.any; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -26,9 +26,11 @@ import com.google.api.core.ApiFutures; import com.google.cloud.spanner.SessionClient.SessionConsumer; import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.SessionPool.PooledSessionFuture; import com.google.cloud.spanner.SessionPool.SessionConsumerImpl; import com.google.common.base.Function; import com.google.common.util.concurrent.Uninterruptibles; +import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import java.util.ArrayList; import java.util.Collection; @@ -39,6 +41,8 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; @@ -66,6 +70,7 @@ public class SessionPoolStressTest extends BaseSessionPoolTest { DatabaseId db = DatabaseId.of("projects/p/instances/i/databases/unused"); SessionPool pool; SessionPoolOptions options; + ExecutorService createExecutor = Executors.newSingleThreadExecutor(); Object lock = new Object(); Random random = new Random(); FakeClock clock = new FakeClock(); @@ -97,43 +102,31 @@ private void setupSpanner(DatabaseId db) { SessionClient sessionClient = mock(SessionClient.class); when(mockSpanner.getSessionClient(db)).thenReturn(sessionClient); when(mockSpanner.getOptions()).thenReturn(spannerOptions); - when(sessionClient.createSession()) - .thenAnswer( - new Answer() { - - @Override - public Session answer(InvocationOnMock invocation) { - synchronized (lock) { - SessionImpl session = mockSession(); - setupSession(session); - - sessions.put(session.getName(), false); - if (sessions.size() > maxAliveSessions) { - maxAliveSessions = sessions.size(); - } - return session; - } - } - }); doAnswer( new Answer() { @Override - public Void answer(InvocationOnMock invocation) { - int sessionCount = invocation.getArgumentAt(0, Integer.class); - for (int s = 0; s < sessionCount; s++) { - synchronized (lock) { - SessionImpl session = mockSession(); - setupSession(session); - - sessions.put(session.getName(), false); - if (sessions.size() > maxAliveSessions) { - maxAliveSessions = sessions.size(); - } - SessionConsumerImpl consumer = - invocation.getArgumentAt(2, SessionConsumerImpl.class); - consumer.onSessionReady(session); - } - } + public Void answer(final InvocationOnMock invocation) { + createExecutor.submit( + new Runnable() { + @Override + public void run() { + int sessionCount = invocation.getArgumentAt(0, Integer.class); + for (int s = 0; s < sessionCount; s++) { + SessionImpl session; + synchronized (lock) { + session = mockSession(); + setupSession(session); + sessions.put(session.getName(), false); + if (sessions.size() > maxAliveSessions) { + maxAliveSessions = sessions.size(); + } + } + SessionConsumerImpl consumer = + invocation.getArgumentAt(2, SessionConsumerImpl.class); + consumer.onSessionReady(session); + } + } + }); return null; } }) @@ -189,36 +182,43 @@ public Void answer(InvocationOnMock invocation) { expireSession(session); throw SpannerExceptionFactoryTest.newSessionNotFoundException(session.getName()); } + String name = session.getName(); synchronized (lock) { - if (sessions.put(session.getName(), true)) { + if (sessions.put(name, true)) { setFailed(); } + session.readyTransactionId = ByteString.copyFromUtf8("foo"); } return null; } }) .when(session) .prepareReadWriteTransaction(); + when(session.hasReadyTransaction()).thenCallRealMethod(); } private void expireSession(Session session) { + String name = session.getName(); synchronized (lock) { - sessions.remove(session.getName()); - expiredSessions.add(session.getName()); + sessions.remove(name); + expiredSessions.add(name); } } private void assertWritePrepared(Session session) { + String name = session.getName(); synchronized (lock) { - if (!sessions.get(session.getName())) { + if (!sessions.containsKey(name) || !sessions.get(name)) { setFailed(); } } } - private void resetTransaction(Session session) { + private void resetTransaction(SessionImpl session) { + String name = session.getName(); synchronized (lock) { - sessions.put(session.getName(), false); + session.readyTransactionId = null; + sessions.put(name, false); } } @@ -264,8 +264,9 @@ public void stressTest() throws Exception { new Function() { @Override public Void apply(PooledSession pooled) { + String name = pooled.getName(); synchronized (lock) { - sessions.remove(pooled.getName()); + sessions.remove(name); return null; } } @@ -279,16 +280,18 @@ public void run() { Uninterruptibles.awaitUninterruptibly(releaseThreads); for (int j = 0; j < numOperationsPerThread; j++) { try { - Session session = null; + PooledSessionFuture session = null; if (random.nextInt(10) < writeOperationFraction) { session = pool.getReadWriteSession(); - assertWritePrepared(session); + PooledSession sess = session.get(); + assertWritePrepared(sess); } else { session = pool.getReadSession(); + session.get(); } Uninterruptibles.sleepUninterruptibly( random.nextInt(5), TimeUnit.MILLISECONDS); - resetTransaction(session); + resetTransaction(session.get().delegate); session.close(); } catch (SpannerException e) { if (e.getErrorCode() != ErrorCode.RESOURCE_EXHAUSTED || shouldBlock) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java index db56c8ff82..d5ea648bbd 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java @@ -47,6 +47,7 @@ import com.google.cloud.spanner.SessionClient.SessionConsumer; import com.google.cloud.spanner.SessionPool.Clock; import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.SessionPool.PooledSessionFuture; import com.google.cloud.spanner.SessionPool.SessionConsumerImpl; import com.google.cloud.spanner.SpannerImpl.ClosedException; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; @@ -59,12 +60,14 @@ import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.ExecuteBatchDmlRequest; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.RollbackRequest; import io.opencensus.metrics.LabelValue; import io.opencensus.metrics.MetricRegistry; +import io.opencensus.trace.Span; import java.io.PrintWriter; import java.io.StringWriter; import java.util.ArrayList; @@ -204,21 +207,21 @@ public void sessionCreation() { public void poolLifo() { setupMockSessionCreation(); pool = createPool(); - Session session1 = pool.getReadSession(); - Session session2 = pool.getReadSession(); + Session session1 = pool.getReadSession().get(); + Session session2 = pool.getReadSession().get(); assertThat(session1).isNotEqualTo(session2); session2.close(); session1.close(); - Session session3 = pool.getReadSession(); - Session session4 = pool.getReadSession(); + Session session3 = pool.getReadSession().get(); + Session session4 = pool.getReadSession().get(); assertThat(session3).isEqualTo(session1); assertThat(session4).isEqualTo(session2); session3.close(); session4.close(); - Session session5 = pool.getReadWriteSession(); - Session session6 = pool.getReadWriteSession(); + Session session5 = pool.getReadWriteSession().get(); + Session session6 = pool.getReadWriteSession().get(); assertThat(session5).isEqualTo(session4); assertThat(session6).isEqualTo(session3); session6.close(); @@ -259,7 +262,7 @@ public void run() { pool = createPool(); Session session1 = pool.getReadSession(); // Leaked sessions - PooledSession leakedSession = pool.getReadSession(); + PooledSessionFuture leakedSession = pool.getReadSession(); // Clear the leaked exception to suppress logging of expected exceptions. leakedSession.clearLeakedException(); session1.close(); @@ -335,7 +338,7 @@ public Void call() throws Exception { .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); - PooledSession leakedSession = pool.getReadSession(); + PooledSessionFuture leakedSession = pool.getReadSession(); // Suppress expected leakedSession warning. leakedSession.clearLeakedException(); AtomicBoolean failed = new AtomicBoolean(false); @@ -344,7 +347,7 @@ public Void call() throws Exception { insideCreation.await(); pool.closeAsync(new SpannerImpl.ClosedException()); releaseCreation.countDown(); - latch.await(); + latch.await(5L, TimeUnit.SECONDS); assertThat(failed.get()).isTrue(); } @@ -393,7 +396,7 @@ public Void call() throws Exception { .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); - PooledSession leakedSession = pool.getReadSession(); + PooledSessionFuture leakedSession = pool.getReadSession(); // Suppress expected leakedSession warning. leakedSession.clearLeakedException(); AtomicBoolean failed = new AtomicBoolean(false); @@ -510,7 +513,8 @@ public void run() { .when(sessionClient) .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); - PooledSession leakedSession = pool.getReadSession(); + PooledSessionFuture leakedSession = pool.getReadSession(); + leakedSession.get(); // Suppress expected leakedSession warning. leakedSession.clearLeakedException(); pool.closeAsync(new SpannerImpl.ClosedException()); @@ -562,7 +566,7 @@ public Void call() { .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); try { - pool.getReadSession(); + pool.getReadSession().get(); fail("Expected exception"); } catch (SpannerException ex) { assertThat(ex.getErrorCode()).isEqualTo(ErrorCode.INTERNAL); @@ -593,7 +597,7 @@ public Void call() { .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); try { - pool.getReadWriteSession(); + pool.getReadWriteSession().get(); fail("Expected exception"); } catch (SpannerException ex) { assertThat(ex.getErrorCode()).isEqualTo(ErrorCode.INTERNAL); @@ -626,7 +630,7 @@ public void run() { .prepareReadWriteTransaction(); pool = createPool(); try { - pool.getReadWriteSession(); + pool.getReadWriteSession().get(); fail("Expected exception"); } catch (SpannerException ex) { assertThat(ex.getErrorCode()).isEqualTo(ErrorCode.INTERNAL); @@ -655,14 +659,15 @@ public void run() { .when(sessionClient) .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); - try (Session session = pool.getReadWriteSession()) { + try (PooledSessionFuture session = pool.getReadWriteSession()) { assertThat(session).isNotNull(); + session.get(); verify(mockSession).prepareReadWriteTransaction(); } } @Test - public void getMultipleReadWriteSessions() { + public void getMultipleReadWriteSessions() throws Exception { SessionImpl mockSession1 = mockSession(); SessionImpl mockSession2 = mockSession(); final LinkedList sessions = @@ -686,8 +691,10 @@ public void run() { .when(sessionClient) .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); - Session session1 = pool.getReadWriteSession(); - Session session2 = pool.getReadWriteSession(); + PooledSessionFuture session1 = pool.getReadWriteSession(); + PooledSessionFuture session2 = pool.getReadWriteSession(); + session1.get(); + session2.get(); verify(mockSession1).prepareReadWriteTransaction(); verify(mockSession2).prepareReadWriteTransaction(); session1.close(); @@ -782,8 +789,8 @@ public void run() { pool = createPool(); // One of the sessions would be pre prepared. Uninterruptibles.awaitUninterruptibly(prepareLatch); - PooledSession readSession = pool.getReadSession(); - PooledSession writeSession = pool.getReadWriteSession(); + PooledSession readSession = pool.getReadSession().get(); + PooledSession writeSession = pool.getReadWriteSession().get(); verify(writeSession.delegate, times(1)).prepareReadWriteTransaction(); verify(readSession.delegate, never()).prepareReadWriteTransaction(); readSession.close(); @@ -832,7 +839,7 @@ public void run() { pool.getReadWriteSession().close(); prepareLatch.await(); // This session should also be write prepared. - PooledSession readSession = pool.getReadSession(); + PooledSession readSession = pool.getReadSession().get(); verify(readSession.delegate, times(2)).prepareReadWriteTransaction(); } @@ -904,7 +911,7 @@ public void run() { .when(sessionClient) .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class)); pool = createPool(); - assertThat(pool.getReadWriteSession().delegate).isEqualTo(mockSession2); + assertThat(pool.getReadWriteSession().get().delegate).isEqualTo(mockSession2); } @Test @@ -949,9 +956,14 @@ public void run() { pool.getReadSession().close(); runMaintainanceLoop(clock, pool, pool.poolMaintainer.numClosureCycles); assertThat(pool.numIdleSessionsRemoved()).isEqualTo(0L); - Session readSession1 = pool.getReadSession(); - Session readSession2 = pool.getReadSession(); - Session readSession3 = pool.getReadSession(); + PooledSessionFuture readSession1 = pool.getReadSession(); + PooledSessionFuture readSession2 = pool.getReadSession(); + PooledSessionFuture readSession3 = pool.getReadSession(); + // Wait until the sessions have actually been gotten in order to make sure they are in use in + // parallel. + readSession1.get(); + readSession2.get(); + readSession3.get(); readSession1.close(); readSession2.close(); readSession3.close(); @@ -1005,8 +1017,10 @@ public void run() { FakeClock clock = new FakeClock(); clock.currentTimeMillis = System.currentTimeMillis(); pool = createPool(clock); - Session session1 = pool.getReadSession(); - Session session2 = pool.getReadSession(); + PooledSessionFuture session1 = pool.getReadSession(); + PooledSessionFuture session2 = pool.getReadSession(); + session1.get(); + session2.get(); session1.close(); session2.close(); runMaintainanceLoop(clock, pool, pool.poolMaintainer.numKeepAliveCycles); @@ -1133,7 +1147,8 @@ public void blockAndTimeoutOnPoolExhaustion() throws Exception { setupMockSessionCreation(); pool = createPool(); // Take the only session that can be in the pool. - Session checkedOutSession = pool.getReadSession(); + PooledSessionFuture checkedOutSession = pool.getReadSession(); + checkedOutSession.get(); final Boolean finWrite = write; ExecutorService executor = Executors.newFixedThreadPool(1); final CountDownLatch latch = new CountDownLatch(1); @@ -1143,7 +1158,7 @@ public void blockAndTimeoutOnPoolExhaustion() throws Exception { new Callable() { @Override public Void call() { - Session session; + PooledSessionFuture session; latch.countDown(); if (finWrite) { session = pool.getReadWriteSession(); @@ -1326,7 +1341,8 @@ public void testSessionNotFoundReadWriteTransaction() { .thenThrow(sessionNotFound); when(rpc.executeBatchDml(any(ExecuteBatchDmlRequest.class), any(Map.class))) .thenThrow(sessionNotFound); - when(rpc.commit(any(CommitRequest.class), any(Map.class))).thenThrow(sessionNotFound); + when(rpc.commitAsync(any(CommitRequest.class), any(Map.class))) + .thenReturn(ApiFutures.immediateFailedFuture(sessionNotFound)); doThrow(sessionNotFound).when(rpc).rollback(any(RollbackRequest.class), any(Map.class)); final SessionImpl closedSession = mock(SessionImpl.class); when(closedSession.getName()) @@ -1342,9 +1358,10 @@ public void testSessionNotFoundReadWriteTransaction() { when(closedSession.asyncClose()) .thenReturn(ApiFutures.immediateFuture(Empty.getDefaultInstance())); when(closedSession.newTransaction()).thenReturn(closedTransactionContext); - when(closedSession.beginTransaction()).thenThrow(sessionNotFound); + when(closedSession.beginTransactionAsync()).thenThrow(sessionNotFound); TransactionRunnerImpl closedTransactionRunner = new TransactionRunnerImpl(closedSession, rpc, 10); + closedTransactionRunner.setSpan(mock(Span.class)); when(closedSession.readWriteTransaction()).thenReturn(closedTransactionRunner); final SessionImpl openSession = mock(SessionImpl.class); @@ -1354,9 +1371,11 @@ public void testSessionNotFoundReadWriteTransaction() { .thenReturn("projects/dummy/instances/dummy/database/dummy/sessions/session-open"); final TransactionContextImpl openTransactionContext = mock(TransactionContextImpl.class); when(openSession.newTransaction()).thenReturn(openTransactionContext); - when(openSession.beginTransaction()).thenReturn(ByteString.copyFromUtf8("open-txn")); + when(openSession.beginTransactionAsync()) + .thenReturn(ApiFutures.immediateFuture(ByteString.copyFromUtf8("open-txn"))); TransactionRunnerImpl openTransactionRunner = new TransactionRunnerImpl(openSession, mock(SpannerRpc.class), 10); + openTransactionRunner.setSpan(mock(Span.class)); when(openSession.readWriteTransaction()).thenReturn(openTransactionRunner); ResultSet openResultSet = mock(ResultSet.class); @@ -1422,7 +1441,7 @@ public void run() { SessionPool pool = SessionPool.createPool( options, new TestExecutorFactory(), spanner.getSessionClient(db)); - try (PooledSession readWriteSession = pool.getReadWriteSession()) { + try (PooledSessionFuture readWriteSession = pool.getReadWriteSession()) { TransactionRunner runner = readWriteSession.readWriteTransaction(); try { runner.run( @@ -1546,7 +1565,7 @@ public void run() { FakeClock clock = new FakeClock(); clock.currentTimeMillis = System.currentTimeMillis(); pool = createPool(clock); - PooledSession session = pool.getReadWriteSession(); + PooledSession session = pool.getReadWriteSession().get(); assertThat(session.delegate).isEqualTo(openSession); } @@ -1726,8 +1745,10 @@ public void testSessionMetrics() throws Exception { setupMockSessionCreation(); pool = createPool(clock, metricRegistry, labelValues); - Session session1 = pool.getReadSession(); - Session session2 = pool.getReadSession(); + PooledSessionFuture session1 = pool.getReadSession(); + PooledSessionFuture session2 = pool.getReadSession(); + session1.get(); + session2.get(); MetricsRecord record = metricRegistry.pollRecord(); assertThat(record.getMetrics().size()).isEqualTo(6); @@ -1861,7 +1882,8 @@ private void getSessionAsync(final CountDownLatch latch, final AtomicBoolean fai new Runnable() { @Override public void run() { - try (Session session = pool.getReadSession()) { + try (PooledSessionFuture future = pool.getReadSession()) { + PooledSession session = future.get(); failed.compareAndSet(false, session == null); Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS); } catch (Throwable e) { @@ -1879,7 +1901,8 @@ private void getReadWriteSessionAsync(final CountDownLatch latch, final AtomicBo new Runnable() { @Override public void run() { - try (Session session = pool.getReadWriteSession()) { + try (PooledSessionFuture future = pool.getReadWriteSession()) { + PooledSession session = future.get(); failed.compareAndSet(false, session == null); Uninterruptibles.sleepUninterruptibly(2, TimeUnit.MILLISECONDS); } catch (SpannerException e) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanTest.java index 9bbbdcea82..7dcc9b65e1 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanTest.java @@ -324,6 +324,7 @@ public Void run(TransactionContext transaction) { } Map spans = failOnOverkillTraceComponent.getSpans(); + assertThat(spans.size()).isEqualTo(5); assertThat(spans).containsEntry("CloudSpanner.ReadWriteTransaction", true); assertThat(spans).containsEntry("CloudSpannerOperation.BatchCreateSessions", true); assertThat(spans).containsEntry("SessionPool.WaitForSession", true); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerExceptionFactoryTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerExceptionFactoryTest.java index bc7dd5498d..49cbfb905d 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerExceptionFactoryTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerExceptionFactoryTest.java @@ -52,6 +52,11 @@ static DatabaseNotFoundException newDatabaseNotFoundException(String name) { "Database", SpannerExceptionFactory.DATABASE_RESOURCE_TYPE, name); } + static StatusRuntimeException newStatusDatabaseNotFoundException(String name) { + return newStatusResourceNotFoundException( + "Database", SpannerExceptionFactory.DATABASE_RESOURCE_TYPE, name); + } + static InstanceNotFoundException newInstanceNotFoundException(String name) { return (InstanceNotFoundException) newResourceNotFoundException( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerGaxRetryTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerGaxRetryTest.java index 05729d6f53..b98702f87c 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerGaxRetryTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerGaxRetryTest.java @@ -321,6 +321,13 @@ public void readWriteTransactionTimeout() { mockSpanner.setBeginTransactionExecutionTime(ONE_SECOND); try { TransactionRunner runner = clientWithTimeout.readWriteTransaction(); + runner.run( + new TransactionCallable() { + @Override + public Void run(TransactionContext transaction) throws Exception { + return null; + } + }); fail("Expected exception"); } catch (SpannerException ex) { assertEquals(ErrorCode.DEADLINE_EXCEEDED, ex.getErrorCode()); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerMatchers.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerMatchers.java index 9662047867..4723497a47 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerMatchers.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerMatchers.java @@ -21,6 +21,7 @@ import com.google.protobuf.Message; import com.google.protobuf.TextFormat; import java.lang.reflect.InvocationTargetException; +import java.util.concurrent.ExecutionException; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.hamcrest.Matcher; @@ -47,6 +48,15 @@ public static Matcher isSpannerException(ErrorCode code return new SpannerExceptionMatcher<>(code); } + /** + * Returns a method that checks that a {@link Throwable} is an {@link ExecutionException} where + * the cause is a {@link SpannerException} with an error code to {@code code}. + */ + public static Matcher isExecutionExceptionWithSpannerCause( + ErrorCode code) { + return new ExecutionExceptionWithSpannerCauseMatcher<>(code); + } + private static class ProtoTextMatcher extends BaseMatcher { private final T expected; @@ -110,4 +120,31 @@ public void describeTo(Description description) { description.appendText("SpannerException[" + expectedCode + "]"); } } + + private static class ExecutionExceptionWithSpannerCauseMatcher + extends BaseMatcher { + private final ErrorCode expectedCode; + + ExecutionExceptionWithSpannerCauseMatcher(ErrorCode expectedCode) { + this.expectedCode = checkNotNull(expectedCode); + } + + @Override + public boolean matches(Object item) { + if (!(item instanceof ExecutionException)) { + return false; + } + ExecutionException ee = (ExecutionException) item; + if (!(ee.getCause() instanceof SpannerException)) { + return false; + } + SpannerException e = (SpannerException) ee.getCause(); + return e.getErrorCode() == expectedCode; + } + + @Override + public void describeTo(Description description) { + description.appendText("ExecutionException[SpannerException[" + expectedCode + "]]"); + } + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java index ba569653c3..38aa66516e 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java @@ -26,6 +26,7 @@ import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.initMocks; +import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; import com.google.cloud.Timestamp; import com.google.cloud.grpc.GrpcTransportOptions; @@ -38,6 +39,7 @@ import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.Transaction; +import io.opencensus.trace.Span; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -75,7 +77,7 @@ public void release(ScheduledExecutorService exec) { @Before public void setUp() { initMocks(this); - manager = new TransactionManagerImpl(session); + manager = new TransactionManagerImpl(session, mock(Span.class)); } @Test @@ -246,26 +248,29 @@ public List answer(InvocationOnMock invocation) { .build()); } }); - when(rpc.beginTransaction(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap())) + when(rpc.beginTransactionAsync(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap())) .thenAnswer( - new Answer() { + new Answer>() { @Override - public Transaction answer(InvocationOnMock invocation) { - return Transaction.newBuilder() - .setId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) - .build(); + public ApiFuture answer(InvocationOnMock invocation) { + return ApiFutures.immediateFuture( + Transaction.newBuilder() + .setId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .build()); } }); - when(rpc.commit(Mockito.any(CommitRequest.class), Mockito.anyMap())) + when(rpc.commitAsync(Mockito.any(CommitRequest.class), Mockito.anyMap())) .thenAnswer( - new Answer() { + new Answer>() { @Override - public CommitResponse answer(InvocationOnMock invocation) { - return CommitResponse.newBuilder() - .setCommitTimestamp( - com.google.protobuf.Timestamp.newBuilder() - .setSeconds(System.currentTimeMillis() * 1000)) - .build(); + public ApiFuture answer(InvocationOnMock invocation) + throws Throwable { + return ApiFutures.immediateFuture( + CommitResponse.newBuilder() + .setCommitTimestamp( + com.google.protobuf.Timestamp.newBuilder() + .setSeconds(System.currentTimeMillis() * 1000)) + .build()); } }); DatabaseId db = DatabaseId.of("test", "test", "test"); @@ -276,7 +281,7 @@ public CommitResponse answer(InvocationOnMock invocation) { mgr.commit(); } verify(rpc, times(1)) - .beginTransaction(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap()); + .beginTransactionAsync(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap()); } } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java index 074f0a905c..d61c89300f 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; import com.google.cloud.grpc.GrpcTransportOptions; import com.google.cloud.grpc.GrpcTransportOptions.ExecutorFactory; @@ -50,6 +51,7 @@ import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.protobuf.ProtoUtils; +import io.opencensus.trace.Span; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -95,6 +97,13 @@ public void setUp() { firstRun = true; when(session.newTransaction()).thenReturn(txn); transactionRunner = new TransactionRunnerImpl(session, rpc, 1); + when(rpc.commitAsync(Mockito.any(CommitRequest.class), Mockito.anyMap())) + .thenReturn( + ApiFutures.immediateFuture( + CommitResponse.newBuilder() + .setCommitTimestamp(Timestamp.getDefaultInstance()) + .build())); + transactionRunner.setSpan(mock(Span.class)); } @SuppressWarnings("unchecked") @@ -126,25 +135,28 @@ public List answer(InvocationOnMock invocation) { .build()); } }); - when(rpc.beginTransaction(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap())) + when(rpc.beginTransactionAsync(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap())) .thenAnswer( - new Answer() { + new Answer>() { @Override - public Transaction answer(InvocationOnMock invocation) { - return Transaction.newBuilder() - .setId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) - .build(); + public ApiFuture answer(InvocationOnMock invocation) { + return ApiFutures.immediateFuture( + Transaction.newBuilder() + .setId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .build()); } }); - when(rpc.commit(Mockito.any(CommitRequest.class), Mockito.anyMap())) + when(rpc.commitAsync(Mockito.any(CommitRequest.class), Mockito.anyMap())) .thenAnswer( - new Answer() { + new Answer>() { @Override - public CommitResponse answer(InvocationOnMock invocation) { - return CommitResponse.newBuilder() - .setCommitTimestamp( - Timestamp.newBuilder().setSeconds(System.currentTimeMillis() * 1000)) - .build(); + public ApiFuture answer(InvocationOnMock invocation) + throws Throwable { + return ApiFutures.immediateFuture( + CommitResponse.newBuilder() + .setCommitTimestamp( + Timestamp.newBuilder().setSeconds(System.currentTimeMillis() * 1000)) + .build()); } }); DatabaseId db = DatabaseId.of("test", "test", "test"); @@ -160,7 +172,7 @@ public Void run(TransactionContext transaction) { } }); verify(rpc, times(1)) - .beginTransaction(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap()); + .beginTransactionAsync(Mockito.any(BeginTransactionRequest.class), Mockito.anyMap()); } } @@ -272,10 +284,12 @@ private long[] batchDmlException(int status) { .setRpc(rpc) .build(); when(session.newTransaction()).thenReturn(transaction); - when(session.beginTransaction()) - .thenReturn(ByteString.copyFromUtf8(UUID.randomUUID().toString())); + when(session.beginTransactionAsync()) + .thenReturn( + ApiFutures.immediateFuture(ByteString.copyFromUtf8(UUID.randomUUID().toString()))); when(session.getName()).thenReturn(SessionId.of("p", "i", "d", "test").getName()); TransactionRunnerImpl runner = new TransactionRunnerImpl(session, rpc, 10); + runner.setSpan(mock(Span.class)); ExecuteBatchDmlResponse response1 = ExecuteBatchDmlResponse.newBuilder() .addResultSets( @@ -300,7 +314,8 @@ private long[] batchDmlException(int status) { .thenReturn(response1, response2); CommitResponse commitResponse = CommitResponse.newBuilder().setCommitTimestamp(Timestamp.getDefaultInstance()).build(); - when(rpc.commit(Mockito.any(CommitRequest.class), Mockito.anyMap())).thenReturn(commitResponse); + when(rpc.commitAsync(Mockito.any(CommitRequest.class), Mockito.anyMap())) + .thenReturn(ApiFutures.immediateFuture(commitResponse)); final Statement statement = Statement.of("UPDATE FOO SET BAR=1"); final AtomicInteger numCalls = new AtomicInteger(0); long updateCount[] = diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadOnlyTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadOnlyTransactionTest.java index aac55bd9f8..118f596c86 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadOnlyTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadOnlyTransactionTest.java @@ -26,7 +26,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.api.core.ApiFuture; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncResultSet; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Key; @@ -131,6 +133,34 @@ public void close() {} public Timestamp getReadTimestamp() { return readTimestamp; } + + @Override + public AsyncResultSet readAsync( + String table, KeySet keys, Iterable columns, ReadOption... options) { + return null; + } + + @Override + public AsyncResultSet readUsingIndexAsync( + String table, String index, KeySet keys, Iterable columns, ReadOption... options) { + return null; + } + + @Override + public ApiFuture readRowAsync(String table, Key key, Iterable columns) { + return null; + } + + @Override + public ApiFuture readRowUsingIndexAsync( + String table, String index, Key key, Iterable columns) { + return null; + } + + @Override + public AsyncResultSet executeQueryAsync(Statement statement, QueryOption... options) { + return null; + } } private ReadOnlyTransaction createSubject() { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SingleUseTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SingleUseTransactionTest.java index 27ee1903fa..e73eb8e0b2 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SingleUseTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SingleUseTransactionTest.java @@ -24,8 +24,10 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.api.core.ApiFuture; import com.google.api.gax.longrunning.OperationFuture; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AsyncResultSet; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Key; @@ -230,6 +232,34 @@ public void close() {} public Timestamp getReadTimestamp() { return readTimestamp; } + + @Override + public AsyncResultSet readAsync( + String table, KeySet keys, Iterable columns, ReadOption... options) { + return null; + } + + @Override + public AsyncResultSet readUsingIndexAsync( + String table, String index, KeySet keys, Iterable columns, ReadOption... options) { + return null; + } + + @Override + public ApiFuture readRowAsync(String table, Key key, Iterable columns) { + return null; + } + + @Override + public ApiFuture readRowUsingIndexAsync( + String table, String index, Key key, Iterable columns) { + return null; + } + + @Override + public AsyncResultSet executeQueryAsync(Statement statement, QueryOption... options) { + return null; + } } private DdlClient createDefaultMockDdlClient() { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITAsyncAPITest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITAsyncAPITest.java new file mode 100644 index 0000000000..a2239aa3b9 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITAsyncAPITest.java @@ -0,0 +1,309 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.it; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; + +import com.google.api.core.ApiFuture; +import com.google.cloud.spanner.AsyncResultSet; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.cloud.spanner.AsyncRunner; +import com.google.cloud.spanner.AsyncRunner.AsyncWork; +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.IntegrationTest; +import com.google.cloud.spanner.IntegrationTestEnv; +import com.google.cloud.spanner.Key; +import com.google.cloud.spanner.KeyRange; +import com.google.cloud.spanner.KeySet; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.TimestampBound; +import com.google.cloud.spanner.TransactionContext; +import com.google.cloud.spanner.Type; +import com.google.cloud.spanner.Type.StructField; +import com.google.cloud.spanner.testing.RemoteSpannerHelper; +import com.google.common.util.concurrent.SettableFuture; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Integration tests for asynchronous APIs. */ +@Category(IntegrationTest.class) +@RunWith(JUnit4.class) +public class ITAsyncAPITest { + @ClassRule public static IntegrationTestEnv env = new IntegrationTestEnv(); + private static final String TABLE_NAME = "TestTable"; + private static final String INDEX_NAME = "TestTableByValue"; + private static final List ALL_COLUMNS = Arrays.asList("Key", "StringValue"); + private static final Type TABLE_TYPE = + Type.struct( + StructField.of("Key", Type.string()), StructField.of("StringValue", Type.string())); + + private static Database db; + private static DatabaseClient client; + private static ExecutorService executor; + + @BeforeClass + public static void setUpDatabase() { + db = + env.getTestHelper() + .createTestDatabase( + "CREATE TABLE TestTable (" + + " Key STRING(MAX) NOT NULL," + + " StringValue STRING(MAX)," + + ") PRIMARY KEY (Key)", + "CREATE INDEX TestTableByValue ON TestTable(StringValue)", + "CREATE INDEX TestTableByValueDesc ON TestTable(StringValue DESC)"); + client = env.getTestHelper().getDatabaseClient(db); + + // Includes k0..k14. Note that strings k{10,14} sort between k1 and k2. + List mutations = new ArrayList<>(); + for (int i = 0; i < 15; ++i) { + mutations.add( + Mutation.newInsertOrUpdateBuilder(TABLE_NAME) + .set("Key") + .to("k" + i) + .set("StringValue") + .to("v" + i) + .build()); + } + client.write(mutations); + executor = Executors.newSingleThreadExecutor(); + } + + @AfterClass + public static void cleanup() { + executor.shutdown(); + } + + @Test + public void emptyReadAsync() throws Exception { + final SettableFuture result = SettableFuture.create(); + AsyncResultSet resultSet = + client + .singleUse(TimestampBound.strong()) + .readAsync( + TABLE_NAME, + KeySet.range(KeyRange.closedOpen(Key.of("k99"), Key.of("z"))), + ALL_COLUMNS); + resultSet.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case OK: + fail("received unexpected data"); + case NOT_READY: + return CallbackResponse.CONTINUE; + case DONE: + assertThat(resultSet.getType()).isEqualTo(TABLE_TYPE); + result.set(true); + return CallbackResponse.DONE; + } + } + } catch (Throwable t) { + result.setException(t); + return CallbackResponse.DONE; + } + } + }); + assertThat(result.get()).isTrue(); + } + + @Test + public void indexEmptyReadAsync() throws Exception { + final SettableFuture result = SettableFuture.create(); + AsyncResultSet resultSet = + client + .singleUse(TimestampBound.strong()) + .readUsingIndexAsync( + TABLE_NAME, + INDEX_NAME, + KeySet.range(KeyRange.closedOpen(Key.of("v99"), Key.of("z"))), + ALL_COLUMNS); + resultSet.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case OK: + fail("received unexpected data"); + case NOT_READY: + return CallbackResponse.CONTINUE; + case DONE: + assertThat(resultSet.getType()).isEqualTo(TABLE_TYPE); + result.set(true); + return CallbackResponse.DONE; + } + } + } catch (Throwable t) { + result.setException(t); + return CallbackResponse.DONE; + } + } + }); + assertThat(result.get()).isTrue(); + } + + @Test + public void pointReadAsync() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync(TABLE_NAME, Key.of("k1"), ALL_COLUMNS); + assertThat(row.get()).isNotNull(); + assertThat(row.get().getString(0)).isEqualTo("k1"); + assertThat(row.get().getString(1)).isEqualTo("v1"); + // Ensure that the Struct implementation supports equality properly. + assertThat(row.get()) + .isEqualTo(Struct.newBuilder().set("Key").to("k1").set("StringValue").to("v1").build()); + } + + @Test + public void indexPointReadAsync() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowUsingIndexAsync(TABLE_NAME, INDEX_NAME, Key.of("v1"), ALL_COLUMNS); + assertThat(row.get()).isNotNull(); + assertThat(row.get().getString(0)).isEqualTo("k1"); + assertThat(row.get().getString(1)).isEqualTo("v1"); + } + + @Test + public void pointReadNotFound() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync(TABLE_NAME, Key.of("k999"), ALL_COLUMNS); + assertThat(row.get()).isNull(); + } + + @Test + public void indexPointReadNotFound() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowUsingIndexAsync(TABLE_NAME, INDEX_NAME, Key.of("v999"), ALL_COLUMNS); + assertThat(row.get()).isNull(); + } + + @Test + public void invalidDatabase() throws Exception { + RemoteSpannerHelper helper = env.getTestHelper(); + DatabaseClient invalidClient = + helper.getClient().getDatabaseClient(DatabaseId.of(helper.getInstanceId(), "invalid")); + ApiFuture row = + invalidClient + .singleUse(TimestampBound.strong()) + .readRowAsync(TABLE_NAME, Key.of("k99"), ALL_COLUMNS); + try { + row.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + } + } + + @Test + public void tableNotFound() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync("BadTableName", Key.of("k1"), ALL_COLUMNS); + try { + row.get(); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + assertThat(se.getMessage()).contains("BadTableName"); + } + } + + @Test + public void columnNotFound() throws Exception { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync(TABLE_NAME, Key.of("k1"), Arrays.asList("Key", "BadColumnName")); + try { + row.get(); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + assertThat(se.getMessage()).contains("BadColumnName"); + } + } + + @Test + public void asyncRunnerFireAndForgetInvalidUpdate() throws Exception { + assumeFalse( + "errors in read/write transactions on emulator are sticky", + env.getTestHelper().isEmulator()); + try { + assertThat(client.singleUse().readRow("TestTable", Key.of("k999"), ALL_COLUMNS)).isNull(); + AsyncRunner runner = client.runAsync(); + ApiFuture res = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + // The error returned by this update statement will not bubble up and fail the + // transaction. + txn.executeUpdateAsync(Statement.of("UPDATE BadTableName SET FOO=1 WHERE ID=2")); + return txn.executeUpdateAsync( + Statement.of( + "INSERT INTO TestTable (Key, StringValue) VALUES ('k999', 'v999')")); + } + }, + executor); + assertThat(res.get()).isEqualTo(1L); + assertThat(client.singleUse().readRow("TestTable", Key.of("k999"), ALL_COLUMNS)).isNotNull(); + } finally { + client.writeAtLeastOnce(Arrays.asList(Mutation.delete("TestTable", Key.of("k999")))); + assertThat(client.singleUse().readRow("TestTable", Key.of("k999"), ALL_COLUMNS)).isNull(); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITAsyncExamplesTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITAsyncExamplesTest.java new file mode 100644 index 0000000000..c5e2419ba6 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITAsyncExamplesTest.java @@ -0,0 +1,550 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.it; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.api.core.ApiFunction; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.cloud.spanner.AsyncResultSet; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; +import com.google.cloud.spanner.AsyncRunner; +import com.google.cloud.spanner.AsyncRunner.AsyncWork; +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.IntegrationTest; +import com.google.cloud.spanner.IntegrationTestEnv; +import com.google.cloud.spanner.Key; +import com.google.cloud.spanner.KeySet; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.ReadOnlyTransaction; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.StructReader; +import com.google.cloud.spanner.TransactionContext; +import com.google.common.base.Function; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Integration tests for asynchronous APIs. */ +@Category(IntegrationTest.class) +@RunWith(JUnit4.class) +public class ITAsyncExamplesTest { + @ClassRule public static IntegrationTestEnv env = new IntegrationTestEnv(); + private static final String TABLE_NAME = "TestTable"; + private static final String INDEX_NAME = "TestTableByValue"; + private static final List ALL_COLUMNS = Arrays.asList("Key", "StringValue"); + private static final ImmutableList ALL_VALUES_IN_PK_ORDER = + ImmutableList.of( + "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9"); + + private static Database db; + private static DatabaseClient client; + private static ExecutorService executor; + + @BeforeClass + public static void setUpDatabase() { + db = + env.getTestHelper() + .createTestDatabase( + "CREATE TABLE TestTable (" + + " Key STRING(MAX) NOT NULL," + + " StringValue STRING(MAX)," + + ") PRIMARY KEY (Key)", + "CREATE INDEX TestTableByValue ON TestTable(StringValue)", + "CREATE INDEX TestTableByValueDesc ON TestTable(StringValue DESC)"); + client = env.getTestHelper().getDatabaseClient(db); + + // Includes k0..k14. Note that strings k{10,14} sort between k1 and k2. + List mutations = new ArrayList<>(); + for (int i = 0; i < 15; ++i) { + mutations.add( + Mutation.newInsertOrUpdateBuilder(TABLE_NAME) + .set("Key") + .to("k" + i) + .set("StringValue") + .to("v" + i) + .build()); + } + client.write(mutations); + executor = Executors.newScheduledThreadPool(8); + } + + @AfterClass + public static void cleanup() { + executor.shutdown(); + } + + @Test + public void readAsync() throws Exception { + final SettableApiFuture> future = SettableApiFuture.create(); + try (AsyncResultSet rs = client.singleUse().readAsync(TABLE_NAME, KeySet.all(), ALL_COLUMNS)) { + rs.setCallback( + executor, + new ReadyCallback() { + final List values = new LinkedList<>(); + + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + future.set(values); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + values.add(resultSet.getString("StringValue")); + break; + } + } + } catch (Throwable t) { + future.setException(t); + return CallbackResponse.DONE; + } + } + }); + } + assertThat(future.get()).containsExactlyElementsIn(ALL_VALUES_IN_PK_ORDER); + } + + @Test + public void readUsingIndexAsync() throws Exception { + final SettableApiFuture> future = SettableApiFuture.create(); + try (AsyncResultSet rs = + client.singleUse().readUsingIndexAsync(TABLE_NAME, INDEX_NAME, KeySet.all(), ALL_COLUMNS)) { + rs.setCallback( + executor, + new ReadyCallback() { + final List values = new LinkedList<>(); + + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + future.set(values); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + values.add(resultSet.getString("StringValue")); + break; + } + } + } catch (Throwable t) { + future.setException(t); + return CallbackResponse.DONE; + } + } + }); + } + assertThat(future.get()).containsExactlyElementsIn(ALL_VALUES_IN_PK_ORDER); + } + + @Test + public void readRowAsync() throws Exception { + ApiFuture row = client.singleUse().readRowAsync(TABLE_NAME, Key.of("k1"), ALL_COLUMNS); + assertThat(row.get().getString("StringValue")).isEqualTo("v1"); + } + + @Test + public void readRowUsingIndexAsync() throws Exception { + ApiFuture row = + client + .singleUse() + .readRowUsingIndexAsync(TABLE_NAME, INDEX_NAME, Key.of("v2"), ALL_COLUMNS); + assertThat(row.get().getString("Key")).isEqualTo("k2"); + } + + @Test + public void executeQueryAsync() throws Exception { + final ImmutableList keys = ImmutableList.of("k3", "k4"); + final SettableApiFuture> future = SettableApiFuture.create(); + try (AsyncResultSet rs = + client + .singleUse() + .executeQueryAsync( + Statement.newBuilder("SELECT StringValue FROM TestTable WHERE Key IN UNNEST(@keys)") + .bind("keys") + .toStringArray(keys) + .build())) { + rs.setCallback( + executor, + new ReadyCallback() { + final List values = new LinkedList<>(); + + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + future.set(values); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + values.add(resultSet.getString("StringValue")); + break; + } + } + } catch (Throwable t) { + future.setException(t); + return CallbackResponse.DONE; + } + } + }); + } + assertThat(future.get()).containsExactly("v3", "v4"); + } + + @Test + public void runAsync() throws Exception { + AsyncRunner runner = client.runAsync(); + ApiFuture insertCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + // Even though this is a shoot-and-forget asynchronous DML statement, it is + // guaranteed to be executed within the transaction before the commit is executed. + return txn.executeUpdateAsync( + Statement.newBuilder( + "INSERT INTO TestTable (Key, StringValue) VALUES (@key, @value)") + .bind("key") + .to("k999") + .bind("value") + .to("v999") + .build()); + } + }, + executor); + assertThat(insertCount.get()).isEqualTo(1L); + ApiFuture deleteCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return txn.executeUpdateAsync( + Statement.newBuilder("DELETE FROM TestTable WHERE Key=@key") + .bind("key") + .to("k999") + .build()); + } + }, + executor); + assertThat(deleteCount.get()).isEqualTo(1L); + } + + @Test + public void runAsyncBatchUpdate() throws Exception { + AsyncRunner runner = client.runAsync(); + ApiFuture insertCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + // Even though this is a shoot-and-forget asynchronous DML statement, it is + // guaranteed to be executed within the transaction before the commit is executed. + return txn.batchUpdateAsync( + ImmutableList.of( + Statement.newBuilder( + "INSERT INTO TestTable (Key, StringValue) VALUES (@key, @value)") + .bind("key") + .to("k997") + .bind("value") + .to("v997") + .build(), + Statement.newBuilder( + "INSERT INTO TestTable (Key, StringValue) VALUES (@key, @value)") + .bind("key") + .to("k998") + .bind("value") + .to("v998") + .build(), + Statement.newBuilder( + "INSERT INTO TestTable (Key, StringValue) VALUES (@key, @value)") + .bind("key") + .to("k999") + .bind("value") + .to("v999") + .build())); + } + }, + executor); + assertThat(insertCount.get()).asList().containsExactly(1L, 1L, 1L); + ApiFuture deleteCount = + runner.runAsync( + new AsyncWork() { + @Override + public ApiFuture doWorkAsync(TransactionContext txn) { + return txn.batchUpdateAsync( + ImmutableList.of( + Statement.newBuilder("DELETE FROM TestTable WHERE Key=@key") + .bind("key") + .to("k997") + .build(), + Statement.newBuilder("DELETE FROM TestTable WHERE Key=@key") + .bind("key") + .to("k998") + .build(), + Statement.newBuilder("DELETE FROM TestTable WHERE Key=@key") + .bind("key") + .to("k999") + .build())); + } + }, + executor); + assertThat(deleteCount.get()).asList().containsExactly(1L, 1L, 1L); + } + + @Test + public void readOnlyTransaction() throws Exception { + ImmutableList keys1 = ImmutableList.of("k10", "k11", "k12"); + ImmutableList keys2 = ImmutableList.of("k1", "k2", "k3"); + ApiFuture> values1; + ApiFuture> values2; + try (ReadOnlyTransaction tx = client.readOnlyTransaction()) { + try (AsyncResultSet rs = + tx.executeQueryAsync( + Statement.newBuilder("SELECT * FROM TestTable WHERE Key IN UNNEST(@keys)") + .bind("keys") + .toStringArray(keys1) + .build())) { + values1 = + rs.toListAsync( + new Function() { + @Override + public String apply(StructReader input) { + return input.getString("StringValue"); + } + }, + executor); + } + try (AsyncResultSet rs = + tx.executeQueryAsync( + Statement.newBuilder("SELECT * FROM TestTable WHERE Key IN UNNEST(@keys)") + .bind("keys") + .toStringArray(keys2) + .build())) { + values2 = + rs.toListAsync( + new Function() { + @Override + public String apply(StructReader input) { + return input.getString("StringValue"); + } + }, + executor); + } + } + ApiFuture> allValues = + ApiFutures.transform( + ApiFutures.allAsList(Arrays.asList(values1, values2)), + new ApiFunction>, Iterable>() { + @Override + public Iterable apply(List> input) { + return Iterables.mergeSorted( + input, + new Comparator() { + @Override + public int compare(String o1, String o2) { + // Compare based on numerical order (i.e. without the preceding 'v'). + return Integer.valueOf(o1.substring(1)) + .compareTo(Integer.valueOf(o2.substring(1))); + } + }); + } + }, + executor); + assertThat(allValues.get()).containsExactly("v1", "v2", "v3", "v10", "v11", "v12"); + } + + @Test + public void pauseResume() throws Exception { + Statement unevenStatement = + Statement.of( + "SELECT * FROM TestTable WHERE MOD(CAST(SUBSTR(Key, 2) AS INT64), 2) = 1 ORDER BY CAST(SUBSTR(Key, 2) AS INT64)"); + Statement evenStatement = + Statement.of( + "SELECT * FROM TestTable WHERE MOD(CAST(SUBSTR(Key, 2) AS INT64), 2) = 0 ORDER BY CAST(SUBSTR(Key, 2) AS INT64)"); + + final Object lock = new Object(); + final SettableApiFuture evenFinished = SettableApiFuture.create(); + final SettableApiFuture unevenFinished = SettableApiFuture.create(); + final CountDownLatch evenReturnedFirstRow = new CountDownLatch(1); + final Deque allValues = new LinkedList<>(); + try (ReadOnlyTransaction tx = client.readOnlyTransaction()) { + try (AsyncResultSet evenRs = tx.executeQueryAsync(evenStatement); + AsyncResultSet unevenRs = tx.executeQueryAsync(unevenStatement)) { + evenRs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + evenFinished.set(true); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + synchronized (lock) { + allValues.add(resultSet.getString("StringValue")); + } + evenReturnedFirstRow.countDown(); + return CallbackResponse.PAUSE; + } + } + } catch (Throwable t) { + evenFinished.setException(t); + return CallbackResponse.DONE; + } + } + }); + + unevenRs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + // Make sure the even result set has returned the first before we start the uneven + // results. + evenReturnedFirstRow.await(); + while (true) { + switch (resultSet.tryNext()) { + case DONE: + unevenFinished.set(true); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + synchronized (lock) { + allValues.add(resultSet.getString("StringValue")); + } + return CallbackResponse.PAUSE; + } + } + } catch (Throwable t) { + unevenFinished.setException(t); + return CallbackResponse.DONE; + } + } + }); + while (!(evenFinished.isDone() && unevenFinished.isDone())) { + synchronized (lock) { + if (allValues.peekLast() != null) { + if (Integer.valueOf(allValues.peekLast().substring(1)) % 2 == 1) { + evenRs.resume(); + } else { + unevenRs.resume(); + } + } + if (allValues.size() == 15) { + unevenRs.resume(); + evenRs.resume(); + } + } + } + } + } + assertThat(ApiFutures.allAsList(Arrays.asList(evenFinished, unevenFinished)).get()) + .containsExactly(Boolean.TRUE, Boolean.TRUE); + assertThat(allValues) + .containsExactly( + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14"); + } + + @Test + public void cancel() throws Exception { + final List values = new LinkedList<>(); + final SettableApiFuture finished = SettableApiFuture.create(); + final CountDownLatch receivedFirstRow = new CountDownLatch(1); + final CountDownLatch cancelled = new CountDownLatch(1); + try (AsyncResultSet rs = client.singleUse().readAsync(TABLE_NAME, KeySet.all(), ALL_COLUMNS)) { + rs.setCallback( + executor, + new ReadyCallback() { + @Override + public CallbackResponse cursorReady(AsyncResultSet resultSet) { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + finished.set(true); + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + values.add(resultSet.getString("StringValue")); + receivedFirstRow.countDown(); + cancelled.await(); + break; + } + } + } catch (Throwable t) { + finished.setException(t); + return CallbackResponse.DONE; + } + } + }); + receivedFirstRow.await(); + rs.cancel(); + } + cancelled.countDown(); + try { + finished.get(); + fail("missing expected exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.CANCELLED); + assertThat(values).containsExactly("v0"); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDatabaseTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDatabaseTest.java index fda120bff6..450c7463c5 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDatabaseTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDatabaseTest.java @@ -146,6 +146,7 @@ public void instanceNotFound() { .getClient() .getDatabaseClient(DatabaseId.of(nonExistingInstanceId, "some-db")); try (ResultSet rs = client.singleUse().executeQuery(Statement.of("SELECT 1"))) { + rs.next(); fail("missing expected exception"); } catch (InstanceNotFoundException e) { assertThat(e.getResourceName()).isEqualTo(nonExistingInstanceId.getName()); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadOnlyTxnTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadOnlyTxnTest.java index a52a108151..db68b5ec88 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadOnlyTxnTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadOnlyTxnTest.java @@ -310,9 +310,12 @@ public void multiReadTimestamp() { @Test public void multiMinReadTimestamp() { // Cannot use bounded modes with multi-read transactions. - try { - client.readOnlyTransaction(TimestampBound.ofMinReadTimestamp(history.get(2).timestamp)); - fail("Expected exception"); + try (ReadOnlyTransaction tx = + client.readOnlyTransaction(TimestampBound.ofMinReadTimestamp(history.get(2).timestamp))) { + try (ResultSet rs = tx.executeQuery(Statement.of("SELECT 1"))) { + rs.next(); + fail("Expected exception"); + } } catch (IllegalArgumentException ex) { assertNotNull(ex.getMessage()); } @@ -341,9 +344,12 @@ public void multiExactStaleness() { @Test public void multiMaxStaleness() { // Cannot use bounded modes with multi-read transactions. - try { - client.readOnlyTransaction(TimestampBound.ofMaxStaleness(1, TimeUnit.SECONDS)); - fail("Expected exception"); + try (ReadOnlyTransaction tx = + client.readOnlyTransaction(TimestampBound.ofMaxStaleness(1, TimeUnit.SECONDS))) { + try (ResultSet rs = tx.executeQuery(Statement.of("SELECT 1"))) { + rs.next(); + fail("Expected exception"); + } } catch (IllegalArgumentException ex) { assertNotNull(ex.getMessage()); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadTest.java index 4682ddd1ec..87c9e0ae3f 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITReadTest.java @@ -358,6 +358,7 @@ public void run() { try { work.run(); + fail("missing expected exception"); } catch (SpannerException e) { MatcherAssert.assertThat(e, isSpannerException(ErrorCode.CANCELLED)); } @@ -381,6 +382,7 @@ public void run() { try { work.run(); + fail("missing expected exception"); } catch (SpannerException e) { MatcherAssert.assertThat(e, isSpannerException(ErrorCode.DEADLINE_EXCEEDED)); } finally { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITTransactionManagerAsyncTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITTransactionManagerAsyncTest.java new file mode 100644 index 0000000000..c802493dec --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITTransactionManagerAsyncTest.java @@ -0,0 +1,318 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.it; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.cloud.spanner.AbortedException; +import com.google.cloud.spanner.AsyncTransactionManager; +import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionFunction; +import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionStep; +import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.IntegrationTestEnv; +import com.google.cloud.spanner.Key; +import com.google.cloud.spanner.KeySet; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.TransactionContext; +import com.google.cloud.spanner.TransactionManager.TransactionState; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.MoreExecutors; +import java.util.Arrays; +import java.util.Collection; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class ITTransactionManagerAsyncTest { + + @Parameter public Executor executor; + + @Parameters(name = "executor = {0}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {MoreExecutors.directExecutor()}, + {Executors.newSingleThreadExecutor()}, + {Executors.newFixedThreadPool(4)} + }); + } + + @ClassRule public static IntegrationTestEnv env = new IntegrationTestEnv(); + private static Database db; + private static DatabaseClient client; + + @BeforeClass + public static void setUpDatabase() { + // Empty database. + db = + env.getTestHelper() + .createTestDatabase( + "CREATE TABLE T (" + + " K STRING(MAX) NOT NULL," + + " BoolValue BOOL," + + ") PRIMARY KEY (K)"); + client = env.getTestHelper().getDatabaseClient(db); + } + + @Before + public void clearTable() { + client.write(ImmutableList.of(Mutation.delete("T", KeySet.all()))); + } + + @Test + public void testSimpleInsert() throws ExecutionException, InterruptedException { + try (AsyncTransactionManager manager = client.transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + assertThat(manager.getState()).isEqualTo(TransactionState.STARTED); + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + txn.buffer( + Mutation.newInsertBuilder("T") + .set("K") + .to("Key1") + .set("BoolValue") + .to(true) + .build()); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync() + .get(); + assertThat(manager.getState()).isEqualTo(TransactionState.COMMITTED); + Struct row = + client.singleUse().readRow("T", Key.of("Key1"), Arrays.asList("K", "BoolValue")); + assertThat(row.getString(0)).isEqualTo("Key1"); + assertThat(row.getBoolean(1)).isTrue(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + txn = manager.resetForRetryAsync(); + } + } + } + } + + @Test + public void testInvalidInsert() throws InterruptedException { + try (AsyncTransactionManager manager = client.transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + try { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + txn.buffer( + Mutation.newInsertBuilder("InvalidTable") + .set("K") + .to("Key1") + .set("BoolValue") + .to(true) + .build()); + return ApiFutures.immediateFuture(null); + } + }, + executor) + .commitAsync() + .get(); + fail("Expected exception"); + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + txn = manager.resetForRetryAsync(); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(SpannerException.class); + SpannerException se = (SpannerException) e.getCause(); + assertThat(se.getErrorCode()).isEqualTo(ErrorCode.NOT_FOUND); + // expected + break; + } + } + assertThat(manager.getState()).isEqualTo(TransactionState.COMMIT_FAILED); + // We cannot retry for non aborted errors. + try { + manager.resetForRetryAsync(); + fail("Expected exception"); + } catch (IllegalStateException ex) { + assertNotNull(ex.getMessage()); + } + } + } + + @Test + public void testRollback() throws InterruptedException { + try (AsyncTransactionManager manager = client.transactionManagerAsync()) { + TransactionContextFuture txn = manager.beginAsync(); + while (true) { + txn.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) throws Exception { + txn.buffer( + Mutation.newInsertBuilder("T") + .set("K") + .to("Key2") + .set("BoolValue") + .to(true) + .build()); + return ApiFutures.immediateFuture(null); + } + }, + executor); + try { + manager.rollbackAsync(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + txn = manager.resetForRetryAsync(); + } + } + assertThat(manager.getState()).isEqualTo(TransactionState.ROLLED_BACK); + // Row should not have been inserted. + assertThat(client.singleUse().readRow("T", Key.of("Key2"), Arrays.asList("K", "BoolValue"))) + .isNull(); + } + } + + @Test + public void testAbortAndRetry() throws InterruptedException, ExecutionException { + assumeFalse( + "Emulator does not support more than 1 simultanous transaction. " + + "This test would therefore loop indefinetly on the emulator.", + env.getTestHelper().isEmulator()); + + client.write( + Arrays.asList( + Mutation.newInsertBuilder("T").set("K").to("Key3").set("BoolValue").to(true).build())); + try (AsyncTransactionManager manager1 = client.transactionManagerAsync()) { + TransactionContextFuture txn1 = manager1.beginAsync(); + AsyncTransactionManager manager2; + TransactionContextFuture txn2; + AsyncTransactionStep txn2Step1; + while (true) { + try { + AsyncTransactionStep txn1Step1 = + txn1.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + return txn.readRowAsync("T", Key.of("Key3"), Arrays.asList("K", "BoolValue")); + } + }, + executor); + manager2 = client.transactionManagerAsync(); + txn2 = manager2.beginAsync(); + txn2Step1 = + txn2.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) + throws Exception { + return txn.readRowAsync("T", Key.of("Key3"), Arrays.asList("K", "BoolValue")); + } + }, + executor); + + AsyncTransactionStep txn1Step2 = + txn1Step1.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Struct input) + throws Exception { + txn.buffer( + Mutation.newUpdateBuilder("T") + .set("K") + .to("Key3") + .set("BoolValue") + .to(false) + .build()); + return ApiFutures.immediateFuture(null); + } + }, + executor); + + txn2Step1.get(); + txn1Step2.commitAsync().get(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + // It is possible that it was txn2 that aborted. + // In that case we should just retry without resetting anything. + if (manager1.getState() == TransactionState.ABORTED) { + txn1 = manager1.resetForRetryAsync(); + } + } + } + + // txn2 should have been aborted. + try { + txn2Step1.commitAsync().get(); + fail("Expected to abort"); + } catch (AbortedException e) { + assertThat(manager2.getState()).isEqualTo(TransactionState.ABORTED); + txn2 = manager2.resetForRetryAsync(); + } + AsyncTransactionStep txn2Step2 = + txn2.then( + new AsyncTransactionFunction() { + @Override + public ApiFuture apply(TransactionContext txn, Void input) throws Exception { + txn.buffer( + Mutation.newUpdateBuilder("T") + .set("K") + .to("Key3") + .set("BoolValue") + .to(true) + .build()); + return ApiFutures.immediateFuture(null); + } + }, + executor); + txn2Step2.commitAsync().get(); + Struct row = client.singleUse().readRow("T", Key.of("Key3"), Arrays.asList("K", "BoolValue")); + assertThat(row.getString(0)).isEqualTo("Key3"); + assertThat(row.getBoolean(1)).isTrue(); + manager2.close(); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITTransactionTest.java index a05029ad99..5e3c1483e7 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITTransactionTest.java @@ -477,7 +477,12 @@ public void nestedSingleUseReadTxnThrows() { new TransactionCallable() { @Override public Void run(TransactionContext transaction) throws SpannerException { - client.singleUseReadOnlyTransaction(); + try (ResultSet rs = + client + .singleUseReadOnlyTransaction() + .executeQuery(Statement.of("SELECT 1"))) { + rs.next(); + } return null; } diff --git a/versions.txt b/versions.txt index ba2b2e5e35..0d5258a9ce 100644 --- a/versions.txt +++ b/versions.txt @@ -7,4 +7,4 @@ proto-google-cloud-spanner-admin-database-v1:1.57.0:1.57.0 grpc-google-cloud-spanner-v1:1.57.0:1.57.0 grpc-google-cloud-spanner-admin-instance-v1:1.57.0:1.57.0 grpc-google-cloud-spanner-admin-database-v1:1.57.0:1.57.0 -google-cloud-spanner:1.57.0:1.57.0 \ No newline at end of file +google-cloud-spanner:1.57.0:1.57.0