Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: respect PDML timeout when using streaming RPC #338

Merged
merged 3 commits into from Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkState;

import com.google.api.gax.grpc.GrpcStatusCode;
import com.google.api.gax.rpc.DeadlineExceededException;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.UnavailableException;
import com.google.cloud.spanner.SessionImpl.SessionTransaction;
Expand Down Expand Up @@ -77,13 +78,12 @@ private ByteString initTransaction() {
* statement, and will retry the stream if an {@link UnavailableException} is thrown, using the
* last seen resume token if the server returns any.
*/
long executeStreamingPartitionedUpdate(final Statement statement, Duration timeout) {
long executeStreamingPartitionedUpdate(final Statement statement, final Duration timeout) {
checkState(isValid, "Partitioned DML has been invalidated by a new operation on the session");
log.log(Level.FINER, "Starting PartitionedUpdate statement");
boolean foundStats = false;
long updateCount = 0L;
Duration remainingTimeout = timeout;
Stopwatch stopWatch = Stopwatch.createStarted();
Stopwatch stopWatch = createStopwatchStarted();
try {
// Loop to catch AbortedExceptions.
while (true) {
Expand All @@ -105,8 +105,13 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
}
}
while (true) {
remainingTimeout =
remainingTimeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS);
Duration remainingTimeout =
timeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS);
if (remainingTimeout.isNegative() || remainingTimeout.isZero()) {
// The total deadline has been exceeded while retrying.
throw new DeadlineExceededException(
null, GrpcStatusCode.of(Code.DEADLINE_EXCEEDED), false);
}
try {
builder.setResumeToken(resumeToken);
ServerStream<PartialResultSet> stream =
Expand Down Expand Up @@ -157,6 +162,10 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
}
}

Stopwatch createStopwatchStarted() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would it be hard to inject the Stopwatch instead? Not a big deal, just wondering if we could use composition here. Disregard if it would be too hard to inject.

return Stopwatch.createStarted();
}

@Override
public void invalidate() {
isValid = false;
Expand Down
Expand Up @@ -309,7 +309,7 @@ public GapicSpannerRpc(final SpannerOptions options) {

// Set a keepalive time of 120 seconds to help long running
// commit GRPC calls succeed
.setKeepAliveTime(Duration.ofSeconds(GRPC_KEEPALIVE_SECONDS * 1000))
.setKeepAliveTime(Duration.ofSeconds(GRPC_KEEPALIVE_SECONDS))

// Then check if SpannerOptions provides an InterceptorProvider. Create a default
// SpannerInterceptorProvider if none is provided
Expand Down Expand Up @@ -365,6 +365,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
.setStreamWatchdogProvider(watchdogProvider)
.executeSqlSettings()
.setRetrySettings(partitionedDmlRetrySettings);
pdmlSettings.executeStreamingSqlSettings().setRetrySettings(partitionedDmlRetrySettings);
// The stream watchdog will by default only check for a timeout every 10 seconds, so if the
// timeout is less than 10 seconds, it would be ignored for the first 10 seconds unless we
// also change the StreamWatchdogCheckInterval.
Expand Down
@@ -0,0 +1,298 @@
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.api.gax.grpc.GrpcStatusCode;
import com.google.api.gax.rpc.AbortedException;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.UnavailableException;
import com.google.cloud.spanner.spi.v1.SpannerRpc;
import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.spanner.v1.BeginTransactionRequest;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.ExecuteSqlRequest.QueryMode;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.ResultSetStats;
import com.google.spanner.v1.Transaction;
import com.google.spanner.v1.TransactionSelector;
import io.grpc.Status.Code;
import java.util.Collections;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.threeten.bp.Duration;

@SuppressWarnings("unchecked")
@RunWith(JUnit4.class)
public class PartitionedDmlTransactionTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice tests!


@Mock private SpannerRpc rpc;

@Mock private SessionImpl session;

private final String sessionId = "projects/p/instances/i/databases/d/sessions/s";
private final ByteString txId = ByteString.copyFromUtf8("tx");
private final ByteString resumeToken = ByteString.copyFromUtf8("resume");
private final String sql = "UPDATE FOO SET BAR=1 WHERE TRUE";
private final ExecuteSqlRequest executeRequestWithoutResumeToken =
ExecuteSqlRequest.newBuilder()
.setQueryMode(QueryMode.NORMAL)
.setSession(sessionId)
.setSql(sql)
.setTransaction(TransactionSelector.newBuilder().setId(txId))
.build();
private final ExecuteSqlRequest executeRequestWithResumeToken =
executeRequestWithoutResumeToken.toBuilder().setResumeToken(resumeToken).build();

@Before
public void setup() {
MockitoAnnotations.initMocks(this);
when(session.getName()).thenReturn(sessionId);
when(session.getOptions()).thenReturn(Collections.EMPTY_MAP);
when(rpc.beginTransaction(any(BeginTransactionRequest.class), anyMap()))
.thenReturn(Transaction.newBuilder().setId(txId).build());
}

@Test
public void testExecuteStreamingPartitionedUpdate() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream = mock(ServerStream.class);
when(stream.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateAborted() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
any(ExecuteSqlRequest.class), anyMap(), any(Duration.class)))
.thenReturn(stream1, stream2);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc, times(2))
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateUnavailable() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new UnavailableException(
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
when(stream1.iterator()).thenReturn(iterator);
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream2);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateUnavailableAndThenDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new UnavailableException(
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}

@Test
public void testExecuteStreamingPartitionedUpdateAbortedAndThenDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}

@Test
public void testExecuteStreamingPartitionedUpdateMultipleAbortsUntilDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
long ticks = 0L;

@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenAnswer(
new Answer<Long>() {
@Override
public Long answer(InvocationOnMock invocation) throws Throwable {
return TimeUnit.NANOSECONDS.convert(++ticks, TimeUnit.MINUTES);
}
});
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
// It should start a transaction exactly 10 times (10 ticks == 10 minutes).
verify(rpc, times(10)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
// The last transaction should timeout before it starts the actual statement execution, which
// means that the execute method is only executed 9 times.
verify(rpc, times(9))
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}
}