diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java index 96ae390dd6..a938d17447 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.api.gax.grpc.GrpcStatusCode; +import com.google.api.gax.rpc.DeadlineExceededException; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.UnavailableException; import com.google.cloud.spanner.SessionImpl.SessionTransaction; @@ -77,13 +78,12 @@ private ByteString initTransaction() { * statement, and will retry the stream if an {@link UnavailableException} is thrown, using the * last seen resume token if the server returns any. */ - long executeStreamingPartitionedUpdate(final Statement statement, Duration timeout) { + long executeStreamingPartitionedUpdate(final Statement statement, final Duration timeout) { checkState(isValid, "Partitioned DML has been invalidated by a new operation on the session"); log.log(Level.FINER, "Starting PartitionedUpdate statement"); boolean foundStats = false; long updateCount = 0L; - Duration remainingTimeout = timeout; - Stopwatch stopWatch = Stopwatch.createStarted(); + Stopwatch stopWatch = createStopwatchStarted(); try { // Loop to catch AbortedExceptions. while (true) { @@ -105,8 +105,13 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo } } while (true) { - remainingTimeout = - remainingTimeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS); + Duration remainingTimeout = + timeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS); + if (remainingTimeout.isNegative() || remainingTimeout.isZero()) { + // The total deadline has been exceeded while retrying. + throw new DeadlineExceededException( + null, GrpcStatusCode.of(Code.DEADLINE_EXCEEDED), false); + } try { builder.setResumeToken(resumeToken); ServerStream stream = @@ -157,6 +162,10 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo } } + Stopwatch createStopwatchStarted() { + return Stopwatch.createStarted(); + } + @Override public void invalidate() { isValid = false; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index b1e571510d..187a6e9a22 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -369,6 +369,7 @@ public GapicSpannerRpc(final SpannerOptions options) { .setStreamWatchdogProvider(watchdogProvider) .executeSqlSettings() .setRetrySettings(partitionedDmlRetrySettings); + pdmlSettings.executeStreamingSqlSettings().setRetrySettings(partitionedDmlRetrySettings); // The stream watchdog will by default only check for a timeout every 10 seconds, so if the // timeout is less than 10 seconds, it would be ignored for the first 10 seconds unless we // also change the StreamWatchdogCheckInterval. diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/PartitionedDmlTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/PartitionedDmlTransactionTest.java new file mode 100644 index 0000000000..8a158eb12e --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/PartitionedDmlTransactionTest.java @@ -0,0 +1,298 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyMap; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.gax.grpc.GrpcStatusCode; +import com.google.api.gax.rpc.AbortedException; +import com.google.api.gax.rpc.ServerStream; +import com.google.api.gax.rpc.UnavailableException; +import com.google.cloud.spanner.spi.v1.SpannerRpc; +import com.google.common.base.Stopwatch; +import com.google.common.base.Ticker; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; +import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.Transaction; +import com.google.spanner.v1.TransactionSelector; +import io.grpc.Status.Code; +import java.util.Collections; +import java.util.Iterator; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.threeten.bp.Duration; + +@SuppressWarnings("unchecked") +@RunWith(JUnit4.class) +public class PartitionedDmlTransactionTest { + + @Mock private SpannerRpc rpc; + + @Mock private SessionImpl session; + + private final String sessionId = "projects/p/instances/i/databases/d/sessions/s"; + private final ByteString txId = ByteString.copyFromUtf8("tx"); + private final ByteString resumeToken = ByteString.copyFromUtf8("resume"); + private final String sql = "UPDATE FOO SET BAR=1 WHERE TRUE"; + private final ExecuteSqlRequest executeRequestWithoutResumeToken = + ExecuteSqlRequest.newBuilder() + .setQueryMode(QueryMode.NORMAL) + .setSession(sessionId) + .setSql(sql) + .setTransaction(TransactionSelector.newBuilder().setId(txId)) + .build(); + private final ExecuteSqlRequest executeRequestWithResumeToken = + executeRequestWithoutResumeToken.toBuilder().setResumeToken(resumeToken).build(); + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + when(session.getName()).thenReturn(sessionId); + when(session.getOptions()).thenReturn(Collections.EMPTY_MAP); + when(rpc.beginTransaction(any(BeginTransactionRequest.class), anyMap())) + .thenReturn(Transaction.newBuilder().setId(txId).build()); + } + + @Test + public void testExecuteStreamingPartitionedUpdate() { + ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build(); + PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build(); + PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build(); + ServerStream stream = mock(ServerStream.class); + when(stream.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator()); + when(rpc.executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class))) + .thenReturn(stream); + + PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc); + long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10)); + assertThat(count).isEqualTo(1000L); + verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap()); + verify(rpc) + .executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)); + } + + @Test + public void testExecuteStreamingPartitionedUpdateAborted() { + ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build(); + PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build(); + PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build(); + ServerStream stream1 = mock(ServerStream.class); + Iterator iterator = mock(Iterator.class); + when(iterator.hasNext()).thenReturn(true, true, false); + when(iterator.next()) + .thenReturn(p1) + .thenThrow( + new AbortedException( + "transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true)); + when(stream1.iterator()).thenReturn(iterator); + ServerStream stream2 = mock(ServerStream.class); + when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator()); + when(rpc.executeStreamingPartitionedDml( + any(ExecuteSqlRequest.class), anyMap(), any(Duration.class))) + .thenReturn(stream1, stream2); + + PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc); + long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10)); + assertThat(count).isEqualTo(1000L); + verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap()); + verify(rpc, times(2)) + .executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)); + } + + @Test + public void testExecuteStreamingPartitionedUpdateUnavailable() { + ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build(); + PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build(); + PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build(); + ServerStream stream1 = mock(ServerStream.class); + Iterator iterator = mock(Iterator.class); + when(iterator.hasNext()).thenReturn(true, true, false); + when(iterator.next()) + .thenReturn(p1) + .thenThrow( + new UnavailableException( + "temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true)); + when(stream1.iterator()).thenReturn(iterator); + ServerStream stream2 = mock(ServerStream.class); + when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator()); + when(rpc.executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class))) + .thenReturn(stream1); + when(rpc.executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class))) + .thenReturn(stream2); + + PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc); + long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10)); + assertThat(count).isEqualTo(1000L); + verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap()); + verify(rpc) + .executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)); + verify(rpc) + .executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class)); + } + + @Test + public void testExecuteStreamingPartitionedUpdateUnavailableAndThenDeadlineExceeded() { + PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build(); + ServerStream stream1 = mock(ServerStream.class); + Iterator iterator = mock(Iterator.class); + when(iterator.hasNext()).thenReturn(true, true, false); + when(iterator.next()) + .thenReturn(p1) + .thenThrow( + new UnavailableException( + "temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true)); + when(stream1.iterator()).thenReturn(iterator); + when(rpc.executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class))) + .thenReturn(stream1); + + PartitionedDMLTransaction tx = + new PartitionedDMLTransaction(session, rpc) { + @Override + Stopwatch createStopwatchStarted() { + Ticker ticker = mock(Ticker.class); + when(ticker.read()) + .thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES)); + return Stopwatch.createStarted(ticker); + } + }; + try { + tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10)); + fail("missing expected DEADLINE_EXCEEDED exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED); + verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap()); + verify(rpc) + .executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)); + } + } + + @Test + public void testExecuteStreamingPartitionedUpdateAbortedAndThenDeadlineExceeded() { + PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build(); + ServerStream stream1 = mock(ServerStream.class); + Iterator iterator = mock(Iterator.class); + when(iterator.hasNext()).thenReturn(true, true, false); + when(iterator.next()) + .thenReturn(p1) + .thenThrow( + new AbortedException( + "transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true)); + when(stream1.iterator()).thenReturn(iterator); + when(rpc.executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class))) + .thenReturn(stream1); + + PartitionedDMLTransaction tx = + new PartitionedDMLTransaction(session, rpc) { + @Override + Stopwatch createStopwatchStarted() { + Ticker ticker = mock(Ticker.class); + when(ticker.read()) + .thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES)); + return Stopwatch.createStarted(ticker); + } + }; + try { + tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10)); + fail("missing expected DEADLINE_EXCEEDED exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED); + verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap()); + verify(rpc) + .executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)); + } + } + + @Test + public void testExecuteStreamingPartitionedUpdateMultipleAbortsUntilDeadlineExceeded() { + PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build(); + ServerStream stream1 = mock(ServerStream.class); + Iterator iterator = mock(Iterator.class); + when(iterator.hasNext()).thenReturn(true); + when(iterator.next()) + .thenReturn(p1) + .thenThrow( + new AbortedException( + "transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true)); + when(stream1.iterator()).thenReturn(iterator); + when(rpc.executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class))) + .thenReturn(stream1); + + PartitionedDMLTransaction tx = + new PartitionedDMLTransaction(session, rpc) { + long ticks = 0L; + + @Override + Stopwatch createStopwatchStarted() { + Ticker ticker = mock(Ticker.class); + when(ticker.read()) + .thenAnswer( + new Answer() { + @Override + public Long answer(InvocationOnMock invocation) throws Throwable { + return TimeUnit.NANOSECONDS.convert(++ticks, TimeUnit.MINUTES); + } + }); + return Stopwatch.createStarted(ticker); + } + }; + try { + tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10)); + fail("missing expected DEADLINE_EXCEEDED exception"); + } catch (SpannerException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED); + // It should start a transaction exactly 10 times (10 ticks == 10 minutes). + verify(rpc, times(10)).beginTransaction(any(BeginTransactionRequest.class), anyMap()); + // The last transaction should timeout before it starts the actual statement execution, which + // means that the execute method is only executed 9 times. + verify(rpc, times(9)) + .executeStreamingPartitionedDml( + Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)); + } + } +}