Skip to content

Commit

Permalink
fix: retry Query streams (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidt-sebastian committed Oct 21, 2020
1 parent 078dd57 commit 3513cd3
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 9 deletions.
Expand Up @@ -32,9 +32,11 @@
import com.google.api.core.InternalExtensionOnly;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.StatusCode;
import com.google.auto.value.AutoValue;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.Query.QueryOptions.Builder;
import com.google.cloud.firestore.v1.FirestoreSettings;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -52,14 +54,17 @@
import com.google.firestore.v1.Value;
import com.google.protobuf.ByteString;
import com.google.protobuf.Int32Value;
import io.grpc.Status;
import io.opencensus.trace.AttributeValue;
import io.opencensus.trace.Tracing;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -1297,7 +1302,8 @@ public void onCompleted() {
responseObserver.onCompleted();
}
},
null);
/* transactionId= */ null,
/* readTime= */ null);
}

/**
Expand Down Expand Up @@ -1431,13 +1437,18 @@ Timestamp getReadTime() {
}

private void internalStream(
final QuerySnapshotObserver documentObserver, @Nullable ByteString transactionId) {
final QuerySnapshotObserver documentObserver,
@Nullable final ByteString transactionId,
@Nullable final Timestamp readTime) {
RunQueryRequest.Builder request = RunQueryRequest.newBuilder();
request.setStructuredQuery(buildQuery()).setParent(options.getParentPath().toString());

if (transactionId != null) {
request.setTransaction(transactionId);
}
if (readTime != null) {
request.setReadTime(readTime.toProto());
}

Tracing.getTracer()
.getCurrentSpan()
Expand All @@ -1446,6 +1457,8 @@ private void internalStream(
ImmutableMap.of(
"transactional", AttributeValue.booleanAttributeValue(transactionId != null)));

final AtomicReference<QueryDocumentSnapshot> lastReceivedDocument = new AtomicReference<>();

ApiStreamObserver<RunQueryResponse> observer =
new ApiStreamObserver<RunQueryResponse>() {
Timestamp readTime;
Expand All @@ -1470,6 +1483,7 @@ public void onNext(RunQueryResponse response) {
QueryDocumentSnapshot.fromDocument(
rpcContext, Timestamp.fromProto(response.getReadTime()), document);
documentObserver.onNext(documentSnapshot);
lastReceivedDocument.set(documentSnapshot);
}

if (readTime == null) {
Expand All @@ -1479,8 +1493,27 @@ public void onNext(RunQueryResponse response) {

@Override
public void onError(Throwable throwable) {
Tracing.getTracer().getCurrentSpan().addAnnotation("Firestore.Query: Error");
documentObserver.onError(throwable);
// If a non-transactional query failed, attempt to restart.
// Transactional queries are retried via the transaction runner.
if (transactionId == null && isRetryableError(throwable)) {
Tracing.getTracer()
.getCurrentSpan()
.addAnnotation("Firestore.Query: Retryable Error");

// Restart the query but use the last document we received as the
// query cursor. Note that this it is ok to not use backoff here
// since we are requiring at least a single document result.
QueryDocumentSnapshot cursor = lastReceivedDocument.get();
if (cursor != null) {
Query.this
.startAfter(cursor)
.internalStream(
documentObserver, /* transactionId= */ null, cursor.getReadTime());
}
} else {
Tracing.getTracer().getCurrentSpan().addAnnotation("Firestore.Query: Error");
documentObserver.onError(throwable);
}
}

@Override
Expand Down Expand Up @@ -1562,7 +1595,8 @@ public void onCompleted() {
result.set(querySnapshot);
}
},
transactionId);
transactionId,
/* readTime= */ null);

return result;
}
Expand Down Expand Up @@ -1624,6 +1658,22 @@ private <T> ImmutableList<T> append(ImmutableList<T> existingList, T newElement)
return builder.build();
}

/** Verifies whether the given exception is retryable based on the RunQuery configuration. */
private boolean isRetryableError(Throwable throwable) {
if (!(throwable instanceof FirestoreException)) {
return false;
}
Set<StatusCode.Code> codes =
FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
Status status = ((FirestoreException) throwable).getStatus();
for (StatusCode.Code code : codes) {
if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) {
return true;
}
}
return false;
}

/**
* Returns true if this Query is equal to the provided object.
*
Expand Down
Expand Up @@ -347,7 +347,10 @@ public void notFound() throws Exception {
getDocumentResponse.setReadTime(
com.google.protobuf.Timestamp.newBuilder().setSeconds(5).setNanos(6));

doAnswer(streamingResponse(getDocumentResponse.build()))
doAnswer(
streamingResponse(
new BatchGetDocumentsResponse[] {getDocumentResponse.build()},
/* throwable= */ null))
.when(firestoreMock)
.streamRequest(
getAllCapture.capture(),
Expand Down
Expand Up @@ -279,7 +279,7 @@ public static Answer<BatchGetDocumentsResponse> getAllResponse(
responses[i] = response.build();
}

return streamingResponse(responses);
return streamingResponse(responses, null);
}

