diff --git a/google-cloud-spanner/pom.xml b/google-cloud-spanner/pom.xml index 3f7a8f3ce2..234ab270b4 100644 --- a/google-cloud-spanner/pom.xml +++ b/google-cloud-spanner/pom.xml @@ -203,6 +203,18 @@ com.google.api.grpc proto-google-cloud-spanner-admin-database-v1 + + com.google.api.grpc + grpc-google-cloud-spanner-admin-instance-v1 + + + com.google.api.grpc + grpc-google-cloud-spanner-v1 + + + com.google.api.grpc + grpc-google-cloud-spanner-admin-database-v1 + com.google.guava guava @@ -246,21 +258,6 @@ test - - com.google.api.grpc - grpc-google-cloud-spanner-v1 - test - - - com.google.api.grpc - grpc-google-cloud-spanner-admin-instance-v1 - test - - - com.google.api.grpc - grpc-google-cloud-spanner-admin-database-v1 - test - com.google.api diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index bc3f513ce0..6bd3d0f90c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -18,9 +18,11 @@ import com.google.api.core.ApiFunction; import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.grpc.GrpcCallContext; import com.google.api.gax.grpc.GrpcInterceptorProvider; import com.google.api.gax.longrunning.OperationTimedPollAlgorithm; import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.cloud.NoCredentials; import com.google.cloud.ServiceDefaults; @@ -29,6 +31,8 @@ import com.google.cloud.TransportOptions; import com.google.cloud.grpc.GrpcTransportOptions; import com.google.cloud.spanner.Options.QueryOption; +import com.google.cloud.spanner.SpannerOptions.CallContextConfigurator; +import com.google.cloud.spanner.SpannerOptions.SpannerCallContextTimeoutConfigurator; import com.google.cloud.spanner.admin.database.v1.DatabaseAdminSettings; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStubSettings; import com.google.cloud.spanner.admin.instance.v1.InstanceAdminSettings; @@ -44,11 +48,15 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.SpannerGrpc; import io.grpc.CallCredentials; import io.grpc.CompressorRegistry; +import io.grpc.Context; import io.grpc.ExperimentalApi; import io.grpc.ManagedChannelBuilder; +import io.grpc.MethodDescriptor; import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; @@ -121,6 +129,324 @@ public static interface CallCredentialsProvider { CallCredentials getCallCredentials(); } + /** Context key for the {@link CallContextConfigurator} to use. */ + public static final Context.Key CALL_CONTEXT_CONFIGURATOR_KEY = + Context.key("call-context-configurator"); + + /** + * {@link CallContextConfigurator} can be used to modify the {@link ApiCallContext} for one or + * more specific RPCs. This can be used to set specific timeout value for RPCs or use specific + * {@link CallCredentials} for an RPC. The {@link CallContextConfigurator} must be set as a value + * on the {@link Context} using the {@link SpannerOptions#CALL_CONTEXT_CONFIGURATOR_KEY} key. + * + *

This API is meant for advanced users. Most users should instead use the {@link + * SpannerCallContextTimeoutConfigurator} for setting timeouts per RPC. + * + *

Example usage: + * + *

{@code
+   * CallContextConfigurator configurator =
+   *     new CallContextConfigurator() {
+   *       public  ApiCallContext configure(
+   *           ApiCallContext context, ReqT request, MethodDescriptor method) {
+   *         if (method == SpannerGrpc.getExecuteBatchDmlMethod()) {
+   *           return GrpcCallContext.createDefault()
+   *               .withCallOptions(CallOptions.DEFAULT.withDeadlineAfter(60L, TimeUnit.SECONDS));
+   *         }
+   *         return null;
+   *       }
+   *     };
+   * Context context =
+   *     Context.current().withValue(SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY, configurator);
+   * context.run(
+   *     new Runnable() {
+   *       public void run() {
+   *         try {
+   *           client
+   *               .readWriteTransaction()
+   *               .run(
+   *                   new TransactionCallable() {
+   *                     public long[] run(TransactionContext transaction) throws Exception {
+   *                       return transaction.batchUpdate(
+   *                           ImmutableList.of(statement1, statement2));
+   *                     }
+   *                   });
+   *         } catch (SpannerException e) {
+   *           if (e.getErrorCode() == ErrorCode.DEADLINE_EXCEEDED) {
+   *             // handle timeout exception.
+   *           }
+   *         }
+   *       }
+   *     });
+   * }
+ */ + public static interface CallContextConfigurator { + /** + * Configure a {@link ApiCallContext} for a specific RPC call. + * + * @param context The default context. This can be used to inspect the current values. + * @param request The request that will be sent. + * @param method The method that is being called. + * @return An {@link ApiCallContext} that will be merged with the default {@link + * ApiCallContext}. If null is returned, no changes to the default {@link + * ApiCallContext} will be made. + */ + @Nullable + ApiCallContext configure( + ApiCallContext context, ReqT request, MethodDescriptor method); + } + + private enum SpannerMethod { + COMMIT { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + return method == SpannerGrpc.getCommitMethod(); + } + }, + ROLLBACK { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + return method == SpannerGrpc.getRollbackMethod(); + } + }, + + EXECUTE_QUERY { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + // This also matches with Partitioned DML calls, but that call will override any timeout + // settings anyway. + return method == SpannerGrpc.getExecuteStreamingSqlMethod(); + } + }, + READ { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + return method == SpannerGrpc.getStreamingReadMethod(); + } + }, + EXECUTE_UPDATE { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + if (method == SpannerGrpc.getExecuteSqlMethod()) { + ExecuteSqlRequest sqlRequest = (ExecuteSqlRequest) request; + return sqlRequest.getSeqno() != 0L; + } + return false; + } + }, + BATCH_UPDATE { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + return method == SpannerGrpc.getExecuteBatchDmlMethod(); + } + }, + + PARTITION_QUERY { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + return method == SpannerGrpc.getPartitionQueryMethod(); + } + }, + PARTITION_READ { + @Override + boolean isMethod(ReqT request, MethodDescriptor method) { + return method == SpannerGrpc.getPartitionReadMethod(); + } + }; + + abstract boolean isMethod(ReqT request, MethodDescriptor method); + + static SpannerMethod valueOf(ReqT request, MethodDescriptor method) { + for (SpannerMethod m : SpannerMethod.values()) { + if (m.isMethod(request, method)) { + return m; + } + } + return null; + } + } + + /** + * Helper class to configure timeouts for specific Spanner RPCs. The {@link + * SpannerCallContextTimeoutConfigurator} must be set as a value on the {@link Context} using the + * {@link SpannerOptions#CALL_CONTEXT_CONFIGURATOR_KEY} key. + * + *

Example usage: + * + *

{@code
+   * // Create a context with a ExecuteQuery timeout of 10 seconds.
+   * Context context =
+   *     Context.current()
+   *         .withValue(
+   *             SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY,
+   *             SpannerCallContextTimeoutConfigurator.create()
+   *                 .withExecuteQueryTimeout(Duration.ofSeconds(10L)));
+   * context.run(
+   *     new Runnable() {
+   *       public void run() {
+   *         try (ResultSet rs =
+   *             client
+   *                 .singleUse()
+   *                 .executeQuery(
+   *                     Statement.of(
+   *                         "SELECT SingerId, FirstName, LastName FROM Singers ORDER BY LastName"))) {
+   *           while (rs.next()) {
+   *             System.out.printf("%d %s %s%n", rs.getLong(0), rs.getString(1), rs.getString(2));
+   *           }
+   *         } catch (SpannerException e) {
+   *           if (e.getErrorCode() == ErrorCode.DEADLINE_EXCEEDED) {
+   *             // Handle timeout.
+   *           }
+   *         }
+   *       }
+   *     });
+   * }
+ */ + public static class SpannerCallContextTimeoutConfigurator implements CallContextConfigurator { + private Duration commitTimeout; + private Duration rollbackTimeout; + + private Duration executeQueryTimeout; + private Duration executeUpdateTimeout; + private Duration batchUpdateTimeout; + private Duration readTimeout; + + private Duration partitionQueryTimeout; + private Duration partitionReadTimeout; + + public static SpannerCallContextTimeoutConfigurator create() { + return new SpannerCallContextTimeoutConfigurator(); + } + + private SpannerCallContextTimeoutConfigurator() {} + + @Override + public ApiCallContext configure( + ApiCallContext context, ReqT request, MethodDescriptor method) { + SpannerMethod spannerMethod = SpannerMethod.valueOf(request, method); + if (spannerMethod == null) { + return null; + } + switch (SpannerMethod.valueOf(request, method)) { + case BATCH_UPDATE: + return batchUpdateTimeout == null + ? null + : GrpcCallContext.createDefault().withTimeout(batchUpdateTimeout); + case COMMIT: + return commitTimeout == null + ? null + : GrpcCallContext.createDefault().withTimeout(commitTimeout); + case EXECUTE_QUERY: + return executeQueryTimeout == null + ? null + : GrpcCallContext.createDefault() + .withTimeout(executeQueryTimeout) + .withStreamWaitTimeout(executeQueryTimeout); + case EXECUTE_UPDATE: + return executeUpdateTimeout == null + ? null + : GrpcCallContext.createDefault().withTimeout(executeUpdateTimeout); + case PARTITION_QUERY: + return partitionQueryTimeout == null + ? null + : GrpcCallContext.createDefault().withTimeout(partitionQueryTimeout); + case PARTITION_READ: + return partitionReadTimeout == null + ? null + : GrpcCallContext.createDefault().withTimeout(partitionReadTimeout); + case READ: + return readTimeout == null + ? null + : GrpcCallContext.createDefault() + .withTimeout(readTimeout) + .withStreamWaitTimeout(readTimeout); + case ROLLBACK: + return rollbackTimeout == null + ? null + : GrpcCallContext.createDefault().withTimeout(rollbackTimeout); + default: + } + return null; + } + + public Duration getCommitTimeout() { + return commitTimeout; + } + + public SpannerCallContextTimeoutConfigurator withCommitTimeout(Duration commitTimeout) { + this.commitTimeout = commitTimeout; + return this; + } + + public Duration getRollbackTimeout() { + return rollbackTimeout; + } + + public SpannerCallContextTimeoutConfigurator withRollbackTimeout(Duration rollbackTimeout) { + this.rollbackTimeout = rollbackTimeout; + return this; + } + + public Duration getExecuteQueryTimeout() { + return executeQueryTimeout; + } + + public SpannerCallContextTimeoutConfigurator withExecuteQueryTimeout( + Duration executeQueryTimeout) { + this.executeQueryTimeout = executeQueryTimeout; + return this; + } + + public Duration getExecuteUpdateTimeout() { + return executeUpdateTimeout; + } + + public SpannerCallContextTimeoutConfigurator withExecuteUpdateTimeout( + Duration executeUpdateTimeout) { + this.executeUpdateTimeout = executeUpdateTimeout; + return this; + } + + public Duration getBatchUpdateTimeout() { + return batchUpdateTimeout; + } + + public SpannerCallContextTimeoutConfigurator withBatchUpdateTimeout( + Duration batchUpdateTimeout) { + this.batchUpdateTimeout = batchUpdateTimeout; + return this; + } + + public Duration getReadTimeout() { + return readTimeout; + } + + public SpannerCallContextTimeoutConfigurator withReadTimeout(Duration readTimeout) { + this.readTimeout = readTimeout; + return this; + } + + public Duration getPartitionQueryTimeout() { + return partitionQueryTimeout; + } + + public SpannerCallContextTimeoutConfigurator withPartitionQueryTimeout( + Duration partitionQueryTimeout) { + this.partitionQueryTimeout = partitionQueryTimeout; + return this; + } + + public Duration getPartitionReadTimeout() { + return partitionReadTimeout; + } + + public SpannerCallContextTimeoutConfigurator withPartitionReadTimeout( + Duration partitionReadTimeout) { + this.partitionReadTimeout = partitionReadTimeout; + return this; + } + } + /** Default implementation of {@code SpannerFactory}. */ private static class DefaultSpannerFactory implements SpannerFactory { private static final DefaultSpannerFactory INSTANCE = new DefaultSpannerFactory(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 187a6e9a22..520c183238 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -32,6 +32,7 @@ import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.retrying.TimedAttemptSettings; import com.google.api.gax.rpc.AlreadyExistsException; +import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.ApiClientHeaderProvider; import com.google.api.gax.rpc.ApiException; import com.google.api.gax.rpc.HeaderProvider; @@ -51,6 +52,7 @@ import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.SpannerOptions.CallContextConfigurator; import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStubSettings; @@ -77,6 +79,7 @@ import com.google.longrunning.CancelOperationRequest; import com.google.longrunning.GetOperationRequest; import com.google.longrunning.Operation; +import com.google.longrunning.OperationsGrpc; import com.google.protobuf.Empty; import com.google.protobuf.FieldMask; import com.google.protobuf.InvalidProtocolBufferException; @@ -88,6 +91,7 @@ import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; import com.google.spanner.admin.database.v1.CreateDatabaseRequest; import com.google.spanner.admin.database.v1.Database; +import com.google.spanner.admin.database.v1.DatabaseAdminGrpc; import com.google.spanner.admin.database.v1.DeleteBackupRequest; import com.google.spanner.admin.database.v1.DropDatabaseRequest; import com.google.spanner.admin.database.v1.GetBackupRequest; @@ -112,6 +116,7 @@ import com.google.spanner.admin.instance.v1.GetInstanceConfigRequest; import com.google.spanner.admin.instance.v1.GetInstanceRequest; import com.google.spanner.admin.instance.v1.Instance; +import com.google.spanner.admin.instance.v1.InstanceAdminGrpc; import com.google.spanner.admin.instance.v1.InstanceConfig; import com.google.spanner.admin.instance.v1.ListInstanceConfigsRequest; import com.google.spanner.admin.instance.v1.ListInstanceConfigsResponse; @@ -136,9 +141,11 @@ import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.Session; +import com.google.spanner.v1.SpannerGrpc; import com.google.spanner.v1.Transaction; import io.grpc.CallCredentials; import io.grpc.Context; +import io.grpc.MethodDescriptor; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; @@ -500,6 +507,7 @@ private final class OperationFutureCallable> { final OperationCallable operationCallable; final RequestT initialRequest; + final MethodDescriptor method; final String instanceName; final OperationsLister lister; final Function getStartTimeFunction; @@ -509,11 +517,13 @@ private final class OperationFutureCallable operationCallable, RequestT initialRequest, + MethodDescriptor method, String instanceName, OperationsLister lister, Function getStartTimeFunction) { this.operationCallable = operationCallable; this.initialRequest = initialRequest; + this.method = method; this.instanceName = instanceName; this.lister = lister; this.getStartTimeFunction = getStartTimeFunction; @@ -542,7 +552,7 @@ public OperationFuture call() throws Exception { isRetry = true; if (operationName == null) { - GrpcCallContext context = newCallContext(null, instanceName); + GrpcCallContext context = newCallContext(null, instanceName, initialRequest, method); return operationCallable.futureCall(initialRequest, context); } else { return operationCallable.resumeFutureCall(operationName); @@ -629,7 +639,9 @@ public Paginated listInstanceConfigs(int pageSize, @Nullable Str } ListInstanceConfigsRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(null, projectName); + GrpcCallContext context = + newCallContext( + null, projectName, request, InstanceAdminGrpc.getListInstanceConfigsMethod()); ListInstanceConfigsResponse response = get(instanceAdminStub.listInstanceConfigsCallable().futureCall(request, context)); return new Paginated<>(response.getInstanceConfigsList(), response.getNextPageToken()); @@ -640,7 +652,8 @@ public InstanceConfig getInstanceConfig(String instanceConfigName) throws Spanne GetInstanceConfigRequest request = GetInstanceConfigRequest.newBuilder().setName(instanceConfigName).build(); - GrpcCallContext context = newCallContext(null, projectName); + GrpcCallContext context = + newCallContext(null, projectName, request, InstanceAdminGrpc.getGetInstanceConfigMethod()); return get(instanceAdminStub.getInstanceConfigCallable().futureCall(request, context)); } @@ -657,7 +670,8 @@ public Paginated listInstances( } ListInstancesRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(null, projectName); + GrpcCallContext context = + newCallContext(null, projectName, request, InstanceAdminGrpc.getListInstancesMethod()); ListInstancesResponse response = get(instanceAdminStub.listInstancesCallable().futureCall(request, context)); return new Paginated<>(response.getInstancesList(), response.getNextPageToken()); @@ -672,7 +686,8 @@ public OperationFuture createInstance( .setInstanceId(instanceId) .setInstance(instance) .build(); - GrpcCallContext context = newCallContext(null, parent); + GrpcCallContext context = + newCallContext(null, parent, request, InstanceAdminGrpc.getCreateInstanceMethod()); return instanceAdminStub.createInstanceOperationCallable().futureCall(request, context); } @@ -681,7 +696,9 @@ public OperationFuture updateInstance( Instance instance, FieldMask fieldMask) throws SpannerException { UpdateInstanceRequest request = UpdateInstanceRequest.newBuilder().setInstance(instance).setFieldMask(fieldMask).build(); - GrpcCallContext context = newCallContext(null, instance.getName()); + GrpcCallContext context = + newCallContext( + null, instance.getName(), request, InstanceAdminGrpc.getUpdateInstanceMethod()); return instanceAdminStub.updateInstanceOperationCallable().futureCall(request, context); } @@ -689,7 +706,8 @@ public OperationFuture updateInstance( public Instance getInstance(String instanceName) throws SpannerException { GetInstanceRequest request = GetInstanceRequest.newBuilder().setName(instanceName).build(); - GrpcCallContext context = newCallContext(null, instanceName); + GrpcCallContext context = + newCallContext(null, instanceName, request, InstanceAdminGrpc.getGetInstanceMethod()); return get(instanceAdminStub.getInstanceCallable().futureCall(request, context)); } @@ -698,7 +716,8 @@ public void deleteInstance(String instanceName) throws SpannerException { DeleteInstanceRequest request = DeleteInstanceRequest.newBuilder().setName(instanceName).build(); - GrpcCallContext context = newCallContext(null, instanceName); + GrpcCallContext context = + newCallContext(null, instanceName, request, InstanceAdminGrpc.getDeleteInstanceMethod()); get(instanceAdminStub.deleteInstanceCallable().futureCall(request, context)); } @@ -716,7 +735,9 @@ public Paginated listBackupOperations( } ListBackupOperationsRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(null, instanceName); + GrpcCallContext context = + newCallContext( + null, instanceName, request, DatabaseAdminGrpc.getListBackupOperationsMethod()); ListBackupOperationsResponse response = get(databaseAdminStub.listBackupOperationsCallable().futureCall(request, context)); return new Paginated<>(response.getOperationsList(), response.getNextPageToken()); @@ -737,7 +758,9 @@ public Paginated listDatabaseOperations( } ListDatabaseOperationsRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(null, instanceName); + GrpcCallContext context = + newCallContext( + null, instanceName, request, DatabaseAdminGrpc.getListDatabaseOperationsMethod()); ListDatabaseOperationsResponse response = get(databaseAdminStub.listDatabaseOperationsCallable().futureCall(request, context)); return new Paginated<>(response.getOperationsList(), response.getNextPageToken()); @@ -758,7 +781,8 @@ public Paginated listBackups( } ListBackupsRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(null, instanceName); + GrpcCallContext context = + newCallContext(null, instanceName, request, DatabaseAdminGrpc.getListBackupsMethod()); ListBackupsResponse response = get(databaseAdminStub.listBackupsCallable().futureCall(request, context)); return new Paginated<>(response.getBackupsList(), response.getNextPageToken()); @@ -775,7 +799,8 @@ public Paginated listDatabases( } ListDatabasesRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(null, instanceName); + GrpcCallContext context = + newCallContext(null, instanceName, request, DatabaseAdminGrpc.getListDatabasesMethod()); ListDatabasesResponse response = get(databaseAdminStub.listDatabasesCallable().futureCall(request, context)); return new Paginated<>(response.getDatabasesList(), response.getNextPageToken()); @@ -801,6 +826,7 @@ public OperationFuture createDatabase( new OperationFutureCallable( databaseAdminStub.createDatabaseOperationCallable(), request, + DatabaseAdminGrpc.getCreateDatabaseMethod(), instanceName, new OperationsLister() { @Override @@ -856,7 +882,8 @@ public OperationFuture updateDatabaseDdl( .addAllStatements(updateDatabaseStatements) .setOperationId(MoreObjects.firstNonNull(updateId, "")) .build(); - GrpcCallContext context = newCallContext(null, databaseName); + GrpcCallContext context = + newCallContext(null, databaseName, request, DatabaseAdminGrpc.getUpdateDatabaseDdlMethod()); OperationCallable callable = databaseAdminStub.updateDatabaseDdlOperationCallable(); OperationFuture operationFuture = @@ -882,7 +909,8 @@ public void dropDatabase(String databaseName) throws SpannerException { DropDatabaseRequest request = DropDatabaseRequest.newBuilder().setDatabase(databaseName).build(); - GrpcCallContext context = newCallContext(null, databaseName); + GrpcCallContext context = + newCallContext(null, databaseName, request, DatabaseAdminGrpc.getDropDatabaseMethod()); get(databaseAdminStub.dropDatabaseCallable().futureCall(request, context)); } @@ -891,7 +919,8 @@ public Database getDatabase(String databaseName) throws SpannerException { acquireAdministrativeRequestsRateLimiter(); GetDatabaseRequest request = GetDatabaseRequest.newBuilder().setName(databaseName).build(); - GrpcCallContext context = newCallContext(null, databaseName); + GrpcCallContext context = + newCallContext(null, databaseName, request, DatabaseAdminGrpc.getGetDatabaseMethod()); return get(databaseAdminStub.getDatabaseCallable().futureCall(request, context)); } @@ -901,7 +930,8 @@ public List getDatabaseDdl(String databaseName) throws SpannerException GetDatabaseDdlRequest request = GetDatabaseDdlRequest.newBuilder().setDatabase(databaseName).build(); - GrpcCallContext context = newCallContext(null, databaseName); + GrpcCallContext context = + newCallContext(null, databaseName, request, DatabaseAdminGrpc.getGetDatabaseDdlMethod()); return get(databaseAdminStub.getDatabaseDdlCallable().futureCall(request, context)) .getStatementsList(); } @@ -920,6 +950,7 @@ public OperationFuture createBackup( new OperationFutureCallable( databaseAdminStub.createBackupOperationCallable(), request, + DatabaseAdminGrpc.getCreateBackupMethod(), instanceName, new OperationsLister() { @Override @@ -972,6 +1003,7 @@ public OperationFuture restoreDatabase( new OperationFutureCallable( databaseAdminStub.restoreDatabaseOperationCallable(), request, + DatabaseAdminGrpc.getRestoreDatabaseMethod(), databaseInstanceName, new OperationsLister() { @Override @@ -1015,7 +1047,8 @@ public Backup updateBackup(Backup backup, FieldMask updateMask) { acquireAdministrativeRequestsRateLimiter(); UpdateBackupRequest request = UpdateBackupRequest.newBuilder().setBackup(backup).setUpdateMask(updateMask).build(); - GrpcCallContext context = newCallContext(null, backup.getName()); + GrpcCallContext context = + newCallContext(null, backup.getName(), request, DatabaseAdminGrpc.getUpdateBackupMethod()); return databaseAdminStub.updateBackupCallable().call(request, context); } @@ -1023,7 +1056,8 @@ public Backup updateBackup(Backup backup, FieldMask updateMask) { public void deleteBackup(String backupName) { acquireAdministrativeRequestsRateLimiter(); DeleteBackupRequest request = DeleteBackupRequest.newBuilder().setName(backupName).build(); - GrpcCallContext context = newCallContext(null, backupName); + GrpcCallContext context = + newCallContext(null, backupName, request, DatabaseAdminGrpc.getDeleteBackupMethod()); databaseAdminStub.deleteBackupCallable().call(request, context); } @@ -1031,7 +1065,8 @@ public void deleteBackup(String backupName) { public Backup getBackup(String backupName) throws SpannerException { acquireAdministrativeRequestsRateLimiter(); GetBackupRequest request = GetBackupRequest.newBuilder().setName(backupName).build(); - GrpcCallContext context = newCallContext(null, backupName); + GrpcCallContext context = + newCallContext(null, backupName, request, DatabaseAdminGrpc.getGetBackupMethod()); return get(databaseAdminStub.getBackupCallable().futureCall(request, context)); } @@ -1039,7 +1074,8 @@ public Backup getBackup(String backupName) throws SpannerException { public Operation getOperation(String name) throws SpannerException { acquireAdministrativeRequestsRateLimiter(); GetOperationRequest request = GetOperationRequest.newBuilder().setName(name).build(); - GrpcCallContext context = newCallContext(null, name); + GrpcCallContext context = + newCallContext(null, name, request, OperationsGrpc.getGetOperationMethod()); return get( databaseAdminStub.getOperationsStub().getOperationCallable().futureCall(request, context)); } @@ -1048,7 +1084,8 @@ public Operation getOperation(String name) throws SpannerException { public void cancelOperation(String name) throws SpannerException { acquireAdministrativeRequestsRateLimiter(); CancelOperationRequest request = CancelOperationRequest.newBuilder().setName(name).build(); - GrpcCallContext context = newCallContext(null, name); + GrpcCallContext context = + newCallContext(null, name, request, OperationsGrpc.getCancelOperationMethod()); get( databaseAdminStub .getOperationsStub() @@ -1072,7 +1109,8 @@ public List batchCreateSessions( requestBuilder.setSessionTemplate(session); } BatchCreateSessionsRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(options, databaseName); + GrpcCallContext context = + newCallContext(options, databaseName, request, SpannerGrpc.getBatchCreateSessionsMethod()); return get(spannerStub.batchCreateSessionsCallable().futureCall(request, context)) .getSessionList(); } @@ -1088,7 +1126,8 @@ public Session createSession( requestBuilder.setSession(session); } CreateSessionRequest request = requestBuilder.build(); - GrpcCallContext context = newCallContext(options, databaseName); + GrpcCallContext context = + newCallContext(options, databaseName, request, SpannerGrpc.getCreateSessionMethod()); return get(spannerStub.createSessionCallable().futureCall(request, context)); } @@ -1101,14 +1140,16 @@ public void deleteSession(String sessionName, @Nullable Map options) @Override public ApiFuture asyncDeleteSession(String sessionName, @Nullable Map options) { DeleteSessionRequest request = DeleteSessionRequest.newBuilder().setName(sessionName).build(); - GrpcCallContext context = newCallContext(options, sessionName); + GrpcCallContext context = + newCallContext(options, sessionName, request, SpannerGrpc.getDeleteSessionMethod()); return spannerStub.deleteSessionCallable().futureCall(request, context); } @Override public StreamingCall read( ReadRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext(options, request.getSession(), request, SpannerGrpc.getReadMethod()); SpannerResponseObserver responseObserver = new SpannerResponseObserver(consumer); spannerStub.streamingReadCallable().call(request, responseObserver, context); final StreamController controller = responseObserver.getController(); @@ -1135,14 +1176,16 @@ public ResultSet executeQuery(ExecuteSqlRequest request, @Nullable Map executeQueryAsync( ExecuteSqlRequest request, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext(options, request.getSession(), request, SpannerGrpc.getExecuteSqlMethod()); return spannerStub.executeSqlCallable().futureCall(request, context); } @Override public ResultSet executePartitionedDml( ExecuteSqlRequest request, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext(options, request.getSession(), request, SpannerGrpc.getExecuteSqlMethod()); return get(partitionedDmlStub.executeSqlCallable().futureCall(request, context)); } @@ -1154,15 +1197,20 @@ public RetrySettings getPartitionedDmlRetrySettings() { @Override public ServerStream executeStreamingPartitionedDml( ExecuteSqlRequest request, Map options, Duration timeout) { - GrpcCallContext context = newCallContext(options, request.getSession()); - context = context.withStreamWaitTimeout(timeout); + GrpcCallContext context = + newCallContext( + options, request.getSession(), request, SpannerGrpc.getExecuteStreamingSqlMethod()); + // Override any timeout settings that might have been set on the call context. + context = context.withTimeout(timeout).withStreamWaitTimeout(timeout); return partitionedDmlStub.executeStreamingSqlCallable().call(request, context); } @Override public StreamingCall executeQuery( ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext( + options, request.getSession(), request, SpannerGrpc.getExecuteStreamingSqlMethod()); SpannerResponseObserver responseObserver = new SpannerResponseObserver(consumer); spannerStub.executeStreamingSqlCallable().call(request, responseObserver, context); final StreamController controller = responseObserver.getController(); @@ -1190,14 +1238,18 @@ public ExecuteBatchDmlResponse executeBatchDml( @Override public ApiFuture executeBatchDmlAsync( ExecuteBatchDmlRequest request, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext( + options, request.getSession(), request, SpannerGrpc.getExecuteBatchDmlMethod()); return spannerStub.executeBatchDmlCallable().futureCall(request, context); } @Override public ApiFuture beginTransactionAsync( BeginTransactionRequest request, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext( + options, request.getSession(), request, SpannerGrpc.getBeginTransactionMethod()); return spannerStub.beginTransactionCallable().futureCall(request, context); } @@ -1209,9 +1261,10 @@ public Transaction beginTransaction( @Override public ApiFuture commitAsync( - CommitRequest commitRequest, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, commitRequest.getSession()); - return spannerStub.commitCallable().futureCall(commitRequest, context); + CommitRequest request, @Nullable Map options) { + GrpcCallContext context = + newCallContext(options, request.getSession(), request, SpannerGrpc.getCommitMethod()); + return spannerStub.commitCallable().futureCall(request, context); } @Override @@ -1222,7 +1275,8 @@ public CommitResponse commit(CommitRequest commitRequest, @Nullable Map rollbackAsync(RollbackRequest request, @Nullable Map options) { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext(options, request.getSession(), request, SpannerGrpc.getRollbackMethod()); return spannerStub.rollbackCallable().futureCall(request, context); } @@ -1235,91 +1289,85 @@ public void rollback(RollbackRequest request, @Nullable Map options) @Override public PartitionResponse partitionQuery( PartitionQueryRequest request, @Nullable Map options) throws SpannerException { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext( + options, request.getSession(), request, SpannerGrpc.getPartitionQueryMethod()); return get(spannerStub.partitionQueryCallable().futureCall(request, context)); } @Override public PartitionResponse partitionRead( PartitionReadRequest request, @Nullable Map options) throws SpannerException { - GrpcCallContext context = newCallContext(options, request.getSession()); + GrpcCallContext context = + newCallContext( + options, request.getSession(), request, SpannerGrpc.getPartitionReadMethod()); return get(spannerStub.partitionReadCallable().futureCall(request, context)); } @Override public Policy getDatabaseAdminIAMPolicy(String resource) { acquireAdministrativeRequestsRateLimiter(); - GrpcCallContext context = newCallContext(null, resource); - return get( - databaseAdminStub - .getIamPolicyCallable() - .futureCall(GetIamPolicyRequest.newBuilder().setResource(resource).build(), context)); + GetIamPolicyRequest request = GetIamPolicyRequest.newBuilder().setResource(resource).build(); + GrpcCallContext context = + newCallContext(null, resource, request, DatabaseAdminGrpc.getGetIamPolicyMethod()); + return get(databaseAdminStub.getIamPolicyCallable().futureCall(request, context)); } @Override public Policy setDatabaseAdminIAMPolicy(String resource, Policy policy) { acquireAdministrativeRequestsRateLimiter(); - GrpcCallContext context = newCallContext(null, resource); - return get( - databaseAdminStub - .setIamPolicyCallable() - .futureCall( - SetIamPolicyRequest.newBuilder().setResource(resource).setPolicy(policy).build(), - context)); + SetIamPolicyRequest request = + SetIamPolicyRequest.newBuilder().setResource(resource).setPolicy(policy).build(); + GrpcCallContext context = + newCallContext(null, resource, request, DatabaseAdminGrpc.getSetIamPolicyMethod()); + return get(databaseAdminStub.setIamPolicyCallable().futureCall(request, context)); } @Override public TestIamPermissionsResponse testDatabaseAdminIAMPermissions( String resource, Iterable permissions) { acquireAdministrativeRequestsRateLimiter(); - GrpcCallContext context = newCallContext(null, resource); - return get( - databaseAdminStub - .testIamPermissionsCallable() - .futureCall( - TestIamPermissionsRequest.newBuilder() - .setResource(resource) - .addAllPermissions(permissions) - .build(), - context)); + TestIamPermissionsRequest request = + TestIamPermissionsRequest.newBuilder() + .setResource(resource) + .addAllPermissions(permissions) + .build(); + GrpcCallContext context = + newCallContext(null, resource, request, DatabaseAdminGrpc.getTestIamPermissionsMethod()); + return get(databaseAdminStub.testIamPermissionsCallable().futureCall(request, context)); } @Override public Policy getInstanceAdminIAMPolicy(String resource) { acquireAdministrativeRequestsRateLimiter(); - GrpcCallContext context = newCallContext(null, resource); - return get( - instanceAdminStub - .getIamPolicyCallable() - .futureCall(GetIamPolicyRequest.newBuilder().setResource(resource).build(), context)); + GetIamPolicyRequest request = GetIamPolicyRequest.newBuilder().setResource(resource).build(); + GrpcCallContext context = + newCallContext(null, resource, request, InstanceAdminGrpc.getGetIamPolicyMethod()); + return get(instanceAdminStub.getIamPolicyCallable().futureCall(request, context)); } @Override public Policy setInstanceAdminIAMPolicy(String resource, Policy policy) { acquireAdministrativeRequestsRateLimiter(); - GrpcCallContext context = newCallContext(null, resource); - return get( - instanceAdminStub - .setIamPolicyCallable() - .futureCall( - SetIamPolicyRequest.newBuilder().setResource(resource).setPolicy(policy).build(), - context)); + SetIamPolicyRequest request = + SetIamPolicyRequest.newBuilder().setResource(resource).setPolicy(policy).build(); + GrpcCallContext context = + newCallContext(null, resource, request, InstanceAdminGrpc.getSetIamPolicyMethod()); + return get(instanceAdminStub.setIamPolicyCallable().futureCall(request, context)); } @Override public TestIamPermissionsResponse testInstanceAdminIAMPermissions( String resource, Iterable permissions) { acquireAdministrativeRequestsRateLimiter(); - GrpcCallContext context = newCallContext(null, resource); - return get( - instanceAdminStub - .testIamPermissionsCallable() - .futureCall( - TestIamPermissionsRequest.newBuilder() - .setResource(resource) - .addAllPermissions(permissions) - .build(), - context)); + TestIamPermissionsRequest request = + TestIamPermissionsRequest.newBuilder() + .setResource(resource) + .addAllPermissions(permissions) + .build(); + GrpcCallContext context = + newCallContext(null, resource, request, InstanceAdminGrpc.getTestIamPermissionsMethod()); + return get(instanceAdminStub.testIamPermissionsCallable().futureCall(request, context)); } /** Gets the result of an async RPC call, handling any exceptions encountered. */ @@ -1337,7 +1385,11 @@ private static T get(final Future future) throws SpannerException { } @VisibleForTesting - GrpcCallContext newCallContext(@Nullable Map options, String resource) { + GrpcCallContext newCallContext( + @Nullable Map options, + String resource, + ReqT request, + MethodDescriptor method) { GrpcCallContext context = GrpcCallContext.createDefault(); if (options != null) { context = context.withChannelAffinity(Option.CHANNEL_HINT.getLong(options).intValue()); @@ -1350,7 +1402,13 @@ GrpcCallContext newCallContext(@Nullable Map options, String resource context.withCallOptions(context.getCallOptions().withCallCredentials(callCredentials)); } } - return context.withStreamWaitTimeout(waitTimeout).withStreamIdleTimeout(idleTimeout); + context = context.withStreamWaitTimeout(waitTimeout).withStreamIdleTimeout(idleTimeout); + CallContextConfigurator configurator = SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY.get(); + ApiCallContext apiCallContextFromContext = null; + if (configurator != null) { + apiCallContextFromContext = configurator.configure(context, request, method); + } + return (GrpcCallContext) context.merge(apiCallContextFromContext); } @Override diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index 5e69b35cd3..001cdfdf04 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -32,6 +32,7 @@ import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; +import com.google.cloud.spanner.SpannerOptions.SpannerCallContextTimeoutConfigurator; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; @@ -40,6 +41,7 @@ import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import io.grpc.Context; import io.grpc.Server; import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -1548,4 +1550,40 @@ public void testReadDoesNotIncludeStatement() { assertThat(e.getMessage()).doesNotContain("Statement:"); } } + + @Test + public void testSpecificTimeout() { + mockSpanner.setExecuteStreamingSqlExecutionTime( + SimulatedExecutionTime.ofMinimumAndRandomTime(10000, 0)); + final DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + Context.current() + .withValue( + SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY, + SpannerCallContextTimeoutConfigurator.create() + .withExecuteQueryTimeout(Duration.ofNanos(1L))) + .run( + new Runnable() { + @Override + public void run() { + // Query should fail with a timeout. + try (ResultSet rs = client.singleUse().executeQuery(SELECT1)) { + rs.next(); + fail("missing expected DEADLINE_EXCEEDED exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED); + } + // Update should succeed. + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + return transaction.executeUpdate(UPDATE_STATEMENT); + } + }); + } + }); + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java index 65636e80fa..1cf7127b8c 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java @@ -21,17 +21,34 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.fail; +import com.google.api.gax.grpc.GrpcCallContext; import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.UnaryCallSettings; import com.google.cloud.NoCredentials; import com.google.cloud.ServiceOptions; import com.google.cloud.TransportOptions; +import com.google.cloud.spanner.SpannerOptions.SpannerCallContextTimeoutConfigurator; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStubSettings; import com.google.cloud.spanner.admin.instance.v1.stub.InstanceAdminStubSettings; import com.google.cloud.spanner.v1.stub.SpannerStubSettings; import com.google.common.base.Strings; +import com.google.spanner.v1.BatchCreateSessionsRequest; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.DeleteSessionRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.GetSessionRequest; +import com.google.spanner.v1.ListSessionsRequest; +import com.google.spanner.v1.PartitionQueryRequest; +import com.google.spanner.v1.PartitionReadRequest; +import com.google.spanner.v1.ReadRequest; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.SpannerGrpc; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -615,4 +632,237 @@ public void testCompressorName() { // ignore, this is the expected exception. } } + + @Test + public void testSpannerCallContextTimeoutConfigurator_NullValues() { + SpannerCallContextTimeoutConfigurator configurator = + SpannerCallContextTimeoutConfigurator.create(); + ApiCallContext inputCallContext = GrpcCallContext.createDefault(); + + assertThat( + configurator.configure( + inputCallContext, + BatchCreateSessionsRequest.getDefaultInstance(), + SpannerGrpc.getBatchCreateSessionsMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + CreateSessionRequest.getDefaultInstance(), + SpannerGrpc.getCreateSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + DeleteSessionRequest.getDefaultInstance(), + SpannerGrpc.getDeleteSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + GetSessionRequest.getDefaultInstance(), + SpannerGrpc.getGetSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + DeleteSessionRequest.getDefaultInstance(), + SpannerGrpc.getDeleteSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + ListSessionsRequest.getDefaultInstance(), + SpannerGrpc.getListSessionsMethod())) + .isNull(); + + assertThat( + configurator.configure( + inputCallContext, + BeginTransactionRequest.getDefaultInstance(), + SpannerGrpc.getBeginTransactionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + CommitRequest.getDefaultInstance(), + SpannerGrpc.getCommitMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + RollbackRequest.getDefaultInstance(), + SpannerGrpc.getRollbackMethod())) + .isNull(); + + assertThat( + configurator.configure( + inputCallContext, + ExecuteSqlRequest.getDefaultInstance(), + SpannerGrpc.getExecuteSqlMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + ExecuteSqlRequest.getDefaultInstance(), + SpannerGrpc.getExecuteStreamingSqlMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + ExecuteBatchDmlRequest.getDefaultInstance(), + SpannerGrpc.getExecuteBatchDmlMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, ReadRequest.getDefaultInstance(), SpannerGrpc.getReadMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + ReadRequest.getDefaultInstance(), + SpannerGrpc.getStreamingReadMethod())) + .isNull(); + + assertThat( + configurator.configure( + inputCallContext, + PartitionQueryRequest.getDefaultInstance(), + SpannerGrpc.getPartitionQueryMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + PartitionReadRequest.getDefaultInstance(), + SpannerGrpc.getPartitionReadMethod())) + .isNull(); + } + + @Test + public void testSpannerCallContextTimeoutConfigurator_WithTimeouts() { + SpannerCallContextTimeoutConfigurator configurator = + SpannerCallContextTimeoutConfigurator.create(); + configurator.withBatchUpdateTimeout(Duration.ofSeconds(1L)); + configurator.withCommitTimeout(Duration.ofSeconds(2L)); + configurator.withExecuteQueryTimeout(Duration.ofSeconds(3L)); + configurator.withExecuteUpdateTimeout(Duration.ofSeconds(4L)); + configurator.withPartitionQueryTimeout(Duration.ofSeconds(5L)); + configurator.withPartitionReadTimeout(Duration.ofSeconds(6L)); + configurator.withReadTimeout(Duration.ofSeconds(7L)); + configurator.withRollbackTimeout(Duration.ofSeconds(8L)); + + ApiCallContext inputCallContext = GrpcCallContext.createDefault(); + + assertThat( + configurator.configure( + inputCallContext, + BatchCreateSessionsRequest.getDefaultInstance(), + SpannerGrpc.getBatchCreateSessionsMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + CreateSessionRequest.getDefaultInstance(), + SpannerGrpc.getCreateSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + DeleteSessionRequest.getDefaultInstance(), + SpannerGrpc.getDeleteSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + GetSessionRequest.getDefaultInstance(), + SpannerGrpc.getGetSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + DeleteSessionRequest.getDefaultInstance(), + SpannerGrpc.getDeleteSessionMethod())) + .isNull(); + assertThat( + configurator.configure( + inputCallContext, + ListSessionsRequest.getDefaultInstance(), + SpannerGrpc.getListSessionsMethod())) + .isNull(); + + assertThat( + configurator.configure( + inputCallContext, + BeginTransactionRequest.getDefaultInstance(), + SpannerGrpc.getBeginTransactionMethod())) + .isNull(); + assertThat( + configurator + .configure( + inputCallContext, + CommitRequest.getDefaultInstance(), + SpannerGrpc.getCommitMethod()) + .getTimeout()) + .isEqualTo(Duration.ofSeconds(2L)); + assertThat( + configurator + .configure( + inputCallContext, + RollbackRequest.getDefaultInstance(), + SpannerGrpc.getRollbackMethod()) + .getTimeout()) + .isEqualTo(Duration.ofSeconds(8L)); + + assertThat( + configurator.configure( + inputCallContext, + ExecuteSqlRequest.getDefaultInstance(), + SpannerGrpc.getExecuteSqlMethod())) + .isNull(); + assertThat( + configurator + .configure( + inputCallContext, + ExecuteSqlRequest.getDefaultInstance(), + SpannerGrpc.getExecuteStreamingSqlMethod()) + .getTimeout()) + .isEqualTo(Duration.ofSeconds(3L)); + assertThat( + configurator + .configure( + inputCallContext, + ExecuteBatchDmlRequest.getDefaultInstance(), + SpannerGrpc.getExecuteBatchDmlMethod()) + .getTimeout()) + .isEqualTo(Duration.ofSeconds(1L)); + assertThat( + configurator.configure( + inputCallContext, ReadRequest.getDefaultInstance(), SpannerGrpc.getReadMethod())) + .isNull(); + assertThat( + configurator + .configure( + inputCallContext, + ReadRequest.getDefaultInstance(), + SpannerGrpc.getStreamingReadMethod()) + .getTimeout()) + .isEqualTo(Duration.ofSeconds(7L)); + + assertThat( + configurator + .configure( + inputCallContext, + PartitionQueryRequest.getDefaultInstance(), + SpannerGrpc.getPartitionQueryMethod()) + .getTimeout()) + .isEqualTo(Duration.ofSeconds(5L)); + assertThat( + configurator + .configure( + inputCallContext, + PartitionReadRequest.getDefaultInstance(), + SpannerGrpc.getPartitionReadMethod()) + .getTimeout()) + .isEqualTo(Duration.ofSeconds(6L)); + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java index 189f9ec6d3..18270dfdd4 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java @@ -20,22 +20,30 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; import static org.junit.Assume.assumeTrue; import com.google.api.core.ApiFunction; +import com.google.api.gax.rpc.ApiCallContext; import com.google.auth.oauth2.AccessToken; import com.google.auth.oauth2.OAuth2Credentials; import com.google.cloud.spanner.DatabaseAdminClient; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.InstanceAdminClient; import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.SpannerOptions.CallContextConfigurator; import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.TransactionContext; +import com.google.cloud.spanner.TransactionRunner.TransactionCallable; import com.google.cloud.spanner.admin.database.v1.MockDatabaseAdminImpl; import com.google.cloud.spanner.admin.instance.v1.MockInstanceAdminImpl; import com.google.cloud.spanner.spi.v1.SpannerRpc.Option; @@ -46,7 +54,10 @@ import com.google.spanner.admin.instance.v1.Instance; import com.google.spanner.admin.instance.v1.InstanceConfigName; import com.google.spanner.admin.instance.v1.InstanceName; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.GetSessionRequest; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.SpannerGrpc; import com.google.spanner.v1.StructType; import com.google.spanner.v1.StructType.Field; import com.google.spanner.v1.TypeCode; @@ -56,6 +67,7 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.Metadata.Key; +import io.grpc.MethodDescriptor; import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; @@ -76,6 +88,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.threeten.bp.Duration; /** Tests that opening and closing multiple Spanner instances does not leak any threads. */ @RunWith(JUnit4.class) @@ -108,6 +121,9 @@ public class GapicSpannerRpcTest { .build()) .setMetadata(SELECT1AND2_METADATA) .build(); + private static final Statement UPDATE_FOO_STATEMENT = + Statement.of("UPDATE FOO SET BAR=1 WHERE BAZ=2"); + private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN"; private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN"; private static final OAuth2Credentials STATIC_CREDENTIALS = @@ -142,6 +158,7 @@ public void startServer() throws IOException { mockSpanner = new MockSpannerServiceImpl(); mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions. mockSpanner.putStatementResult(StatementResult.query(SELECT1AND2, SELECT1_RESULTSET)); + mockSpanner.putStatementResult(StatementResult.update(UPDATE_FOO_STATEMENT, 1L)); mockInstanceAdmin = new MockInstanceAdminImpl(); mockDatabaseAdmin = new MockDatabaseAdminImpl(); @@ -303,7 +320,14 @@ public CallCredentials getCallCredentials() { GapicSpannerRpc rpc = new GapicSpannerRpc(options); // GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the // existence. - assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials()) + assertThat( + rpc.newCallContext( + optionsMap, + "/some/resource", + GetSessionRequest.getDefaultInstance(), + SpannerGrpc.getGetSessionMethod()) + .getCallOptions() + .getCredentials()) .isNotNull(); rpc.shutdown(); } @@ -323,7 +347,14 @@ public CallCredentials getCallCredentials() { }) .build(); GapicSpannerRpc rpc = new GapicSpannerRpc(options); - assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials()) + assertThat( + rpc.newCallContext( + optionsMap, + "/some/resource", + GetSessionRequest.getDefaultInstance(), + SpannerGrpc.getGetSessionMethod()) + .getCallOptions() + .getCredentials()) .isNull(); rpc.shutdown(); } @@ -336,11 +367,93 @@ public void testNoCallCredentials() { .setCredentials(STATIC_CREDENTIALS) .build(); GapicSpannerRpc rpc = new GapicSpannerRpc(options); - assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials()) + assertThat( + rpc.newCallContext( + optionsMap, + "/some/resource", + GetSessionRequest.getDefaultInstance(), + SpannerGrpc.getGetSessionMethod()) + .getCallOptions() + .getCredentials()) .isNull(); rpc.shutdown(); } + private static final class TimeoutHolder { + private Duration timeout; + } + + @Test + public void testCallContextTimeout() { + // Create a CallContextConfigurator that uses a variable timeout value. + final TimeoutHolder timeoutHolder = new TimeoutHolder(); + CallContextConfigurator configurator = + new CallContextConfigurator() { + @Override + public ApiCallContext configure( + ApiCallContext context, ReqT request, MethodDescriptor method) { + // Only configure a timeout for the ExecuteSql method as this method is used for + // executing DML statements. + if (request instanceof ExecuteSqlRequest + && method.equals(SpannerGrpc.getExecuteSqlMethod())) { + ExecuteSqlRequest sqlRequest = (ExecuteSqlRequest) request; + // Sequence numbers are only assigned for DML statements, which means that + // this is an update statement. + if (sqlRequest.getSeqno() > 0L) { + return context.withTimeout(timeoutHolder.timeout); + } + } + return null; + } + }; + + mockSpanner.setExecuteSqlExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(10, 0)); + SpannerOptions options = createSpannerOptions(); + try (Spanner spanner = options.getService()) { + final DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + Context context = + Context.current().withValue(SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY, configurator); + context.run( + new Runnable() { + @Override + public void run() { + try { + // First try with a 1ns timeout. This should always cause a DEADLINE_EXCEEDED + // exception. + timeoutHolder.timeout = Duration.ofNanos(1L); + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + return transaction.executeUpdate(UPDATE_FOO_STATEMENT); + } + }); + fail("missing expected timeout exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED); + } + + // Then try with a longer timeout. This should now succeed. + timeoutHolder.timeout = Duration.ofMinutes(1L); + Long updateCount = + client + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + return transaction.executeUpdate(UPDATE_FOO_STATEMENT); + } + }); + assertThat(updateCount).isEqualTo(1L); + } + }); + } + } + @SuppressWarnings("rawtypes") private SpannerOptions createSpannerOptions() { String endpoint = address.getHostString() + ":" + server.getPort();