diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 9797b4d7d..dcfa5c39e 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -38,6 +38,7 @@ import com.google.api.gax.rpc.HeaderProvider; import com.google.api.gax.rpc.TransportChannel; import com.google.api.gax.rpc.TransportChannelProvider; +import com.google.api.gax.rpc.mtls.MtlsProvider; import com.google.auth.Credentials; import com.google.auth.oauth2.ComputeEngineCredentials; import com.google.common.annotations.VisibleForTesting; @@ -46,16 +47,22 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.CharStreams; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.TlsChannelCredentials; import io.grpc.alts.ComputeEngineChannelBuilder; import java.io.IOException; import java.io.InputStreamReader; +import java.security.GeneralSecurityException; +import java.security.KeyStore; import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import javax.net.ssl.KeyManagerFactory; import org.threeten.bp.Duration; /** @@ -96,6 +103,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @Nullable private final ChannelPrimer channelPrimer; @Nullable private final Boolean attemptDirectPath; @VisibleForTesting final ImmutableMap directPathServiceConfig; + @Nullable private final MtlsProvider mtlsProvider; @Nullable private final ApiFunction channelConfigurator; @@ -105,6 +113,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) { this.executor = builder.executor; this.headerProvider = builder.headerProvider; this.endpoint = builder.endpoint; + this.mtlsProvider = builder.mtlsProvider; this.envProvider = builder.envProvider; this.interceptorProvider = builder.interceptorProvider; this.maxInboundMessageSize = builder.maxInboundMessageSize; @@ -216,8 +225,13 @@ private TransportChannel createChannel() throws IOException { int realPoolSize = MoreObjects.firstNonNull(poolSize, 1); ChannelFactory channelFactory = new ChannelFactory() { + @Override public ManagedChannel createSingleChannel() throws IOException { - return InstantiatingGrpcChannelProvider.this.createSingleChannel(); + try { + return InstantiatingGrpcChannelProvider.this.createSingleChannel(); + } catch (GeneralSecurityException e) { + throw new IOException(e); + } } }; ManagedChannel outerChannel; @@ -264,7 +278,21 @@ static boolean isOnComputeEngine() { return false; } - private ManagedChannel createSingleChannel() throws IOException { + @VisibleForTesting + ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSecurityException { + if (mtlsProvider.useMtlsClientCertificate()) { + KeyStore mtlsKeyStore = mtlsProvider.getKeyStore(); + if (mtlsKeyStore != null) { + KeyManagerFactory factory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + factory.init(mtlsKeyStore, new char[] {}); + return TlsChannelCredentials.newBuilder().keyManager(factory.getKeyManagers()).build(); + } + } + return null; + } + + private ManagedChannel createSingleChannel() throws IOException, GeneralSecurityException { GrpcHeaderInterceptor headerInterceptor = new GrpcHeaderInterceptor(headerProvider.getHeaders()); GrpcMetadataHandlerInterceptor metadataHandlerInterceptor = @@ -290,7 +318,12 @@ && isOnComputeEngine()) { builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS); builder.defaultServiceConfig(directPathServiceConfig); } else { - builder = ManagedChannelBuilder.forAddress(serviceAddress, port); + ChannelCredentials channelCredentials = createMtlsChannelCredentials(); + if (channelCredentials != null) { + builder = Grpc.newChannelBuilder(endpoint, channelCredentials); + } else { + builder = ManagedChannelBuilder.forAddress(serviceAddress, port); + } } builder = builder @@ -376,6 +409,7 @@ public static final class Builder { private HeaderProvider headerProvider; private String endpoint; private EnvironmentProvider envProvider; + private MtlsProvider mtlsProvider = new MtlsProvider(); @Nullable private GrpcInterceptorProvider interceptorProvider; @Nullable private Integer maxInboundMessageSize; @Nullable private Integer maxInboundMetadataSize; @@ -412,6 +446,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) { this.channelPrimer = provider.channelPrimer; this.attemptDirectPath = provider.attemptDirectPath; this.directPathServiceConfig = provider.directPathServiceConfig; + this.mtlsProvider = provider.mtlsProvider; } /** Sets the number of available CPUs, used internally for testing. */ @@ -458,6 +493,12 @@ public Builder setEndpoint(String endpoint) { return this; } + @VisibleForTesting + Builder setMtlsProvider(MtlsProvider mtlsProvider) { + this.mtlsProvider = mtlsProvider; + return this; + } + /** * Sets the GrpcInterceptorProvider for this TransportChannelProvider. * diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index 47f612db5..72f3cc464 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -38,6 +38,8 @@ import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; import com.google.api.gax.rpc.HeaderProvider; import com.google.api.gax.rpc.TransportChannelProvider; +import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest; +import com.google.api.gax.rpc.mtls.MtlsProvider; import com.google.auth.oauth2.CloudShellCredentials; import com.google.auth.oauth2.ComputeEngineCredentials; import com.google.common.collect.ImmutableList; @@ -46,6 +48,7 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.alts.ComputeEngineChannelBuilder; import java.io.IOException; +import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -63,8 +66,7 @@ import org.threeten.bp.Duration; @RunWith(JUnit4.class) -public class InstantiatingGrpcChannelProviderTest { - +public class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest { @Test public void testEndpoint() { String endpoint = "localhost:8080"; @@ -499,4 +501,17 @@ public void testWithCustomDirectPathServiceConfig() { ImmutableMap defaultServiceConfig = provider.directPathServiceConfig; assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig); } + + @Override + protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider) + throws IOException, GeneralSecurityException { + InstantiatingGrpcChannelProvider channelProvider = + InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint("localhost:8080") + .setMtlsProvider(provider) + .setHeaderProvider(Mockito.mock(HeaderProvider.class)) + .setExecutor(Mockito.mock(Executor.class)) + .build(); + return channelProvider.createMtlsChannelCredentials(); + } } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/SettingsTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/SettingsTest.java index 65ca3d061..f40716d4f 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/SettingsTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/SettingsTest.java @@ -50,6 +50,7 @@ import com.google.api.gax.rpc.StubSettings; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.api.gax.rpc.UnaryCallSettings; +import com.google.api.gax.rpc.mtls.MtlsProvider; import com.google.auth.Credentials; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -83,6 +84,7 @@ private static class FakeStubSettings extends StubSettings { public static final int DEFAULT_SERVICE_PORT = 443; public static final String DEFAULT_SERVICE_ENDPOINT = DEFAULT_SERVICE_ADDRESS + ':' + DEFAULT_SERVICE_PORT; + public static final MtlsProvider DEFAULT_MTLS_PROVIDER = new MtlsProvider(); public static final ImmutableList DEFAULT_SERVICE_SCOPES = ImmutableList.builder() .add("https://www.googleapis.com/auth/pubsub") @@ -148,7 +150,9 @@ public static InstantiatingExecutorProvider.Builder defaultExecutorProviderBuild /** Returns a builder for the default TransportChannelProvider for this service. */ public static InstantiatingGrpcChannelProvider.Builder defaultGrpcChannelProviderBuilder() { - return InstantiatingGrpcChannelProvider.newBuilder().setEndpoint(DEFAULT_SERVICE_ENDPOINT); + return InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint(DEFAULT_SERVICE_ENDPOINT) + .setMtlsProvider(DEFAULT_MTLS_PROVIDER); } public static ApiClientHeaderProvider.Builder defaultGoogleServiceHeaderProviderBuilder() { diff --git a/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java b/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java index 198ce4d7d..b6da11a39 100644 --- a/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java +++ b/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java @@ -30,6 +30,7 @@ package com.google.api.gax.httpjson; import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.javanet.NetHttpTransport; import com.google.api.core.BetaApi; import com.google.api.core.InternalExtensionOnly; import com.google.api.gax.core.ExecutorProvider; @@ -37,9 +38,13 @@ import com.google.api.gax.rpc.HeaderProvider; import com.google.api.gax.rpc.TransportChannel; import com.google.api.gax.rpc.TransportChannelProvider; +import com.google.api.gax.rpc.mtls.MtlsProvider; import com.google.auth.Credentials; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStore; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; @@ -64,6 +69,7 @@ public final class InstantiatingHttpJsonChannelProvider implements TransportChan private final HeaderProvider headerProvider; private final String endpoint; private final HttpTransport httpTransport; + private final MtlsProvider mtlsProvider; private InstantiatingHttpJsonChannelProvider( Executor executor, HeaderProvider headerProvider, String endpoint) { @@ -71,17 +77,20 @@ private InstantiatingHttpJsonChannelProvider( this.headerProvider = headerProvider; this.endpoint = endpoint; this.httpTransport = null; + this.mtlsProvider = new MtlsProvider(); } private InstantiatingHttpJsonChannelProvider( Executor executor, HeaderProvider headerProvider, String endpoint, - HttpTransport httpTransport) { + HttpTransport httpTransport, + MtlsProvider mtlsProvider) { this.executor = executor; this.headerProvider = headerProvider; this.endpoint = endpoint; this.httpTransport = httpTransport; + this.mtlsProvider = mtlsProvider; } @Override @@ -145,7 +154,11 @@ public TransportChannel getTransportChannel() throws IOException { } else if (needsHeaders()) { throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true"); } else { - return createChannel(); + try { + return createChannel(); + } catch (GeneralSecurityException e) { + throw new IOException(e); + } } } @@ -160,7 +173,17 @@ public TransportChannelProvider withCredentials(Credentials credentials) { "InstantiatingHttpJsonChannelProvider doesn't need credentials"); } - private TransportChannel createChannel() throws IOException { + HttpTransport createHttpTransport() throws IOException, GeneralSecurityException { + if (mtlsProvider.useMtlsClientCertificate()) { + KeyStore mtlsKeyStore = mtlsProvider.getKeyStore(); + if (mtlsKeyStore != null) { + return new NetHttpTransport.Builder().trustCertificates(null, mtlsKeyStore, "").build(); + } + } + return null; + } + + private TransportChannel createChannel() throws IOException, GeneralSecurityException { Map headers = headerProvider.getHeaders(); List headerEnhancers = Lists.newArrayList(); @@ -168,12 +191,17 @@ private TransportChannel createChannel() throws IOException { headerEnhancers.add(HttpJsonHeaderEnhancers.create(header.getKey(), header.getValue())); } + HttpTransport httpTransportToUse = httpTransport; + if (httpTransportToUse == null) { + httpTransportToUse = createHttpTransport(); + } + ManagedHttpJsonChannel channel = ManagedHttpJsonChannel.newBuilder() .setEndpoint(endpoint) .setHeaderEnhancers(headerEnhancers) .setExecutor(executor) - .setHttpTransport(httpTransport) + .setHttpTransport(httpTransportToUse) .build(); return HttpJsonTransportChannel.newBuilder().setManagedChannel(channel).build(); @@ -202,6 +230,7 @@ public static final class Builder { private HeaderProvider headerProvider; private String endpoint; private HttpTransport httpTransport; + private MtlsProvider mtlsProvider = new MtlsProvider(); private Builder() {} @@ -210,6 +239,7 @@ private Builder(InstantiatingHttpJsonChannelProvider provider) { this.headerProvider = provider.headerProvider; this.endpoint = provider.endpoint; this.httpTransport = provider.httpTransport; + this.mtlsProvider = provider.mtlsProvider; } /** @@ -259,9 +289,15 @@ public String getEndpoint() { return endpoint; } + @VisibleForTesting + Builder setMtlsProvider(MtlsProvider mtlsProvider) { + this.mtlsProvider = mtlsProvider; + return this; + } + public InstantiatingHttpJsonChannelProvider build() { return new InstantiatingHttpJsonChannelProvider( - executor, headerProvider, endpoint, httpTransport); + executor, headerProvider, endpoint, httpTransport, mtlsProvider); } } } diff --git a/gax-httpjson/src/test/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProviderTest.java b/gax-httpjson/src/test/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProviderTest.java index baff7bee5..bc00f164b 100644 --- a/gax-httpjson/src/test/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProviderTest.java +++ b/gax-httpjson/src/test/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProviderTest.java @@ -32,8 +32,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import com.google.api.gax.rpc.HeaderProvider; import com.google.api.gax.rpc.TransportChannelProvider; +import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest; +import com.google.api.gax.rpc.mtls.MtlsProvider; import java.io.IOException; +import java.security.GeneralSecurityException; import java.util.Collections; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -41,9 +45,10 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; @RunWith(JUnit4.class) -public class InstantiatingHttpJsonChannelProviderTest { +public class InstantiatingHttpJsonChannelProviderTest extends AbstractMtlsTransportChannelTest { @Test public void basicTest() throws IOException { @@ -94,4 +99,17 @@ public void basicTest() throws IOException { // Make sure we can create channels OK. provider.getTransportChannel().shutdownNow(); } + + @Override + protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider) + throws IOException, GeneralSecurityException { + InstantiatingHttpJsonChannelProvider channelProvider = + InstantiatingHttpJsonChannelProvider.newBuilder() + .setEndpoint("localhost:8080") + .setMtlsProvider(provider) + .setHeaderProvider(Mockito.mock(HeaderProvider.class)) + .setExecutor(Mockito.mock(Executor.class)) + .build(); + return channelProvider.createHttpTransport(); + } } diff --git a/gax/BUILD.bazel b/gax/BUILD.bazel index 78deafdaa..d46b3c7a2 100644 --- a/gax/BUILD.bazel +++ b/gax/BUILD.bazel @@ -51,6 +51,10 @@ java_library( srcs = glob(["src/test/java/**/*.java"]), javacopts = _JAVA_COPTS, plugins = ["//:auto_value_plugin"], + resources = glob([ + "src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json", + "src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem", + ]), visibility = ["//visibility:public"], deps = [":gax"] + _COMPILE_DEPS + _TEST_COMPILE_DEPS, ) diff --git a/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java b/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java index ac2335235..c8a1920c9 100644 --- a/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java +++ b/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java @@ -36,6 +36,7 @@ import com.google.api.gax.core.ExecutorAsBackgroundResource; import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.rpc.internal.QuotaProjectIdHidingCredentials; +import com.google.api.gax.rpc.mtls.MtlsProvider; import com.google.api.gax.tracing.ApiTracerFactory; import com.google.api.gax.tracing.NoopApiTracerFactory; import com.google.auth.Credentials; @@ -132,6 +133,29 @@ public static ClientContext create(ClientSettings settings) throws IOException { return create(settings.getStubSettings()); } + /** Returns the endpoint that should be used. See https://google.aip.dev/auth/4114. */ + static String getEndpoint( + String endpoint, + String mtlsEndpoint, + boolean switchToMtlsEndpointAllowed, + MtlsProvider mtlsProvider) + throws IOException { + if (switchToMtlsEndpointAllowed) { + switch (mtlsProvider.getMtlsEndpointUsagePolicy()) { + case ALWAYS: + return mtlsEndpoint; + case NEVER: + return endpoint; + default: + if (mtlsProvider.useMtlsClientCertificate() && mtlsProvider.getKeyStore() != null) { + return mtlsEndpoint; + } + return endpoint; + } + } + return endpoint; + } + /** * Instantiates the executor, credentials, and transport context based on the given client * settings. @@ -160,12 +184,18 @@ public static ClientContext create(StubSettings settings) throws IOException { if (transportChannelProvider.needsHeaders()) { transportChannelProvider = transportChannelProvider.withHeaders(headers); } - if (transportChannelProvider.needsEndpoint()) { - transportChannelProvider = transportChannelProvider.withEndpoint(settings.getEndpoint()); - } if (transportChannelProvider.needsCredentials() && credentials != null) { transportChannelProvider = transportChannelProvider.withCredentials(credentials); } + String endpoint = + getEndpoint( + settings.getEndpoint(), + settings.getMtlsEndpoint(), + settings.getSwitchToMtlsEndpointAllowed(), + new MtlsProvider()); + if (transportChannelProvider.needsEndpoint()) { + transportChannelProvider = transportChannelProvider.withEndpoint(endpoint); + } TransportChannel transportChannel = transportChannelProvider.getTransportChannel(); ApiCallContext defaultCallContext = diff --git a/gax/src/main/java/com/google/api/gax/rpc/StubSettings.java b/gax/src/main/java/com/google/api/gax/rpc/StubSettings.java index 1b6090107..25fc85512 100644 --- a/gax/src/main/java/com/google/api/gax/rpc/StubSettings.java +++ b/gax/src/main/java/com/google/api/gax/rpc/StubSettings.java @@ -70,11 +70,20 @@ public abstract class StubSettings> { private final TransportChannelProvider transportChannelProvider; private final ApiClock clock; private final String endpoint; + private final String mtlsEndpoint; private final String quotaProjectId; @Nullable private final WatchdogProvider streamWatchdogProvider; @Nonnull private final Duration streamWatchdogCheckInterval; @Nonnull private final ApiTracerFactory tracerFactory; + /** + * Indicate when creating transport whether it is allowed to use mTLS endpoint instead of the + * default endpoint. Only the endpoint set by client libraries is allowed. User provided endpoint + * should always be used as it is. Client libraries can set it via {@link + * Builder#setSwitchToMtlsEndpointAllowed} method. + */ + private final boolean switchToMtlsEndpointAllowed; + /** Constructs an instance of StubSettings. */ protected StubSettings(Builder builder) { this.executorProvider = builder.executorProvider; @@ -84,6 +93,8 @@ protected StubSettings(Builder builder) { this.internalHeaderProvider = builder.internalHeaderProvider; this.clock = builder.clock; this.endpoint = builder.endpoint; + this.mtlsEndpoint = builder.mtlsEndpoint; + this.switchToMtlsEndpointAllowed = builder.switchToMtlsEndpointAllowed; this.quotaProjectId = builder.quotaProjectId; this.streamWatchdogProvider = builder.streamWatchdogProvider; this.streamWatchdogCheckInterval = builder.streamWatchdogCheckInterval; @@ -120,6 +131,15 @@ public final String getEndpoint() { return endpoint; } + public final String getMtlsEndpoint() { + return mtlsEndpoint; + } + + /** Limit the visibility to this package only since only this package needs it. */ + final boolean getSwitchToMtlsEndpointAllowed() { + return switchToMtlsEndpointAllowed; + } + public final String getQuotaProjectId() { return quotaProjectId; } @@ -155,6 +175,8 @@ public String toString() { .add("internalHeaderProvider", internalHeaderProvider) .add("clock", clock) .add("endpoint", endpoint) + .add("mtlsEndpoint", mtlsEndpoint) + .add("switchToMtlsEndpointAllowed", switchToMtlsEndpointAllowed) .add("quotaProjectId", quotaProjectId) .add("streamWatchdogProvider", streamWatchdogProvider) .add("streamWatchdogCheckInterval", streamWatchdogCheckInterval) @@ -174,11 +196,20 @@ public abstract static class Builder< private TransportChannelProvider transportChannelProvider; private ApiClock clock; private String endpoint; + private String mtlsEndpoint; private String quotaProjectId; @Nullable private WatchdogProvider streamWatchdogProvider; @Nonnull private Duration streamWatchdogCheckInterval; @Nonnull private ApiTracerFactory tracerFactory; + /** + * Indicate when creating transport whether it is allowed to use mTLS endpoint instead of the + * default endpoint. Only the endpoint set by client libraries is allowed. User provided + * endpoint should always be used as it is. Client libraries can set it via {@link + * Builder#setSwitchToMtlsEndpointAllowed} method. + */ + private boolean switchToMtlsEndpointAllowed = false; + /** Create a builder from a StubSettings object. */ protected Builder(StubSettings settings) { this.executorProvider = settings.executorProvider; @@ -188,6 +219,8 @@ protected Builder(StubSettings settings) { this.internalHeaderProvider = settings.internalHeaderProvider; this.clock = settings.clock; this.endpoint = settings.endpoint; + this.mtlsEndpoint = settings.mtlsEndpoint; + this.switchToMtlsEndpointAllowed = settings.switchToMtlsEndpointAllowed; this.quotaProjectId = settings.quotaProjectId; this.streamWatchdogProvider = settings.streamWatchdogProvider; this.streamWatchdogCheckInterval = settings.streamWatchdogCheckInterval; @@ -220,6 +253,7 @@ protected Builder(ClientContext clientContext) { this.internalHeaderProvider = new NoHeaderProvider(); this.clock = NanoClock.getDefaultClock(); this.endpoint = null; + this.mtlsEndpoint = null; this.quotaProjectId = null; this.streamWatchdogProvider = InstantiatingWatchdogProvider.create(); this.streamWatchdogCheckInterval = Duration.ofSeconds(10); @@ -234,6 +268,9 @@ protected Builder(ClientContext clientContext) { FixedHeaderProvider.create(clientContext.getInternalHeaders()); this.clock = clientContext.getClock(); this.endpoint = clientContext.getEndpoint(); + if (this.endpoint != null) { + this.mtlsEndpoint = this.endpoint.replace("googleapis.com", "mtls.googleapis.com"); + } this.streamWatchdogProvider = FixedWatchdogProvider.create(clientContext.getStreamWatchdog()); this.streamWatchdogCheckInterval = clientContext.getStreamWatchdogCheckInterval(); @@ -334,6 +371,20 @@ public B setClock(ApiClock clock) { public B setEndpoint(String endpoint) { this.endpoint = endpoint; + this.switchToMtlsEndpointAllowed = false; + if (this.endpoint != null && this.mtlsEndpoint == null) { + this.mtlsEndpoint = this.endpoint.replace("googleapis.com", "mtls.googleapis.com"); + } + return self(); + } + + protected B setSwitchToMtlsEndpointAllowed(boolean switchToMtlsEndpointAllowed) { + this.switchToMtlsEndpointAllowed = switchToMtlsEndpointAllowed; + return self(); + } + + public B setMtlsEndpoint(String mtlsEndpoint) { + this.mtlsEndpoint = mtlsEndpoint; return self(); } @@ -408,6 +459,10 @@ public String getEndpoint() { return endpoint; } + public String getMtlsEndpoint() { + return mtlsEndpoint; + } + /** Gets the QuotaProjectId that was previously set on this Builder. */ public String getQuotaProjectId() { return quotaProjectId; @@ -445,6 +500,8 @@ public String toString() { .add("internalHeaderProvider", internalHeaderProvider) .add("clock", clock) .add("endpoint", endpoint) + .add("mtlsEndpoint", mtlsEndpoint) + .add("switchToMtlsEndpointAllowed", switchToMtlsEndpointAllowed) .add("quotaProjectId", quotaProjectId) .add("streamWatchdogProvider", streamWatchdogProvider) .add("streamWatchdogCheckInterval", streamWatchdogCheckInterval) diff --git a/gax/src/main/java/com/google/api/gax/rpc/mtls/ContextAwareMetadataJson.java b/gax/src/main/java/com/google/api/gax/rpc/mtls/ContextAwareMetadataJson.java new file mode 100644 index 000000000..25d7d27de --- /dev/null +++ b/gax/src/main/java/com/google/api/gax/rpc/mtls/ContextAwareMetadataJson.java @@ -0,0 +1,50 @@ +/* + * Copyright 2021 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.api.gax.rpc.mtls; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Key; +import com.google.api.core.BetaApi; +import com.google.common.collect.ImmutableList; +import java.util.List; + +/** Data class representing context_aware_metadata.json file. */ +@BetaApi +public class ContextAwareMetadataJson extends GenericJson { + /** Cert provider command */ + @Key("cert_provider_command") + private List commands; + + /** Returns the cert provider command. */ + public final ImmutableList getCommands() { + return ImmutableList.copyOf(commands); + } +} diff --git a/gax/src/main/java/com/google/api/gax/rpc/mtls/MtlsProvider.java b/gax/src/main/java/com/google/api/gax/rpc/mtls/MtlsProvider.java new file mode 100644 index 000000000..367e8bede --- /dev/null +++ b/gax/src/main/java/com/google/api/gax/rpc/mtls/MtlsProvider.java @@ -0,0 +1,197 @@ +/* + * Copyright 2021 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.api.gax.rpc.mtls; + +import com.google.api.client.json.JsonParser; +import com.google.api.client.json.gson.GsonFactory; +import com.google.api.client.util.SecurityUtils; +import com.google.api.core.BetaApi; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.util.List; + +/** + * Provider class for mutual TLS. It is used to configure the mutual TLS in the transport with the + * default client certificate on device. + */ +@BetaApi +public class MtlsProvider { + interface EnvironmentProvider { + String getenv(String name); + } + + static class SystemEnvironmentProvider implements EnvironmentProvider { + @Override + public String getenv(String name) { + return System.getenv(name); + } + } + + interface ProcessProvider { + public Process createProcess(InputStream metadata) throws IOException; + } + + static class DefaultProcessProvider implements ProcessProvider { + @Override + public Process createProcess(InputStream metadata) throws IOException { + if (metadata == null) { + return null; + } + List command = extractCertificateProviderCommand(metadata); + return new ProcessBuilder(command).start(); + } + } + + private static final String DEFAULT_CONTEXT_AWARE_METADATA_PATH = + System.getProperty("user.home") + "/.secureConnect/context_aware_metadata.json"; + + private String metadataPath; + private EnvironmentProvider envProvider; + private ProcessProvider processProvider; + + /** + * The policy for mutual TLS endpoint usage. NEVER means always use regular endpoint; ALWAYS means + * always use mTLS endpoint; AUTO means auto switch to mTLS endpoint if client certificate exists + * and should be used. + */ + public enum MtlsEndpointUsagePolicy { + NEVER, + AUTO, + ALWAYS; + } + + @VisibleForTesting + MtlsProvider( + EnvironmentProvider envProvider, ProcessProvider processProvider, String metadataPath) { + this.envProvider = envProvider; + this.processProvider = processProvider; + this.metadataPath = metadataPath; + } + + public MtlsProvider() { + this( + new SystemEnvironmentProvider(), + new DefaultProcessProvider(), + DEFAULT_CONTEXT_AWARE_METADATA_PATH); + } + + /** + * Returns if mutual TLS client certificate should be used. If the value is true, the key store + * from {@link #getKeyStore()} will be used to configure mutual TLS transport. + */ + public boolean useMtlsClientCertificate() { + String useClientCertificate = envProvider.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE"); + return "true".equals(useClientCertificate); + } + + /** Returns the current mutual TLS endpoint usage policy. */ + public MtlsEndpointUsagePolicy getMtlsEndpointUsagePolicy() { + String mtlsEndpointUsagePolicy = envProvider.getenv("GOOGLE_API_USE_MTLS_ENDPOINT"); + if ("never".equals(mtlsEndpointUsagePolicy)) { + return MtlsEndpointUsagePolicy.NEVER; + } else if ("always".equals(mtlsEndpointUsagePolicy)) { + return MtlsEndpointUsagePolicy.ALWAYS; + } + return MtlsEndpointUsagePolicy.AUTO; + } + + /** The mutual TLS key store created with the default client certificate on device. */ + public KeyStore getKeyStore() throws IOException { + try (InputStream stream = new FileInputStream(metadataPath)) { + return getKeyStore(stream, processProvider); + } catch (InterruptedException e) { + throw new IOException("Interrupted executing certificate provider command", e); + } catch (GeneralSecurityException e) { + // Return null as if the file doesn't exist. + return null; + } catch (FileNotFoundException exception) { + // If the metadata file doesn't exist, then there is no key store, just return null. + return null; + } + } + + @VisibleForTesting + static KeyStore getKeyStore(InputStream metadata, ProcessProvider processProvider) + throws IOException, InterruptedException, GeneralSecurityException { + Process process = processProvider.createProcess(metadata); + + // Run the command and timeout after 1000 milliseconds. + int exitCode = runCertificateProviderCommand(process, 1000); + if (exitCode != 0) { + throw new IOException("Cert provider command failed with exit code: " + exitCode); + } + + // Create mTLS key store with the input certificates from shell command. + return SecurityUtils.createMtlsKeyStore(process.getInputStream()); + } + + @VisibleForTesting + static ImmutableList extractCertificateProviderCommand(InputStream contextAwareMetadata) + throws IOException { + JsonParser parser = new GsonFactory().createJsonParser(contextAwareMetadata); + ContextAwareMetadataJson json = parser.parse(ContextAwareMetadataJson.class); + return json.getCommands(); + } + + @VisibleForTesting + static int runCertificateProviderCommand(Process commandProcess, long timeoutMilliseconds) + throws IOException, InterruptedException { + long startTime = System.currentTimeMillis(); + long remainTime = timeoutMilliseconds; + + // In the while loop, keep checking if the process is terminated every 100 milliseconds + // until timeout is reached or process is terminated. In getKeyStore we set timeout to + // 1000 milliseconds, so 100 millisecond is a good number for the sleep. + while (remainTime > 0) { + Thread.sleep(Math.min(remainTime + 1, 100)); + remainTime -= System.currentTimeMillis() - startTime; + + try { + return commandProcess.exitValue(); + } catch (IllegalThreadStateException ignored) { + // exitValue throws IllegalThreadStateException if process has not yet terminated. + // Once the process is terminated, exitValue no longer throws exception. Therefore + // in the while loop, we use exitValue to check if process is terminated. See + // https://docs.oracle.com/javase/7/docs/api/java/lang/Process.html#exitValue() + // for more details. + } + } + + commandProcess.destroy(); + throw new IOException("cert provider command timed out"); + } +} diff --git a/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java b/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java index 174b35164..9bacd72c0 100644 --- a/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java +++ b/gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java @@ -30,6 +30,10 @@ package com.google.api.gax.rpc; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import com.google.api.core.ApiClock; import com.google.api.gax.core.BackgroundResource; @@ -37,8 +41,12 @@ import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.core.FixedExecutorProvider; +import com.google.api.gax.rpc.mtls.MtlsProvider; +import com.google.api.gax.rpc.mtls.MtlsProvider.MtlsEndpointUsagePolicy; import com.google.api.gax.rpc.testing.FakeChannel; import com.google.api.gax.rpc.testing.FakeClientSettings; +import com.google.api.gax.rpc.testing.FakeMtlsProvider; +import com.google.api.gax.rpc.testing.FakeStubSettings; import com.google.api.gax.rpc.testing.FakeTransportChannel; import com.google.auth.Credentials; import com.google.auth.oauth2.GoogleCredentials; @@ -597,4 +605,118 @@ public void testUserAgentConcat() throws Exception { assertThat(transportChannel.getHeaders()) .containsEntry("user-agent", "user-supplied-agent internal-agent"); } + + private static String endpoint = "https://foo.googleapis.com"; + private static String mtlsEndpoint = "https://foo.mtls.googleapis.com"; + + @Test + public void testAutoUseMtlsEndpoint() throws IOException { + // Test the case client certificate exists and mTLS endpoint is selected. + boolean switchToMtlsEndpointAllowed = true; + MtlsProvider provider = + new FakeMtlsProvider( + true, + MtlsEndpointUsagePolicy.AUTO, + FakeMtlsProvider.createTestMtlsKeyStore(), + "", + false); + String endpointSelected = + ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider); + assertEquals(mtlsEndpoint, endpointSelected); + } + + @Test + public void testEndpointNotOverridable() throws IOException { + // Test the case that switching to mTLS endpoint is not allowed so the original endpoint is + // selected. + boolean switchToMtlsEndpointAllowed = false; + MtlsProvider provider = + new FakeMtlsProvider( + true, + MtlsEndpointUsagePolicy.AUTO, + FakeMtlsProvider.createTestMtlsKeyStore(), + "", + false); + String endpointSelected = + ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider); + assertEquals(endpoint, endpointSelected); + } + + @Test + public void testNoClientCertificate() throws IOException { + // Test the case that client certificates doesn't exists so the original endpoint is selected. + boolean switchToMtlsEndpointAllowed = true; + MtlsProvider provider = + new FakeMtlsProvider(true, MtlsEndpointUsagePolicy.AUTO, null, "", false); + String endpointSelected = + ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider); + assertEquals(endpoint, endpointSelected); + } + + @Test + public void testAlwaysUseMtlsEndpoint() throws IOException { + // Test the case that mTLS endpoint is always used. + boolean switchToMtlsEndpointAllowed = true; + MtlsProvider provider = + new FakeMtlsProvider(false, MtlsEndpointUsagePolicy.ALWAYS, null, "", false); + String endpointSelected = + ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider); + assertEquals(mtlsEndpoint, endpointSelected); + } + + @Test + public void testNeverUseMtlsEndpoint() throws IOException { + // Test the case that mTLS endpoint is never used. + boolean switchToMtlsEndpointAllowed = true; + MtlsProvider provider = + new FakeMtlsProvider( + true, + MtlsEndpointUsagePolicy.NEVER, + FakeMtlsProvider.createTestMtlsKeyStore(), + "", + false); + String endpointSelected = + ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider); + assertEquals(endpoint, endpointSelected); + } + + @Test + public void testGetKeyStoreThrows() throws IOException { + // Test the case that getKeyStore throws exceptions. + try { + boolean switchToMtlsEndpointAllowed = true; + MtlsProvider provider = + new FakeMtlsProvider(true, MtlsEndpointUsagePolicy.AUTO, null, "", true); + ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider); + fail("should throw an exception"); + } catch (IOException e) { + assertTrue( + "expected getKeyStore to throw an exception", + e.getMessage().contains("getKeyStore throws exception")); + } + } + + @Test + public void testSwitchToMtlsEndpointAllowed() throws IOException { + StubSettings settings = new FakeStubSettings.Builder().setEndpoint(endpoint).build(); + assertFalse(settings.getSwitchToMtlsEndpointAllowed()); + assertEquals(endpoint, settings.getEndpoint()); + + settings = + new FakeStubSettings.Builder() + .setEndpoint(endpoint) + .setSwitchToMtlsEndpointAllowed(true) + .build(); + assertTrue(settings.getSwitchToMtlsEndpointAllowed()); + assertEquals(endpoint, settings.getEndpoint()); + + // Test setEndpoint sets the switchToMtlsEndpointAllowed value to false. + settings = + new FakeStubSettings.Builder() + .setSwitchToMtlsEndpointAllowed(true) + .setEndpoint(endpoint) + .build(); + assertFalse(settings.getSwitchToMtlsEndpointAllowed()); + assertEquals(endpoint, settings.getEndpoint()); + } } diff --git a/gax/src/test/java/com/google/api/gax/rpc/mtls/AbstractMtlsTransportChannelTest.java b/gax/src/test/java/com/google/api/gax/rpc/mtls/AbstractMtlsTransportChannelTest.java new file mode 100644 index 000000000..a75c7b812 --- /dev/null +++ b/gax/src/test/java/com/google/api/gax/rpc/mtls/AbstractMtlsTransportChannelTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.api.gax.rpc.mtls; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.api.gax.rpc.mtls.MtlsProvider.MtlsEndpointUsagePolicy; +import com.google.api.gax.rpc.testing.FakeMtlsProvider; +import java.io.IOException; +import java.security.GeneralSecurityException; +import org.junit.Test; + +public abstract class AbstractMtlsTransportChannelTest { + /** + * Returns the mTLS object from the created transport channel. mTLS object is created with mTLS + * keystore. For HttpJsonTransportChannel, the mTLS object is the SslContext; for + * GrpcTransportChannel, the mTLS object is the ChannelCredentials. The transport channel is mTLS + * if and only if the related mTLS object is not null. + */ + protected abstract Object getMtlsObjectFromTransportChannel(MtlsProvider provider) + throws IOException, GeneralSecurityException; + + @Test + public void testNotUseClientCertificate() throws IOException, GeneralSecurityException { + MtlsProvider provider = + new FakeMtlsProvider(false, MtlsEndpointUsagePolicy.AUTO, null, "", false); + assertNull(getMtlsObjectFromTransportChannel(provider)); + } + + @Test + public void testUseClientCertificate() throws IOException, GeneralSecurityException { + MtlsProvider provider = + new FakeMtlsProvider( + true, + MtlsEndpointUsagePolicy.AUTO, + FakeMtlsProvider.createTestMtlsKeyStore(), + "", + false); + assertNotNull(getMtlsObjectFromTransportChannel(provider)); + } + + @Test + public void testNoClientCertificate() throws IOException, GeneralSecurityException { + MtlsProvider provider = + new FakeMtlsProvider(true, MtlsEndpointUsagePolicy.AUTO, null, "", false); + assertNull(getMtlsObjectFromTransportChannel(provider)); + } + + @Test + public void testGetKeyStoreThrows() throws GeneralSecurityException { + // Test the case where provider.getKeyStore() throws. + MtlsProvider provider = + new FakeMtlsProvider(true, MtlsEndpointUsagePolicy.AUTO, null, "", true); + try { + getMtlsObjectFromTransportChannel(provider); + fail("should throw an exception"); + } catch (IOException e) { + assertTrue( + "expected getKeyStore to throw an exception", + e.getMessage().contains("getKeyStore throws exception")); + } + } +} diff --git a/gax/src/test/java/com/google/api/gax/rpc/mtls/MtlsProviderTest.java b/gax/src/test/java/com/google/api/gax/rpc/mtls/MtlsProviderTest.java new file mode 100644 index 000000000..1888402f0 --- /dev/null +++ b/gax/src/test/java/com/google/api/gax/rpc/mtls/MtlsProviderTest.java @@ -0,0 +1,233 @@ +/* + * Copyright 2021 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.api.gax.rpc.mtls; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.GeneralSecurityException; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MtlsProviderTest { + static class TestEnvironmentProvider implements MtlsProvider.EnvironmentProvider { + private final String useClientCertificate; + private final String useMtlsEndpoint; + + TestEnvironmentProvider(String useClientCertificate, String useMtlsEndpoint) { + this.useClientCertificate = useClientCertificate; + this.useMtlsEndpoint = useMtlsEndpoint; + } + + @Override + public String getenv(String name) { + if (name.equals("GOOGLE_API_USE_MTLS_ENDPOINT")) { + return useMtlsEndpoint; + } + return useClientCertificate; + } + } + + static class TestCertProviderCommandProcess extends Process { + private boolean runForever; + private int exitValue; + + public TestCertProviderCommandProcess(int exitValue, boolean runForever) { + this.runForever = runForever; + this.exitValue = exitValue; + } + + @Override + public OutputStream getOutputStream() { + return null; + } + + @Override + public InputStream getInputStream() { + return null; + } + + @Override + public InputStream getErrorStream() { + return null; + } + + @Override + public int waitFor() throws InterruptedException { + return 0; + } + + @Override + public int exitValue() { + if (runForever) { + throw new IllegalThreadStateException(); + } + return exitValue; + } + + @Override + public void destroy() {} + } + + static class TestProcessProvider implements MtlsProvider.ProcessProvider { + private int exitCode; + + public TestProcessProvider(int exitCode) { + this.exitCode = exitCode; + } + + @Override + public Process createProcess(InputStream metadata) throws IOException { + return new TestCertProviderCommandProcess(exitCode, false); + } + } + + @Test + public void testUseMtlsEndpointAlways() { + MtlsProvider mtlsProvider = + new MtlsProvider( + new TestEnvironmentProvider("false", "always"), + new TestProcessProvider(0), + "/path/to/missing/file"); + assertEquals( + MtlsProvider.MtlsEndpointUsagePolicy.ALWAYS, mtlsProvider.getMtlsEndpointUsagePolicy()); + } + + @Test + public void testUseMtlsEndpointAuto() { + MtlsProvider mtlsProvider = + new MtlsProvider( + new TestEnvironmentProvider("false", "auto"), + new TestProcessProvider(0), + "/path/to/missing/file"); + assertEquals( + MtlsProvider.MtlsEndpointUsagePolicy.AUTO, mtlsProvider.getMtlsEndpointUsagePolicy()); + } + + @Test + public void testUseMtlsEndpointNever() { + MtlsProvider mtlsProvider = + new MtlsProvider( + new TestEnvironmentProvider("false", "never"), + new TestProcessProvider(0), + "/path/to/missing/file"); + assertEquals( + MtlsProvider.MtlsEndpointUsagePolicy.NEVER, mtlsProvider.getMtlsEndpointUsagePolicy()); + } + + @Test + public void testUseMtlsClientCertificateTrue() { + MtlsProvider mtlsProvider = + new MtlsProvider( + new TestEnvironmentProvider("true", "auto"), + new TestProcessProvider(0), + "/path/to/missing/file"); + assertTrue(mtlsProvider.useMtlsClientCertificate()); + } + + @Test + public void testUseMtlsClientCertificateFalse() { + MtlsProvider mtlsProvider = + new MtlsProvider( + new TestEnvironmentProvider("false", "auto"), + new TestProcessProvider(0), + "/path/to/missing/file"); + assertFalse(mtlsProvider.useMtlsClientCertificate()); + } + + @Test + public void testGetKeyStore() throws IOException { + MtlsProvider mtlsProvider = + new MtlsProvider( + new TestEnvironmentProvider("false", "always"), + new TestProcessProvider(0), + "/path/to/missing/file"); + assertNull(mtlsProvider.getKeyStore()); + } + + @Test + public void testGetKeyStoreNonZeroExitCode() + throws IOException, InterruptedException, GeneralSecurityException { + try { + InputStream metadata = + this.getClass() + .getClassLoader() + .getResourceAsStream("com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem"); + MtlsProvider.getKeyStore(metadata, new TestProcessProvider(1)); + fail("should throw an exception"); + } catch (IOException e) { + assertTrue( + "expected to fail with nonzero exit code", + e.getMessage().contains("Cert provider command failed with exit code: 1")); + } + } + + @Test + public void testExtractCertificateProviderCommand() throws IOException { + InputStream inputStream = + this.getClass() + .getClassLoader() + .getResourceAsStream("com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json"); + List command = MtlsProvider.extractCertificateProviderCommand(inputStream); + assertEquals(2, command.size()); + assertEquals("some_binary", command.get(0)); + assertEquals("some_argument", command.get(1)); + } + + @Test + public void testRunCertificateProviderCommandSuccess() throws IOException, InterruptedException { + Process certCommandProcess = new TestCertProviderCommandProcess(0, false); + int exitValue = MtlsProvider.runCertificateProviderCommand(certCommandProcess, 100); + assertEquals(0, exitValue); + } + + @Test + public void testRunCertificateProviderCommandTimeout() throws InterruptedException { + Process certCommandProcess = new TestCertProviderCommandProcess(0, true); + try { + MtlsProvider.runCertificateProviderCommand(certCommandProcess, 100); + fail("should throw an exception"); + } catch (IOException e) { + assertTrue( + "expected to fail with timeout", + e.getMessage().contains("cert provider command timed out")); + } + } +} diff --git a/gax/src/test/java/com/google/api/gax/rpc/testing/FakeMtlsProvider.java b/gax/src/test/java/com/google/api/gax/rpc/testing/FakeMtlsProvider.java new file mode 100644 index 000000000..e5f5c3a91 --- /dev/null +++ b/gax/src/test/java/com/google/api/gax/rpc/testing/FakeMtlsProvider.java @@ -0,0 +1,90 @@ +/* + * Copyright 2021 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.api.gax.rpc.testing; + +import com.google.api.client.util.SecurityUtils; +import com.google.api.core.InternalApi; +import com.google.api.gax.rpc.mtls.MtlsProvider; +import java.io.IOException; +import java.io.InputStream; +import java.security.GeneralSecurityException; +import java.security.KeyStore; + +@InternalApi("for testing") +public class FakeMtlsProvider extends MtlsProvider { + private boolean useClientCertificate; + private MtlsEndpointUsagePolicy mtlsEndpointUsagePolicy; + private KeyStore keyStore; + private boolean throwExceptionForGetKeyStore; + + public FakeMtlsProvider( + boolean useClientCertificate, + MtlsEndpointUsagePolicy mtlsEndpointUsagePolicy, + KeyStore keystore, + String keyStorePassword, + boolean throwExceptionForGetKeyStore) { + super(); + this.useClientCertificate = useClientCertificate; + this.mtlsEndpointUsagePolicy = mtlsEndpointUsagePolicy; + this.keyStore = keystore; + this.throwExceptionForGetKeyStore = throwExceptionForGetKeyStore; + } + + @Override + public boolean useMtlsClientCertificate() { + return useClientCertificate; + } + + @Override + public MtlsEndpointUsagePolicy getMtlsEndpointUsagePolicy() { + return mtlsEndpointUsagePolicy; + } + + @Override + public KeyStore getKeyStore() throws IOException { + if (throwExceptionForGetKeyStore) { + throw new IOException("getKeyStore throws exception"); + } + return keyStore; + } + + public static KeyStore createTestMtlsKeyStore() throws IOException { + try { + InputStream certAndKey = + FakeMtlsProvider.class + .getClassLoader() + .getResourceAsStream("com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem"); + return SecurityUtils.createMtlsKeyStore(certAndKey); + } catch (GeneralSecurityException e) { + throw new IOException("Failed to create key store", e); + } + } +} diff --git a/gax/src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem b/gax/src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem new file mode 100644 index 000000000..f95c93bcf --- /dev/null +++ b/gax/src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem @@ -0,0 +1,30 @@ +-----BEGIN CERTIFICATE----- +MIICGzCCAYSgAwIBAgIIWrt6xtmHPs4wDQYJKoZIhvcNAQEFBQAwMzExMC8GA1UE +AxMoMTAwOTEyMDcyNjg3OC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbTAeFw0x +MjEyMDExNjEwNDRaFw0yMjExMjkxNjEwNDRaMDMxMTAvBgNVBAMTKDEwMDkxMjA3 +MjY4NzguYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20wgZ8wDQYJKoZIhvcNAQEB +BQADgY0AMIGJAoGBAL1SdY8jTUVU7O4/XrZLYTw0ON1lV6MQRGajFDFCqD2Fd9tQ +GLW8Iftx9wfXe1zuaehJSgLcyCxazfyJoN3RiONBihBqWY6d3lQKqkgsRTNZkdFJ +Wdzl/6CxhK9sojh2p0r3tydtv9iwq5fuuWIvtODtT98EgphhncQAqkKoF3zVAgMB +AAGjODA2MAwGA1UdEwEB/wQCMAAwDgYDVR0PAQH/BAQDAgeAMBYGA1UdJQEB/wQM +MAoGCCsGAQUFBwMCMA0GCSqGSIb3DQEBBQUAA4GBAD8XQEqzGePa9VrvtEGpf+R4 +fkxKbcYAzqYq202nKu0kfjhIYkYSBj6gi348YaxE64yu60TVl42l5HThmswUheW4 +uQIaq36JvwvsDP5Zoj5BgiNSnDAFQp+jJFBRUA5vooJKgKgMDf/r/DCOsbO6VJF1 +kWwa9n19NFiV0z3m6isj +-----END CERTIFICATE----- +-----BEGIN PRIVATE KEY----- +MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAL1SdY8jTUVU7O4/ +XrZLYTw0ON1lV6MQRGajFDFCqD2Fd9tQGLW8Iftx9wfXe1zuaehJSgLcyCxazfyJ +oN3RiONBihBqWY6d3lQKqkgsRTNZkdFJWdzl/6CxhK9sojh2p0r3tydtv9iwq5fu +uWIvtODtT98EgphhncQAqkKoF3zVAgMBAAECgYB51B9cXe4yiGTzJ4pOKpHGySAy +sC1F/IjXt2eeD3PuKv4m/hL4l7kScpLx0+NJuQ4j8U2UK/kQOdrGANapB1ZbMZAK +/q0xmIUzdNIDiGSoTXGN2mEfdsEpQ/Xiv0lyhYBBPC/K4sYIpHccnhSRQUZlWLLY +lE5cFNKC9b7226mNvQJBAPt0hfCNIN0kUYOA9jdLtx7CE4ySGMPf5KPBuzPd8ty1 +fxaFm9PB7B76VZQYmHcWy8rT5XjoLJHrmGW1ZvP+iDsCQQDAvnKoarPOGb5iJfkq +RrA4flf1TOlf+1+uqIOJ94959jkkJeb0gv/TshDnm6/bWn+1kJylQaKygCizwPwB +Z84vAkA0Duur4YvsPJijoQ9YY1SGCagCcjyuUKwFOxaGpmyhRPIKt56LOJqpzyno +fy8ReKa4VyYq4eZYT249oFCwMwIBAkAROPNF2UL3x5UbcAkznd1hLujtIlI4IV4L +XUNjsJtBap7we/KHJq11XRPlniO4lf2TW7iji5neGVWJulTKS1xBAkAerktk4Hsw +ErUaUG1s/d+Sgc8e/KMeBElV+NxGhcWEeZtfHMn/6VOlbzY82JyvC9OKC80A5CAE +VUV6b25kqrcu +-----END PRIVATE KEY----- diff --git a/gax/src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json b/gax/src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json new file mode 100644 index 000000000..62f10dd25 --- /dev/null +++ b/gax/src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json @@ -0,0 +1,9 @@ +{ + "cert_provider_command": [ + "some_binary", + "some_argument" + ], + "device_resource_ids": [ + "123" + ] +}