Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
feat: add mtls to gax
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Nov 20, 2020
1 parent 7e15b60 commit 18c26b5
Show file tree
Hide file tree
Showing 20 changed files with 1,090 additions and 11 deletions.
4 changes: 2 additions & 2 deletions dependencies.properties
Expand Up @@ -34,8 +34,8 @@ version.io_grpc=1.32.2
# 2) Replace all characters which are neither alphabetic nor digits with the underscore ('_') character
maven.com_google_api_grpc_proto_google_common_protos=com.google.api.grpc:proto-google-common-protos:1.17.0
maven.com_google_api_grpc_grpc_google_common_protos=com.google.api.grpc:grpc-google-common-protos:1.17.0
maven.com_google_auth_google_auth_library_oauth2_http=com.google.auth:google-auth-library-oauth2-http:0.22.0
maven.com_google_auth_google_auth_library_credentials=com.google.auth:google-auth-library-credentials:0.22.0
maven.com_google_auth_google_auth_library_oauth2_http=com.google.auth:google-auth-library-oauth2-http:0.22.1
maven.com_google_auth_google_auth_library_credentials=com.google.auth:google-auth-library-credentials:0.22.1
maven.io_opencensus_opencensus_api=io.opencensus:opencensus-api:0.24.0
maven.io_opencensus_opencensus_contrib_grpc_metrics=io.opencensus:opencensus-contrib-grpc-metrics:0.24.0
maven.io_opencensus_opencensus_contrib_http_util=io.opencensus:opencensus-contrib-http-util:0.24.0
Expand Down
Expand Up @@ -38,6 +38,8 @@
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.api.gax.rpc.mtls.MtlsUtils;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.base.MoreObjects;
Expand All @@ -48,13 +50,20 @@
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.net.ssl.KeyManagerFactory;
import org.threeten.bp.Duration;

/**
Expand Down Expand Up @@ -94,6 +103,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final Credentials credentials;
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
@Nullable private final MtlsProvider mtlsProvider;

@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
Expand All @@ -103,6 +113,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.executor = builder.executor;
this.headerProvider = builder.headerProvider;
this.endpoint = builder.endpoint;
this.mtlsProvider = builder.mtlsProvider;
this.envProvider = builder.envProvider;
this.interceptorProvider = builder.interceptorProvider;
this.maxInboundMessageSize = builder.maxInboundMessageSize;
Expand Down Expand Up @@ -192,6 +203,16 @@ public TransportChannelProvider withCredentials(Credentials credentials) {
return toBuilder().setCredentials(credentials).build();
}

@Override
public TransportChannelProvider withMtlsProvider(MtlsProvider provider) {
return toBuilder().setMtlsProvider(provider).build();
}

@Override
public MtlsProvider getMtlsProvider() {
return mtlsProvider;
}

@Override
public TransportChannel getTransportChannel() throws IOException {
if (needsExecutor()) {
Expand Down Expand Up @@ -257,6 +278,26 @@ static boolean isOnComputeEngine() {
}
return false;
}

SslContext createSslContext() throws IOException {
if (mtlsProvider.useMtlsClientCertificate()) {
try {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
String mtlsKeyStorePassword = mtlsProvider.getKeyStorePassword();
if (mtlsKeyStore != null) {
SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
KeyManagerFactory factory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
factory.init(mtlsKeyStore, mtlsKeyStorePassword.toCharArray());
sslContextBuilder.keyManager(factory);
return sslContextBuilder.build();
}
} catch (GeneralSecurityException e) {
throw new IOException(e.toString());
}
}
return null;
}

private ManagedChannel createSingleChannel() throws IOException {
GrpcHeaderInterceptor headerInterceptor =
Expand Down Expand Up @@ -314,6 +355,10 @@ && isOnComputeEngine()) {
.intercept(metadataHandlerInterceptor)
.userAgent(headerInterceptor.getUserAgentHeader())
.executor(executor);
SslContext sslContext = createSslContext();
if (sslContext != null) {
builder = ((NettyChannelBuilder) builder).sslContext(sslContext);
}

if (maxInboundMetadataSize != null) {
builder.maxInboundMetadataSize(maxInboundMetadataSize);
Expand Down Expand Up @@ -389,6 +434,7 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private EnvironmentProvider envProvider;
private MtlsProvider mtlsProvider;
@Nullable private GrpcInterceptorProvider interceptorProvider;
@Nullable private Integer maxInboundMessageSize;
@Nullable private Integer maxInboundMetadataSize;
Expand All @@ -404,6 +450,7 @@ public static final class Builder {
private Builder() {
processorCount = Runtime.getRuntime().availableProcessors();
envProvider = DirectPathEnvironmentProvider.getInstance();
mtlsProvider = MtlsUtils.getDefaultMtlsProvider();
}

private Builder(InstantiatingGrpcChannelProvider provider) {
Expand All @@ -423,6 +470,10 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.credentials = provider.credentials;
this.channelPrimer = provider.channelPrimer;
this.attemptDirectPath = provider.attemptDirectPath;
this.mtlsProvider =
provider.mtlsProvider == null
? MtlsUtils.getDefaultMtlsProvider()
: provider.mtlsProvider;
}

/** Sets the number of available CPUs, used internally for testing. */
Expand Down Expand Up @@ -469,6 +520,11 @@ public Builder setEndpoint(String endpoint) {
return this;
}

public Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
return this;
}

