Skip to content

Commit

Permalink
allow the use of CallCredentials
Browse files Browse the repository at this point in the history
Allow the user to supply io.grpc.CallCredentials instead
of only com.google.auth.Credentials. Any CallCredentials
supplied will take precedence above the Credentials set on
SpannerOptions.

Fixes #6373
  • Loading branch information
olavloite committed Oct 3, 2019
1 parent 19eee78 commit 997d8b9
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 2 deletions.
Expand Up @@ -20,6 +20,7 @@
import com.google.api.gax.grpc.GrpcInterceptorProvider;
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.Credentials;
import com.google.cloud.ServiceDefaults;
import com.google.cloud.ServiceOptions;
import com.google.cloud.ServiceRpc;
Expand All @@ -37,6 +38,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannelBuilder;
import java.io.IOException;
import java.net.MalformedURLException;
Expand Down Expand Up @@ -72,6 +74,16 @@ public class SpannerOptions extends ServiceOptions<Spanner, SpannerOptions> {
private final InstanceAdminStubSettings instanceAdminStubSettings;
private final DatabaseAdminStubSettings databaseAdminStubSettings;
private final Duration partitionedDmlTimeout;
private final CallCredentialsProvider callCredentialsProvider;

/**
* Interface that can be used to provide {@link CallCredentials} instead of {@link Credentials} to
* {@link SpannerOptions}.
*/
public static interface CallCredentialsProvider {
/** Return the {@link CallCredentials} to use for a gRPC call. */
CallCredentials getCallCredentials();
}

/** Default implementation of {@code SpannerFactory}. */
private static class DefaultSpannerFactory implements SpannerFactory {
Expand Down Expand Up @@ -119,6 +131,7 @@ private SpannerOptions(Builder builder) {
throw SpannerExceptionFactory.newSpannerException(e);
}
partitionedDmlTimeout = builder.partitionedDmlTimeout;
callCredentialsProvider = builder.callCredentialsProvider;
}

/** Builder for {@link SpannerOptions} instances. */
Expand Down Expand Up @@ -150,6 +163,7 @@ public static class Builder
private DatabaseAdminStubSettings.Builder databaseAdminStubSettingsBuilder =
DatabaseAdminStubSettings.newBuilder();
private Duration partitionedDmlTimeout = Duration.ofHours(2L);
private CallCredentialsProvider callCredentialsProvider;

private Builder() {}

Expand All @@ -163,6 +177,7 @@ private Builder() {}
this.instanceAdminStubSettingsBuilder = options.instanceAdminStubSettings.toBuilder();
this.databaseAdminStubSettingsBuilder = options.databaseAdminStubSettings.toBuilder();
this.partitionedDmlTimeout = options.partitionedDmlTimeout;
this.callCredentialsProvider = options.callCredentialsProvider;
this.channelProvider = options.channelProvider;
this.channelConfigurator = options.channelConfigurator;
this.interceptorProvider = options.interceptorProvider;
Expand Down Expand Up @@ -354,6 +369,17 @@ public Builder setPartitionedDmlTimeout(Duration timeout) {
return this;
}

/**
* Sets a {@link CallCredentialsProvider} that can deliver {@link CallCredentials} to use on a
* per-gRPC basis. Any credentials returned by this {@link CallCredentialsProvider} will have
* preference above any {@link Credentials} that may have been set on the {@link SpannerOptions}
* instance.
*/
public Builder setCallCredentialsProvider(CallCredentialsProvider callCredentialsProvider) {
this.callCredentialsProvider = callCredentialsProvider;
return this;
}

/**
* Specifying this will allow the client to prefetch up to {@code prefetchChunks} {@code
* PartialResultSet} chunks for each read and query. The data size of each chunk depends on the
Expand Down Expand Up @@ -426,6 +452,10 @@ public Duration getPartitionedDmlTimeout() {
return partitionedDmlTimeout;
}

public CallCredentialsProvider getCallCredentialsProvider() {
return callCredentialsProvider;
}

public int getPrefetchChunks() {
return prefetchChunks;
}
Expand Down
Expand Up @@ -40,6 +40,7 @@
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub;
import com.google.cloud.spanner.admin.database.v1.stub.GrpcDatabaseAdminStub;
import com.google.cloud.spanner.admin.instance.v1.stub.GrpcInstanceAdminStub;
Expand Down Expand Up @@ -94,6 +95,7 @@
import com.google.spanner.v1.RollbackRequest;
import com.google.spanner.v1.Session;
import com.google.spanner.v1.Transaction;
import io.grpc.CallCredentials;
import io.grpc.Context;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -166,6 +168,7 @@ private synchronized void shutdown() {
private final String projectId;
private final String projectName;
private final SpannerMetadataProvider metadataProvider;
private final CallCredentialsProvider callCredentialsProvider;
private final Duration waitTimeout =
systemProperty(PROPERTY_TIMEOUT_SECONDS, DEFAULT_TIMEOUT_SECONDS);
private final Duration idleTimeout =
Expand Down Expand Up @@ -201,6 +204,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
SpannerMetadataProvider.create(
mergedHeaderProvider.getHeaders(),
internalHeaderProviderBuilder.getResourceHeaderKey());
this.callCredentialsProvider = options.getCallCredentialsProvider();

// Create a managed executor provider.
this.executorProvider =
Expand Down Expand Up @@ -631,6 +635,13 @@ private GrpcCallContext newCallContext(
if (timeout != null) {
context = context.withTimeout(timeout);
}
if (callCredentialsProvider != null) {
CallCredentials callCredentials = callCredentialsProvider.getCallCredentials();
if (callCredentials != null) {
context =
context.withCallOptions(context.getCallOptions().withCallCredentials(callCredentials));
}
}
return context.withStreamWaitTimeout(waitTimeout).withStreamIdleTimeout(idleTimeout);
}

Expand Down
Expand Up @@ -16,12 +16,14 @@

package com.google.cloud.spanner.spi.v1;

import static com.google.common.truth.Truth.assertThat;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;

import com.google.api.core.ApiFunction;
import com.google.cloud.NoCredentials;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.cloud.spanner.DatabaseAdminClient;
import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.DatabaseId;
Expand All @@ -31,6 +33,7 @@
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.admin.database.v1.MockDatabaseAdminImpl;
import com.google.cloud.spanner.admin.instance.v1.MockInstanceAdminImpl;
Expand All @@ -45,11 +48,22 @@
import com.google.spanner.v1.StructType;
import com.google.spanner.v1.StructType.Field;
import com.google.spanner.v1.TypeCode;
import io.grpc.CallCredentials;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.Metadata.Key;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.sql.Date;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -91,6 +105,8 @@ public class GapicSpannerRpcTest {
.build())
.setMetadata(SELECT1AND2_METADATA)
.build();
private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN";
private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN";
private MockSpannerServiceImpl mockSpanner;
private MockInstanceAdminImpl mockInstanceAdmin;
private MockDatabaseAdminImpl mockDatabaseAdmin;
Expand All @@ -111,6 +127,21 @@ public void startServer() throws IOException {
.addService(mockSpanner)
.addService(mockInstanceAdmin)
.addService(mockDatabaseAdmin)
// Add a server interceptor that will check that we receive the variable OAuth token
// from the CallCredentials, and not the one set as static credentials.
.intercept(
new ServerInterceptor() {
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
String auth =
headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER));
assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN);
return Contexts.interceptCall(Context.current(), call, headers, next);
}
})
.build()
.start();
}
Expand Down Expand Up @@ -244,7 +275,29 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
}
})
.setHost("http://" + endpoint)
.setCredentials(NoCredentials.getInstance())
// Set static credentials that will return the static OAuth test token.
.setCredentials(
OAuth2Credentials.create(
new AccessToken(
STATIC_OAUTH_TOKEN,
new Date(
System.currentTimeMillis()
+ TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS)))))
// Also set a CallCredentialsProvider. These credentials should take precedence above
// the static credentials.
.setCallCredentialsProvider(
new CallCredentialsProvider() {
@Override
public CallCredentials getCallCredentials() {
return MoreCallCredentials.from(
OAuth2Credentials.create(
new AccessToken(
VARIABLE_OAUTH_TOKEN,
new Date(
System.currentTimeMillis()
+ TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS)))));
}
})
.build();
}

Expand Down

0 comments on commit 997d8b9

Please sign in to comment.