Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: retry Query streams #426

Merged
merged 2 commits into from Oct 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -32,9 +32,11 @@
import com.google.api.core.InternalExtensionOnly;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.StatusCode;
import com.google.auto.value.AutoValue;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.Query.QueryOptions.Builder;
import com.google.cloud.firestore.v1.FirestoreSettings;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -52,14 +54,17 @@
import com.google.firestore.v1.Value;
import com.google.protobuf.ByteString;
import com.google.protobuf.Int32Value;
import io.grpc.Status;
import io.opencensus.trace.AttributeValue;
import io.opencensus.trace.Tracing;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -1297,7 +1302,8 @@ public void onCompleted() {
responseObserver.onCompleted();
}
},
null);
/* transactionId= */ null,
/* readTime= */ null);
}

/**
Expand Down Expand Up @@ -1431,13 +1437,18 @@ Timestamp getReadTime() {
}

private void internalStream(
final QuerySnapshotObserver documentObserver, @Nullable ByteString transactionId) {
final QuerySnapshotObserver documentObserver,
@Nullable final ByteString transactionId,
@Nullable final Timestamp readTime) {
RunQueryRequest.Builder request = RunQueryRequest.newBuilder();
request.setStructuredQuery(buildQuery()).setParent(options.getParentPath().toString());

if (transactionId != null) {
request.setTransaction(transactionId);
}
if (readTime != null) {
request.setReadTime(readTime.toProto());
}

Tracing.getTracer()
.getCurrentSpan()
Expand All @@ -1446,6 +1457,8 @@ private void internalStream(
ImmutableMap.of(
"transactional", AttributeValue.booleanAttributeValue(transactionId != null)));

final AtomicReference<QueryDocumentSnapshot> lastReceivedDocument = new AtomicReference<>();

ApiStreamObserver<RunQueryResponse> observer =
new ApiStreamObserver<RunQueryResponse>() {
Timestamp readTime;
Expand All @@ -1470,6 +1483,7 @@ public void onNext(RunQueryResponse response) {
QueryDocumentSnapshot.fromDocument(
rpcContext, Timestamp.fromProto(response.getReadTime()), document);
documentObserver.onNext(documentSnapshot);
lastReceivedDocument.set(documentSnapshot);
}

if (readTime == null) {
Expand All @@ -1479,8 +1493,27 @@ public void onNext(RunQueryResponse response) {

/**
 * Handles a failure of the RunQuery stream.
 *
 * <p>A non-transactional query that fails with a retryable error after at least one document
 * was received is restarted from the last received document, reusing that document's read time.
 * Every other failure is forwarded to {@code documentObserver} exactly once.
 *
 * @param throwable the error reported by the stream
 */
@Override
public void onError(Throwable throwable) {
  // Transactional queries are retried via the transaction runner, so only
  // non-transactional queries are restarted here. A restart also needs a
  // cursor: without a received document there is nothing to resume from.
  QueryDocumentSnapshot cursor = lastReceivedDocument.get();
  boolean canResume = transactionId == null && cursor != null && isRetryableError(throwable);

  if (!canResume) {
    // Report the failure once; it must never be dropped silently.
    Tracing.getTracer().getCurrentSpan().addAnnotation("Firestore.Query: Error");
    documentObserver.onError(throwable);
    return;
  }

  Tracing.getTracer().getCurrentSpan().addAnnotation("Firestore.Query: Retryable Error");

  // Restart the query but use the last document we received as the
  // query cursor. Note that it is ok to not use backoff here
  // since we are requiring at least a single document result.
  Query.this
      .startAfter(cursor)
      .internalStream(documentObserver, /* transactionId= */ null, cursor.getReadTime());
}

@Override
Expand Down Expand Up @@ -1562,7 +1595,8 @@ public void onCompleted() {
result.set(querySnapshot);
}
},
transactionId);
transactionId,
/* readTime= */ null);