public static ApiFuture<Empty> rollbackResponse() {
Expand All @@ -291,6 +291,12 @@ public static Answer<RunQueryResponse> queryResponse() {
}

public static Answer<RunQueryResponse> queryResponse(String... documentNames) {
return queryResponse(/* throwable= */ null, documentNames);
}

/** Returns a stream of documents followed by an optional exception. */
public static Answer<RunQueryResponse> queryResponse(
@Nullable Throwable throwable, String... documentNames) {
RunQueryResponse[] responses = new RunQueryResponse[documentNames.length];

for (int i = 0; i < documentNames.length; ++i) {
Expand All @@ -301,17 +307,23 @@ public static Answer<RunQueryResponse> queryResponse(String... documentNames) {
com.google.protobuf.Timestamp.newBuilder().setSeconds(1).setNanos(2));
responses[i] = runQueryResponse.build();
}
return streamingResponse(responses);

return streamingResponse(responses, throwable);
}

public static <T> Answer<T> streamingResponse(final T... response) {
/** Returns a stream of responses followed by an optional exception. */
public static <T> Answer<T> streamingResponse(
final T[] response, @Nullable final Throwable throwable) {
return new Answer<T>() {
public T answer(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
ApiStreamObserver<T> observer = (ApiStreamObserver<T>) args[1];
for (T resp : response) {
observer.onNext(resp);
}
if (throwable != null) {
observer.onError(throwable);
}
observer.onCompleted();
return null;
}
Expand Down
Expand Up @@ -46,17 +46,20 @@
import com.google.common.io.BaseEncoding;
import com.google.firestore.v1.ArrayValue;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.RunQueryResponse;
import com.google.firestore.v1.StructuredQuery;
import com.google.firestore.v1.StructuredQuery.Direction;
import com.google.firestore.v1.StructuredQuery.FieldFilter.Operator;
import com.google.firestore.v1.Value;
import com.google.protobuf.InvalidProtocolBufferException;
import io.grpc.Status;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Semaphore;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -66,7 +69,9 @@
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;

@RunWith(MockitoJUnitRunner.class)
public class QueryTest {
Expand Down Expand Up @@ -902,6 +907,113 @@ public void onCompleted() {
semaphore.acquire();
}

@Test
public void retriesAfterRetryableError() throws Exception {
final boolean[] returnError = new boolean[] {true};

doAnswer(
new Answer<RunQueryResponse>() {
public RunQueryResponse answer(InvocationOnMock invocation) throws Throwable {
if (returnError[0]) {
returnError[0] = false;
return queryResponse(
FirestoreException.serverRejected(
Status.DEADLINE_EXCEEDED, "Simulated test failure"),
DOCUMENT_NAME + "1",
DOCUMENT_NAME + "2")
.answer(invocation);
} else {
return queryResponse(DOCUMENT_NAME + "3").answer(invocation);
}
}
})
.when(firestoreMock)
.streamRequest(
runQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

// Verify the responses
final Semaphore semaphore = new Semaphore(0);
final Iterator<String> iterator = Arrays.asList("doc1", "doc2", "doc3").iterator();

query.stream(
new ApiStreamObserver<DocumentSnapshot>() {
@Override
public void onNext(DocumentSnapshot documentSnapshot) {
assertEquals(iterator.next(), documentSnapshot.getId());
}

@Override
public void onError(Throwable throwable) {
fail();
}

@Override
public void onCompleted() {
semaphore.release();
}
});

semaphore.acquire();

// Verify the requests
List<RunQueryRequest> requests = runQuery.getAllValues();
assertEquals(2, requests.size());

assertFalse(requests.get(0).hasReadTime());
assertFalse(requests.get(0).getStructuredQuery().hasStartAt());

assertEquals(
com.google.protobuf.Timestamp.newBuilder().setSeconds(1).setNanos(2).build(),
requests.get(1).getReadTime());
assertFalse(requests.get(1).getStructuredQuery().getStartAt().getBefore());
assertEquals(
DOCUMENT_NAME + "2",
requests.get(1).getStructuredQuery().getStartAt().getValues(0).getReferenceValue());
}

@Test
public void doesNotRetryAfterNonRetryableError() throws Exception {
doAnswer(
queryResponse(
FirestoreException.serverRejected(
Status.PERMISSION_DENIED, "Simulated test failure"),
DOCUMENT_NAME + "1",
DOCUMENT_NAME + "2"))
.when(firestoreMock)
.streamRequest(
runQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

// Verify the responses
final Semaphore semaphore = new Semaphore(0);
final Iterator<String> iterator = Arrays.asList("doc1", "doc2").iterator();

query.stream(
new ApiStreamObserver<DocumentSnapshot>() {
@Override
public void onNext(DocumentSnapshot documentSnapshot) {
assertEquals(iterator.next(), documentSnapshot.getId());
}

@Override
public void onError(Throwable throwable) {
semaphore.release();
}

@Override
public void onCompleted() {}
});

semaphore.acquire();

// Verify the request count
List<RunQueryRequest> requests = runQuery.getAllValues();
assertEquals(1, runQuery.getAllValues().size());
}

@Test
public void equalsTest() {
assertEquals(query.limit(42).offset(1337), query.offset(1337).limit(42));
Expand Down

0 comments on commit 3513cd3

Please sign in to comment.