Skip to content

Commit

Permalink
feat: add support for CallCredentials (#26)
Browse files Browse the repository at this point in the history
* feat: add support for CallCredentials

Adds support for using CallCredentials that could vary per call
instead of only static credentials passed in at startup.

Fixes #18

* fix: declare dependency usage
  • Loading branch information
olavloite committed Jan 14, 2020
1 parent 064eab0 commit 1112357
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 3 deletions.
4 changes: 4 additions & 0 deletions google-cloud-spanner/pom.xml
Expand Up @@ -80,6 +80,10 @@
<groupId>io.grpc</groupId>
<artifactId>grpc-api</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-auth</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-context</artifactId>
Expand Down
Expand Up @@ -37,6 +37,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannelBuilder;
import java.io.IOException;
import java.net.MalformedURLException;
Expand Down Expand Up @@ -72,6 +73,16 @@ public class SpannerOptions extends ServiceOptions<Spanner, SpannerOptions> {
private final InstanceAdminStubSettings instanceAdminStubSettings;
private final DatabaseAdminStubSettings databaseAdminStubSettings;
private final Duration partitionedDmlTimeout;
private final CallCredentialsProvider callCredentialsProvider;

/**
* Interface that can be used to provide {@link CallCredentials} instead of {@link Credentials} to
* {@link SpannerOptions}.
*/
public static interface CallCredentialsProvider {
/** Return the {@link CallCredentials} to use for a gRPC call. */
CallCredentials getCallCredentials();
}

/** Default implementation of {@code SpannerFactory}. */
private static class DefaultSpannerFactory implements SpannerFactory {
Expand Down Expand Up @@ -119,6 +130,7 @@ private SpannerOptions(Builder builder) {
throw SpannerExceptionFactory.newSpannerException(e);
}
partitionedDmlTimeout = builder.partitionedDmlTimeout;
callCredentialsProvider = builder.callCredentialsProvider;
}

/** Builder for {@link SpannerOptions} instances. */
Expand Down Expand Up @@ -150,6 +162,7 @@ public static class Builder
private DatabaseAdminStubSettings.Builder databaseAdminStubSettingsBuilder =
DatabaseAdminStubSettings.newBuilder();
private Duration partitionedDmlTimeout = Duration.ofHours(2L);
private CallCredentialsProvider callCredentialsProvider;
private String emulatorHost = System.getenv("SPANNER_EMULATOR_HOST");

private Builder() {}
Expand All @@ -164,6 +177,7 @@ private Builder() {}
this.instanceAdminStubSettingsBuilder = options.instanceAdminStubSettings.toBuilder();
this.databaseAdminStubSettingsBuilder = options.databaseAdminStubSettings.toBuilder();
this.partitionedDmlTimeout = options.partitionedDmlTimeout;
this.callCredentialsProvider = options.callCredentialsProvider;
this.channelProvider = options.channelProvider;
this.channelConfigurator = options.channelConfigurator;
this.interceptorProvider = options.interceptorProvider;
Expand Down Expand Up @@ -355,6 +369,17 @@ public Builder setPartitionedDmlTimeout(Duration timeout) {
return this;
}

/**
* Sets a {@link CallCredentialsProvider} that can deliver {@link CallCredentials} to use on a
* per-gRPC basis. Any credentials returned by this {@link CallCredentialsProvider} will have
* preference above any {@link Credentials} that may have been set on the {@link SpannerOptions}
* instance.
*/
public Builder setCallCredentialsProvider(CallCredentialsProvider callCredentialsProvider) {
this.callCredentialsProvider = callCredentialsProvider;
return this;
}

/**
* Specifying this will allow the client to prefetch up to {@code prefetchChunks} {@code
* PartialResultSet} chunks for each read and query. The data size of each chunk depends on the
Expand Down Expand Up @@ -452,6 +477,10 @@ public Duration getPartitionedDmlTimeout() {
return partitionedDmlTimeout;
}

public CallCredentialsProvider getCallCredentialsProvider() {
return callCredentialsProvider;
}

public int getPrefetchChunks() {
return prefetchChunks;
}
Expand Down
Expand Up @@ -40,12 +40,14 @@
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub;
import com.google.cloud.spanner.admin.database.v1.stub.GrpcDatabaseAdminStub;
import com.google.cloud.spanner.admin.instance.v1.stub.GrpcInstanceAdminStub;
import com.google.cloud.spanner.admin.instance.v1.stub.InstanceAdminStub;
import com.google.cloud.spanner.v1.stub.GrpcSpannerStub;
import com.google.cloud.spanner.v1.stub.SpannerStub;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand Down Expand Up @@ -99,6 +101,7 @@
import com.google.spanner.v1.RollbackRequest;
import com.google.spanner.v1.Session;
import com.google.spanner.v1.Transaction;
import io.grpc.CallCredentials;
import io.grpc.Context;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
Expand Down Expand Up @@ -174,6 +177,7 @@ private synchronized void shutdown() {
private final String projectId;
private final String projectName;
private final SpannerMetadataProvider metadataProvider;
private final CallCredentialsProvider callCredentialsProvider;
private final Duration waitTimeout =
systemProperty(PROPERTY_TIMEOUT_SECONDS, DEFAULT_TIMEOUT_SECONDS);
private final Duration idleTimeout =
Expand Down Expand Up @@ -216,6 +220,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
SpannerMetadataProvider.create(
mergedHeaderProvider.getHeaders(),
internalHeaderProviderBuilder.getResourceHeaderKey());
this.callCredentialsProvider = options.getCallCredentialsProvider();

// Create a managed executor provider.
this.executorProvider =
Expand Down Expand Up @@ -702,7 +707,8 @@ private static <T> T get(final Future<T> future) throws SpannerException {
}
}

private GrpcCallContext newCallContext(@Nullable Map<Option, ?> options, String resource) {
@VisibleForTesting
GrpcCallContext newCallContext(@Nullable Map<Option, ?> options, String resource) {
return newCallContext(options, resource, null);
}

Expand All @@ -716,6 +722,13 @@ private GrpcCallContext newCallContext(
if (timeout != null) {
context = context.withTimeout(timeout);
}
if (callCredentialsProvider != null) {
CallCredentials callCredentials = callCredentialsProvider.getCallCredentials();
if (callCredentials != null) {
context =
context.withCallOptions(context.getCallOptions().withCallCredentials(callCredentials));
}
}
return context.withStreamWaitTimeout(waitTimeout).withStreamIdleTimeout(idleTimeout);
}

Expand Down
Expand Up @@ -16,12 +16,14 @@

package com.google.cloud.spanner.spi.v1;

import static com.google.common.truth.Truth.assertThat;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

import com.google.api.core.ApiFunction;
import com.google.cloud.NoCredentials;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.cloud.spanner.DatabaseAdminClient;
import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.DatabaseId;
Expand All @@ -31,9 +33,11 @@
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.admin.database.v1.MockDatabaseAdminImpl;
import com.google.cloud.spanner.admin.instance.v1.MockInstanceAdminImpl;
import com.google.cloud.spanner.spi.v1.SpannerRpc.Option;
import com.google.common.base.Stopwatch;
import com.google.protobuf.ListValue;
import com.google.spanner.admin.database.v1.Database;
Expand All @@ -45,13 +49,24 @@
import com.google.spanner.v1.StructType;
import com.google.spanner.v1.StructType.Field;
import com.google.spanner.v1.TypeCode;
import io.grpc.CallCredentials;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.Metadata.Key;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import org.junit.After;
Expand Down Expand Up @@ -91,11 +106,27 @@ public class GapicSpannerRpcTest {
.build())
.setMetadata(SELECT1AND2_METADATA)
.build();
private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN";
private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN";
private static final OAuth2Credentials STATIC_CREDENTIALS =
OAuth2Credentials.create(
new AccessToken(
STATIC_OAUTH_TOKEN,
new java.util.Date(
System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS))));
private static final OAuth2Credentials VARIABLE_CREDENTIALS =
OAuth2Credentials.create(
new AccessToken(
VARIABLE_OAUTH_TOKEN,
new java.util.Date(
System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS))));

private MockSpannerServiceImpl mockSpanner;
private MockInstanceAdminImpl mockInstanceAdmin;
private MockDatabaseAdminImpl mockDatabaseAdmin;
private Server server;
private InetSocketAddress address;
private final Map<SpannerRpc.Option, Object> optionsMap = new HashMap<>();

@Before
public void startServer() throws IOException {
Expand All @@ -111,8 +142,24 @@ public void startServer() throws IOException {
.addService(mockSpanner)
.addService(mockInstanceAdmin)
.addService(mockDatabaseAdmin)
// Add a server interceptor that will check that we receive the variable OAuth token
// from the CallCredentials, and not the one set as static credentials.
.intercept(
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
String auth =
headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER));
assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN);
return Contexts.interceptCall(Context.current(), call, headers, next);
}
})
.build()
.start();
optionsMap.put(Option.CHANNEL_HINT, Long.valueOf(1L));
}

@After
Expand Down Expand Up @@ -229,6 +276,55 @@ && getNumberOfThreadsWithName(SPANNER_THREAD_NAME, false)
assertThat(getNumberOfThreadsWithName(SPANNER_THREAD_NAME, true), is(equalTo(0)));
}

@Test
public void testCallCredentialsProviderPreferenceAboveCredentials() {
SpannerOptions options =
SpannerOptions.newBuilder()
.setCredentials(STATIC_CREDENTIALS)
.setCallCredentialsProvider(
new CallCredentialsProvider() {
@Override
public CallCredentials getCallCredentials() {
return MoreCallCredentials.from(VARIABLE_CREDENTIALS);
}
})
.build();
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
// GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the
// existence.
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
.isNotNull();
rpc.shutdown();
}

@Test
public void testCallCredentialsProviderReturnsNull() {
SpannerOptions options =
SpannerOptions.newBuilder()
.setCredentials(STATIC_CREDENTIALS)
.setCallCredentialsProvider(
new CallCredentialsProvider() {
@Override
public CallCredentials getCallCredentials() {
return null;
}
})
.build();
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
.isNull();
rpc.shutdown();
}

@Test
public void testNoCallCredentials() {
SpannerOptions options = SpannerOptions.newBuilder().setCredentials(STATIC_CREDENTIALS).build();
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
.isNull();
rpc.shutdown();
}

@SuppressWarnings("rawtypes")
private SpannerOptions createSpannerOptions() {
String endpoint = address.getHostString() + ":" + server.getPort();
Expand All @@ -244,7 +340,17 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
}
})
.setHost("http://" + endpoint)
.setCredentials(NoCredentials.getInstance())
// Set static credentials that will return the static OAuth test token.
.setCredentials(STATIC_CREDENTIALS)
// Also set a CallCredentialsProvider. These credentials should take precedence above
// the static credentials.
.setCallCredentialsProvider(
new CallCredentialsProvider() {
@Override
public CallCredentials getCallCredentials() {
return MoreCallCredentials.from(VARIABLE_CREDENTIALS);
}
})
.build();
}

Expand Down

0 comments on commit 1112357

Please sign in to comment.