return result;
}
Expand Down Expand Up @@ -1624,6 +1658,22 @@ private <T> ImmutableList<T> append(ImmutableList<T> existingList, T newElement)
return builder.build();
}

/**
 * Verifies whether the given exception is retryable based on the RunQuery configuration.
 *
 * <p>Only {@link FirestoreException}s are considered. The name of the exception's status code is
 * compared against the names of the retryable codes reported by the RunQuery settings of a
 * freshly created {@link FirestoreSettings} builder.
 *
 * @param throwable the error reported by the RunQuery stream
 * @return {@code true} if the error carries a status code listed as retryable for RunQuery
 */
private boolean isRetryableError(Throwable throwable) {
  if (!(throwable instanceof FirestoreException)) {
    return false;
  }

  Status status = ((FirestoreException) throwable).getStatus();
  if (status == null) {
    // Defensive: with no status attached there is nothing to match, so treat the error as
    // non-retryable instead of throwing from inside the stream's error handler.
    return false;
  }

  // Compare by name so the conversion is done once and an unmapped name cannot throw.
  String statusName = status.getCode().name();
  Set<StatusCode.Code> retryableCodes =
      FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
  for (StatusCode.Code code : retryableCodes) {
    if (code.name().equals(statusName)) {
      return true;
    }
  }
  return false;
}

/**
* Returns true if this Query is equal to the provided object.
*
Expand Down
Expand Up @@ -347,7 +347,10 @@ public void notFound() throws Exception {
getDocumentResponse.setReadTime(
com.google.protobuf.Timestamp.newBuilder().setSeconds(5).setNanos(6));

doAnswer(streamingResponse(getDocumentResponse.build()))
doAnswer(
streamingResponse(
new BatchGetDocumentsResponse[] {getDocumentResponse.build()},
/* throwable= */ null))
.when(firestoreMock)
.streamRequest(
getAllCapture.capture(),
Expand Down
Expand Up @@ -279,7 +279,7 @@ public static Answer<BatchGetDocumentsResponse> getAllResponse(
responses[i] = response.build();
}

return streamingResponse(responses);
return streamingResponse(responses, null);
}

