diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/CollectionGroup.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/CollectionGroup.java index bcfc3b920..0a4c35e73 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/CollectionGroup.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/CollectionGroup.java @@ -33,7 +33,9 @@ import io.opencensus.common.Scope; import io.opencensus.trace.Span; import io.opencensus.trace.Status; +import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; import java.util.List; import javax.annotation.Nullable; @@ -63,6 +65,8 @@ public class CollectionGroup extends Query { * parallel. The returned partition cursors are split points that can be used as starting/end * points for the query results. * + * @deprecated Please use {@link #getPartitions(long)} instead. All cursors will be loaded before + * any value will be provided to {@code observer}. * @param desiredPartitionCount The desired maximum number of partition points. The number must be * strictly positive. The actual number of partitions returned may be fewer. * @param observer a stream observer that receives the result of the Partition request. @@ -159,8 +163,23 @@ private PartitionQueryRequest buildRequest(long desiredPartitionCount) { private void consumePartitions( PartitionQueryPagedResponse response, Function consumer) { - @Nullable Object[] lastCursor = null; + List cursors = new ArrayList<>(); for (Cursor cursor : response.iterateAll()) { + cursors.add(cursor); + } + + // Sort the partitions as they may not be ordered if responses are paged. + Collections.sort( + cursors, + new Comparator() { + @Override + public int compare(Cursor left, Cursor right) { + return Order.INSTANCE.compareArrays(left.getValuesList(), right.getValuesList()); + } + }); + + @Nullable Object[] lastCursor = null; + for (Cursor cursor : cursors) { Object[] decodedCursorValue = new Object[cursor.getValuesCount()]; for (int i = 0; i < cursor.getValuesCount(); ++i) { decodedCursorValue[i] = UserDataConverter.decodeValue(rpcContext, cursor.getValues(i)); diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Order.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Order.java index 46d1ffa7c..67348a728 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Order.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Order.java @@ -109,7 +109,8 @@ public int compare(@Nonnull Value left, @Nonnull Value right) { case GEO_POINT: return compareGeoPoints(left, right); case ARRAY: - return compareArrays(left, right); + return compareArrays( + left.getArrayValue().getValuesList(), right.getArrayValue().getValuesList()); case OBJECT: return compareObjects(left, right); default: @@ -171,27 +172,22 @@ private int compareResourcePaths(Value left, Value right) { return leftPath.compareTo(rightPath); } - private int compareArrays(Value left, Value right) { - List leftValue = left.getArrayValue().getValuesList(); - List rightValue = right.getArrayValue().getValuesList(); - - int minLength = Math.min(leftValue.size(), rightValue.size()); + public int compareArrays(List left, List right) { + int minLength = Math.min(left.size(), right.size()); for (int i = 0; i < minLength; i++) { - int cmp = compare(leftValue.get(i), rightValue.get(i)); + int cmp = compare(left.get(i), right.get(i)); if (cmp != 0) { return cmp; } } - return Integer.compare(leftValue.size(), rightValue.size()); + return Integer.compare(left.size(), right.size()); } private int compareObjects(Value left, Value right) { // This requires iterating over the keys in the object in order and doing a // deep comparison. - SortedMap leftMap = new TreeMap<>(); - leftMap.putAll(left.getMapValue().getFieldsMap()); - SortedMap rightMap = new TreeMap<>(); - rightMap.putAll(right.getMapValue().getFieldsMap()); + SortedMap leftMap = new TreeMap<>(left.getMapValue().getFieldsMap()); + SortedMap rightMap = new TreeMap<>(right.getMapValue().getFieldsMap()); Iterator> leftIterator = leftMap.entrySet().iterator(); Iterator> rightIterator = rightMap.entrySet().iterator(); diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/PartitionQuery.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/PartitionQuery.java new file mode 100644 index 000000000..26d7c4cde --- /dev/null +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/PartitionQuery.java @@ -0,0 +1,204 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.firestore; + +import static com.google.cloud.firestore.LocalFirestoreHelper.queryResponse; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; + +import com.google.api.core.ApiFutures; +import com.google.api.gax.rpc.ApiStreamObserver; +import com.google.api.gax.rpc.ServerStreamingCallable; +import com.google.api.gax.rpc.UnaryCallable; +import com.google.cloud.firestore.spi.v1.FirestoreRpc; +import com.google.cloud.firestore.v1.FirestoreClient.PartitionQueryPage; +import com.google.cloud.firestore.v1.FirestoreClient.PartitionQueryPagedResponse; +import com.google.common.collect.ImmutableList; +import com.google.firestore.v1.Cursor; +import com.google.firestore.v1.PartitionQueryRequest; +import com.google.firestore.v1.PartitionQueryResponse; +import com.google.firestore.v1.RunQueryRequest; +import com.google.firestore.v1.StructuredQuery; +import com.google.firestore.v1.Value; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Matchers; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.Spy; +import org.mockito.runners.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class PartitionQuery { + public static final String DATABASE_NAME = "projects/test-project/databases/(default)/documents"; + public static final Cursor CURSOR1 = + Cursor.newBuilder() + .addValues(Value.newBuilder().setReferenceValue(DATABASE_NAME + "/collection/doc1")) + .build(); + public static final Cursor PARTITION1 = CURSOR1.toBuilder().setBefore(true).build(); + public static final Cursor CURSOR2 = + Cursor.newBuilder() + .addValues(Value.newBuilder().setReferenceValue(DATABASE_NAME + "/collection/doc2")) + .build(); + public static final Cursor PARTITION2 = CURSOR2.toBuilder().setBefore(true).build(); + + @Spy + private final FirestoreImpl firestoreMock = + new FirestoreImpl( + FirestoreOptions.newBuilder().setProjectId("test-project").build(), + Mockito.mock(FirestoreRpc.class)); + + @Mock private UnaryCallable callable; + @Mock private PartitionQueryPagedResponse pagedResponse; + @Mock private PartitionQueryPage queryPage; + + @Captor private ArgumentCaptor runQuery; + @Captor private ArgumentCaptor streamObserverCapture; + @Captor private ArgumentCaptor requestCaptor; + + @Test + public void requestsOneLessThanDesired() throws Exception { + int desiredPartitionsCount = 2; + + PartitionQueryRequest expectedRequest = + PartitionQueryRequest.newBuilder() + .setParent(DATABASE_NAME) + .setStructuredQuery( + StructuredQuery.newBuilder() + .addFrom( + StructuredQuery.CollectionSelector.newBuilder() + .setAllDescendants(true) + .setCollectionId("collectionId")) + .addOrderBy( + StructuredQuery.Order.newBuilder() + .setField( + StructuredQuery.FieldReference.newBuilder() + .setFieldPath("__name__")) + .setDirection(StructuredQuery.Direction.ASCENDING))) + .setPartitionCount(desiredPartitionsCount - 1) + .build(); + + PartitionQueryResponse response = + PartitionQueryResponse.newBuilder().addPartitions(CURSOR1).build(); + + when(pagedResponse.iterateAll()).thenReturn(ImmutableList.of(CURSOR1)); + when(queryPage.getResponse()).thenReturn(response); + doReturn(ApiFutures.immediateFuture(pagedResponse)) + .when(firestoreMock) + .sendRequest( + requestCaptor.capture(), + Matchers.>any()); + + firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get(); + + PartitionQueryRequest actualRequest = requestCaptor.getValue(); + assertEquals(actualRequest, expectedRequest); + } + + @Test + public void doesNotIssueRpcIfOnlyASinglePartitionIsRequested() throws Exception { + int desiredPartitionsCount = 1; + + List partitions = + firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get(); + + assertEquals(partitions.size(), 1); + assertNull(partitions.get(0).getStartAt()); + assertNull(partitions.get(0).getEndBefore()); + } + + @Test + public void validatesPartitionCount() { + int desiredPartitionsCount = 0; + try { + firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount); + fail(); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "Desired partition count must be one or greater"); + } + } + + @Test + public void convertsPartitionsToQueries() throws Exception { + int desiredPartitionsCount = 3; + + PartitionQueryResponse response = + PartitionQueryResponse.newBuilder().addPartitions(CURSOR1).build(); + + when(pagedResponse.iterateAll()).thenReturn(ImmutableList.of(CURSOR1, CURSOR2)); + when(queryPage.getResponse()).thenReturn(response); + doReturn(ApiFutures.immediateFuture(pagedResponse)) + .when(firestoreMock) + .sendRequest( + requestCaptor.capture(), + Matchers.>any()); + + doAnswer(queryResponse()) + .when(firestoreMock) + .streamRequest( + runQuery.capture(), + streamObserverCapture.capture(), + Matchers.any()); + + List partitions = + firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get(); + + assertEquals(partitions.size(), 3); + for (QueryPartition partition : partitions) { + partition.createQuery().get(); + } + + assertEquals(runQuery.getAllValues().size(), 3); + + assertFalse(runQuery.getAllValues().get(0).getStructuredQuery().hasStartAt()); + assertEquals(runQuery.getAllValues().get(0).getStructuredQuery().getEndAt(), PARTITION1); + assertEquals(runQuery.getAllValues().get(1).getStructuredQuery().getStartAt(), PARTITION1); + assertEquals(runQuery.getAllValues().get(1).getStructuredQuery().getEndAt(), PARTITION2); + assertEquals(runQuery.getAllValues().get(2).getStructuredQuery().getStartAt(), PARTITION2); + assertFalse(runQuery.getAllValues().get(2).getStructuredQuery().hasEndAt()); + } + + @Test + public void sortsPartitions() throws Exception { + int desiredPartitionsCount = 3; + + PartitionQueryResponse response = + PartitionQueryResponse.newBuilder().addPartitions(CURSOR1).build(); + + when(pagedResponse.iterateAll()).thenReturn(ImmutableList.of(CURSOR2, CURSOR1)); + when(queryPage.getResponse()).thenReturn(response); + doReturn(ApiFutures.immediateFuture(pagedResponse)) + .when(firestoreMock) + .sendRequest( + requestCaptor.capture(), + Matchers.>any()); + + List partitions = + firestoreMock.collectionGroup("collectionId").getPartitions(desiredPartitionsCount).get(); + + assertEquals(((DocumentReference) partitions.get(0).getEndBefore()[0]).getId(), "doc1"); + assertEquals(((DocumentReference) partitions.get(1).getEndBefore()[0]).getId(), "doc2"); + } +}