diff --git a/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java b/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java index e16d03c3e..cdae46ffc 100644 --- a/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java +++ b/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java @@ -181,11 +181,12 @@ public void cancel() { } isStarted = true; - // Propagate the totalTimeout as the overall stream deadline. + // Propagate the totalTimeout as the overall stream deadline, so long as the user + // has not provided a timeout via the ApiCallContext. If they have, retain it. Duration totalTimeout = outerRetryingFuture.getAttemptSettings().getGlobalSettings().getTotalTimeout(); - if (totalTimeout != null && context != null) { + if (totalTimeout != null && context != null && context.getTimeout() == null) { context = context.withTimeout(totalTimeout); } @@ -217,7 +218,10 @@ public Void call() { ApiCallContext attemptContext = context; - if (!outerRetryingFuture.getAttemptSettings().getRpcTimeout().isZero()) { + // Set the streamWaitTimeout to the attempt RPC Timeout, only if the context + // does not already have a timeout set by a user via withStreamWaitTimeout. + if (!outerRetryingFuture.getAttemptSettings().getRpcTimeout().isZero() + && attemptContext.getStreamWaitTimeout() == null) { attemptContext = attemptContext.withStreamWaitTimeout( outerRetryingFuture.getAttemptSettings().getRpcTimeout()); diff --git a/gax/src/test/java/com/google/api/gax/rpc/ServerStreamingAttemptCallableTest.java b/gax/src/test/java/com/google/api/gax/rpc/ServerStreamingAttemptCallableTest.java index 68b48fb08..5cdf8bc4a 100644 --- a/gax/src/test/java/com/google/api/gax/rpc/ServerStreamingAttemptCallableTest.java +++ b/gax/src/test/java/com/google/api/gax/rpc/ServerStreamingAttemptCallableTest.java @@ -41,6 +41,7 @@ import com.google.api.gax.rpc.testing.FakeCallContext; import com.google.api.gax.rpc.testing.MockStreamingApi.MockServerStreamingCall; import com.google.api.gax.rpc.testing.MockStreamingApi.MockServerStreamingCallable; +import com.google.api.gax.tracing.NoopApiTracer; import com.google.common.collect.Queues; import com.google.common.truth.Truth; import java.util.concurrent.BlockingDeque; @@ -51,6 +52,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; import org.threeten.bp.Duration; @RunWith(JUnit4.class) @@ -59,22 +61,25 @@ public class ServerStreamingAttemptCallableTest { private AccumulatingObserver observer; private FakeRetryingFuture fakeRetryingFuture; private StreamResumptionStrategy resumptionStrategy; + private static Duration totalTimeout = Duration.ofHours(1); + private FakeCallContext mockedCallContext; @Before public void setUp() { innerCallable = new MockServerStreamingCallable<>(); observer = new AccumulatingObserver(true); resumptionStrategy = new MyStreamResumptionStrategy(); + mockedCallContext = Mockito.mock(FakeCallContext.class); } private ServerStreamingAttemptCallable createCallable() { + return createCallable(FakeCallContext.createDefault()); + } + + private ServerStreamingAttemptCallable createCallable(ApiCallContext context) { ServerStreamingAttemptCallable callable = new ServerStreamingAttemptCallable<>( - innerCallable, - resumptionStrategy, - "request", - FakeCallContext.createDefault(), - observer); + innerCallable, resumptionStrategy, "request", context, observer); fakeRetryingFuture = new FakeRetryingFuture(callable); callable.setExternalFuture(fakeRetryingFuture); @@ -82,6 +87,81 @@ private ServerStreamingAttemptCallable createCallable() { return callable; } + @Test + public void testUserProvidedContextTimeout() { + // Mock up the ApiCallContext as if the user provided a timeout and streamWaitTimeout. + Mockito.doReturn(NoopApiTracer.getInstance()).when(mockedCallContext).getTracer(); + Mockito.doReturn(Duration.ofHours(5)).when(mockedCallContext).getTimeout(); + Mockito.doReturn(Duration.ofHours(5)).when(mockedCallContext).getStreamWaitTimeout(); + + ServerStreamingAttemptCallable callable = createCallable(mockedCallContext); + callable.start(); + + // Ensure that the callable did not overwrite the user provided timeouts + Mockito.verify(mockedCallContext, Mockito.times(1)).getTimeout(); + Mockito.verify(mockedCallContext, Mockito.never()).withTimeout(totalTimeout); + Mockito.verify(mockedCallContext, Mockito.times(1)).getStreamWaitTimeout(); + Mockito.verify(mockedCallContext, Mockito.never()) + .withStreamWaitTimeout(Mockito.any(Duration.class)); + + // Should notify outer observer + Truth.assertThat(observer.controller).isNotNull(); + + // Should configure the inner controller correctly. + MockServerStreamingCall call = innerCallable.popLastCall(); + Truth.assertThat(call.getController().isAutoFlowControlEnabled()).isTrue(); + Truth.assertThat(call.getRequest()).isEqualTo("request"); + + // Send a response in auto flow mode. + call.getController().getObserver().onResponse("response1"); + call.getController().getObserver().onResponse("response2"); + call.getController().getObserver().onComplete(); + + // Make sure the responses are received + Truth.assertThat(observer.responses).containsExactly("response1", "response2").inOrder(); + fakeRetryingFuture.assertSuccess(); + } + + @Test + public void testNoUserProvidedContextTimeout() { + // Mock up the ApiCallContext as if the user did not provide custom timeouts. + Mockito.doReturn(NoopApiTracer.getInstance()).when(mockedCallContext).getTracer(); + Mockito.doReturn(null).when(mockedCallContext).getTimeout(); + Mockito.doReturn(null).when(mockedCallContext).getStreamWaitTimeout(); + Mockito.doReturn(mockedCallContext).when(mockedCallContext).withTimeout(totalTimeout); + Mockito.doReturn(mockedCallContext) + .when(mockedCallContext) + .withStreamWaitTimeout(Mockito.any(Duration.class)); + + ServerStreamingAttemptCallable callable = createCallable(mockedCallContext); + callable.start(); + + // Ensure that the callable configured the timeouts via the Settings in the + // absence of user-defined timeouts. + Mockito.verify(mockedCallContext, Mockito.times(1)).getTimeout(); + Mockito.verify(mockedCallContext, Mockito.times(1)).withTimeout(totalTimeout); + Mockito.verify(mockedCallContext, Mockito.times(1)).getStreamWaitTimeout(); + Mockito.verify(mockedCallContext, Mockito.times(1)) + .withStreamWaitTimeout(Mockito.any(Duration.class)); + + // Should notify outer observer + Truth.assertThat(observer.controller).isNotNull(); + + // Should configure the inner controller correctly. + MockServerStreamingCall call = innerCallable.popLastCall(); + Truth.assertThat(call.getController().isAutoFlowControlEnabled()).isTrue(); + Truth.assertThat(call.getRequest()).isEqualTo("request"); + + // Send a response in auto flow mode. + call.getController().getObserver().onResponse("response1"); + call.getController().getObserver().onResponse("response2"); + call.getController().getObserver().onComplete(); + + // Make sure the responses are received + Truth.assertThat(observer.responses).containsExactly("response1", "response2").inOrder(); + fakeRetryingFuture.assertSuccess(); + } + @Test public void testNoErrorsAutoFlow() { ServerStreamingAttemptCallable callable = createCallable(); @@ -396,8 +476,7 @@ private static class FakeRetryingFuture extends AbstractApiFuture this.attemptCallable = attemptCallable; attemptSettings = TimedAttemptSettings.newBuilder() - .setGlobalSettings( - RetrySettings.newBuilder().setTotalTimeout(Duration.ofHours(1)).build()) + .setGlobalSettings(RetrySettings.newBuilder().setTotalTimeout(totalTimeout).build()) .setFirstAttemptStartTimeNanos(0) .setAttemptCount(0) .setOverallAttemptCount(0)