Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
feat: add mtls feature to http and grpc transport provider (#1249)
Browse files Browse the repository at this point in the history
* feat: add mtls support to grpc and http transport
  • Loading branch information
arithmetic1728 committed May 26, 2021
1 parent 3b1859e commit b863041
Show file tree
Hide file tree
Showing 16 changed files with 1,045 additions and 15 deletions.
Expand Up @@ -38,6 +38,7 @@
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -46,16 +47,22 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.CharStreams;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.net.ssl.KeyManagerFactory;
import org.threeten.bp.Duration;

/**
Expand Down Expand Up @@ -96,6 +103,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private final MtlsProvider mtlsProvider;

@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
Expand All @@ -105,6 +113,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.executor = builder.executor;
this.headerProvider = builder.headerProvider;
this.endpoint = builder.endpoint;
this.mtlsProvider = builder.mtlsProvider;
this.envProvider = builder.envProvider;
this.interceptorProvider = builder.interceptorProvider;
this.maxInboundMessageSize = builder.maxInboundMessageSize;
Expand Down Expand Up @@ -216,8 +225,13 @@ private TransportChannel createChannel() throws IOException {
int realPoolSize = MoreObjects.firstNonNull(poolSize, 1);
ChannelFactory channelFactory =
new ChannelFactory() {
@Override
public ManagedChannel createSingleChannel() throws IOException {
return InstantiatingGrpcChannelProvider.this.createSingleChannel();
try {
return InstantiatingGrpcChannelProvider.this.createSingleChannel();
} catch (GeneralSecurityException e) {
throw new IOException(e);
}
}
};
ManagedChannel outerChannel;
Expand Down Expand Up @@ -264,7 +278,21 @@ static boolean isOnComputeEngine() {
return false;
}

private ManagedChannel createSingleChannel() throws IOException {
@VisibleForTesting
ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSecurityException {
if (mtlsProvider.useMtlsClientCertificate()) {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
if (mtlsKeyStore != null) {
KeyManagerFactory factory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
factory.init(mtlsKeyStore, new char[] {});
return TlsChannelCredentials.newBuilder().keyManager(factory.getKeyManagers()).build();
}
}
return null;
}

private ManagedChannel createSingleChannel() throws IOException, GeneralSecurityException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headerProvider.getHeaders());
GrpcMetadataHandlerInterceptor metadataHandlerInterceptor =
Expand All @@ -290,7 +318,12 @@ && isOnComputeEngine()) {
builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
builder.defaultServiceConfig(directPathServiceConfig);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
ChannelCredentials channelCredentials = createMtlsChannelCredentials();
if (channelCredentials != null) {
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
}
}
builder =
builder
Expand Down Expand Up @@ -376,6 +409,7 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private EnvironmentProvider envProvider;
private MtlsProvider mtlsProvider = new MtlsProvider();
@Nullable private GrpcInterceptorProvider interceptorProvider;
@Nullable private Integer maxInboundMessageSize;
@Nullable private Integer maxInboundMetadataSize;
Expand Down Expand Up @@ -412,6 +446,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.channelPrimer = provider.channelPrimer;
this.attemptDirectPath = provider.attemptDirectPath;
this.directPathServiceConfig = provider.directPathServiceConfig;
this.mtlsProvider = provider.mtlsProvider;
}

/** Sets the number of available CPUs, used internally for testing. */
Expand Down Expand Up @@ -458,6 +493,12 @@ public Builder setEndpoint(String endpoint) {
return this;
}

@VisibleForTesting
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
return this;
}

/**
* Sets the GrpcInterceptorProvider for this TransportChannelProvider.
*
Expand Down
Expand Up @@ -38,6 +38,8 @@
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.collect.ImmutableList;
Expand All @@ -46,6 +48,7 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -63,8 +66,7 @@
import org.threeten.bp.Duration;

@RunWith(JUnit4.class)
public class InstantiatingGrpcChannelProviderTest {

public class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
@Test
public void testEndpoint() {
String endpoint = "localhost:8080";
Expand Down Expand Up @@ -499,4 +501,17 @@ public void testWithCustomDirectPathServiceConfig() {
ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;
assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig);
}

@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
throws IOException, GeneralSecurityException {
InstantiatingGrpcChannelProvider channelProvider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build();
return channelProvider.createMtlsChannelCredentials();
}
}
Expand Up @@ -50,6 +50,7 @@
import com.google.api.gax.rpc.StubSettings;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -83,6 +84,7 @@ private static class FakeStubSettings extends StubSettings<FakeStubSettings> {
public static final int DEFAULT_SERVICE_PORT = 443;
public static final String DEFAULT_SERVICE_ENDPOINT =
DEFAULT_SERVICE_ADDRESS + ':' + DEFAULT_SERVICE_PORT;
public static final MtlsProvider DEFAULT_MTLS_PROVIDER = new MtlsProvider();
public static final ImmutableList<String> DEFAULT_SERVICE_SCOPES =
ImmutableList.<String>builder()
.add("https://www.googleapis.com/auth/pubsub")
Expand Down Expand Up @@ -148,7 +150,9 @@ public static InstantiatingExecutorProvider.Builder defaultExecutorProviderBuild

/** Returns a builder for the default TransportChannelProvider for this service. */
public static InstantiatingGrpcChannelProvider.Builder defaultGrpcChannelProviderBuilder() {
return InstantiatingGrpcChannelProvider.newBuilder().setEndpoint(DEFAULT_SERVICE_ENDPOINT);
return InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint(DEFAULT_SERVICE_ENDPOINT)
.setMtlsProvider(DEFAULT_MTLS_PROVIDER);
}

