Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

feat: add mtls feature to http and grpc transport provider #1249

Merged
merged 14 commits into from May 26, 2021
2 changes: 1 addition & 1 deletion dependencies.properties
Expand Up @@ -25,7 +25,7 @@ version.gax_httpjson=0.81.1-SNAPSHOT
# with the sources.
version.com_google_protobuf=3.15.2
version.google_java_format=1.1
version.io_grpc=1.36.0
version.io_grpc=1.37.0
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

# Maven artifacts.
# Note, the actual name of each property matters (bazel build scripts depend on it).
Expand Down
Expand Up @@ -38,6 +38,7 @@
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -46,16 +47,22 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.CharStreams;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.net.ssl.KeyManagerFactory;
import org.threeten.bp.Duration;

/**
Expand Down Expand Up @@ -96,6 +103,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private final MtlsProvider mtlsProvider;

@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
Expand All @@ -105,6 +113,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.executor = builder.executor;
this.headerProvider = builder.headerProvider;
this.endpoint = builder.endpoint;
this.mtlsProvider = builder.mtlsProvider;
this.envProvider = builder.envProvider;
this.interceptorProvider = builder.interceptorProvider;
this.maxInboundMessageSize = builder.maxInboundMessageSize;
Expand Down Expand Up @@ -264,6 +273,24 @@ static boolean isOnComputeEngine() {
return false;
}

@VisibleForTesting
ChannelCredentials createMtlsChannelCredentials() throws IOException {
if (mtlsProvider.useMtlsClientCertificate()) {
try {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
if (mtlsKeyStore != null) {
KeyManagerFactory factory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
factory.init(mtlsKeyStore, new char[] {});
return TlsChannelCredentials.newBuilder().keyManager(factory.getKeyManagers()).build();
}
} catch (GeneralSecurityException e) {
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
throw new IOException(e.toString());
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
}
}
return null;
}

private ManagedChannel createSingleChannel() throws IOException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headerProvider.getHeaders());
Expand All @@ -290,7 +317,12 @@ && isOnComputeEngine()) {
builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
builder.defaultServiceConfig(directPathServiceConfig);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
ChannelCredentials mTlsChannelCredentials = createMtlsChannelCredentials();
if (mTlsChannelCredentials != null) {
builder = Grpc.newChannelBuilder(endpoint, mTlsChannelCredentials);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
}
}
builder =
builder
Expand Down Expand Up @@ -376,6 +408,7 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private EnvironmentProvider envProvider;
private MtlsProvider mtlsProvider = new MtlsProvider();
@Nullable private GrpcInterceptorProvider interceptorProvider;
@Nullable private Integer maxInboundMessageSize;
@Nullable private Integer maxInboundMetadataSize;
Expand Down Expand Up @@ -412,6 +445,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.channelPrimer = provider.channelPrimer;
this.attemptDirectPath = provider.attemptDirectPath;
this.directPathServiceConfig = provider.directPathServiceConfig;
this.mtlsProvider = provider.mtlsProvider;
}

/** Sets the number of available CPUs, used internally for testing. */
Expand Down Expand Up @@ -458,6 +492,12 @@ public Builder setEndpoint(String endpoint) {
return this;
}

@VisibleForTesting
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
return this;
}

/**
* Sets the GrpcInterceptorProvider for this TransportChannelProvider.
*
Expand Down
Expand Up @@ -38,6 +38,8 @@
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.collect.ImmutableList;
Expand All @@ -63,8 +65,7 @@
import org.threeten.bp.Duration;

@RunWith(JUnit4.class)
public class InstantiatingGrpcChannelProviderTest {

public class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
@Test
public void testEndpoint() {
String endpoint = "localhost:8080";
Expand Down Expand Up @@ -499,4 +500,15 @@ public void testWithCustomDirectPathServiceConfig() {
ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;
assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig);
}

@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider) throws IOException {
return InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build()
.createMtlsChannelCredentials();
}
}
Expand Up @@ -50,6 +50,7 @@
import com.google.api.gax.rpc.StubSettings;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -83,6 +84,7 @@ private static class FakeStubSettings extends StubSettings<FakeStubSettings> {
public static final int DEFAULT_SERVICE_PORT = 443;
public static final String DEFAULT_SERVICE_ENDPOINT =
DEFAULT_SERVICE_ADDRESS + ':' + DEFAULT_SERVICE_PORT;
public static final MtlsProvider DEFAULT_MTLS_PROVIDER = new MtlsProvider();
public static final ImmutableList<String> DEFAULT_SERVICE_SCOPES =
ImmutableList.<String>builder()
.add("https://www.googleapis.com/auth/pubsub")
Expand Down Expand Up @@ -148,7 +150,9 @@ public static InstantiatingExecutorProvider.Builder defaultExecutorProviderBuild

/** Returns a builder for the default TransportChannelProvider for this service. */
public static InstantiatingGrpcChannelProvider.Builder defaultGrpcChannelProviderBuilder() {
return InstantiatingGrpcChannelProvider.newBuilder().setEndpoint(DEFAULT_SERVICE_ENDPOINT);
return InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint(DEFAULT_SERVICE_ENDPOINT)
.setMtlsProvider(DEFAULT_MTLS_PROVIDER);
}

