Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for CallCredentials #26

Merged
merged 2 commits into from Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions google-cloud-spanner/pom.xml
Expand Up @@ -80,6 +80,10 @@
<groupId>io.grpc</groupId>
<artifactId>grpc-api</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-auth</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-context</artifactId>
Expand Down
Expand Up @@ -37,6 +37,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannelBuilder;
import java.io.IOException;
import java.net.MalformedURLException;
Expand Down Expand Up @@ -72,6 +73,16 @@ public class SpannerOptions extends ServiceOptions<Spanner, SpannerOptions> {
private final InstanceAdminStubSettings instanceAdminStubSettings;
private final DatabaseAdminStubSettings databaseAdminStubSettings;
private final Duration partitionedDmlTimeout;
private final CallCredentialsProvider callCredentialsProvider;

/**
* Interface that can be used to provide {@link CallCredentials} instead of {@link Credentials} to
* {@link SpannerOptions}.
*/
public static interface CallCredentialsProvider {
/** Return the {@link CallCredentials} to use for a gRPC call. */
CallCredentials getCallCredentials();
}

/** Default implementation of {@code SpannerFactory}. */
private static class DefaultSpannerFactory implements SpannerFactory {
Expand Down Expand Up @@ -119,6 +130,7 @@ private SpannerOptions(Builder builder) {
throw SpannerExceptionFactory.newSpannerException(e);
}
partitionedDmlTimeout = builder.partitionedDmlTimeout;
callCredentialsProvider = builder.callCredentialsProvider;
}

/** Builder for {@link SpannerOptions} instances. */
Expand Down Expand Up @@ -150,6 +162,7 @@ public static class Builder
private DatabaseAdminStubSettings.Builder databaseAdminStubSettingsBuilder =
DatabaseAdminStubSettings.newBuilder();
private Duration partitionedDmlTimeout = Duration.ofHours(2L);
private CallCredentialsProvider callCredentialsProvider;
private String emulatorHost = System.getenv("SPANNER_EMULATOR_HOST");

private Builder() {}
Expand All @@ -164,6 +177,7 @@ private Builder() {}
this.instanceAdminStubSettingsBuilder = options.instanceAdminStubSettings.toBuilder();
this.databaseAdminStubSettingsBuilder = options.databaseAdminStubSettings.toBuilder();
this.partitionedDmlTimeout = options.partitionedDmlTimeout;
this.callCredentialsProvider = options.callCredentialsProvider;
this.channelProvider = options.channelProvider;
this.channelConfigurator = options.channelConfigurator;
this.interceptorProvider = options.interceptorProvider;
Expand Down Expand Up @@ -355,6 +369,17 @@ public Builder setPartitionedDmlTimeout(Duration timeout) {
return this;
}

/**
* Sets a {@link CallCredentialsProvider} that can deliver {@link CallCredentials} to use on a
* per-gRPC basis. Any credentials returned by this {@link CallCredentialsProvider} will have
* preference above any {@link Credentials} that may have been set on the {@link SpannerOptions}
* instance.
*/
public Builder setCallCredentialsProvider(CallCredentialsProvider callCredentialsProvider) {
this.callCredentialsProvider = callCredentialsProvider;
return this;
}

/**
* Specifying this will allow the client to prefetch up to {@code prefetchChunks} {@code
* PartialResultSet} chunks for each read and query. The data size of each chunk depends on the
Expand Down Expand Up @@ -452,6 +477,10 @@ public Duration getPartitionedDmlTimeout() {
return partitionedDmlTimeout;
}

public CallCredentialsProvider getCallCredentialsProvider() {
return callCredentialsProvider;
}

public int getPrefetchChunks() {
return prefetchChunks;
}
Expand Down
Expand Up @@ -40,12 +40,14 @@
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub;
import com.google.cloud.spanner.admin.database.v1.stub.GrpcDatabaseAdminStub;
import com.google.cloud.spanner.admin.instance.v1.stub.GrpcInstanceAdminStub;
import com.google.cloud.spanner.admin.instance.v1.stub.InstanceAdminStub;
import com.google.cloud.spanner.v1.stub.GrpcSpannerStub;
import com.google.cloud.spanner.v1.stub.SpannerStub;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand Down Expand Up @@ -99,6 +101,7 @@
import com.google.spanner.v1.RollbackRequest;
import com.google.spanner.v1.Session;
import com.google.spanner.v1.Transaction;
import io.grpc.CallCredentials;
import io.grpc.Context;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
Expand Down Expand Up @@ -174,6 +177,7 @@ private synchronized void shutdown() {
private final String projectId;
private final String projectName;
private final SpannerMetadataProvider metadataProvider;
private final CallCredentialsProvider callCredentialsProvider;
private final Duration waitTimeout =
systemProperty(PROPERTY_TIMEOUT_SECONDS, DEFAULT_TIMEOUT_SECONDS);
private final Duration idleTimeout =
Expand Down Expand Up @@ -216,6 +220,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
SpannerMetadataProvider.create(
mergedHeaderProvider.getHeaders(),
internalHeaderProviderBuilder.getResourceHeaderKey());
this.callCredentialsProvider = options.getCallCredentialsProvider();

// Create a managed executor provider.
this.executorProvider =
Expand Down Expand Up @@ -702,7 +707,8 @@ private static <T> T get(final Future<T> future) throws SpannerException {
}
}

private GrpcCallContext newCallContext(@Nullable Map<Option, ?> options, String resource) {
@VisibleForTesting
GrpcCallContext newCallContext(@Nullable Map<Option, ?> options, String resource) {
return newCallContext(options, resource, null);
}

Expand All @@ -716,6 +722,13 @@ private GrpcCallContext newCallContext(
if (timeout != null) {
context = context.withTimeout(timeout);
}
if (callCredentialsProvider != null) {
CallCredentials callCredentials = callCredentialsProvider.getCallCredentials();
if (callCredentials != null) {
context =
context.withCallOptions(context.getCallOptions().withCallCredentials(callCredentials));
}
}
return context.withStreamWaitTimeout(waitTimeout).withStreamIdleTimeout(idleTimeout);
}

Expand Down
Expand Up @@ -16,12 +16,14 @@

package com.google.cloud.spanner.spi.v1;

import static com.google.common.truth.Truth.assertThat;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

import com.google.api.core.ApiFunction;
import com.google.cloud.NoCredentials;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.cloud.spanner.DatabaseAdminClient;
import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.DatabaseId;
Expand All @@ -31,9 +33,11 @@
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.admin.database.v1.MockDatabaseAdminImpl;
import com.google.cloud.spanner.admin.instance.v1.MockInstanceAdminImpl;
import com.google.cloud.spanner.spi.v1.SpannerRpc.Option;
import com.google.common.base.Stopwatch;
import com.google.protobuf.ListValue;
import com.google.spanner.admin.database.v1.Database;
Expand All @@ -45,13 +49,24 @@
import com.google.spanner.v1.StructType;
import com.google.spanner.v1.StructType.Field;
import com.google.spanner.v1.TypeCode;
import io.grpc.CallCredentials;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.Metadata.Key;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import org.junit.After;
Expand Down Expand Up @@ -91,11 +106,27 @@ public class GapicSpannerRpcTest {
.build())
.setMetadata(SELECT1AND2_METADATA)
.build();
private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN";
private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN";
private static final OAuth2Credentials STATIC_CREDENTIALS =
OAuth2Credentials.create(
new AccessToken(
STATIC_OAUTH_TOKEN,
new java.util.Date(
System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS))));
private static final OAuth2Credentials VARIABLE_CREDENTIALS =
OAuth2Credentials.create(
new AccessToken(
VARIABLE_OAUTH_TOKEN,
new java.util.Date(
System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS))));

private MockSpannerServiceImpl mockSpanner;
private MockInstanceAdminImpl mockInstanceAdmin;
private MockDatabaseAdminImpl mockDatabaseAdmin;
private Server server;
private InetSocketAddress address;
private final Map<SpannerRpc.Option, Object> optionsMap = new HashMap<>();

@Before
public void startServer() throws IOException {
Expand All @@ -111,8 +142,24 @@ public void startServer() throws IOException {
.addService(mockSpanner)
.addService(mockInstanceAdmin)
.addService(mockDatabaseAdmin)
// Add a server interceptor that will check that we receive the variable OAuth token
// from the CallCredentials, and not the one set as static credentials.
.intercept(
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
String auth =
headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER));
assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN);
return Contexts.interceptCall(Context.current(), call, headers, next);
}
})
.build()
.start();
optionsMap.put(Option.CHANNEL_HINT, Long.valueOf(1L));
}

@After
Expand Down Expand Up @@ -229,6 +276,55 @@ && getNumberOfThreadsWithName(SPANNER_THREAD_NAME, false)
assertThat(getNumberOfThreadsWithName(SPANNER_THREAD_NAME, true), is(equalTo(0)));
}

@Test
public void testCallCredentialsProviderPreferenceAboveCredentials() {
SpannerOptions options =
SpannerOptions.newBuilder()
.setCredentials(STATIC_CREDENTIALS)
.setCallCredentialsProvider(
new CallCredentialsProvider() {
@Override
public CallCredentials getCallCredentials() {
return MoreCallCredentials.from(VARIABLE_CREDENTIALS);
}
})
.build();
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
// GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the
// existence.
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
.isNotNull();
rpc.shutdown();
}

@Test
public void testCallCredentialsProviderReturnsNull() {
SpannerOptions options =
SpannerOptions.newBuilder()
.setCredentials(STATIC_CREDENTIALS)
.setCallCredentialsProvider(
new CallCredentialsProvider() {
@Override
public CallCredentials getCallCredentials() {
return null;
}
})
.build();
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
.isNull();
rpc.shutdown();
}

@Test
public void testNoCallCredentials() {
SpannerOptions options = SpannerOptions.newBuilder().setCredentials(STATIC_CREDENTIALS).build();
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
.isNull();
rpc.shutdown();
}

@SuppressWarnings("rawtypes")
private SpannerOptions createSpannerOptions() {
String endpoint = address.getHostString() + ":" + server.getPort();
Expand All @@ -244,7 +340,17 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
}
})
.setHost("http://" + endpoint)
.setCredentials(NoCredentials.getInstance())
// Set static credentials that will return the static OAuth test token.
.setCredentials(STATIC_CREDENTIALS)
// Also set a CallCredentialsProvider. These credentials should take precedence above
// the static credentials.
.setCallCredentialsProvider(
new CallCredentialsProvider() {
@Override
public CallCredentials getCallCredentials() {
return MoreCallCredentials.from(VARIABLE_CREDENTIALS);
}
})
.build();
}

Expand Down