From 1112357be1c5fb9c4abfba48989fe8217853876a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 14 Jan 2020 09:43:59 +0100 Subject: [PATCH] feat: add support for CallCredentials (#26) * feat: add support for CallCredentials Adds support for using CallCredentials that could vary per call instead of only static credentials passed in at startup. Fixes #18 * fix: declare dependency usage --- google-cloud-spanner/pom.xml | 4 + .../google/cloud/spanner/SpannerOptions.java | 29 +++++ .../cloud/spanner/spi/v1/GapicSpannerRpc.java | 15 ++- .../spanner/spi/v1/GapicSpannerRpcTest.java | 110 +++++++++++++++++- 4 files changed, 155 insertions(+), 3 deletions(-) diff --git a/google-cloud-spanner/pom.xml b/google-cloud-spanner/pom.xml index 95d059bc12..e76c35f50a 100644 --- a/google-cloud-spanner/pom.xml +++ b/google-cloud-spanner/pom.xml @@ -80,6 +80,10 @@ io.grpc grpc-api + + io.grpc + grpc-auth + io.grpc grpc-context diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 70cc945cbb..a6aa502909 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -37,6 +37,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.grpc.CallCredentials; import io.grpc.ManagedChannelBuilder; import java.io.IOException; import java.net.MalformedURLException; @@ -72,6 +73,16 @@ public class SpannerOptions extends ServiceOptions { private final InstanceAdminStubSettings instanceAdminStubSettings; private final DatabaseAdminStubSettings databaseAdminStubSettings; private final Duration partitionedDmlTimeout; + private final CallCredentialsProvider callCredentialsProvider; + + /** + * Interface that can be used to provide {@link CallCredentials} instead of {@link Credentials} to + * {@link SpannerOptions}. + */ + public static interface CallCredentialsProvider { + /** Return the {@link CallCredentials} to use for a gRPC call. */ + CallCredentials getCallCredentials(); + } /** Default implementation of {@code SpannerFactory}. */ private static class DefaultSpannerFactory implements SpannerFactory { @@ -119,6 +130,7 @@ private SpannerOptions(Builder builder) { throw SpannerExceptionFactory.newSpannerException(e); } partitionedDmlTimeout = builder.partitionedDmlTimeout; + callCredentialsProvider = builder.callCredentialsProvider; } /** Builder for {@link SpannerOptions} instances. */ @@ -150,6 +162,7 @@ public static class Builder private DatabaseAdminStubSettings.Builder databaseAdminStubSettingsBuilder = DatabaseAdminStubSettings.newBuilder(); private Duration partitionedDmlTimeout = Duration.ofHours(2L); + private CallCredentialsProvider callCredentialsProvider; private String emulatorHost = System.getenv("SPANNER_EMULATOR_HOST"); private Builder() {} @@ -164,6 +177,7 @@ private Builder() {} this.instanceAdminStubSettingsBuilder = options.instanceAdminStubSettings.toBuilder(); this.databaseAdminStubSettingsBuilder = options.databaseAdminStubSettings.toBuilder(); this.partitionedDmlTimeout = options.partitionedDmlTimeout; + this.callCredentialsProvider = options.callCredentialsProvider; this.channelProvider = options.channelProvider; this.channelConfigurator = options.channelConfigurator; this.interceptorProvider = options.interceptorProvider; @@ -355,6 +369,17 @@ public Builder setPartitionedDmlTimeout(Duration timeout) { return this; } + /** + * Sets a {@link CallCredentialsProvider} that can deliver {@link CallCredentials} to use on a + * per-gRPC basis. Any credentials returned by this {@link CallCredentialsProvider} will have + * preference above any {@link Credentials} that may have been set on the {@link SpannerOptions} + * instance. + */ + public Builder setCallCredentialsProvider(CallCredentialsProvider callCredentialsProvider) { + this.callCredentialsProvider = callCredentialsProvider; + return this; + } + /** * Specifying this will allow the client to prefetch up to {@code prefetchChunks} {@code * PartialResultSet} chunks for each read and query. The data size of each chunk depends on the @@ -452,6 +477,10 @@ public Duration getPartitionedDmlTimeout() { return partitionedDmlTimeout; } + public CallCredentialsProvider getCallCredentialsProvider() { + return callCredentialsProvider; + } + public int getPrefetchChunks() { return prefetchChunks; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 49d6d9077d..917a5fdb61 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -40,12 +40,14 @@ import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub; import com.google.cloud.spanner.admin.database.v1.stub.GrpcDatabaseAdminStub; import com.google.cloud.spanner.admin.instance.v1.stub.GrpcInstanceAdminStub; import com.google.cloud.spanner.admin.instance.v1.stub.InstanceAdminStub; import com.google.cloud.spanner.v1.stub.GrpcSpannerStub; import com.google.cloud.spanner.v1.stub.SpannerStub; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -99,6 +101,7 @@ import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.Session; import com.google.spanner.v1.Transaction; +import io.grpc.CallCredentials; import io.grpc.Context; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; @@ -174,6 +177,7 @@ private synchronized void shutdown() { private final String projectId; private final String projectName; private final SpannerMetadataProvider metadataProvider; + private final CallCredentialsProvider callCredentialsProvider; private final Duration waitTimeout = systemProperty(PROPERTY_TIMEOUT_SECONDS, DEFAULT_TIMEOUT_SECONDS); private final Duration idleTimeout = @@ -216,6 +220,7 @@ public GapicSpannerRpc(final SpannerOptions options) { SpannerMetadataProvider.create( mergedHeaderProvider.getHeaders(), internalHeaderProviderBuilder.getResourceHeaderKey()); + this.callCredentialsProvider = options.getCallCredentialsProvider(); // Create a managed executor provider. this.executorProvider = @@ -702,7 +707,8 @@ private static T get(final Future future) throws SpannerException { } } - private GrpcCallContext newCallContext(@Nullable Map options, String resource) { + @VisibleForTesting + GrpcCallContext newCallContext(@Nullable Map options, String resource) { return newCallContext(options, resource, null); } @@ -716,6 +722,13 @@ private GrpcCallContext newCallContext( if (timeout != null) { context = context.withTimeout(timeout); } + if (callCredentialsProvider != null) { + CallCredentials callCredentials = callCredentialsProvider.getCallCredentials(); + if (callCredentials != null) { + context = + context.withCallOptions(context.getCallOptions().withCallCredentials(callCredentials)); + } + } return context.withStreamWaitTimeout(waitTimeout).withStreamIdleTimeout(idleTimeout); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java index dd860a6173..413799bc2c 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java @@ -16,12 +16,14 @@ package com.google.cloud.spanner.spi.v1; +import static com.google.common.truth.Truth.assertThat; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import com.google.api.core.ApiFunction; -import com.google.cloud.NoCredentials; +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; import com.google.cloud.spanner.DatabaseAdminClient; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; @@ -31,9 +33,11 @@ import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.admin.database.v1.MockDatabaseAdminImpl; import com.google.cloud.spanner.admin.instance.v1.MockInstanceAdminImpl; +import com.google.cloud.spanner.spi.v1.SpannerRpc.Option; import com.google.common.base.Stopwatch; import com.google.protobuf.ListValue; import com.google.spanner.admin.database.v1.Database; @@ -45,13 +49,24 @@ import com.google.spanner.v1.StructType; import com.google.spanner.v1.StructType.Field; import com.google.spanner.v1.TypeCode; +import io.grpc.CallCredentials; +import io.grpc.Context; +import io.grpc.Contexts; import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.Metadata.Key; import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.auth.MoreCallCredentials; import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; import java.io.IOException; import java.net.InetSocketAddress; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; import org.junit.After; @@ -91,11 +106,27 @@ public class GapicSpannerRpcTest { .build()) .setMetadata(SELECT1AND2_METADATA) .build(); + private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN"; + private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN"; + private static final OAuth2Credentials STATIC_CREDENTIALS = + OAuth2Credentials.create( + new AccessToken( + STATIC_OAUTH_TOKEN, + new java.util.Date( + System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS)))); + private static final OAuth2Credentials VARIABLE_CREDENTIALS = + OAuth2Credentials.create( + new AccessToken( + VARIABLE_OAUTH_TOKEN, + new java.util.Date( + System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS)))); + private MockSpannerServiceImpl mockSpanner; private MockInstanceAdminImpl mockInstanceAdmin; private MockDatabaseAdminImpl mockDatabaseAdmin; private Server server; private InetSocketAddress address; + private final Map optionsMap = new HashMap<>(); @Before public void startServer() throws IOException { @@ -111,8 +142,24 @@ public void startServer() throws IOException { .addService(mockSpanner) .addService(mockInstanceAdmin) .addService(mockDatabaseAdmin) + // Add a server interceptor that will check that we receive the variable OAuth token + // from the CallCredentials, and not the one set as static credentials. + .intercept( + new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + String auth = + headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER)); + assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN); + return Contexts.interceptCall(Context.current(), call, headers, next); + } + }) .build() .start(); + optionsMap.put(Option.CHANNEL_HINT, Long.valueOf(1L)); } @After @@ -229,6 +276,55 @@ && getNumberOfThreadsWithName(SPANNER_THREAD_NAME, false) assertThat(getNumberOfThreadsWithName(SPANNER_THREAD_NAME, true), is(equalTo(0))); } + @Test + public void testCallCredentialsProviderPreferenceAboveCredentials() { + SpannerOptions options = + SpannerOptions.newBuilder() + .setCredentials(STATIC_CREDENTIALS) + .setCallCredentialsProvider( + new CallCredentialsProvider() { + @Override + public CallCredentials getCallCredentials() { + return MoreCallCredentials.from(VARIABLE_CREDENTIALS); + } + }) + .build(); + GapicSpannerRpc rpc = new GapicSpannerRpc(options); + // GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the + // existence. + assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials()) + .isNotNull(); + rpc.shutdown(); + } + + @Test + public void testCallCredentialsProviderReturnsNull() { + SpannerOptions options = + SpannerOptions.newBuilder() + .setCredentials(STATIC_CREDENTIALS) + .setCallCredentialsProvider( + new CallCredentialsProvider() { + @Override + public CallCredentials getCallCredentials() { + return null; + } + }) + .build(); + GapicSpannerRpc rpc = new GapicSpannerRpc(options); + assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials()) + .isNull(); + rpc.shutdown(); + } + + @Test + public void testNoCallCredentials() { + SpannerOptions options = SpannerOptions.newBuilder().setCredentials(STATIC_CREDENTIALS).build(); + GapicSpannerRpc rpc = new GapicSpannerRpc(options); + assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials()) + .isNull(); + rpc.shutdown(); + } + @SuppressWarnings("rawtypes") private SpannerOptions createSpannerOptions() { String endpoint = address.getHostString() + ":" + server.getPort(); @@ -244,7 +340,17 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) { } }) .setHost("http://" + endpoint) - .setCredentials(NoCredentials.getInstance()) + // Set static credentials that will return the static OAuth test token. + .setCredentials(STATIC_CREDENTIALS) + // Also set a CallCredentialsProvider. These credentials should take precedence above + // the static credentials. + .setCallCredentialsProvider( + new CallCredentialsProvider() { + @Override + public CallCredentials getCallCredentials() { + return MoreCallCredentials.from(VARIABLE_CREDENTIALS); + } + }) .build(); }