public static ApiClientHeaderProvider.Builder defaultGoogleServiceHeaderProviderBuilder() {
Expand Down
Expand Up @@ -30,16 +30,21 @@
package com.google.api.gax.httpjson;

import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.javanet.NetHttpTransport;
import com.google.api.core.BetaApi;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
Expand All @@ -64,24 +69,28 @@ public final class InstantiatingHttpJsonChannelProvider implements TransportChan
private final HeaderProvider headerProvider;
private final String endpoint;
private final HttpTransport httpTransport;
private final MtlsProvider mtlsProvider;

private InstantiatingHttpJsonChannelProvider(
Executor executor, HeaderProvider headerProvider, String endpoint) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = null;
this.mtlsProvider = new MtlsProvider();
}

private InstantiatingHttpJsonChannelProvider(
Executor executor,
HeaderProvider headerProvider,
String endpoint,
HttpTransport httpTransport) {
HttpTransport httpTransport,
MtlsProvider mtlsProvider) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = httpTransport;
this.mtlsProvider = mtlsProvider;
}

@Override
Expand Down Expand Up @@ -160,6 +169,20 @@ public TransportChannelProvider withCredentials(Credentials credentials) {
"InstantiatingHttpJsonChannelProvider doesn't need credentials");
}

HttpTransport createHttpTransport() throws IOException {
if (mtlsProvider.useMtlsClientCertificate()) {
try {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
if (mtlsKeyStore != null) {
return new NetHttpTransport.Builder().trustCertificates(null, mtlsKeyStore, "").build();
}
} catch (GeneralSecurityException e) {
throw new IOException(e.toString());
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
}
}
return null;
}

private TransportChannel createChannel() throws IOException {
Map<String, String> headers = headerProvider.getHeaders();

Expand All @@ -168,12 +191,17 @@ private TransportChannel createChannel() throws IOException {
headerEnhancers.add(HttpJsonHeaderEnhancers.create(header.getKey(), header.getValue()));
}

HttpTransport httpTransportToUse = httpTransport;
if (httpTransportToUse == null) {
httpTransportToUse = createHttpTransport();
}

ManagedHttpJsonChannel channel =
ManagedHttpJsonChannel.newBuilder()
.setEndpoint(endpoint)
.setHeaderEnhancers(headerEnhancers)
.setExecutor(executor)
.setHttpTransport(httpTransport)
.setHttpTransport(httpTransportToUse)
.build();

return HttpJsonTransportChannel.newBuilder().setManagedChannel(channel).build();
Expand Down Expand Up @@ -202,6 +230,7 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private HttpTransport httpTransport;
private MtlsProvider mtlsProvider = new MtlsProvider();

private Builder() {}

Expand All @@ -210,6 +239,7 @@ private Builder(InstantiatingHttpJsonChannelProvider provider) {
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.httpTransport = provider.httpTransport;
this.mtlsProvider = provider.mtlsProvider;
}

/**
Expand Down Expand Up @@ -259,9 +289,15 @@ public String getEndpoint() {
return endpoint;
}

@VisibleForTesting
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
return this;
}

public InstantiatingHttpJsonChannelProvider build() {
return new InstantiatingHttpJsonChannelProvider(
executor, headerProvider, endpoint, httpTransport);
executor, headerProvider, endpoint, httpTransport, mtlsProvider);
}
}
}
Expand Up @@ -32,7 +32,10 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;

import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.Executor;
Expand All @@ -41,9 +44,10 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class InstantiatingHttpJsonChannelProviderTest {
public class InstantiatingHttpJsonChannelProviderTest extends AbstractMtlsTransportChannelTest {

@Test
public void basicTest() throws IOException {
Expand Down Expand Up @@ -94,4 +98,15 @@ public void basicTest() throws IOException {
// Make sure we can create channels OK.
provider.getTransportChannel().shutdownNow();
}

@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider) throws IOException {
return InstantiatingHttpJsonChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build()
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
.createHttpTransport();
}
}
4 changes: 4 additions & 0 deletions gax/BUILD.bazel
Expand Up @@ -51,6 +51,10 @@ java_library(
srcs = glob(["src/test/java/**/*.java"]),
javacopts = _JAVA_COPTS,
plugins = ["//:auto_value_plugin"],
resources = glob([
"src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json",
"src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem",
]),
visibility = ["//visibility:public"],
deps = [":gax"] + _COMPILE_DEPS + _TEST_COMPILE_DEPS,
)
Expand Down