/**
* Sets the GrpcInterceptorProvider for this TransportChannelProvider.
*
Expand All @@ -485,6 +541,10 @@ public String getEndpoint() {
return endpoint;
}

public MtlsProvider getMtlsProvider() {
return mtlsProvider;
}

/** The maximum message size allowed to be received on the channel. */
public Builder setMaxInboundMessageSize(Integer max) {
this.maxInboundMessageSize = max;
Expand Down
Expand Up @@ -37,6 +37,8 @@
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.api.gax.rpc.mtls.MtlsTransportChannelBaseTest;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import io.grpc.ManagedChannel;
Expand All @@ -55,7 +57,7 @@
import org.threeten.bp.Duration;

@RunWith(JUnit4.class)
public class InstantiatingGrpcChannelProviderTest {
public class InstantiatingGrpcChannelProviderTest extends MtlsTransportChannelBaseTest {

@Test
public void testEndpoint() {
Expand Down Expand Up @@ -378,4 +380,15 @@ public void testWithPrimeChannel() throws IOException {
.primeChannel(Mockito.any(ManagedChannel.class));
}
}

@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider) throws IOException {
return InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build()
.createSslContext();
}
}
Expand Up @@ -36,6 +36,7 @@
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import io.grpc.CallOptions;
import io.grpc.Channel;
Expand Down Expand Up @@ -202,4 +203,14 @@ List<Metadata> getSubmittedHeaders() {
return submittedHeaders;
}
}

@Override
public TransportChannelProvider withMtlsProvider(MtlsProvider provider) {
throw new UnsupportedOperationException("LocalChannelProvider doesn't need MtlsProvider");
}

@Override
public MtlsProvider getMtlsProvider() {
return null;
}
}
Expand Up @@ -30,16 +30,21 @@
package com.google.api.gax.httpjson;

import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.javanet.NetHttpTransport;
import com.google.api.core.BetaApi;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.api.gax.rpc.mtls.MtlsUtils;
import com.google.auth.Credentials;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
Expand All @@ -64,24 +69,28 @@ public final class InstantiatingHttpJsonChannelProvider implements TransportChan
private final HeaderProvider headerProvider;
private final String endpoint;
private final HttpTransport httpTransport;
private final MtlsProvider mtlsProvider;

private InstantiatingHttpJsonChannelProvider(
Executor executor, HeaderProvider headerProvider, String endpoint) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = null;
this.mtlsProvider = MtlsUtils.getDefaultMtlsProvider();
}

private InstantiatingHttpJsonChannelProvider(
Executor executor,
HeaderProvider headerProvider,
String endpoint,
HttpTransport httpTransport) {
HttpTransport httpTransport,
MtlsProvider mtlsProvider) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = httpTransport;
this.mtlsProvider = mtlsProvider;
}

@Override
Expand Down Expand Up @@ -160,6 +169,23 @@ public TransportChannelProvider withCredentials(Credentials credentials) {
"InstantiatingHttpJsonChannelProvider doesn't need credentials");
}

HttpTransport createHttpTransport() throws IOException {
if (mtlsProvider.useMtlsClientCertificate()) {
try {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
String mtlsKeyStorePassword = mtlsProvider.getKeyStorePassword();
if (mtlsKeyStore != null) {
return new NetHttpTransport.Builder()
.trustCertificates(null, mtlsKeyStore, mtlsKeyStorePassword)
.build();
}
} catch (GeneralSecurityException e) {
throw new IOException(e.toString());
}
}
return null;
}

private TransportChannel createChannel() throws IOException {
Map<String, String> headers = headerProvider.getHeaders();

Expand All @@ -168,12 +194,17 @@ private TransportChannel createChannel() throws IOException {
headerEnhancers.add(HttpJsonHeaderEnhancers.create(header.getKey(), header.getValue()));
}

HttpTransport httpTransportToUse = httpTransport;
if (httpTransportToUse == null) {
httpTransportToUse = createHttpTransport();
}

ManagedHttpJsonChannel channel =
ManagedHttpJsonChannel.newBuilder()
.setEndpoint(endpoint)
.setHeaderEnhancers(headerEnhancers)
.setExecutor(executor)
.setHttpTransport(httpTransport)
.setHttpTransport(httpTransportToUse)
.build();

return HttpJsonTransportChannel.newBuilder().setManagedChannel(channel).build();
Expand All @@ -184,6 +215,16 @@ public String getEndpoint() {
return endpoint;
}

@Override
public TransportChannelProvider withMtlsProvider(MtlsProvider provider) {
return toBuilder().setMtlsProvider(provider).build();
}

@Override
public MtlsProvider getMtlsProvider() {
return mtlsProvider;
}

@Override
public boolean shouldAutoClose() {
return true;
Expand All @@ -202,14 +243,21 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private HttpTransport httpTransport;
private MtlsProvider mtlsProvider;

private Builder() {}
private Builder() {
mtlsProvider = MtlsUtils.getDefaultMtlsProvider();
}

private Builder(InstantiatingHttpJsonChannelProvider provider) {
this.executor = provider.executor;
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.httpTransport = provider.httpTransport;
this.mtlsProvider =
provider.mtlsProvider == null
? MtlsUtils.getDefaultMtlsProvider()
: provider.mtlsProvider;
}

/**
Expand Down Expand Up @@ -259,9 +307,14 @@ public String getEndpoint() {
return endpoint;
}

public Builder setMtlsProvider(MtlsProvider provider) {
this.mtlsProvider = provider;
return this;
}

public InstantiatingHttpJsonChannelProvider build() {
return new InstantiatingHttpJsonChannelProvider(
executor, headerProvider, endpoint, httpTransport);
executor, headerProvider, endpoint, httpTransport, mtlsProvider);
}
}
}

0 comments on commit 18c26b5

Please sign in to comment.