Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
fix: return results from getPartitions() in order (#653)
  • Loading branch information
schmidt-sebastian committed Jun 4, 2021
1 parent 41a0078 commit 12d17d1
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 13 deletions.
Expand Up @@ -33,7 +33,9 @@
import io.opencensus.common.Scope;
import io.opencensus.trace.Span;
import io.opencensus.trace.Status;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -63,6 +65,8 @@ public class CollectionGroup extends Query {
* parallel. The returned partition cursors are split points that can be used as starting/end
* points for the query results.
*
* @deprecated Please use {@link #getPartitions(long)} instead. All cursors will be loaded before
* any value will be provided to {@code observer}.
* @param desiredPartitionCount The desired maximum number of partition points. The number must be
* strictly positive. The actual number of partitions returned may be fewer.
* @param observer a stream observer that receives the result of the Partition request.
Expand Down Expand Up @@ -159,8 +163,23 @@ private PartitionQueryRequest buildRequest(long desiredPartitionCount) {

private void consumePartitions(
PartitionQueryPagedResponse response, Function<QueryPartition, Void> consumer) {
@Nullable Object[] lastCursor = null;
List<Cursor> cursors = new ArrayList<>();
for (Cursor cursor : response.iterateAll()) {
cursors.add(cursor);
}

// Sort the partitions as they may not be ordered if responses are paged.
Collections.sort(
cursors,
new Comparator<Cursor>() {
@Override
public int compare(Cursor left, Cursor right) {
return Order.INSTANCE.compareArrays(left.getValuesList(), right.getValuesList());
}
});

@Nullable Object[] lastCursor = null;
for (Cursor cursor : cursors) {
Object[] decodedCursorValue = new Object[cursor.getValuesCount()];
for (int i = 0; i < cursor.getValuesCount(); ++i) {
decodedCursorValue[i] = UserDataConverter.decodeValue(rpcContext, cursor.getValues(i));
Expand Down
Expand Up @@ -109,7 +109,8 @@ public int compare(@Nonnull Value left, @Nonnull Value right) {
case GEO_POINT:
return compareGeoPoints(left, right);
case ARRAY:
return compareArrays(left, right);
return compareArrays(
left.getArrayValue().getValuesList(), right.getArrayValue().getValuesList());
case OBJECT:
return compareObjects(left, right);
default:
Expand Down Expand Up @@ -171,27 +172,22 @@ private int compareResourcePaths(Value left, Value right) {
return leftPath.compareTo(rightPath);
}

private int compareArrays(Value left, Value right) {
List<Value> leftValue = left.getArrayValue().getValuesList();
List<Value> rightValue = right.getArrayValue().getValuesList();

int minLength = Math.min(leftValue.size(), rightValue.size());
public int compareArrays(List<Value> left, List<Value> right) {
int minLength = Math.min(left.size(), right.size());
for (int i = 0; i < minLength; i++) {
int cmp = compare(leftValue.get(i), rightValue.get(i));
int cmp = compare(left.get(i), right.get(i));
if (cmp != 0) {
return cmp;
}
}
return Integer.compare(leftValue.size(), rightValue.size());
return Integer.compare(left.size(), right.size());
}

private int compareObjects(Value left, Value right) {
// This requires iterating over the keys in the object in order and doing a
// deep comparison.
SortedMap<String, Value> leftMap = new TreeMap<>();
leftMap.putAll(left.getMapValue().getFieldsMap());
SortedMap<String, Value> rightMap = new TreeMap<>();
rightMap.putAll(right.getMapValue().getFieldsMap());
SortedMap<String, Value> leftMap = new TreeMap<>(left.getMapValue().getFieldsMap());
SortedMap<String, Value> rightMap = new TreeMap<>(right.getMapValue().getFieldsMap());

Iterator<Entry<String, Value>> leftIterator = leftMap.entrySet().iterator();
Iterator<Entry<String, Value>> rightIterator = rightMap.entrySet().iterator();
Expand Down
@@ -0,0 +1,204 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

import static com.google.cloud.firestore.LocalFirestoreHelper.queryResponse;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.when;

import com.google.api.core.ApiFutures;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.cloud.firestore.spi.v1.FirestoreRpc;
import com.google.cloud.firestore.v1.FirestoreClient.PartitionQueryPage;
import com.google.cloud.firestore.v1.FirestoreClient.PartitionQueryPagedResponse;
import com.google.common.collect.ImmutableList;
import com.google.firestore.v1.Cursor;
import com.google.firestore.v1.PartitionQueryRequest;
import com.google.firestore.v1.PartitionQueryResponse;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.StructuredQuery;
import com.google.firestore.v1.Value;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.runners.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public class PartitionQuery {
public static final String DATABASE_NAME = "projects/test-project/databases/(default)/documents";
public static final Cursor CURSOR1 =
Cursor.newBuilder()
.addValues(Value.newBuilder().setReferenceValue(DATABASE_NAME + "/collection/doc1"))
.build();
public static final Cursor PARTITION1 = CURSOR1.toBuilder().setBefore(true).build();
public static final Cursor CURSOR2 =
Cursor.newBuilder()
.addValues(Value.newBuilder().setReferenceValue(DATABASE_NAME + "/collection/doc2"))
.build();
public static final Cursor PARTITION2 = CURSOR2.toBuilder().setBefore(true).build();

@Spy
private final FirestoreImpl firestoreMock =
new FirestoreImpl(
FirestoreOptions.newBuilder().setProjectId("test-project").build(),
Mockito.mock(FirestoreRpc.class));

@Mock private UnaryCallable<PartitionQueryRequest, PartitionQueryPagedResponse> callable;
@Mock private PartitionQueryPagedResponse pagedResponse;
@Mock private PartitionQueryPage queryPage;

@Captor private ArgumentCaptor<RunQueryRequest> runQuery;
@Captor private ArgumentCaptor<ApiStreamObserver> streamObserverCapture;
@Captor private ArgumentCaptor<PartitionQueryRequest> requestCaptor;

@Test
public void requestsOneLessThanDesired() throws Exception {
int desiredPartitionsCount = 2;

PartitionQueryRequest expectedRequest =
PartitionQueryRequest.newBuilder()
.setParent(DATABASE_NAME)
.setStructuredQuery(
StructuredQuery.newBuilder()
.addFrom(
StructuredQuery.CollectionSelector.newBuilder()
.setAllDescendants(true)
.setCollectionId("collectionId"))
.addOrderBy(
StructuredQuery.Order.newBuilder()
.setField(
StructuredQuery.FieldReference.newBuilder()
.setFieldPath("__name__"))
.setDirection(StructuredQuery.Direction.ASCENDING)))
.setPartitionCount(desiredPartitionsCount - 1)
.build();

PartitionQueryResponse response =
PartitionQueryResponse.newBuilder().addPartitions(CURSOR1).build();

when(pagedResponse.iterateAll()).thenReturn(ImmutableList.of(CURSOR1));
when(queryPage.getResponse()).thenReturn(response);
doReturn(ApiFutures.immediateFuture(pagedResponse))
.when(firestoreMock)
.sendRequest(
requestCaptor.capture(),
Matchers.<UnaryCallable<PartitionQueryRequest, PartitionQueryPagedResponse>>any());

firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get();

PartitionQueryRequest actualRequest = requestCaptor.getValue();
assertEquals(actualRequest, expectedRequest);
}

@Test
public void doesNotIssueRpcIfOnlyASinglePartitionIsRequested() throws Exception {
int desiredPartitionsCount = 1;

List<QueryPartition> partitions =
firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get();

assertEquals(partitions.size(), 1);
assertNull(partitions.get(0).getStartAt());
assertNull(partitions.get(0).getEndBefore());
}

@Test
public void validatesPartitionCount() {
int desiredPartitionsCount = 0;
try {
firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount);
fail();
} catch (IllegalArgumentException e) {
assertEquals(e.getMessage(), "Desired partition count must be one or greater");
}
}

@Test
public void convertsPartitionsToQueries() throws Exception {
int desiredPartitionsCount = 3;

PartitionQueryResponse response =
PartitionQueryResponse.newBuilder().addPartitions(CURSOR1).build();

when(pagedResponse.iterateAll()).thenReturn(ImmutableList.of(CURSOR1, CURSOR2));
when(queryPage.getResponse()).thenReturn(response);
doReturn(ApiFutures.immediateFuture(pagedResponse))
.when(firestoreMock)
.sendRequest(
requestCaptor.capture(),
Matchers.<UnaryCallable<PartitionQueryRequest, PartitionQueryPagedResponse>>any());

doAnswer(queryResponse())
.when(firestoreMock)
.streamRequest(
runQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

List<QueryPartition> partitions =
firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get();

assertEquals(partitions.size(), 3);
for (QueryPartition partition : partitions) {
partition.createQuery().get();
}

assertEquals(runQuery.getAllValues().size(), 3);

assertFalse(runQuery.getAllValues().get(0).getStructuredQuery().hasStartAt());
assertEquals(runQuery.getAllValues().get(0).getStructuredQuery().getEndAt(), PARTITION1);
assertEquals(runQuery.getAllValues().get(1).getStructuredQuery().getStartAt(), PARTITION1);
assertEquals(runQuery.getAllValues().get(1).getStructuredQuery().getEndAt(), PARTITION2);
assertEquals(runQuery.getAllValues().get(2).getStructuredQuery().getStartAt(), PARTITION2);
assertFalse(runQuery.getAllValues().get(2).getStructuredQuery().hasEndAt());
}

@Test
public void sortsPartitions() throws Exception {
int desiredPartitionsCount = 3;

PartitionQueryResponse response =
PartitionQueryResponse.newBuilder().addPartitions(CURSOR1).build();

when(pagedResponse.iterateAll()).thenReturn(ImmutableList.of(CURSOR2, CURSOR1));
when(queryPage.getResponse()).thenReturn(response);
doReturn(ApiFutures.immediateFuture(pagedResponse))
.when(firestoreMock)
.sendRequest(
requestCaptor.capture(),
Matchers.<UnaryCallable<PartitionQueryRequest, PartitionQueryPagedResponse>>any());

List<QueryPartition> partitions =
firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get();

assertEquals(((DocumentReference) partitions.get(0).getEndBefore()[0]).getId(), "doc1");
assertEquals(((DocumentReference) partitions.get(1).getEndBefore()[0]).getId(), "doc2");
}
}

0 comments on commit 12d17d1

Please sign in to comment.