Skip to content

Commit

Permalink
fix: respect PDML timeout when using streaming RPC (#338)
Browse files Browse the repository at this point in the history
* fix: respect PDML timeout when using streaming RPC

* fix: check for negative or zero deadline

* fix: subtract from original timeout to get remaining
  • Loading branch information
olavloite committed Jul 15, 2020
1 parent 78c3192 commit d67f108
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 5 deletions.
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkState;

import com.google.api.gax.grpc.GrpcStatusCode;
import com.google.api.gax.rpc.DeadlineExceededException;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.UnavailableException;
import com.google.cloud.spanner.SessionImpl.SessionTransaction;
Expand Down Expand Up @@ -77,13 +78,12 @@ private ByteString initTransaction() {
* statement, and will retry the stream if an {@link UnavailableException} is thrown, using the
* last seen resume token if the server returns any.
*/
long executeStreamingPartitionedUpdate(final Statement statement, Duration timeout) {
long executeStreamingPartitionedUpdate(final Statement statement, final Duration timeout) {
checkState(isValid, "Partitioned DML has been invalidated by a new operation on the session");
log.log(Level.FINER, "Starting PartitionedUpdate statement");
boolean foundStats = false;
long updateCount = 0L;
Duration remainingTimeout = timeout;
Stopwatch stopWatch = Stopwatch.createStarted();
Stopwatch stopWatch = createStopwatchStarted();
try {
// Loop to catch AbortedExceptions.
while (true) {
Expand All @@ -105,8 +105,13 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
}
}
while (true) {
remainingTimeout =
remainingTimeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS);
Duration remainingTimeout =
timeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS);
if (remainingTimeout.isNegative() || remainingTimeout.isZero()) {
// The total deadline has been exceeded while retrying.
throw new DeadlineExceededException(
null, GrpcStatusCode.of(Code.DEADLINE_EXCEEDED), false);
}
try {
builder.setResumeToken(resumeToken);
ServerStream<PartialResultSet> stream =
Expand Down Expand Up @@ -157,6 +162,10 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
}
}

Stopwatch createStopwatchStarted() {
return Stopwatch.createStarted();
}

@Override
public void invalidate() {
isValid = false;
Expand Down
Expand Up @@ -369,6 +369,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
.setStreamWatchdogProvider(watchdogProvider)
.executeSqlSettings()
.setRetrySettings(partitionedDmlRetrySettings);
pdmlSettings.executeStreamingSqlSettings().setRetrySettings(partitionedDmlRetrySettings);
// The stream watchdog will by default only check for a timeout every 10 seconds, so if the
// timeout is less than 10 seconds, it would be ignored for the first 10 seconds unless we
// also change the StreamWatchdogCheckInterval.
Expand Down
@@ -0,0 +1,298 @@
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.api.gax.grpc.GrpcStatusCode;
import com.google.api.gax.rpc.AbortedException;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.UnavailableException;
import com.google.cloud.spanner.spi.v1.SpannerRpc;
import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.spanner.v1.BeginTransactionRequest;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.ExecuteSqlRequest.QueryMode;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.ResultSetStats;
import com.google.spanner.v1.Transaction;
import com.google.spanner.v1.TransactionSelector;
import io.grpc.Status.Code;
import java.util.Collections;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.threeten.bp.Duration;

@SuppressWarnings("unchecked")
@RunWith(JUnit4.class)
public class PartitionedDmlTransactionTest {

@Mock private SpannerRpc rpc;

@Mock private SessionImpl session;

private final String sessionId = "projects/p/instances/i/databases/d/sessions/s";
private final ByteString txId = ByteString.copyFromUtf8("tx");
private final ByteString resumeToken = ByteString.copyFromUtf8("resume");
private final String sql = "UPDATE FOO SET BAR=1 WHERE TRUE";
private final ExecuteSqlRequest executeRequestWithoutResumeToken =
ExecuteSqlRequest.newBuilder()
.setQueryMode(QueryMode.NORMAL)
.setSession(sessionId)
.setSql(sql)
.setTransaction(TransactionSelector.newBuilder().setId(txId))
.build();
private final ExecuteSqlRequest executeRequestWithResumeToken =
executeRequestWithoutResumeToken.toBuilder().setResumeToken(resumeToken).build();

@Before
public void setup() {
MockitoAnnotations.initMocks(this);
when(session.getName()).thenReturn(sessionId);
when(session.getOptions()).thenReturn(Collections.EMPTY_MAP);
when(rpc.beginTransaction(any(BeginTransactionRequest.class), anyMap()))
.thenReturn(Transaction.newBuilder().setId(txId).build());
}

@Test
public void testExecuteStreamingPartitionedUpdate() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream = mock(ServerStream.class);
when(stream.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateAborted() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
any(ExecuteSqlRequest.class), anyMap(), any(Duration.class)))
.thenReturn(stream1, stream2);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc, times(2))
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateUnavailable() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new UnavailableException(
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
when(stream1.iterator()).thenReturn(iterator);
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream2);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateUnavailableAndThenDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new UnavailableException(
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}

@Test
public void testExecuteStreamingPartitionedUpdateAbortedAndThenDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}

@Test
public void testExecuteStreamingPartitionedUpdateMultipleAbortsUntilDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
long ticks = 0L;

@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenAnswer(
new Answer<Long>() {
@Override
public Long answer(InvocationOnMock invocation) throws Throwable {
return TimeUnit.NANOSECONDS.convert(++ticks, TimeUnit.MINUTES);
}
});
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
// It should start a transaction exactly 10 times (10 ticks == 10 minutes).
verify(rpc, times(10)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
// The last transaction should timeout before it starts the actual statement execution, which
// means that the execute method is only executed 9 times.
verify(rpc, times(9))
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}
}

0 comments on commit d67f108

Please sign in to comment.