Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
gax-grpc: allow custom direct path service config (#1235)
Browse files Browse the repository at this point in the history
  • Loading branch information
dapengzhang0 committed Dec 21, 2020
1 parent ef9b3aa commit 20ce65b
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 21 deletions.
Expand Up @@ -40,6 +40,7 @@
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -94,6 +95,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final Credentials credentials;
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;

@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
Expand All @@ -115,6 +117,10 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.credentials = builder.credentials;
this.channelPrimer = builder.channelPrimer;
this.attemptDirectPath = builder.attemptDirectPath;
this.directPathServiceConfig =
builder.directPathServiceConfig == null
? getDefaultDirectPathServiceConfig()
: builder.directPathServiceConfig;
}

@Override
Expand Down Expand Up @@ -271,7 +277,7 @@ private ManagedChannel createSingleChannel() throws IOException {
int port = Integer.parseInt(endpoint.substring(colon + 1));
String serviceAddress = endpoint.substring(0, colon);

ManagedChannelBuilder builder;
ManagedChannelBuilder<?> builder;

// TODO(weiranf): Add API in ComputeEngineCredentials to check default service account.
if (isDirectPathEnabled(serviceAddress)
Expand All @@ -282,26 +288,7 @@ && isOnComputeEngine()) {
// Will be overridden by user defined values if any.
builder.keepAliveTime(DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS, TimeUnit.SECONDS);
builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS);

// When channel pooling is enabled, force the pick_first grpclb strategy.
// This is necessary to avoid the multiplicative effect of creating channel pool with
// `poolSize` number of `ManagedChannel`s, each with a `subSetting` number of number of
// subchannels.
// See the service config proto definition for more details:
// https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L182
ImmutableMap<String, Object> pickFirstStrategy =
ImmutableMap.<String, Object>of("pick_first", ImmutableMap.of());

ImmutableMap<String, Object> childPolicy =
ImmutableMap.<String, Object>of("childPolicy", ImmutableList.of(pickFirstStrategy));

ImmutableMap<String, Object> grpcLbPolicy =
ImmutableMap.<String, Object>of("grpclb", childPolicy);

ImmutableMap<String, Object> loadBalancingConfig =
ImmutableMap.<String, Object>of("loadBalancingConfig", ImmutableList.of(grpcLbPolicy));

builder.defaultServiceConfig(loadBalancingConfig);
builder.defaultServiceConfig(directPathServiceConfig);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
}
Expand Down Expand Up @@ -400,6 +387,7 @@ public static final class Builder {
@Nullable private Credentials credentials;
@Nullable private ChannelPrimer channelPrimer;
@Nullable private Boolean attemptDirectPath;
@Nullable private ImmutableMap<String, ?> directPathServiceConfig;

private Builder() {
processorCount = Runtime.getRuntime().availableProcessors();
Expand Down Expand Up @@ -610,6 +598,21 @@ public Builder setAttemptDirectPath(boolean attemptDirectPath) {
return this;
}

/**
* Sets a service config for direct path. If direct path is not enabled, the provided service
* config will be ignored.
*
* <p>See <a href=
* "https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto">
* the service config proto definition</a> for more details.
*/
@InternalApi("For internal use by google-cloud-java clients only")
public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
Preconditions.checkNotNull(serviceConfig, "serviceConfig");
this.directPathServiceConfig = ImmutableMap.copyOf(serviceConfig);
return this;
}

public InstantiatingGrpcChannelProvider build() {
return new InstantiatingGrpcChannelProvider(this);
}
Expand All @@ -634,6 +637,25 @@ public ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> getChannelConfi
}
}

private static ImmutableMap<String, ?> getDefaultDirectPathServiceConfig() {
// When channel pooling is enabled, force the pick_first grpclb strategy.
// This is necessary to avoid the multiplicative effect of creating channel pool with
// `poolSize` number of `ManagedChannel`s, each with a `subSetting` number of number of
// subchannels.
// See the service config proto definition for more details:
// https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto
ImmutableMap<String, Object> pickFirstStrategy =
ImmutableMap.<String, Object>of("pick_first", ImmutableMap.of());

ImmutableMap<String, Object> childPolicy =
ImmutableMap.<String, Object>of("childPolicy", ImmutableList.of(pickFirstStrategy));

ImmutableMap<String, Object> grpcLbPolicy =
ImmutableMap.<String, Object>of("grpclb", childPolicy);

return ImmutableMap.<String, Object>of("loadBalancingConfig", ImmutableList.of(grpcLbPolicy));
}

private static void validateEndpoint(String endpoint) {
int colon = endpoint.lastIndexOf(':');
if (colon < 0) {
Expand Down
Expand Up @@ -29,6 +29,7 @@
*/
package com.google.api.gax.grpc;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
Expand All @@ -39,14 +40,21 @@
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import javax.annotation.Nullable;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -378,4 +386,78 @@ public void testWithPrimeChannel() throws IOException {
.primeChannel(Mockito.any(ManagedChannel.class));
}
}

@Test
public void testWithDefaultDirectPathServiceConfig() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder().build();

ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;

List<Map<String, ?>> lbConfigs = getAsObjectList(defaultServiceConfig, "loadBalancingConfig");
assertThat(lbConfigs).hasSize(1);
Map<String, ?> lbConfig = lbConfigs.get(0);
Map<String, ?> grpclb = getAsObject(lbConfig, "grpclb");
List<Map<String, ?>> childPolicies = getAsObjectList(grpclb, "childPolicy");
assertThat(childPolicies).hasSize(1);
Map<String, ?> childPolicy = childPolicies.get(0);
assertThat(childPolicy.keySet()).containsExactly("pick_first");
}

@Nullable
private static Map<String, ?> getAsObject(Map<String, ?> json, String key) {
Object mapObject = json.get(key);
if (mapObject == null) {
return null;
}
return checkObject(mapObject);
}

@SuppressWarnings("unchecked")
private static Map<String, ?> checkObject(Object json) {
checkArgument(json instanceof Map, "Invalid json object representation: %s", json);
for (Map.Entry<Object, Object> entry : ((Map<Object, Object>) json).entrySet()) {
checkArgument(entry.getKey() instanceof String, "Key is not string");
}
return (Map<String, ?>) json;
}

private static List<Map<String, ?>> getAsObjectList(Map<String, ?> json, String key) {
Object listObject = json.get(key);
if (listObject == null) {
return null;
}
return checkListOfObjects(listObject);
}

@SuppressWarnings("unchecked")
private static List<Map<String, ?>> checkListOfObjects(Object listObject) {
checkArgument(listObject instanceof List, "Passed object is not a list");
List<Map<String, ?>> list = new ArrayList<>();
for (Object object : ((List<Object>) listObject)) {
list.add(checkObject(object));
}
return list;
}

@Test
public void testWithCustomDirectPathServiceConfig() {
ImmutableMap<String, Object> pickFirstStrategy =
ImmutableMap.<String, Object>of("round_robin", ImmutableMap.of());
ImmutableMap<String, Object> childPolicy =
ImmutableMap.<String, Object>of(
"childPolicy", ImmutableList.of(pickFirstStrategy), "foo", "bar");
ImmutableMap<String, Object> grpcLbPolicy =
ImmutableMap.<String, Object>of("grpclb", childPolicy);
Map<String, Object> passedServiceConfig = new HashMap<>();
passedServiceConfig.put("loadBalancingConfig", ImmutableList.of(grpcLbPolicy));

InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setDirectPathServiceConfig(passedServiceConfig)
.build();

ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;
assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig);
}
}

0 comments on commit 20ce65b

Please sign in to comment.