From 3513cd39ff43d26c8432c05ce20693350539ae8f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 21 Oct 2020 14:36:52 -0700 Subject: [PATCH] fix: retry Query streams (#426) --- .../com/google/cloud/firestore/Query.java | 60 +++++++++- .../firestore/DocumentReferenceTest.java | 5 +- .../cloud/firestore/LocalFirestoreHelper.java | 18 ++- .../com/google/cloud/firestore/QueryTest.java | 112 ++++++++++++++++++ 4 files changed, 186 insertions(+), 9 deletions(-) diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java index e236c40ea..1d57f64aa 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java @@ -32,9 +32,11 @@ import com.google.api.core.InternalExtensionOnly; import com.google.api.core.SettableApiFuture; import com.google.api.gax.rpc.ApiStreamObserver; +import com.google.api.gax.rpc.StatusCode; import com.google.auto.value.AutoValue; import com.google.cloud.Timestamp; import com.google.cloud.firestore.Query.QueryOptions.Builder; +import com.google.cloud.firestore.v1.FirestoreSettings; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -52,6 +54,7 @@ import com.google.firestore.v1.Value; import com.google.protobuf.ByteString; import com.google.protobuf.Int32Value; +import io.grpc.Status; import io.opencensus.trace.AttributeValue; import io.opencensus.trace.Tracing; import java.util.ArrayList; @@ -59,7 +62,9 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; +import java.util.Set; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -1297,7 +1302,8 @@ public void onCompleted() { responseObserver.onCompleted(); } }, - null); + /* transactionId= */ null, + /* readTime= */ null); } /** @@ -1431,13 +1437,18 @@ Timestamp getReadTime() { } private void internalStream( - final QuerySnapshotObserver documentObserver, @Nullable ByteString transactionId) { + final QuerySnapshotObserver documentObserver, + @Nullable final ByteString transactionId, + @Nullable final Timestamp readTime) { RunQueryRequest.Builder request = RunQueryRequest.newBuilder(); request.setStructuredQuery(buildQuery()).setParent(options.getParentPath().toString()); if (transactionId != null) { request.setTransaction(transactionId); } + if (readTime != null) { + request.setReadTime(readTime.toProto()); + } Tracing.getTracer() .getCurrentSpan() @@ -1446,6 +1457,8 @@ private void internalStream( ImmutableMap.of( "transactional", AttributeValue.booleanAttributeValue(transactionId != null))); + final AtomicReference lastReceivedDocument = new AtomicReference<>(); + ApiStreamObserver observer = new ApiStreamObserver() { Timestamp readTime; @@ -1470,6 +1483,7 @@ public void onNext(RunQueryResponse response) { QueryDocumentSnapshot.fromDocument( rpcContext, Timestamp.fromProto(response.getReadTime()), document); documentObserver.onNext(documentSnapshot); + lastReceivedDocument.set(documentSnapshot); } if (readTime == null) { @@ -1479,8 +1493,27 @@ public void onNext(RunQueryResponse response) { @Override public void onError(Throwable throwable) { - Tracing.getTracer().getCurrentSpan().addAnnotation("Firestore.Query: Error"); - documentObserver.onError(throwable); + // If a non-transactional query failed, attempt to restart. + // Transactional queries are retried via the transaction runner. + if (transactionId == null && isRetryableError(throwable)) { + Tracing.getTracer() + .getCurrentSpan() + .addAnnotation("Firestore.Query: Retryable Error"); + + // Restart the query but use the last document we received as the + // query cursor. Note that this it is ok to not use backoff here + // since we are requiring at least a single document result. + QueryDocumentSnapshot cursor = lastReceivedDocument.get(); + if (cursor != null) { + Query.this + .startAfter(cursor) + .internalStream( + documentObserver, /* transactionId= */ null, cursor.getReadTime()); + } + } else { + Tracing.getTracer().getCurrentSpan().addAnnotation("Firestore.Query: Error"); + documentObserver.onError(throwable); + } } @Override @@ -1562,7 +1595,8 @@ public void onCompleted() { result.set(querySnapshot); } }, - transactionId); + transactionId, + /* readTime= */ null); return result; } @@ -1624,6 +1658,22 @@ private ImmutableList append(ImmutableList existingList, T newElement) return builder.build(); } + /** Verifies whether the given exception is retryable based on the RunQuery configuration. */ + private boolean isRetryableError(Throwable throwable) { + if (!(throwable instanceof FirestoreException)) { + return false; + } + Set codes = + FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes(); + Status status = ((FirestoreException) throwable).getStatus(); + for (StatusCode.Code code : codes) { + if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) { + return true; + } + } + return false; + } + /** * Returns true if this Query is equal to the provided object. * diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/DocumentReferenceTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/DocumentReferenceTest.java index 0cce72d07..85433b5c8 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/DocumentReferenceTest.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/DocumentReferenceTest.java @@ -347,7 +347,10 @@ public void notFound() throws Exception { getDocumentResponse.setReadTime( com.google.protobuf.Timestamp.newBuilder().setSeconds(5).setNanos(6)); - doAnswer(streamingResponse(getDocumentResponse.build())) + doAnswer( + streamingResponse( + new BatchGetDocumentsResponse[] {getDocumentResponse.build()}, + /* throwable= */ null)) .when(firestoreMock) .streamRequest( getAllCapture.capture(), diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java index ad7c5c2dc..85a2690f3 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java @@ -279,7 +279,7 @@ public static Answer getAllResponse( responses[i] = response.build(); } - return streamingResponse(responses); + return streamingResponse(responses, null); } public static ApiFuture rollbackResponse() { @@ -291,6 +291,12 @@ public static Answer queryResponse() { } public static Answer queryResponse(String... documentNames) { + return queryResponse(/* throwable= */ null, documentNames); + } + + /** Returns a stream of documents followed by an optional exception. */ + public static Answer queryResponse( + @Nullable Throwable throwable, String... documentNames) { RunQueryResponse[] responses = new RunQueryResponse[documentNames.length]; for (int i = 0; i < documentNames.length; ++i) { @@ -301,10 +307,13 @@ public static Answer queryResponse(String... documentNames) { com.google.protobuf.Timestamp.newBuilder().setSeconds(1).setNanos(2)); responses[i] = runQueryResponse.build(); } - return streamingResponse(responses); + + return streamingResponse(responses, throwable); } - public static Answer streamingResponse(final T... response) { + /** Returns a stream of responses followed by an optional exception. */ + public static Answer streamingResponse( + final T[] response, @Nullable final Throwable throwable) { return new Answer() { public T answer(InvocationOnMock invocation) { Object[] args = invocation.getArguments(); @@ -312,6 +321,9 @@ public T answer(InvocationOnMock invocation) { for (T resp : response) { observer.onNext(resp); } + if (throwable != null) { + observer.onError(throwable); + } observer.onCompleted(); return null; } diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/QueryTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/QueryTest.java index a7b6b5716..6c0d4e292 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/QueryTest.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/QueryTest.java @@ -46,17 +46,20 @@ import com.google.common.io.BaseEncoding; import com.google.firestore.v1.ArrayValue; import com.google.firestore.v1.RunQueryRequest; +import com.google.firestore.v1.RunQueryResponse; import com.google.firestore.v1.StructuredQuery; import com.google.firestore.v1.StructuredQuery.Direction; import com.google.firestore.v1.StructuredQuery.FieldFilter.Operator; import com.google.firestore.v1.Value; import com.google.protobuf.InvalidProtocolBufferException; +import io.grpc.Status; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.util.Arrays; import java.util.Collections; import java.util.Iterator; +import java.util.List; import java.util.concurrent.Semaphore; import org.junit.Before; import org.junit.Test; @@ -66,7 +69,9 @@ import org.mockito.Matchers; import org.mockito.Mockito; import org.mockito.Spy; +import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; @RunWith(MockitoJUnitRunner.class) public class QueryTest { @@ -902,6 +907,113 @@ public void onCompleted() { semaphore.acquire(); } + @Test + public void retriesAfterRetryableError() throws Exception { + final boolean[] returnError = new boolean[] {true}; + + doAnswer( + new Answer() { + public RunQueryResponse answer(InvocationOnMock invocation) throws Throwable { + if (returnError[0]) { + returnError[0] = false; + return queryResponse( + FirestoreException.serverRejected( + Status.DEADLINE_EXCEEDED, "Simulated test failure"), + DOCUMENT_NAME + "1", + DOCUMENT_NAME + "2") + .answer(invocation); + } else { + return queryResponse(DOCUMENT_NAME + "3").answer(invocation); + } + } + }) + .when(firestoreMock) + .streamRequest( + runQuery.capture(), + streamObserverCapture.capture(), + Matchers.any()); + + // Verify the responses + final Semaphore semaphore = new Semaphore(0); + final Iterator iterator = Arrays.asList("doc1", "doc2", "doc3").iterator(); + + query.stream( + new ApiStreamObserver() { + @Override + public void onNext(DocumentSnapshot documentSnapshot) { + assertEquals(iterator.next(), documentSnapshot.getId()); + } + + @Override + public void onError(Throwable throwable) { + fail(); + } + + @Override + public void onCompleted() { + semaphore.release(); + } + }); + + semaphore.acquire(); + + // Verify the requests + List requests = runQuery.getAllValues(); + assertEquals(2, requests.size()); + + assertFalse(requests.get(0).hasReadTime()); + assertFalse(requests.get(0).getStructuredQuery().hasStartAt()); + + assertEquals( + com.google.protobuf.Timestamp.newBuilder().setSeconds(1).setNanos(2).build(), + requests.get(1).getReadTime()); + assertFalse(requests.get(1).getStructuredQuery().getStartAt().getBefore()); + assertEquals( + DOCUMENT_NAME + "2", + requests.get(1).getStructuredQuery().getStartAt().getValues(0).getReferenceValue()); + } + + @Test + public void doesNotRetryAfterNonRetryableError() throws Exception { + doAnswer( + queryResponse( + FirestoreException.serverRejected( + Status.PERMISSION_DENIED, "Simulated test failure"), + DOCUMENT_NAME + "1", + DOCUMENT_NAME + "2")) + .when(firestoreMock) + .streamRequest( + runQuery.capture(), + streamObserverCapture.capture(), + Matchers.any()); + + // Verify the responses + final Semaphore semaphore = new Semaphore(0); + final Iterator iterator = Arrays.asList("doc1", "doc2").iterator(); + + query.stream( + new ApiStreamObserver() { + @Override + public void onNext(DocumentSnapshot documentSnapshot) { + assertEquals(iterator.next(), documentSnapshot.getId()); + } + + @Override + public void onError(Throwable throwable) { + semaphore.release(); + } + + @Override + public void onCompleted() {} + }); + + semaphore.acquire(); + + // Verify the request count + List requests = runQuery.getAllValues(); + assertEquals(1, runQuery.getAllValues().size()); + } + @Test public void equalsTest() { assertEquals(query.limit(42).offset(1337), query.offset(1337).limit(42));