public static ApiClientHeaderProvider.Builder defaultGoogleServiceHeaderProviderBuilder() {
Expand Down
Expand Up @@ -30,16 +30,21 @@
package com.google.api.gax.httpjson;

import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.javanet.NetHttpTransport;
import com.google.api.core.BetaApi;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
Expand All @@ -64,24 +69,28 @@ public final class InstantiatingHttpJsonChannelProvider implements TransportChan
private final HeaderProvider headerProvider;
private final String endpoint;
private final HttpTransport httpTransport;
private final MtlsProvider mtlsProvider;

private InstantiatingHttpJsonChannelProvider(
Executor executor, HeaderProvider headerProvider, String endpoint) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = null;
this.mtlsProvider = new MtlsProvider();
}

private InstantiatingHttpJsonChannelProvider(
Executor executor,
HeaderProvider headerProvider,
String endpoint,
HttpTransport httpTransport) {
HttpTransport httpTransport,
MtlsProvider mtlsProvider) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = httpTransport;
this.mtlsProvider = mtlsProvider;
}

@Override
Expand Down Expand Up @@ -145,7 +154,11 @@ public TransportChannel getTransportChannel() throws IOException {
} else if (needsHeaders()) {
throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true");
} else {
return createChannel();
try {
return createChannel();
} catch (GeneralSecurityException e) {
throw new IOException(e);
}
}
}

Expand All @@ -160,20 +173,35 @@ public TransportChannelProvider withCredentials(Credentials credentials) {
"InstantiatingHttpJsonChannelProvider doesn't need credentials");
}

private TransportChannel createChannel() throws IOException {
HttpTransport createHttpTransport() throws IOException, GeneralSecurityException {
if (mtlsProvider.useMtlsClientCertificate()) {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
if (mtlsKeyStore != null) {
return new NetHttpTransport.Builder().trustCertificates(null, mtlsKeyStore, "").build();
}
}
return null;
}

private TransportChannel createChannel() throws IOException, GeneralSecurityException {
Map<String, String> headers = headerProvider.getHeaders();

List<HttpJsonHeaderEnhancer> headerEnhancers = Lists.newArrayList();
for (Map.Entry<String, String> header : headers.entrySet()) {
headerEnhancers.add(HttpJsonHeaderEnhancers.create(header.getKey(), header.getValue()));
}

HttpTransport httpTransportToUse = httpTransport;
if (httpTransportToUse == null) {
httpTransportToUse = createHttpTransport();
}

ManagedHttpJsonChannel channel =
ManagedHttpJsonChannel.newBuilder()
.setEndpoint(endpoint)
.setHeaderEnhancers(headerEnhancers)
.setExecutor(executor)
.setHttpTransport(httpTransport)
.setHttpTransport(httpTransportToUse)
.build();

return HttpJsonTransportChannel.newBuilder().setManagedChannel(channel).build();
Expand Down Expand Up @@ -202,6 +230,7 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private HttpTransport httpTransport;
private MtlsProvider mtlsProvider = new MtlsProvider();

private Builder() {}

Expand All @@ -210,6 +239,7 @@ private Builder(InstantiatingHttpJsonChannelProvider provider) {
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.httpTransport = provider.httpTransport;
this.mtlsProvider = provider.mtlsProvider;
}

/**
Expand Down Expand Up @@ -259,9 +289,15 @@ public String getEndpoint() {
return endpoint;
}

@VisibleForTesting
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
return this;
}

public InstantiatingHttpJsonChannelProvider build() {
return new InstantiatingHttpJsonChannelProvider(
executor, headerProvider, endpoint, httpTransport);
executor, headerProvider, endpoint, httpTransport, mtlsProvider);
}
}
}
Expand Up @@ -32,18 +32,23 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;

import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class InstantiatingHttpJsonChannelProviderTest {
public class InstantiatingHttpJsonChannelProviderTest extends AbstractMtlsTransportChannelTest {

@Test
public void basicTest() throws IOException {
Expand Down Expand Up @@ -94,4 +99,17 @@ public void basicTest() throws IOException {
// Make sure we can create channels OK.
provider.getTransportChannel().shutdownNow();
}

@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
throws IOException, GeneralSecurityException {
InstantiatingHttpJsonChannelProvider channelProvider =
InstantiatingHttpJsonChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build();
return channelProvider.createHttpTransport();
}
}
4 changes: 4 additions & 0 deletions gax/BUILD.bazel
Expand Up @@ -51,6 +51,10 @@ java_library(
srcs = glob(["src/test/java/**/*.java"]),
javacopts = _JAVA_COPTS,
plugins = ["//:auto_value_plugin"],
resources = glob([
"src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json",
"src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem",
]),
visibility = ["//visibility:public"],
deps = [":gax"] + _COMPILE_DEPS + _TEST_COMPILE_DEPS,
)
Expand Down

0 comments on commit b863041

Please sign in to comment.