public static ApiFuture<Empty> rollbackResponse() {
Expand All @@ -291,6 +291,12 @@ public static Answer<RunQueryResponse> queryResponse() {
}

/** Returns a stream of the named documents that is not followed by an exception. */
public static Answer<RunQueryResponse> queryResponse(String... documentNames) {
  final Throwable noError = null;
  return queryResponse(noError, documentNames);
}

/** Returns a stream of documents followed by an optional exception. */
public static Answer<RunQueryResponse> queryResponse(
@Nullable Throwable throwable, String... documentNames) {
RunQueryResponse[] responses = new RunQueryResponse[documentNames.length];

for (int i = 0; i < documentNames.length; ++i) {
Expand All @@ -301,17 +307,23 @@ public static Answer<RunQueryResponse> queryResponse(String... documentNames) {
com.google.protobuf.Timestamp.newBuilder().setSeconds(1).setNanos(2));
responses[i] = runQueryResponse.build();
}
return streamingResponse(responses);

return streamingResponse(responses, throwable);
}

public static <T> Answer<T> streamingResponse(final T... response) {
/** Returns a stream of responses followed by an optional exception. */
public static <T> Answer<T> streamingResponse(
final T[] response, @Nullable final Throwable throwable) {
return new Answer<T>() {
public T answer(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
ApiStreamObserver<T> observer = (ApiStreamObserver<T>) args[1];
for (T resp : response) {
observer.onNext(resp);
}
if (throwable != null) {
observer.onError(throwable);
}
observer.onCompleted();
return null;
}
Expand Down
Expand Up @@ -46,17 +46,20 @@
import com.google.common.io.BaseEncoding;
import com.google.firestore.v1.ArrayValue;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.RunQueryResponse;
import com.google.firestore.v1.StructuredQuery;
import com.google.firestore.v1.StructuredQuery.Direction;
import com.google.firestore.v1.StructuredQuery.FieldFilter.Operator;
import com.google.firestore.v1.Value;
import com.google.protobuf.InvalidProtocolBufferException;
import io.grpc.Status;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Semaphore;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -66,7 +69,9 @@
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;

@RunWith(MockitoJUnitRunner.class)
public class QueryTest {
Expand Down Expand Up @@ -902,6 +907,113 @@ public void onCompleted() {
semaphore.acquire();
}

@Test
public void retriesAfterRetryableError() throws Exception {
final boolean[] returnError = new boolean[] {true};

doAnswer(
new Answer<RunQueryResponse>() {
public RunQueryResponse answer(InvocationOnMock invocation) throws Throwable {
if (returnError[0]) {
returnError[0] = false;
return queryResponse(
FirestoreException.serverRejected(
Status.DEADLINE_EXCEEDED, "Simulated test failure"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it makes sense to add a test to verify the no-retry behavior in cases where it is not supposed to retry.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

DOCUMENT_NAME + "1",
DOCUMENT_NAME + "2")
.answer(invocation);
} else {
return queryResponse(DOCUMENT_NAME + "3").answer(invocation);
}
}
})
.when(firestoreMock)
.streamRequest(
runQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

// Verify the responses
final Semaphore semaphore = new Semaphore(0);
final Iterator<String> iterator = Arrays.asList("doc1", "doc2", "doc3").iterator();

query.stream(
new ApiStreamObserver<DocumentSnapshot>() {
@Override
public void onNext(DocumentSnapshot documentSnapshot) {
assertEquals(iterator.next(), documentSnapshot.getId());
}

@Override
public void onError(Throwable throwable) {
fail();
}

@Override
public void onCompleted() {
semaphore.release();
}
});

semaphore.acquire();

// Verify the requests
List<RunQueryRequest> requests = runQuery.getAllValues();
assertEquals(2, requests.size());

assertFalse(requests.get(0).hasReadTime());
assertFalse(requests.get(0).getStructuredQuery().hasStartAt());

assertEquals(
com.google.protobuf.Timestamp.newBuilder().setSeconds(1).setNanos(2).build(),
requests.get(1).getReadTime());
assertFalse(requests.get(1).getStructuredQuery().getStartAt().getBefore());
assertEquals(
DOCUMENT_NAME + "2",
requests.get(1).getStructuredQuery().getStartAt().getValues(0).getReferenceValue());
}

/**
 * Verifies that a stream failing with a non-retryable status (PERMISSION_DENIED) reports the
 * error to the observer and that only a single RunQuery request is issued.
 */
@Test
public void doesNotRetryAfterNonRetryableError() throws Exception {
  doAnswer(
          queryResponse(
              FirestoreException.serverRejected(
                  Status.PERMISSION_DENIED, "Simulated test failure"),
              DOCUMENT_NAME + "1",
              DOCUMENT_NAME + "2"))
      .when(firestoreMock)
      .streamRequest(
          runQuery.capture(),
          streamObserverCapture.capture(),
          Matchers.<ServerStreamingCallable>any());

  // Verify the responses
  final Semaphore semaphore = new Semaphore(0);
  final Iterator<String> iterator = Arrays.asList("doc1", "doc2").iterator();

  query.stream(
      new ApiStreamObserver<DocumentSnapshot>() {
        @Override
        public void onNext(DocumentSnapshot documentSnapshot) {
          assertEquals(iterator.next(), documentSnapshot.getId());
        }

        @Override
        public void onError(Throwable throwable) {
          // The error is the expected terminal signal; unblock the test thread.
          semaphore.release();
        }

        @Override
        public void onCompleted() {}
      });

  semaphore.acquire();

  // Verify the request count: the failed query must not have been re-issued.
  List<RunQueryRequest> requests = runQuery.getAllValues();
  assertEquals(1, requests.size());
}

@Test
public void equalsTest() {
assertEquals(query.limit(42).offset(1337), query.offset(1337).limit(42));
Expand Down