Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
fix: do not override grpc default executor
Browse files Browse the repository at this point in the history
  • Loading branch information
mutianf committed Apr 28, 2021
1 parent 52e39f8 commit 73bc3ce
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 28 deletions.
Expand Up @@ -123,6 +123,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
: builder.directPathServiceConfig;
}

@Deprecated
@Override
public boolean needsExecutor() {
return executor == null;
Expand Down Expand Up @@ -200,9 +201,7 @@ public TransportChannelProvider withCredentials(Credentials credentials) {

@Override
public TransportChannel getTransportChannel() throws IOException {
if (needsExecutor()) {
throw new IllegalStateException("getTransportChannel() called when needsExecutor() is true");
} else if (needsHeaders()) {
if (needsHeaders()) {
throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true");
} else if (needsEndpoint()) {
throw new IllegalStateException("getTransportChannel() called when needsEndpoint() is true");
Expand Down
Expand Up @@ -36,24 +36,36 @@

import com.google.api.core.ApiFunction;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture;
import com.google.type.Color;
import com.google.type.Money;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -443,6 +455,62 @@ public void testWithDefaultDirectPathServiceConfig() {
assertThat(childPolicy.keySet()).containsExactly("pick_first");
}

@Test
public void testDefaultExecutor() throws Exception {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:1234")
.setHeaderProvider(FixedHeaderProvider.create())
.build();

// The default name thread name for grpc threads configured in GrpcUtil
assertThat(extractExecutorThreadName(provider)).contains("grpc-default-executor");
}

/**
* Extract the name of the channel executor thread by instantiating a channel and issuing a fake
* call.
*/
private static String extractExecutorThreadName(InstantiatingGrpcChannelProvider channelProvider)
throws IOException, ExecutionException, InterruptedException {
GrpcTransportChannel transportChannel =
(GrpcTransportChannel) channelProvider.getTransportChannel();
try {
Channel channel = transportChannel.getChannel();

ClientCall<com.google.type.Color, Money> call =
channel.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT);
Color request = Color.getDefaultInstance();

final SettableFuture<String> threadNameFuture = SettableFuture.create();

// Issue a call just to get the thread name of the channel executor
ClientCalls.asyncUnaryCall(
call,
request,
new StreamObserver<Money>() {
@Override
public void onNext(Money ignored) {
threadNameFuture.set(Thread.currentThread().getName());
}

@Override
public void onError(Throwable ignored) {
threadNameFuture.set(Thread.currentThread().getName());
}

@Override
public void onCompleted() {
threadNameFuture.set(Thread.currentThread().getName());
}
});
return threadNameFuture.get();
} finally {
transportChannel.shutdown();
transportChannel.awaitTermination(10, TimeUnit.SECONDS);
}
}

@Nullable
private static Map<String, ?> getAsObject(Map<String, ?> json, String key) {
Object mapObject = json.get(key);
Expand Down
Expand Up @@ -74,6 +74,7 @@ public boolean shouldAutoClose() {
return true;
}

@Deprecated
@Override
public boolean needsExecutor() {
return false;
Expand Down
Expand Up @@ -84,6 +84,7 @@ private InstantiatingHttpJsonChannelProvider(
this.httpTransport = httpTransport;
}

@Deprecated
@Override
public boolean needsExecutor() {
return executor == null;
Expand Down Expand Up @@ -140,9 +141,7 @@ public String getTransportName() {

@Override
public TransportChannel getTransportChannel() throws IOException {
if (needsExecutor()) {
throw new IllegalStateException("getTransportChannel() called when needsExecutor() is true");
} else if (needsHeaders()) {
if (needsHeaders()) {
throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true");
} else {
return createChannel();
Expand Down
Expand Up @@ -39,17 +39,26 @@
import com.google.api.gax.core.BackgroundResource;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/** Implementation of HttpJsonChannel which can issue http-json calls. */
@BetaApi
public class ManagedHttpJsonChannel implements HttpJsonChannel, BackgroundResource {
private static final JsonFactory JSON_FACTORY = GsonFactory.getDefaultInstance();
private static final ExecutorService DEFAULT_EXECUTOR =
Executors.newCachedThreadPool(
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("http-default-executor-%d")
.build());

private final Executor executor;
private final String endpoint;
Expand Down Expand Up @@ -134,7 +143,9 @@ public boolean awaitTermination(long duration, TimeUnit unit) throws Interrupted
public void close() {}

public static Builder newBuilder() {
return new Builder().setHeaderEnhancers(new LinkedList<HttpJsonHeaderEnhancer>());
return new Builder()
.setHeaderEnhancers(new LinkedList<HttpJsonHeaderEnhancer>())
.setExecutor(DEFAULT_EXECUTOR);
}

public static class Builder {
Expand All @@ -147,7 +158,11 @@ public static class Builder {
private Builder() {}

public Builder setExecutor(Executor executor) {
this.executor = executor;
if (executor != null) {
this.executor = executor;
} else {
this.executor = DEFAULT_EXECUTOR;
}
return this;
}

Expand Down
Expand Up @@ -31,16 +31,29 @@

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;

import com.google.api.core.ApiFuture;
import com.google.api.gax.httpjson.testing.MockHttpService;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

@RunWith(JUnit4.class)
public class InstantiatingHttpJsonChannelProviderTest {
Expand Down Expand Up @@ -94,4 +107,85 @@ public void basicTest() throws IOException {
// Make sure we can create channels OK.
provider.getTransportChannel().shutdownNow();
}

@Test
public void testDefaultExecutor() throws Exception {
// Create a mock service that will always return errors. We just want to inspect the thread that
// those errors are returned on
MockHttpService mockHttpService =
new MockHttpService(Collections.<ApiMethodDescriptor>emptyList(), "/");
mockHttpService.addException(new RuntimeException("Fake error"));
InstantiatingHttpJsonChannelProvider channelProvider =
InstantiatingHttpJsonChannelProvider.newBuilder()
.setEndpoint("localhost:1234")
.setHeaderProvider(FixedHeaderProvider.create())
.setHttpTransport(mockHttpService)
.build();

assertThat(getThreadName(channelProvider)).contains("http-default-executor");
}

@Test
public void testExecutorOverride() throws IOException, ExecutionException, InterruptedException {
MockHttpService mockHttpService =
new MockHttpService(Collections.<ApiMethodDescriptor>emptyList(), "/");
mockHttpService.addException(new RuntimeException("Fake error"));

final String expectedThreadName = "testExecutorOverrideExecutor";

ExecutorService executor =
Executors.newFixedThreadPool(
1,
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(expectedThreadName).build());
try {
InstantiatingHttpJsonChannelProvider channelProvider =
InstantiatingHttpJsonChannelProvider.newBuilder()
.setExecutor(executor)
.setEndpoint("localhost:1234")
.setHeaderProvider(FixedHeaderProvider.create())
.setHttpTransport(mockHttpService)
.build();

assertThat(getThreadName(channelProvider)).isEqualTo(expectedThreadName);
} finally {
executor.shutdown();
executor.awaitTermination(10, TimeUnit.SECONDS);
}
}

private static String getThreadName(InstantiatingHttpJsonChannelProvider provider)
throws IOException, InterruptedException, ExecutionException {
@SuppressWarnings("unchecked")
ApiMethodDescriptor<Object, Object> apiMethodDescriptor =
mock(
ApiMethodDescriptor.class,
new Answer() {
@Override
public Object answer(InvocationOnMock invocation) {
throw new UnsupportedOperationException("fake error");
}
});

HttpJsonTransportChannel transportChannel =
(HttpJsonTransportChannel) provider.getTransportChannel();
final SettableFuture<String> threadNameFuture = SettableFuture.create();
try {
HttpJsonChannel channel = transportChannel.getChannel();
ApiFuture<Object> rpcFuture =
channel.issueFutureUnaryCall(
HttpJsonCallOptions.newBuilder().build(), new Object(), apiMethodDescriptor);
rpcFuture.addListener(
new Runnable() {
@Override
public void run() {
threadNameFuture.set(Thread.currentThread().getName());
}
},
MoreExecutors.directExecutor());
} finally {
transportChannel.shutdown();
transportChannel.awaitTermination(10, TimeUnit.SECONDS);
}
return threadNameFuture.get();
}
}
25 changes: 16 additions & 9 deletions gax/src/main/java/com/google/api/gax/rpc/ClientContext.java
Expand Up @@ -49,7 +49,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import javax.annotation.Nonnull;
Expand Down Expand Up @@ -139,8 +138,13 @@ public static ClientContext create(ClientSettings settings) throws IOException {
public static ClientContext create(StubSettings settings) throws IOException {
ApiClock clock = settings.getClock();

ExecutorProvider executorProvider = settings.getExecutorProvider();
final ScheduledExecutorService executor = executorProvider.getExecutor();
ExecutorProvider workerExecutorProvider = settings.getWorkerExecutorProvider();
final ScheduledExecutorService workerExecutor = workerExecutorProvider.getExecutor();

final ScheduledExecutorService executor =
settings.getExecutorProvider() == null
? null
: settings.getExecutorProvider().getExecutor();

Credentials credentials = settings.getCredentialsProvider().getCredentials();

Expand All @@ -153,8 +157,11 @@ public static ClientContext create(StubSettings settings) throws IOException {
}

TransportChannelProvider transportChannelProvider = settings.getTransportChannelProvider();
if (transportChannelProvider.needsExecutor()) {
transportChannelProvider = transportChannelProvider.withExecutor((Executor) executor);
// After needsExecutor and StubSettings#setExecutor are deprecated, transport channel executor
// can only be set from TransportChannelProvider#withExecutor directly, and all providers will
// have default executors.
if (transportChannelProvider.needsExecutor() && executor != null) {
transportChannelProvider = transportChannelProvider.withExecutor(executor);
}
Map<String, String> headers = getHeadersFromSettings(settings);
if (transportChannelProvider.needsHeaders()) {
Expand Down Expand Up @@ -186,7 +193,7 @@ public static ClientContext create(StubSettings settings) throws IOException {
watchdogProvider = watchdogProvider.withClock(clock);
}
if (watchdogProvider.needsExecutor()) {
watchdogProvider = watchdogProvider.withExecutor(executor);
watchdogProvider = watchdogProvider.withExecutor(workerExecutor);
}
watchdog = watchdogProvider.getWatchdog();
}
Expand All @@ -196,16 +203,16 @@ public static ClientContext create(StubSettings settings) throws IOException {
if (transportChannelProvider.shouldAutoClose()) {
backgroundResources.add(transportChannel);
}
if (executorProvider.shouldAutoClose()) {
backgroundResources.add(new ExecutorAsBackgroundResource(executor));
if (workerExecutorProvider.shouldAutoClose()) {
backgroundResources.add(new ExecutorAsBackgroundResource(workerExecutor));
}
if (watchdogProvider != null && watchdogProvider.shouldAutoClose()) {
backgroundResources.add(watchdog);
}

return newBuilder()
.setBackgroundResources(backgroundResources.build())
.setExecutor(executor)
.setExecutor(workerExecutor)
.setCredentials(credentials)
.setTransportChannel(transportChannel)
.setHeaders(ImmutableMap.copyOf(settings.getHeaderProvider().getHeaders()))
Expand Down

0 comments on commit 73bc3ce

Please sign in to comment.