diff --git a/google/cloud/firestore.py b/google/cloud/firestore.py index 904aedc00..f80d62c09 100644 --- a/google/cloud/firestore.py +++ b/google/cloud/firestore.py @@ -26,6 +26,7 @@ from google.cloud.firestore_v1 import AsyncTransaction from google.cloud.firestore_v1 import AsyncWriteBatch from google.cloud.firestore_v1 import Client +from google.cloud.firestore_v1 import CollectionGroup from google.cloud.firestore_v1 import CollectionReference from google.cloud.firestore_v1 import DELETE_FIELD from google.cloud.firestore_v1 import DocumentReference @@ -61,6 +62,7 @@ "AsyncTransaction", "AsyncWriteBatch", "Client", + "CollectionGroup", "CollectionReference", "DELETE_FIELD", "DocumentReference", diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index 23588e4a8..79d96c3dd 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -40,6 +40,7 @@ from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.query import CollectionGroup from google.cloud.firestore_v1.query import Query from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1.transaction import transactional @@ -115,6 +116,7 @@ "AsyncTransaction", "AsyncWriteBatch", "Client", + "CollectionGroup", "CollectionReference", "DELETE_FIELD", "DocumentReference", diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 9cdab62b4..dafd1a28d 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -35,7 +35,7 @@ ) from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.async_query import AsyncQuery +from google.cloud.firestore_v1.async_query import AsyncCollectionGroup from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1.async_document import ( @@ -150,7 +150,7 @@ def collection(self, *collection_path) -> AsyncCollectionReference: """ return AsyncCollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id) -> AsyncQuery: + def collection_group(self, collection_id) -> AsyncCollectionGroup: """ Creates and returns a new AsyncQuery that includes all documents in the database that are contained in a collection or subcollection with the @@ -167,12 +167,10 @@ def collection_group(self, collection_id) -> AsyncQuery: path will be included. Cannot contain a slash. Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: + :class:`~google.cloud.firestore_v1.async_query.AsyncCollectionGroup`: The created AsyncQuery. """ - return AsyncQuery( - self._get_collection_reference(collection_id), all_descendants=True - ) + return AsyncCollectionGroup(self._get_collection_reference(collection_id)) def document(self, *document_path) -> AsyncDocumentReference: """Get a reference to a document in a collection. diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 3f89b04a8..8c5302db7 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -19,7 +19,9 @@ a more common way to create a query than direct usage of the constructor. """ from google.cloud.firestore_v1.base_query import ( + BaseCollectionGroup, BaseQuery, + QueryPartition, _query_response_to_snapshot, _collection_group_query_response_to_snapshot, _enum_from_direction, @@ -207,3 +209,83 @@ async def stream( ) if snapshot is not None: yield snapshot + + +class AsyncCollectionGroup(AsyncQuery, BaseCollectionGroup): + """Represents a Collection Group in the Firestore API. + + This is a specialization of :class:`.AsyncQuery` that includes all documents in the + database that are contained in a collection or subcollection of the given + parent. + + Args: + parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + The collection that this query applies to. + """ + + def __init__( + self, + parent, + projection=None, + field_filters=(), + orders=(), + limit=None, + limit_to_last=False, + offset=None, + start_at=None, + end_at=None, + all_descendants=True, + ) -> None: + super(AsyncCollectionGroup, self).__init__( + parent=parent, + projection=projection, + field_filters=field_filters, + orders=orders, + limit=limit, + limit_to_last=limit_to_last, + offset=offset, + start_at=start_at, + end_at=end_at, + all_descendants=all_descendants, + ) + + async def get_partitions( + self, partition_count + ) -> AsyncGenerator[QueryPartition, None]: + """Partition a query for parallelization. + + Partitions a query by returning partition cursors that can be used to run the + query in parallel. The returned partition cursors are split points that can be + used as starting/end points for the query results. + + Args: + partition_count (int): The desired maximum number of partition points. The + number must be strictly positive. The actual number of partitions + returned may be fewer. + """ + self._validate_partition_query() + query = AsyncQuery( + self._parent, + orders=self._PARTITION_QUERY_ORDER, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + parent_path, expected_prefix = self._parent._parent_info() + pager = await self._client._firestore_api.partition_query( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "partition_count": partition_count, + }, + metadata=self._client._rpc_metadata, + ) + + start_at = None + async for cursor_pb in pager: + cursor = self._client.document(cursor_pb.values[0].reference_value) + yield QueryPartition(self, start_at, cursor) + start_at = cursor + + yield QueryPartition(self, start_at, None) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index a7c006c11..1f7d9fdb7 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1020,3 +1020,115 @@ def _collection_group_query_response_to_snapshot( update_time=response_pb._pb.document.update_time, ) return snapshot + + +class BaseCollectionGroup(BaseQuery): + """Represents a Collection Group in the Firestore API. + + This is a specialization of :class:`.Query` that includes all documents in the + database that are contained in a collection or subcollection of the given + parent. + + Args: + parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + The collection that this query applies to. + """ + + _PARTITION_QUERY_ORDER = ( + BaseQuery._make_order( + field_path_module.FieldPath.document_id(), BaseQuery.ASCENDING, + ), + ) + + def __init__( + self, + parent, + projection=None, + field_filters=(), + orders=(), + limit=None, + limit_to_last=False, + offset=None, + start_at=None, + end_at=None, + all_descendants=True, + ) -> None: + if not all_descendants: + raise ValueError("all_descendants must be True for collection group query.") + + super(BaseCollectionGroup, self).__init__( + parent=parent, + projection=projection, + field_filters=field_filters, + orders=orders, + limit=limit, + limit_to_last=limit_to_last, + offset=offset, + start_at=start_at, + end_at=end_at, + all_descendants=all_descendants, + ) + + def _validate_partition_query(self): + if self._field_filters: + raise ValueError("Can't partition query with filters.") + + if self._projection: + raise ValueError("Can't partition query with projection.") + + if self._limit: + raise ValueError("Can't partition query with limit.") + + if self._offset: + raise ValueError("Can't partition query with offset.") + + +class QueryPartition: + """Represents a bounded partition of a collection group query. + + Contains cursors that can be used in a query as a starting and/or end point for the + collection group query. The cursors may only be used in a query that matches the + constraints of the query that produced this partition. + + Args: + query (BaseQuery): The original query that this is a partition of. + start_at (Optional[~google.cloud.firestore_v1.document.DocumentSnapshot]): + Cursor for first query result to include. If `None`, the partition starts at + the beginning of the result set. + end_at (Optional[~google.cloud.firestore_v1.document.DocumentSnapshot]): + Cursor for first query result after the last result included in the + partition. If `None`, the partition runs to the end of the result set. + + """ + + def __init__(self, query, start_at, end_at): + self._query = query + self._start_at = start_at + self._end_at = end_at + + @property + def start_at(self): + return self._start_at + + @property + def end_at(self): + return self._end_at + + def query(self): + """Generate a new query using this partition's bounds. + + Returns: + BaseQuery: Copy of the original query with start and end bounds set by the + cursors from this partition. + """ + query = self._query + start_at = ([self.start_at], True) if self.start_at else None + end_at = ([self.end_at], True) if self.end_at else None + + return type(query)( + query._parent, + all_descendants=query._all_descendants, + orders=query._PARTITION_QUERY_ORDER, + start_at=start_at, + end_at=end_at, + ) diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 30d6bd1cd..448a8f4fb 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -35,7 +35,7 @@ ) from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.query import Query +from google.cloud.firestore_v1.query import CollectionGroup from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference @@ -145,7 +145,7 @@ def collection(self, *collection_path) -> CollectionReference: """ return CollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id) -> Query: + def collection_group(self, collection_id) -> CollectionGroup: """ Creates and returns a new Query that includes all documents in the database that are contained in a collection or subcollection with the @@ -162,12 +162,10 @@ def collection_group(self, collection_id) -> Query: path will be included. Cannot contain a slash. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.CollectionGroup`: The created Query. """ - return Query( - self._get_collection_reference(collection_id), all_descendants=True - ) + return CollectionGroup(self._get_collection_reference(collection_id)) def document(self, *document_path) -> DocumentReference: """Get a reference to a document in a collection. diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 9b0dc4462..09f8dc47b 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -19,7 +19,9 @@ a more common way to create a query than direct usage of the constructor. """ from google.cloud.firestore_v1.base_query import ( + BaseCollectionGroup, BaseQuery, + QueryPartition, _query_response_to_snapshot, _collection_group_query_response_to_snapshot, _enum_from_direction, @@ -239,3 +241,81 @@ def on_snapshot(docs, changes, read_time): return Watch.for_query( self, callback, document.DocumentSnapshot, document.DocumentReference ) + + +class CollectionGroup(Query, BaseCollectionGroup): + """Represents a Collection Group in the Firestore API. + + This is a specialization of :class:`.Query` that includes all documents in the + database that are contained in a collection or subcollection of the given + parent. + + Args: + parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + The collection that this query applies to. + """ + + def __init__( + self, + parent, + projection=None, + field_filters=(), + orders=(), + limit=None, + limit_to_last=False, + offset=None, + start_at=None, + end_at=None, + all_descendants=True, + ) -> None: + super(CollectionGroup, self).__init__( + parent=parent, + projection=projection, + field_filters=field_filters, + orders=orders, + limit=limit, + limit_to_last=limit_to_last, + offset=offset, + start_at=start_at, + end_at=end_at, + all_descendants=all_descendants, + ) + + def get_partitions(self, partition_count) -> Generator[QueryPartition, None, None]: + """Partition a query for parallelization. + + Partitions a query by returning partition cursors that can be used to run the + query in parallel. The returned partition cursors are split points that can be + used as starting/end points for the query results. + + Args: + partition_count (int): The desired maximum number of partition points. The + number must be strictly positive. The actual number of partitions + returned may be fewer. + """ + self._validate_partition_query() + query = Query( + self._parent, + orders=self._PARTITION_QUERY_ORDER, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + parent_path, expected_prefix = self._parent._parent_info() + pager = self._client._firestore_api.partition_query( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "partition_count": partition_count, + }, + metadata=self._client._rpc_metadata, + ) + + start_at = None + for cursor_pb in pager: + cursor = self._client.document(cursor_pb.values[0].reference_value) + yield QueryPartition(self, start_at, cursor) + start_at = cursor + + yield QueryPartition(self, start_at, None) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 8b754e93f..988fa082c 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +import itertools import math import operator @@ -52,7 +53,7 @@ def _get_credentials_and_project(): return credentials, project -@pytest.fixture(scope=u"module") +@pytest.fixture(scope="module") def client(): credentials, project = _get_credentials_and_project() yield firestore.Client(project=project, credentials=credentials) @@ -389,7 +390,7 @@ def test_document_get(client, cleanup): "fire": 199099299, "referee": ref_doc, "gio": firestore.GeoPoint(45.5, 90.0), - "deep": [u"some", b"\xde\xad\xbe\xef"], + "deep": ["some", b"\xde\xad\xbe\xef"], "map": {"ice": True, "water": None, "vapor": {"deeper": now}}, } write_result = document.create(data) @@ -717,9 +718,9 @@ def test_query_with_order_dot_key(client, cleanup): .stream() ) found_data = [ - {u"count": 30, u"wordcount": {u"page1": 130}}, - {u"count": 40, u"wordcount": {u"page1": 140}}, - {u"count": 50, u"wordcount": {u"page1": 150}}, + {"count": 30, "wordcount": {"page1": 130}}, + {"count": 40, "wordcount": {"page1": 140}}, + {"count": 50, "wordcount": {"page1": 150}}, ] assert found_data == [snap.to_dict() for snap in found] cursor_with_dotted_paths = {"wordcount.page1": last_value} @@ -890,6 +891,63 @@ def test_collection_group_queries_filters(client, cleanup): assert found == set(["cg-doc2"]) +def test_partition_query_no_partitions(client, cleanup): + collection_group = "b" + UNIQUE_RESOURCE_ID + + # less than minimum partition size + doc_paths = [ + "abc/123/" + collection_group + "/cg-doc1", + "abc/123/" + collection_group + "/cg-doc2", + collection_group + "/cg-doc3", + collection_group + "/cg-doc4", + "def/456/" + collection_group + "/cg-doc5", + ] + + batch = client.batch() + cleanup_batch = client.batch() + cleanup(cleanup_batch.commit) + for doc_path in doc_paths: + doc_ref = client.document(doc_path) + batch.set(doc_ref, {"x": 1}) + cleanup_batch.delete(doc_ref) + + batch.commit() + + query = client.collection_group(collection_group) + partitions = list(query.get_partitions(3)) + streams = [partition.query().stream() for partition in partitions] + snapshots = itertools.chain(*streams) + found = [snapshot.id for snapshot in snapshots] + expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] + assert found == expected + + +def test_partition_query(client, cleanup): + collection_group = "b" + UNIQUE_RESOURCE_ID + n_docs = 128 * 2 + 127 # Minimum partition size is 128 + parents = itertools.cycle(("", "abc/123/", "def/456/", "ghi/789/")) + batch = client.batch() + cleanup_batch = client.batch() + cleanup(cleanup_batch.commit) + expected = [] + for i, parent in zip(range(n_docs), parents): + doc_path = parent + collection_group + f"/cg-doc{i:03d}" + doc_ref = client.document(doc_path) + batch.set(doc_ref, {"x": i}) + cleanup_batch.delete(doc_ref) + expected.append(doc_path) + + batch.commit() + + query = client.collection_group(collection_group) + partitions = list(query.get_partitions(3)) + streams = [partition.query().stream() for partition in partitions] + snapshots = itertools.chain(*streams) + found = [snapshot.reference.path for snapshot in snapshots] + expected.sort() + assert found == expected + + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992") def test_get_all(client, cleanup): collection_name = "get-all" + UNIQUE_RESOURCE_ID @@ -989,11 +1047,11 @@ def test_batch(client, cleanup): def test_watch_document(client, cleanup): db = client - collection_ref = db.collection(u"wd-users" + UNIQUE_RESOURCE_ID) - doc_ref = collection_ref.document(u"alovelace") + collection_ref = db.collection("wd-users" + UNIQUE_RESOURCE_ID) + doc_ref = collection_ref.document("alovelace") # Initial setting - doc_ref.set({u"first": u"Jane", u"last": u"Doe", u"born": 1900}) + doc_ref.set({"first": "Jane", "last": "Doe", "born": 1900}) cleanup(doc_ref.delete) sleep(1) @@ -1007,7 +1065,7 @@ def on_snapshot(docs, changes, read_time): doc_ref.on_snapshot(on_snapshot) # Alter document - doc_ref.set({u"first": u"Ada", u"last": u"Lovelace", u"born": 1815}) + doc_ref.set({"first": "Ada", "last": "Lovelace", "born": 1815}) sleep(1) @@ -1025,11 +1083,11 @@ def on_snapshot(docs, changes, read_time): def test_watch_collection(client, cleanup): db = client - collection_ref = db.collection(u"wc-users" + UNIQUE_RESOURCE_ID) - doc_ref = collection_ref.document(u"alovelace") + collection_ref = db.collection("wc-users" + UNIQUE_RESOURCE_ID) + doc_ref = collection_ref.document("alovelace") # Initial setting - doc_ref.set({u"first": u"Jane", u"last": u"Doe", u"born": 1900}) + doc_ref.set({"first": "Jane", "last": "Doe", "born": 1900}) cleanup(doc_ref.delete) # Setup listener @@ -1046,7 +1104,7 @@ def on_snapshot(docs, changes, read_time): # delay here so initial on_snapshot occurs and isn't combined with set sleep(1) - doc_ref.set({u"first": u"Ada", u"last": u"Lovelace", u"born": 1815}) + doc_ref.set({"first": "Ada", "last": "Lovelace", "born": 1815}) for _ in range(10): if on_snapshot.born == 1815: @@ -1061,12 +1119,12 @@ def on_snapshot(docs, changes, read_time): def test_watch_query(client, cleanup): db = client - collection_ref = db.collection(u"wq-users" + UNIQUE_RESOURCE_ID) - doc_ref = collection_ref.document(u"alovelace") - query_ref = collection_ref.where("first", "==", u"Ada") + collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) + doc_ref = collection_ref.document("alovelace") + query_ref = collection_ref.where("first", "==", "Ada") # Initial setting - doc_ref.set({u"first": u"Jane", u"last": u"Doe", u"born": 1900}) + doc_ref.set({"first": "Jane", "last": "Doe", "born": 1900}) cleanup(doc_ref.delete) sleep(1) @@ -1076,7 +1134,7 @@ def on_snapshot(docs, changes, read_time): on_snapshot.called_count += 1 # A snapshot should return the same thing as if a query ran now. - query_ran = collection_ref.where("first", "==", u"Ada").stream() + query_ran = collection_ref.where("first", "==", "Ada").stream() assert len(docs) == len([i for i in query_ran]) on_snapshot.called_count = 0 @@ -1084,7 +1142,7 @@ def on_snapshot(docs, changes, read_time): query_ref.on_snapshot(on_snapshot) # Alter document - doc_ref.set({u"first": u"Ada", u"last": u"Lovelace", u"born": 1815}) + doc_ref.set({"first": "Ada", "last": "Lovelace", "born": 1815}) for _ in range(10): if on_snapshot.called_count == 1: @@ -1100,14 +1158,14 @@ def on_snapshot(docs, changes, read_time): def test_watch_query_order(client, cleanup): db = client - collection_ref = db.collection(u"users") - doc_ref1 = collection_ref.document(u"alovelace" + UNIQUE_RESOURCE_ID) - doc_ref2 = collection_ref.document(u"asecondlovelace" + UNIQUE_RESOURCE_ID) - doc_ref3 = collection_ref.document(u"athirdlovelace" + UNIQUE_RESOURCE_ID) - doc_ref4 = collection_ref.document(u"afourthlovelace" + UNIQUE_RESOURCE_ID) - doc_ref5 = collection_ref.document(u"afifthlovelace" + UNIQUE_RESOURCE_ID) + collection_ref = db.collection("users") + doc_ref1 = collection_ref.document("alovelace" + UNIQUE_RESOURCE_ID) + doc_ref2 = collection_ref.document("asecondlovelace" + UNIQUE_RESOURCE_ID) + doc_ref3 = collection_ref.document("athirdlovelace" + UNIQUE_RESOURCE_ID) + doc_ref4 = collection_ref.document("afourthlovelace" + UNIQUE_RESOURCE_ID) + doc_ref5 = collection_ref.document("afifthlovelace" + UNIQUE_RESOURCE_ID) - query_ref = collection_ref.where("first", "==", u"Ada").order_by("last") + query_ref = collection_ref.where("first", "==", "Ada").order_by("last") # Setup listener def on_snapshot(docs, changes, read_time): @@ -1139,19 +1197,19 @@ def on_snapshot(docs, changes, read_time): sleep(1) - doc_ref1.set({u"first": u"Ada", u"last": u"Lovelace", u"born": 1815}) + doc_ref1.set({"first": "Ada", "last": "Lovelace", "born": 1815}) cleanup(doc_ref1.delete) - doc_ref2.set({u"first": u"Ada", u"last": u"SecondLovelace", u"born": 1815}) + doc_ref2.set({"first": "Ada", "last": "SecondLovelace", "born": 1815}) cleanup(doc_ref2.delete) - doc_ref3.set({u"first": u"Ada", u"last": u"ThirdLovelace", u"born": 1815}) + doc_ref3.set({"first": "Ada", "last": "ThirdLovelace", "born": 1815}) cleanup(doc_ref3.delete) - doc_ref4.set({u"first": u"Ada", u"last": u"FourthLovelace", u"born": 1815}) + doc_ref4.set({"first": "Ada", "last": "FourthLovelace", "born": 1815}) cleanup(doc_ref4.delete) - doc_ref5.set({u"first": u"Ada", u"last": u"lovelace", u"born": 1815}) + doc_ref5.set({"first": "Ada", "last": "lovelace", "born": 1815}) cleanup(doc_ref5.delete) for _ in range(10): diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 09646ca46..65a46d984 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -14,6 +14,7 @@ import asyncio import datetime +import itertools import math import pytest import operator @@ -54,7 +55,7 @@ def _get_credentials_and_project(): return credentials, project -@pytest.fixture(scope=u"module") +@pytest.fixture(scope="module") def client(): credentials, project = _get_credentials_and_project() yield firestore.AsyncClient(project=project, credentials=credentials) @@ -399,7 +400,7 @@ async def test_document_get(client, cleanup): "fire": 199099299, "referee": ref_doc, "gio": firestore.GeoPoint(45.5, 90.0), - "deep": [u"some", b"\xde\xad\xbe\xef"], + "deep": ["some", b"\xde\xad\xbe\xef"], "map": {"ice": True, "water": None, "vapor": {"deeper": now}}, } write_result = await document.create(data) @@ -741,9 +742,9 @@ async def test_query_with_order_dot_key(client, cleanup): .stream() ] found_data = [ - {u"count": 30, u"wordcount": {u"page1": 130}}, - {u"count": 40, u"wordcount": {u"page1": 140}}, - {u"count": 50, u"wordcount": {u"page1": 150}}, + {"count": 30, "wordcount": {"page1": 130}}, + {"count": 40, "wordcount": {"page1": 140}}, + {"count": 50, "wordcount": {"page1": 150}}, ] assert found_data == [snap.to_dict() for snap in found] cursor_with_dotted_paths = {"wordcount.page1": last_value} @@ -915,6 +916,61 @@ async def test_collection_group_queries_filters(client, cleanup): assert found == set(["cg-doc2"]) +async def test_partition_query_no_partitions(client, cleanup): + collection_group = "b" + UNIQUE_RESOURCE_ID + + # less than minimum partition size + doc_paths = [ + "abc/123/" + collection_group + "/cg-doc1", + "abc/123/" + collection_group + "/cg-doc2", + collection_group + "/cg-doc3", + collection_group + "/cg-doc4", + "def/456/" + collection_group + "/cg-doc5", + ] + + batch = client.batch() + cleanup_batch = client.batch() + cleanup(cleanup_batch.commit) + for doc_path in doc_paths: + doc_ref = client.document(doc_path) + batch.set(doc_ref, {"x": 1}) + cleanup_batch.delete(doc_ref) + + await batch.commit() + + query = client.collection_group(collection_group) + partitions = [i async for i in query.get_partitions(3)] + streams = [partition.query().stream() for partition in partitions] + found = [snapshot.id async for snapshot in _chain(*streams)] + expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] + assert found == expected + + +async def test_partition_query(client, cleanup): + collection_group = "b" + UNIQUE_RESOURCE_ID + n_docs = 128 * 2 + 127 # Minimum partition size is 128 + parents = itertools.cycle(("", "abc/123/", "def/456/", "ghi/789/")) + batch = client.batch() + cleanup_batch = client.batch() + cleanup(cleanup_batch.commit) + expected = [] + for i, parent in zip(range(n_docs), parents): + doc_path = parent + collection_group + f"/cg-doc{i:03d}" + doc_ref = client.document(doc_path) + batch.set(doc_ref, {"x": i}) + cleanup_batch.delete(doc_ref) + expected.append(doc_path) + + await batch.commit() + + query = client.collection_group(collection_group) + partitions = [i async for i in query.get_partitions(3)] + streams = [partition.query().stream() for partition in partitions] + found = [snapshot.reference.path async for snapshot in _chain(*streams)] + expected.sort() + assert found == expected + + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992") async def test_get_all(client, cleanup): collection_name = "get-all" + UNIQUE_RESOURCE_ID @@ -1013,3 +1069,10 @@ async def test_batch(client, cleanup): assert snapshot2.update_time == write_result2.update_time assert not (await document3.get()).exists + + +async def _chain(*iterators): + """Asynchronous reimplementation of `itertools.chain`.""" + for iterator in iterators: + async for value in iterator: + yield value diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 14e41c278..944c63ae0 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -18,7 +18,11 @@ import mock from tests.unit.v1.test__helpers import AsyncMock, AsyncIter -from tests.unit.v1.test_base_query import _make_credentials, _make_query_response +from tests.unit.v1.test_base_query import ( + _make_credentials, + _make_query_response, + _make_cursor_pb, +) class MockAsyncIter: @@ -434,6 +438,116 @@ async def test_stream_w_collection_group(self): ) +class TestCollectionGroup(aiounittest.AsyncTestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_query import AsyncCollectionGroup + + return AsyncCollectionGroup + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + query = self._make_one(mock.sentinel.parent) + self.assertIs(query._parent, mock.sentinel.parent) + self.assertIsNone(query._projection) + self.assertEqual(query._field_filters, ()) + self.assertEqual(query._orders, ()) + self.assertIsNone(query._limit) + self.assertIsNone(query._offset) + self.assertIsNone(query._start_at) + self.assertIsNone(query._end_at) + self.assertTrue(query._all_descendants) + + def test_constructor_all_descendents_is_false(self): + with pytest.raises(ValueError): + self._make_one(mock.sentinel.parent, all_descendants=False) + + @pytest.mark.asyncio + async def test_get_partitions(self): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["partition_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + + # Make two **real** document references to use as cursors + document1 = parent.document("one") + document2 = parent.document("two") + + # Add cursor pb's to the minimal fake GAPIC. + cursor_pb1 = _make_cursor_pb(([document1], False)) + cursor_pb2 = _make_cursor_pb(([document2], False)) + firestore_api.partition_query.return_value = AsyncIter([cursor_pb1, cursor_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.get_partitions(2) + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [i async for i in get_response] + self.assertEqual(len(returned), 3) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + partition_query = self._make_one( + parent, orders=(query._make_order("__name__", query.ASCENDING),), + ) + firestore_api.partition_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": partition_query._to_protobuf(), + "partition_count": 2, + }, + metadata=client._rpc_metadata, + ) + + async def test_get_partitions_w_filter(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).where("foo", "==", "bar") + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] + + async def test_get_partitions_w_projection(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).select("foo") + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] + + async def test_get_partitions_w_limit(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).limit(10) + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] + + async def test_get_partitions_w_offset(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).offset(10) + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] + + def _make_client(project="project-project"): from google.cloud.firestore_v1.async_client import AsyncClient diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index faa0e2e78..59578af39 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1427,3 +1427,71 @@ def _make_query_response(**kwargs): kwargs["document"] = document_pb return firestore.RunQueryResponse(**kwargs) + + +def _make_cursor_pb(pair): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import query + + values, before = pair + value_pbs = [_helpers.encode_value(value) for value in values] + return query.Cursor(values=value_pbs, before=before) + + +class TestQueryPartition(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.base_query import QueryPartition + + return QueryPartition + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + partition = self._make_one(mock.sentinel.query, "start", "end") + assert partition._query is mock.sentinel.query + assert partition.start_at == "start" + assert partition.end_at == "end" + + def test_query_begin(self): + partition = self._make_one(DummyQuery("PARENT"), None, "end") + query = partition.query() + assert query._parent == "PARENT" + assert query.all_descendants == "YUP" + assert query.orders == "ORDER" + assert query.start_at is None + assert query.end_at == (["end"], True) + + def test_query_middle(self): + partition = self._make_one(DummyQuery("PARENT"), "start", "end") + query = partition.query() + assert query._parent == "PARENT" + assert query.all_descendants == "YUP" + assert query.orders == "ORDER" + assert query.start_at == (["start"], True) + assert query.end_at == (["end"], True) + + def test_query_end(self): + partition = self._make_one(DummyQuery("PARENT"), "start", None) + query = partition.query() + assert query._parent == "PARENT" + assert query.all_descendants == "YUP" + assert query.orders == "ORDER" + assert query.start_at == (["start"], True) + assert query.end_at is None + + +class DummyQuery: + _all_descendants = "YUP" + _PARTITION_QUERY_ORDER = "ORDER" + + def __init__( + self, parent, *, all_descendants=None, orders=None, start_at=None, end_at=None + ): + self._parent = parent + self.all_descendants = all_descendants + self.orders = orders + self.start_at = start_at + self.end_at = end_at diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index 3ad01d02c..e2290db37 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -16,8 +16,11 @@ import unittest import mock +import pytest -from tests.unit.v1.test_base_query import _make_credentials, _make_query_response +from tests.unit.v1.test_base_query import _make_credentials +from tests.unit.v1.test_base_query import _make_cursor_pb +from tests.unit.v1.test_base_query import _make_query_response class TestQuery(unittest.TestCase): @@ -418,6 +421,115 @@ def test_on_snapshot(self, watch): watch.for_query.assert_called_once() +class TestCollectionGroup(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.query import CollectionGroup + + return CollectionGroup + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + query = self._make_one(mock.sentinel.parent) + self.assertIs(query._parent, mock.sentinel.parent) + self.assertIsNone(query._projection) + self.assertEqual(query._field_filters, ()) + self.assertEqual(query._orders, ()) + self.assertIsNone(query._limit) + self.assertIsNone(query._offset) + self.assertIsNone(query._start_at) + self.assertIsNone(query._end_at) + self.assertTrue(query._all_descendants) + + def test_constructor_all_descendents_is_false(self): + with pytest.raises(ValueError): + self._make_one(mock.sentinel.parent, all_descendants=False) + + def test_get_partitions(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["partition_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + + # Make two **real** document references to use as cursors + document1 = parent.document("one") + document2 = parent.document("two") + + # Add cursor pb's to the minimal fake GAPIC. + cursor_pb1 = _make_cursor_pb(([document1], False)) + cursor_pb2 = _make_cursor_pb(([document2], False)) + firestore_api.partition_query.return_value = iter([cursor_pb1, cursor_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.get_partitions(2) + self.assertIsInstance(get_response, types.GeneratorType) + returned = list(get_response) + self.assertEqual(len(returned), 3) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + partition_query = self._make_one( + parent, orders=(query._make_order("__name__", query.ASCENDING),), + ) + firestore_api.partition_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": partition_query._to_protobuf(), + "partition_count": 2, + }, + metadata=client._rpc_metadata, + ) + + def test_get_partitions_w_filter(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).where("foo", "==", "bar") + with pytest.raises(ValueError): + list(query.get_partitions(2)) + + def test_get_partitions_w_projection(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).select("foo") + with pytest.raises(ValueError): + list(query.get_partitions(2)) + + def test_get_partitions_w_limit(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).limit(10) + with pytest.raises(ValueError): + list(query.get_partitions(2)) + + def test_get_partitions_w_offset(self): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = self._make_one(parent).offset(10) + with pytest.raises(ValueError): + list(query.get_partitions(2)) + + def _make_client(project="project-project"): from google.cloud.firestore_v1.client import Client