Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
fix: check Compute Engine environment for DirectPath (#1250)
Browse files Browse the repository at this point in the history
  • Loading branch information
mohanli-ml committed Nov 19, 2020
1 parent d455da2 commit 656b613
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
Expand Up @@ -44,10 +44,12 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.CharStreams;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
Expand All @@ -74,6 +76,8 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
static final long DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS = 20;
// reduce the thundering herd problem of too many channels trying to (re)connect at the same time
static final int MAX_POOL_SIZE = 1000;
static final String GCE_PRODUCTION_NAME_PRIOR_2016 = "Google";
static final String GCE_PRODUCTION_NAME_AFTER_2016 = "Google Compute Engine";

private final int processorCount;
private final Executor executor;
Expand Down Expand Up @@ -234,6 +238,26 @@ private boolean isDirectPathEnabled(String serviceAddress) {
return false;
}

// DirectPath should only be used on Compute Engine.
// Notice Windows is supported for now.
static boolean isOnComputeEngine() {
String osName = System.getProperty("os.name");
if ("Linux".equals(osName)) {
String cmd = "cat /sys/class/dmi/id/product_name";
try {
Process process = Runtime.getRuntime().exec(new String[] {"/bin/sh", "-c", cmd});
process.waitFor();
String result =
CharStreams.toString(new InputStreamReader(process.getInputStream(), "UTF-8"));
return result.contains(GCE_PRODUCTION_NAME_PRIOR_2016)
|| result.contains(GCE_PRODUCTION_NAME_AFTER_2016);
} catch (IOException | InterruptedException e) {
return false;
}
}
return false;
}

private ManagedChannel createSingleChannel() throws IOException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headerProvider.getHeaders());
Expand All @@ -250,7 +274,9 @@ private ManagedChannel createSingleChannel() throws IOException {
ManagedChannelBuilder builder;

// TODO(weiranf): Add API in ComputeEngineCredentials to check default service account.
if (isDirectPathEnabled(serviceAddress) && credentials instanceof ComputeEngineCredentials) {
if (isDirectPathEnabled(serviceAddress)
&& credentials instanceof ComputeEngineCredentials
&& isOnComputeEngine()) {
builder = ComputeEngineChannelBuilder.forAddress(serviceAddress, port);
// Set default keepAliveTime and keepAliveTimeout when directpath environment is enabled.
// Will be overridden by user defined values if any.
Expand Down
Expand Up @@ -219,7 +219,11 @@ public void testWithGCECredentials() throws IOException {
ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator =
new ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder>() {
public ManagedChannelBuilder apply(ManagedChannelBuilder channelBuilder) {
assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isTrue();
if (InstantiatingGrpcChannelProvider.isOnComputeEngine()) {
assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isTrue();
} else {
assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isFalse();
}
return channelBuilder;
}
};
Expand All @@ -234,7 +238,11 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder channelBuilder) {
.withEndpoint("localhost:8080");

assertThat(provider.needsCredentials()).isTrue();
provider = provider.withCredentials(ComputeEngineCredentials.create());
if (InstantiatingGrpcChannelProvider.isOnComputeEngine()) {
provider = provider.withCredentials(ComputeEngineCredentials.create());
} else {
provider = provider.withCredentials(CloudShellCredentials.create(3000));
}
assertThat(provider.needsCredentials()).isFalse();

provider.getTransportChannel().shutdownNow();
Expand Down

0 comments on commit 656b613

Please sign in to comment.