From 5905438af6501353e978210808834a26947aae95 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 5 Nov 2020 14:00:28 -0800 Subject: [PATCH] fix: retry transactions that fail with expired transaction IDs (#447) --- .../google/cloud/firestore/Transaction.java | 15 +- .../cloud/firestore/TransactionRunner.java | 105 ++++----- .../cloud/firestore/BulkWriterTest.java | 55 +---- .../cloud/firestore/LocalFirestoreHelper.java | 82 +++---- .../cloud/firestore/TransactionTest.java | 201 +++++++++++------- 5 files changed, 224 insertions(+), 234 deletions(-) diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Transaction.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Transaction.java index e8308b57d..26eec365e 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Transaction.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Transaction.java @@ -65,7 +65,6 @@ public interface AsyncFunction { } private final TransactionOptions transactionOptions; - @Nullable private final ByteString previousTransactionId; private ByteString transactionId; Transaction( @@ -74,8 +73,11 @@ public interface AsyncFunction { @Nullable Transaction previousTransaction) { super(firestore); this.transactionOptions = transactionOptions; - this.previousTransactionId = - previousTransaction != null ? previousTransaction.transactionId : null; + this.transactionId = previousTransaction != null ? previousTransaction.transactionId : null; + } + + public boolean hasTransactionId() { + return transactionId != null; } Transaction wrapResult(ApiFuture result) { @@ -89,11 +91,8 @@ ApiFuture begin() { beginTransaction.setDatabase(firestore.getDatabaseName()); if (TransactionOptionsType.READ_WRITE.equals(transactionOptions.getType()) - && previousTransactionId != null) { - beginTransaction - .getOptionsBuilder() - .getReadWriteBuilder() - .setRetryTransaction(previousTransactionId); + && transactionId != null) { + beginTransaction.getOptionsBuilder().getReadWriteBuilder().setRetryTransaction(transactionId); } else if (TransactionOptionsType.READ_ONLY.equals(transactionOptions.getType())) { final ReadOnly.Builder readOnlyBuilder = ReadOnly.newBuilder(); if (transactionOptions.getReadTime() != null) { diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/TransactionRunner.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/TransactionRunner.java index 2c6a8c735..6babd3bad 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/TransactionRunner.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/TransactionRunner.java @@ -105,23 +105,40 @@ ApiFuture run() { "Start runTransaction", ImmutableMap.of("attemptsRemaining", AttributeValue.longAttributeValue(attemptsRemaining))); - final SettableApiFuture backoff = SettableApiFuture.create(); + return ApiFutures.catchingAsync( + ApiFutures.transformAsync( + maybeRollback(), new RollbackCallback(), MoreExecutors.directExecutor()), + Throwable.class, + new RestartTransactionCallback(), + MoreExecutors.directExecutor()); + } - // Add a backoff delay. At first, this is 0. - this.firestoreExecutor.schedule( - new Runnable() { - @Override - public void run() { - backoff.set(null); - } - }, - nextBackoffAttempt.getRandomizedRetryDelay().toMillis(), - TimeUnit.MILLISECONDS); + private ApiFuture maybeRollback() { + return transaction.hasTransactionId() + ? transaction.rollback() + : ApiFutures.immediateFuture(null); + } - nextBackoffAttempt = backoffAlgorithm.createNextAttempt(nextBackoffAttempt); + /** A callback that invokes the BeginTransaction callback. */ + private class RollbackCallback implements ApiAsyncFunction { + @Override + public ApiFuture apply(Void input) { + final SettableApiFuture backoff = SettableApiFuture.create(); + // Add a backoff delay. At first, this is 0. + firestoreExecutor.schedule( + new Runnable() { + @Override + public void run() { + backoff.set(null); + } + }, + nextBackoffAttempt.getRandomizedRetryDelay().toMillis(), + TimeUnit.MILLISECONDS); - return ApiFutures.transformAsync( - backoff, new BackoffCallback(), MoreExecutors.directExecutor()); + nextBackoffAttempt = backoffAlgorithm.createNextAttempt(nextBackoffAttempt); + return ApiFutures.transformAsync( + backoff, new BackoffCallback(), MoreExecutors.directExecutor()); + } } /** @@ -138,7 +155,6 @@ public void run() { new ApiFutureCallback() { @Override public void onFailure(Throwable t) { - callbackResult.setException(t); } @@ -168,12 +184,8 @@ public ApiFuture apply(Void input) { */ private class BeginTransactionCallback implements ApiAsyncFunction { public ApiFuture apply(Void ignored) { - return ApiFutures.catchingAsync( - ApiFutures.transformAsync( - invokeUserCallback(), new UserFunctionCallback(), MoreExecutors.directExecutor()), - Throwable.class, - new RestartTransactionCallback(), - MoreExecutors.directExecutor()); + return ApiFutures.transformAsync( + invokeUserCallback(), new UserFunctionCallback(), MoreExecutors.directExecutor()); } } @@ -217,10 +229,10 @@ public ApiFuture apply(Throwable throwable) { } ApiException apiException = (ApiException) throwable; - if (isRetryableTransactionError(apiException)) { + if (transaction.hasTransactionId() && isRetryableTransactionError(apiException)) { if (attemptsRemaining > 0) { span.addAnnotation("retrying"); - return rollbackAndContinue(); + return run(); } else { span.setStatus(TOO_MANY_RETRIES_STATUS); final FirestoreException firestoreException = @@ -251,39 +263,36 @@ private boolean isRetryableTransactionError(ApiException exception) { case UNAUTHENTICATED: case RESOURCE_EXHAUSTED: return true; + case INVALID_ARGUMENT: + // The Firestore backend uses "INVALID_ARGUMENT" for transactions IDs that have expired. + // While INVALID_ARGUMENT is generally not retryable, we retry this specific case. + return exception.getMessage().contains("transaction has expired"); default: return false; } } - /** Rolls the transaction back and attempts it again. */ - private ApiFuture rollbackAndContinue() { - return ApiFutures.transformAsync( - transaction.rollback(), - new ApiAsyncFunction() { - @Override - public ApiFuture apply(Void input) { - return run(); - } - }, - MoreExecutors.directExecutor()); - } - /** Rolls the transaction back and returns the error. */ private ApiFuture rollbackAndReject(final Throwable throwable) { final SettableApiFuture failedTransaction = SettableApiFuture.create(); - // We use `addListener()` since we want to return the original exception regardless of whether - // rollback() succeeds. - transaction - .rollback() - .addListener( - new Runnable() { - @Override - public void run() { - failedTransaction.setException(throwable); - } - }, - MoreExecutors.directExecutor()); + + if (transaction.hasTransactionId()) { + // We use `addListener()` since we want to return the original exception regardless of + // whether rollback() succeeds. + transaction + .rollback() + .addListener( + new Runnable() { + @Override + public void run() { + failedTransaction.setException(throwable); + } + }, + MoreExecutors.directExecutor()); + } else { + failedTransaction.setException(throwable); + } + span.end(); return failedTransaction; } diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/BulkWriterTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/BulkWriterTest.java index e4e4b0230..529f3b4b4 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/BulkWriterTest.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/BulkWriterTest.java @@ -36,7 +36,6 @@ import com.google.cloud.firestore.spi.v1.FirestoreRpc; import com.google.firestore.v1.BatchWriteRequest; import com.google.firestore.v1.BatchWriteResponse; -import com.google.protobuf.GeneratedMessageV3; import com.google.rpc.Code; import io.grpc.Status; import java.util.ArrayList; @@ -120,13 +119,6 @@ private ApiFuture mergeResponses(ApiFuture requests, ResponseStubber responseStubber) { - int index = 0; - for (GeneratedMessageV3 request : responseStubber.keySet()) { - assertEquals(request, requests.get(index++)); - } - } - @Before public void before() { doReturn(immediateExecutor).when(firestoreRpc).getExecutor(); @@ -150,10 +142,7 @@ public void hasSetMethod() throws Exception { ApiFuture result = bulkWriter.set(doc1, LocalFirestoreHelper.SINGLE_FIELD_MAP); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result.get().getUpdateTime()); } @@ -172,10 +161,7 @@ public void hasUpdateMethod() throws Exception { ApiFuture result = bulkWriter.update(doc1, LocalFirestoreHelper.SINGLE_FIELD_MAP); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result.get().getUpdateTime()); } @@ -192,10 +178,7 @@ public void hasDeleteMethod() throws Exception { ApiFuture result = bulkWriter.delete(doc1); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result.get().getUpdateTime()); } @@ -214,10 +197,7 @@ public void hasCreateMethod() throws Exception { ApiFuture result = bulkWriter.create(doc1, LocalFirestoreHelper.SINGLE_FIELD_MAP); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result.get().getUpdateTime()); } @@ -236,10 +216,7 @@ public void surfacesErrors() throws Exception { ApiFuture result = bulkWriter.set(doc1, LocalFirestoreHelper.SINGLE_FIELD_MAP); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); try { result.get(); fail("set() should have failed"); @@ -274,10 +251,7 @@ public void addsWritesToNewBatchAfterFlush() throws Exception { ApiFuture result2 = bulkWriter.set(doc2, LocalFirestoreHelper.SINGLE_FIELD_MAP); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); assertEquals(Timestamp.ofTimeSecondsAndNanos(1, 0), result1.get().getUpdateTime()); assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result2.get().getUpdateTime()); } @@ -350,10 +324,7 @@ public void canSendWritesToSameDocInSameBatch() throws Exception { bulkWriter.update(sameDoc, LocalFirestoreHelper.SINGLE_FIELD_MAP); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); assertEquals(Timestamp.ofTimeSecondsAndNanos(1, 0), result1.get().getUpdateTime()); assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result2.get().getUpdateTime()); } @@ -376,10 +347,7 @@ public void sendWritesToDifferentDocsInSameBatch() throws Exception { ApiFuture result2 = bulkWriter.update(doc2, LocalFirestoreHelper.SINGLE_FIELD_MAP); bulkWriter.close(); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); assertEquals(Timestamp.ofTimeSecondsAndNanos(1, 0), result1.get().getUpdateTime()); assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result2.get().getUpdateTime()); } @@ -413,9 +381,7 @@ public void sendBatchesWhenSizeLimitIsReached() throws Exception { assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result2.get().getUpdateTime()); assertEquals(Timestamp.ofTimeSecondsAndNanos(3, 0), result3.get().getUpdateTime()); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - verifyRequests(requests, responseStubber); + responseStubber.verifyAllRequestsSent(); } @Test @@ -462,8 +428,7 @@ public void retriesIndividualWritesThatFailWithAbortedOrUnavailable() throws Exc assertEquals(Timestamp.ofTimeSecondsAndNanos(2, 0), result2.get().getUpdateTime()); assertEquals(Timestamp.ofTimeSecondsAndNanos(3, 0), result3.get().getUpdateTime()); - List requests = batchWriteCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); + responseStubber.verifyAllRequestsSent(); } @Test diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java index 85a2690f3..5ebd32712 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java @@ -17,14 +17,12 @@ package com.google.cloud.firestore; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; import static org.mockito.Mockito.doAnswer; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; -import com.google.api.core.SettableApiFuture; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ApiStreamObserver; import com.google.api.gax.rpc.UnaryCallable; @@ -75,11 +73,9 @@ import java.util.Comparator; import java.util.Date; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import org.mockito.ArgumentCaptor; @@ -993,32 +989,41 @@ public static String fullPath(DocumentReference ref, FirestoreOptions options) { .toString(); } + static class RequestResponsePair { + GeneratedMessageV3 request; + ApiFuture response; + + public RequestResponsePair( + GeneratedMessageV3 request, ApiFuture response) { + this.request = request; + this.response = response; + } + } /** * Contains a map of request/response pairs that are used to create stub responses when * `sendRequest()` is called. */ - static class ResponseStubber - extends LinkedHashMap> { - - /** - * Verifies the response before returning. This method can be overridden to perform logic before - * the stubbed response is returned. - */ - ApiFuture verifyResponse( - ApiFuture response) { - return response; + static class ResponseStubber { + int requestCount = 0; + + List operationList = new ArrayList<>(); + + void put(GeneratedMessageV3 request, ApiFuture response) { + operationList.add(new RequestResponsePair(request, response)); } void initializeStub( ArgumentCaptor argumentCaptor, FirestoreImpl firestoreMock) { Stubber stubber = null; - for (final ApiFuture response : values()) { + for (final RequestResponsePair entry : operationList) { Answer> answer = new Answer>() { @Override public ApiFuture answer( InvocationOnMock invocationOnMock) throws Throwable { - return verifyResponse(response); + ++requestCount; + assertEquals(entry.request, invocationOnMock.getArguments()[0]); + return entry.response; } }; stubber = (stubber != null) ? stubber.doAnswer(answer) : doAnswer(answer); @@ -1028,47 +1033,12 @@ public ApiFuture answer( .when(firestoreMock) .sendRequest(argumentCaptor.capture(), Matchers.>any()); } - } - /** - * Contains a map of request/response pairs that are used to create stub responses when - * `sendRequest()` is called. - * - *

Enforces that only one active request can be pending at a time. - */ - static class SerialResponseStubber extends ResponseStubber { - int activeRequestCounter = 0; - SettableApiFuture activeRequestComplete = SettableApiFuture.create(); - Semaphore semaphore = new Semaphore(0); - - void markAllRequestsComplete() { - activeRequestComplete.set(null); - } - - void awaitRequest() { - activeRequestComplete = SettableApiFuture.create(); - try { - semaphore.acquire(); - } catch (Exception e) { - fail("sempahore.acquire() should not fail"); - } - } - - @Override - ApiFuture verifyResponse( - ApiFuture response) { - ++activeRequestCounter; - - // This assert is used to test that only one request is made at a time. - assertEquals(1, activeRequestCounter); - try { - semaphore.release(); - activeRequestComplete.get(); - } catch (Exception e) { - fail("activeRequestComplete.get() should not fail"); - } - --activeRequestCounter; - return response; + public void verifyAllRequestsSent() { + assertEquals( + String.format("Expected %d requests, but got %d", operationList.size(), requestCount), + operationList.size(), + requestCount); } } diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/TransactionTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/TransactionTest.java index 25c60c5f4..a8405972a 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/TransactionTest.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/TransactionTest.java @@ -49,7 +49,6 @@ import com.google.api.gax.rpc.ApiException; import com.google.api.gax.rpc.ApiStreamObserver; import com.google.api.gax.rpc.ServerStreamingCallable; -import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.UnaryCallable; import com.google.cloud.Timestamp; import com.google.cloud.firestore.LocalFirestoreHelper.ResponseStubber; @@ -57,6 +56,7 @@ import com.google.cloud.firestore.TransactionOptions.ReadWriteOptionsBuilder; import com.google.cloud.firestore.TransactionOptions.TransactionOptionsType; import com.google.cloud.firestore.spi.v1.FirestoreRpc; +import com.google.common.base.Function; import com.google.firestore.v1.BatchGetDocumentsRequest; import com.google.firestore.v1.DocumentMask; import com.google.firestore.v1.Write; @@ -65,9 +65,7 @@ import io.grpc.Status; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; @@ -385,13 +383,7 @@ public String updateCallback(Transaction transaction) { assertTrue(e.getMessage().endsWith("Transaction was cancelled because of too many retries.")); } - List requests = requestCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - int index = 0; - for (GeneratedMessageV3 request : responseStubber.keySet()) { - assertEquals(request, requests.get(index++)); - } + responseStubber.verifyAllRequestsSent(); } @Test @@ -435,96 +427,68 @@ public String updateCallback(Transaction transaction) { assertEquals("foo6", transaction.get()); - List requests = requestCapture.getAllValues(); - assertEquals(responseStubber.size(), requests.size()); - - int index = 0; - for (GeneratedMessageV3 request : responseStubber.keySet()) { - assertEquals(request, requests.get(index++)); - } + responseStubber.verifyAllRequestsSent(); } - @Test - public void retriesBasedOnErrorCode() throws Exception { - Map retryBehavior = - new HashMap() { - { - put(Status.Code.CANCELLED, true); - put(Status.Code.UNKNOWN, true); - put(Status.Code.INVALID_ARGUMENT, false); - put(Status.Code.DEADLINE_EXCEEDED, true); - put(Status.Code.NOT_FOUND, false); - put(Status.Code.ALREADY_EXISTS, false); - put(Status.Code.RESOURCE_EXHAUSTED, true); - put(Status.Code.FAILED_PRECONDITION, false); - put(Status.Code.ABORTED, true); - put(Status.Code.OUT_OF_RANGE, false); - put(Status.Code.UNIMPLEMENTED, false); - put(Status.Code.INTERNAL, true); - put(Status.Code.UNAVAILABLE, true); - put(Status.Code.DATA_LOSS, false); - put(Status.Code.UNAUTHENTICATED, true); - } + private void verifyRetries( + Function expectedSequenceWithRetry, + Function expectedSequenceWithoutRetry) + throws ExecutionException, InterruptedException { + ApiException[] exceptionWithRetryBehavior = + new ApiException[] { + (exception(Status.Code.CANCELLED, true)), + (exception(Status.Code.UNKNOWN, true)), + (exception(Status.Code.INVALID_ARGUMENT, false)), + (exception( + Status.Code.INVALID_ARGUMENT, + "The referenced transaction has expired or is no longer valid.", + true)), + (exception(Status.Code.DEADLINE_EXCEEDED, true)), + (exception(Status.Code.NOT_FOUND, false)), + (exception(Status.Code.ALREADY_EXISTS, false)), + (exception(Status.Code.RESOURCE_EXHAUSTED, true)), + (exception(Status.Code.FAILED_PRECONDITION, false)), + (exception(Status.Code.ABORTED, true)), + (exception(Status.Code.OUT_OF_RANGE, false)), + (exception(Status.Code.UNIMPLEMENTED, false)), + (exception(Status.Code.INTERNAL, true)), + (exception(Status.Code.UNAVAILABLE, true)), + (exception(Status.Code.DATA_LOSS, false)), + (exception(Status.Code.UNAUTHENTICATED, true)) }; - for (Map.Entry entry : retryBehavior.entrySet()) { - StatusCode code = GrpcStatusCode.of(entry.getKey()); - boolean shouldRetry = entry.getValue(); - - final ApiException apiException = - new ApiException(new Exception("Test Exception"), code, shouldRetry); - - if (shouldRetry) { - ResponseStubber responseStubber = - new ResponseStubber() { - { - put(begin(), beginResponse("foo1")); - put( - commit("foo1"), - ApiFutures.immediateFailedFuture(apiException)); - put(rollback("foo1"), rollbackResponse()); - put(begin("foo1"), beginResponse("foo2")); - put(commit("foo2"), commitResponse(0, 0)); - } - }; - responseStubber.initializeStub(requestCapture, firestoreMock); + for (ApiException apiException : exceptionWithRetryBehavior) { + if (apiException.isRetryable()) { + ResponseStubber stubber = expectedSequenceWithRetry.apply(apiException); + stubber.initializeStub(requestCapture, firestoreMock); final int[] attempts = new int[] {0}; - ApiFuture transaction = + ApiFuture transaction = firestoreMock.runTransaction( - new Transaction.Function() { + new Transaction.Function() { @Override - public String updateCallback(Transaction transaction) { + public Void updateCallback(Transaction transaction) { ++attempts[0]; return null; } }); transaction.get(); + stubber.verifyAllRequestsSent(); assertEquals(2, attempts[0]); } else { - ResponseStubber responseStubber = - new ResponseStubber() { - { - put(begin(), beginResponse("foo1")); - put( - commit("foo1"), - ApiFutures.immediateFailedFuture(apiException)); - put(rollback("foo1"), rollbackResponse()); - } - }; - - responseStubber.initializeStub(requestCapture, firestoreMock); + ResponseStubber stubber = expectedSequenceWithoutRetry.apply(apiException); + stubber.initializeStub(requestCapture, firestoreMock); final int[] attempts = new int[] {0}; - ApiFuture transaction = + ApiFuture transaction = firestoreMock.runTransaction( - new Transaction.Function() { + new Transaction.Function() { @Override - public String updateCallback(Transaction transaction) { + public Void updateCallback(Transaction transaction) { ++attempts[0]; return null; } @@ -536,11 +500,86 @@ public String updateCallback(Transaction transaction) { } catch (Exception ignored) { } + stubber.verifyAllRequestsSent(); assertEquals(1, attempts[0]); } } } + @Test + public void retriesCommitBasedOnErrorCode() throws Exception { + verifyRetries( + /* expectedSequenceWithRetry= */ new Function() { + @Override + public ResponseStubber apply(final ApiException e) { + return new ResponseStubber() { + { + put(begin(), beginResponse("foo1")); + put(commit("foo1"), ApiFutures.immediateFailedFuture(e)); + put(rollback("foo1"), rollbackResponse()); + put(begin("foo1"), beginResponse("foo2")); + put(commit("foo2"), commitResponse(0, 0)); + } + }; + } + }, + /* expectedSequenceWithoutRetry= */ new Function() { + @Override + public ResponseStubber apply(final ApiException e) { + return new ResponseStubber() { + { + put(begin(), beginResponse("foo1")); + put(commit("foo1"), ApiFutures.immediateFailedFuture(e)); + put(rollback("foo1"), rollbackResponse()); + } + }; + } + }); + } + + @Test + public void retriesRollbackBasedOnErrorCode() throws Exception { + final ApiException commitException = exception(Status.Code.ABORTED, true); + + verifyRetries( + /* expectedSequenceWithRetry= */ new Function() { + @Override + public ResponseStubber apply(final ApiException e) { + final ApiFuture rollbackException = + ApiFutures.immediateFailedFuture(e); + return new ResponseStubber() { + { + put(begin(), beginResponse("foo1")); + put( + commit("foo1"), + ApiFutures.immediateFailedFuture(commitException)); + put(rollback("foo1"), rollbackException); + put(rollback("foo1"), rollbackResponse()); + put(begin("foo1"), beginResponse("foo2")); + put(commit("foo2"), commitResponse(0, 0)); + } + }; + } + }, + /* expectedSequenceWithoutRetry= */ new Function() { + @Override + public ResponseStubber apply(final ApiException e) { + final ApiFuture rollbackException = + ApiFutures.immediateFailedFuture(e); + return new ResponseStubber() { + { + put(begin(), beginResponse("foo1")); + put( + commit("foo1"), + ApiFutures.immediateFailedFuture(commitException)); + put(rollback("foo1"), rollbackException); + put(rollback("foo1"), rollbackResponse()); + } + }; + } + }); + } + @Test public void getDocument() throws Exception { doReturn(beginResponse()) @@ -946,4 +985,12 @@ public void readWriteTransactionOptionsBuilder_errorAttemptingToSetNumAttemptsLe // expected } } + + private ApiException exception(Status.Code code, boolean shouldRetry) { + return exception(code, "Test exception", shouldRetry); + } + + private ApiException exception(Status.Code code, String message, boolean shouldRetry) { + return new ApiException(new Exception(message), GrpcStatusCode.of(code), shouldRetry); + } }