diff --git a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/JwtCredentialsWithAudience.java b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/JwtCredentialsWithAudience.java
new file mode 100644
index 000000000..a88652769
--- /dev/null
+++ b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/JwtCredentialsWithAudience.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.bigtable.data.v2.internal;
+
+import com.google.api.core.InternalApi;
+import com.google.auth.Credentials;
+import com.google.auth.RequestMetadataCallback;
+import com.google.auth.oauth2.ServiceAccountJwtAccessCredentials;
+import java.io.IOException;
+import java.net.URI;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Executor;
+
+/**
+ * Internal helper to fix the mapping between JWT audiences and service endpoints.
+ *
+ * In most cases JWT audiences correspond to service endpoints. However, in some cases they
+ * diverge. To workaround this, this class hardcodes the audience and forces the underlying
+ * implementation to use it.
+ *
+ *
Internal Only - public for technical reasons
+ */
+@InternalApi
+public class JwtCredentialsWithAudience extends Credentials {
+ private final ServiceAccountJwtAccessCredentials delegate;
+
+ public JwtCredentialsWithAudience(ServiceAccountJwtAccessCredentials delegate, URI audience) {
+ this.delegate = delegate.toBuilder().setDefaultAudience(audience).build();
+ }
+
+ @Override
+ public String getAuthenticationType() {
+ return delegate.getAuthenticationType();
+ }
+
+ @Override
+ public Map> getRequestMetadata() throws IOException {
+ return delegate.getRequestMetadata();
+ }
+
+ @Override
+ public void getRequestMetadata(URI ignored, Executor executor, RequestMetadataCallback callback) {
+ delegate.getRequestMetadata(null, executor, callback);
+ }
+
+ @Override
+ public Map> getRequestMetadata(URI ignored) throws IOException {
+ return delegate.getRequestMetadata(null);
+ }
+
+ @Override
+ public boolean hasRequestMetadata() {
+ return delegate.hasRequestMetadata();
+ }
+
+ @Override
+ public boolean hasRequestMetadataOnly() {
+ return delegate.hasRequestMetadataOnly();
+ }
+
+ @Override
+ public void refresh() throws IOException {
+ delegate.refresh();
+ }
+}
diff --git a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java
index 62619f5bb..161dde232 100644
--- a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java
+++ b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java
@@ -21,6 +21,7 @@
import com.google.api.gax.batching.BatcherImpl;
import com.google.api.gax.batching.FlowController;
import com.google.api.gax.core.BackgroundResource;
+import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.grpc.GaxGrpcProperties;
import com.google.api.gax.grpc.GrpcCallContext;
@@ -42,6 +43,7 @@
import com.google.api.gax.tracing.TracedServerStreamingCallable;
import com.google.api.gax.tracing.TracedUnaryCallable;
import com.google.auth.Credentials;
+import com.google.auth.oauth2.ServiceAccountJwtAccessCredentials;
import com.google.bigtable.v2.BigtableGrpc;
import com.google.bigtable.v2.CheckAndMutateRowRequest;
import com.google.bigtable.v2.CheckAndMutateRowResponse;
@@ -56,6 +58,7 @@
import com.google.bigtable.v2.SampleRowKeysRequest;
import com.google.bigtable.v2.SampleRowKeysResponse;
import com.google.cloud.bigtable.Version;
+import com.google.cloud.bigtable.data.v2.internal.JwtCredentialsWithAudience;
import com.google.cloud.bigtable.data.v2.internal.RequestContext;
import com.google.cloud.bigtable.data.v2.models.BulkMutation;
import com.google.cloud.bigtable.data.v2.models.ConditionalRowMutation;
@@ -94,6 +97,8 @@
import io.opencensus.tags.Tagger;
import io.opencensus.tags.Tags;
import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
@@ -146,6 +151,9 @@ public static EnhancedBigtableStubSettings finalizeSettings(
// TODO: this implementation is on the cusp of unwieldy, if we end up adding more features
// consider splitting it up by feature.
+ // workaround JWT audience issues
+ patchCredentials(builder);
+
// Inject channel priming
if (settings.isRefreshingChannel()) {
// Fix the credentials so that they can be shared
@@ -218,6 +226,41 @@ public static EnhancedBigtableStubSettings finalizeSettings(
return builder.build();
}
+ private static void patchCredentials(EnhancedBigtableStubSettings.Builder settings)
+ throws IOException {
+ int i = settings.getEndpoint().lastIndexOf(":");
+ String host = settings.getEndpoint().substring(0, i);
+ String audience = settings.getJwtAudienceMapping().get(host);
+
+ if (audience == null) {
+ return;
+ }
+ URI audienceUri = null;
+ try {
+ audienceUri = new URI(audience);
+ } catch (URISyntaxException e) {
+ throw new IllegalStateException("invalid JWT audience override", e);
+ }
+
+ CredentialsProvider credentialsProvider = settings.getCredentialsProvider();
+ if (credentialsProvider == null) {
+ return;
+ }
+
+ Credentials credentials = credentialsProvider.getCredentials();
+ if (credentials == null) {
+ return;
+ }
+
+ if (!(credentials instanceof ServiceAccountJwtAccessCredentials)) {
+ return;
+ }
+
+ ServiceAccountJwtAccessCredentials jwtCreds = (ServiceAccountJwtAccessCredentials) credentials;
+ JwtCredentialsWithAudience patchedCreds = new JwtCredentialsWithAudience(jwtCreds, audienceUri);
+ settings.setCredentialsProvider(FixedCredentialsProvider.create(patchedCreds));
+ }
+
public EnhancedBigtableStub(EnhancedBigtableStubSettings settings, ClientContext clientContext) {
this.settings = settings;
this.clientContext = clientContext;
diff --git a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java
index c5e39e460..e5424c586 100644
--- a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java
+++ b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java
@@ -16,6 +16,7 @@
package com.google.cloud.bigtable.data.v2.stub;
import com.google.api.core.BetaApi;
+import com.google.api.core.InternalApi;
import com.google.api.gax.batching.BatchingCallSettings;
import com.google.api.gax.batching.BatchingSettings;
import com.google.api.gax.batching.FlowControlSettings;
@@ -151,12 +152,20 @@ public class EnhancedBigtableStubSettings extends StubSettings DEFAULT_JWT_AUDIENCE_MAPPING =
+ ImmutableMap.of("batch-bigtable.googleapis.com", "https://bigtable.googleapis.com/");
+
private final String projectId;
private final String instanceId;
private final String appProfileId;
private final boolean isRefreshingChannel;
private ImmutableList primedTableIds;
private HeaderTracer headerTracer;
+ private final Map jwtAudienceMapping;
private final ServerStreamingCallSettings readRowsSettings;
private final UnaryCallSettings readRowSettings;
@@ -191,6 +200,7 @@ private EnhancedBigtableStubSettings(Builder builder) {
isRefreshingChannel = builder.isRefreshingChannel;
primedTableIds = builder.primedTableIds;
headerTracer = builder.headerTracer;
+ jwtAudienceMapping = builder.jwtAudienceMapping;
// Per method settings.
readRowsSettings = builder.readRowsSettings.build();
@@ -240,6 +250,11 @@ HeaderTracer getHeaderTracer() {
return headerTracer;
}
+ @InternalApi("Used for internal testing")
+ public Map getJwtAudienceMapping() {
+ return jwtAudienceMapping;
+ }
+
/** Returns a builder for the default ChannelProvider for this service. */
public static InstantiatingGrpcChannelProvider.Builder defaultGrpcTransportProviderBuilder() {
return BigtableStubSettings.defaultGrpcTransportProviderBuilder()
@@ -498,6 +513,7 @@ public static class Builder extends StubSettings.Builder primedTableIds;
private HeaderTracer headerTracer;
+ private Map jwtAudienceMapping;
private final ServerStreamingCallSettings.Builder readRowsSettings;
private final UnaryCallSettings.Builder readRowSettings;
@@ -522,6 +538,7 @@ private Builder() {
this.isRefreshingChannel = false;
primedTableIds = ImmutableList.of();
headerTracer = HeaderTracer.newBuilder().build();
+ jwtAudienceMapping = DEFAULT_JWT_AUDIENCE_MAPPING;
setCredentialsProvider(defaultCredentialsProviderBuilder().build());
// Defaults provider
@@ -629,6 +646,7 @@ private Builder(EnhancedBigtableStubSettings settings) {
isRefreshingChannel = settings.isRefreshingChannel;
primedTableIds = settings.primedTableIds;
headerTracer = settings.headerTracer;
+ jwtAudienceMapping = settings.jwtAudienceMapping;
// Per method settings.
readRowsSettings = settings.readRowsSettings.toBuilder();
@@ -762,6 +780,17 @@ HeaderTracer getHeaderTracer() {
return headerTracer;
}
+ @InternalApi("Used for internal testing")
+ public Builder setJwtAudienceMapping(Map jwtAudienceMapping) {
+ this.jwtAudienceMapping = Preconditions.checkNotNull(jwtAudienceMapping);
+ return this;
+ }
+
+ @InternalApi("Used for internal testing")
+ public Map getJwtAudienceMapping() {
+ return jwtAudienceMapping;
+ }
+
/** Returns the builder for the settings used for calls to readRows. */
public ServerStreamingCallSettings.Builder readRowsSettings() {
return readRowsSettings;
@@ -842,6 +871,7 @@ public String toString() {
.add("isRefreshingChannel", isRefreshingChannel)
.add("primedTableIds", primedTableIds)
.add("headerTracer", headerTracer)
+ .add("jwtAudienceMapping", jwtAudienceMapping)
.add("readRowsSettings", readRowsSettings)
.add("readRowSettings", readRowSettings)
.add("sampleRowKeysSettings", sampleRowKeysSettings)
diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java
index b0f60be60..8af4bdafb 100644
--- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java
+++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java
@@ -716,6 +716,7 @@ public void verifyDefaultHeaderTracerNotNullTest() {
"isRefreshingChannel",
"primedTableIds",
"headerTracer",
+ "jwtAudienceMapping",
"readRowsSettings",
"readRowSettings",
"sampleRowKeysSettings",
diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubTest.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubTest.java
index 8cb82359a..ae045123f 100644
--- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubTest.java
+++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubTest.java
@@ -17,15 +17,21 @@
import static com.google.common.truth.Truth.assertThat;
+import com.google.api.client.json.gson.GsonFactory;
+import com.google.api.client.json.webtoken.JsonWebSignature;
import com.google.api.gax.batching.Batcher;
import com.google.api.gax.batching.BatcherImpl;
import com.google.api.gax.batching.BatchingSettings;
import com.google.api.gax.batching.FlowControlSettings;
import com.google.api.gax.batching.FlowController.LimitExceededBehavior;
+import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.core.NoCredentialsProvider;
import com.google.api.gax.grpc.GaxGrpcProperties;
import com.google.api.gax.grpc.GrpcCallContext;
+import com.google.api.gax.grpc.GrpcTransportChannel;
+import com.google.api.gax.rpc.FixedTransportChannelProvider;
import com.google.api.gax.rpc.ServerStreamingCallable;
+import com.google.auth.oauth2.ServiceAccountJwtAccessCredentials;
import com.google.bigtable.v2.BigtableGrpc;
import com.google.bigtable.v2.MutateRowsRequest;
import com.google.bigtable.v2.MutateRowsResponse;
@@ -42,6 +48,7 @@
import com.google.cloud.bigtable.data.v2.models.Row;
import com.google.cloud.bigtable.data.v2.models.RowMutationEntry;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Queues;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
@@ -49,7 +56,10 @@
import io.grpc.BindableService;
import io.grpc.Context;
import io.grpc.Deadline;
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
+import io.grpc.Metadata.Key;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
@@ -63,6 +73,9 @@
import io.opencensus.trace.export.SpanExporter.Handler;
import io.opencensus.trace.samplers.Samplers;
import java.io.IOException;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
import java.util.Collection;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
@@ -122,6 +135,98 @@ public void tearDown() {
serviceHelper.shutdown();
}
+ @Test
+ public void testJwtAudience()
+ throws InterruptedException, IOException, NoSuchAlgorithmException, ExecutionException {
+ // close default stub - need to create custom one
+ enhancedBigtableStub.close();
+
+ // Create fake jwt creds
+ KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
+ KeyPair keyPair = keyGen.genKeyPair();
+
+ ServiceAccountJwtAccessCredentials jwtCreds =
+ ServiceAccountJwtAccessCredentials.newBuilder()
+ .setClientId("fake-id")
+ .setClientEmail("fake@example.com")
+ .setPrivateKey(keyPair.getPrivate())
+ .setPrivateKeyId("fake-private-key")
+ .build();
+
+ // Create a stub with overridden audience
+ String expectedAudience = "http://localaudience";
+ EnhancedBigtableStubSettings settings =
+ defaultSettings
+ .toBuilder()
+ .setJwtAudienceMapping(ImmutableMap.of("localhost", expectedAudience))
+ .setCredentialsProvider(FixedCredentialsProvider.create(jwtCreds))
+ .build();
+ enhancedBigtableStub = EnhancedBigtableStub.create(settings);
+
+ // Send rpc and grab the credentials sent
+ enhancedBigtableStub.readRowCallable().futureCall(Query.create("fake-table")).get();
+ Metadata metadata = metadataInterceptor.headers.take();
+
+ String authValue = metadata.get(Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER));
+ String expectedPrefix = "Bearer ";
+ assertThat(authValue).startsWith(expectedPrefix);
+ String jwtStr = authValue.substring(expectedPrefix.length());
+ JsonWebSignature parsed = JsonWebSignature.parse(GsonFactory.getDefaultInstance(), jwtStr);
+ assertThat(parsed.getPayload().getAudience()).isEqualTo(expectedAudience);
+ }
+
+ @Test
+ public void testBatchJwtAudience()
+ throws InterruptedException, IOException, NoSuchAlgorithmException, ExecutionException {
+ // close default stub - need to create custom one
+ enhancedBigtableStub.close();
+
+ // Create fake jwt creds
+ KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
+ KeyPair keyPair = keyGen.genKeyPair();
+
+ ServiceAccountJwtAccessCredentials jwtCreds =
+ ServiceAccountJwtAccessCredentials.newBuilder()
+ .setClientId("fake-id")
+ .setClientEmail("fake@example.com")
+ .setPrivateKey(keyPair.getPrivate())
+ .setPrivateKeyId("fake-private-key")
+ .build();
+
+ // Create a fixed channel that will ignore the default endpoint and connect to the emulator
+ ManagedChannel emulatorChannel =
+ ManagedChannelBuilder.forAddress("localhost", serviceHelper.getPort())
+ .usePlaintext()
+ .build();
+
+ Metadata metadata;
+ try {
+ EnhancedBigtableStubSettings settings =
+ EnhancedBigtableStubSettings.newBuilder()
+ .setProjectId("fake-project")
+ .setInstanceId("fake-instance")
+ .setEndpoint("batch-bigtable.googleapis.com:443")
+ .setCredentialsProvider(FixedCredentialsProvider.create(jwtCreds))
+ .setTransportChannelProvider(
+ FixedTransportChannelProvider.create(
+ GrpcTransportChannel.create(emulatorChannel)))
+ .build();
+ enhancedBigtableStub = EnhancedBigtableStub.create(settings);
+ // Send rpc and grab the credentials sent
+ enhancedBigtableStub.readRowCallable().futureCall(Query.create("fake-table")).get();
+ metadata = metadataInterceptor.headers.take();
+ } finally {
+ emulatorChannel.shutdown();
+ }
+
+ String authValue = metadata.get(Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER));
+ String expectedPrefix = "Bearer ";
+ assertThat(authValue).startsWith(expectedPrefix);
+ String jwtStr = authValue.substring(expectedPrefix.length());
+ JsonWebSignature parsed = JsonWebSignature.parse(GsonFactory.getDefaultInstance(), jwtStr);
+ assertThat(parsed.getPayload().getAudience()).isEqualTo("https://bigtable.googleapis.com/");
+ }
+
@Test
public void testCreateReadRowsCallable() throws InterruptedException {
ServerStreamingCallable streamingCallable =