diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index aebdbee47..52d88006c 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -144,6 +144,7 @@ def verify_path(path, is_collection) -> None: if is_collection: if num_elements % 2 == 0: raise ValueError("A collection must have an odd number of path elements") + else: if num_elements % 2 == 1: raise ValueError("A document must have an even number of path elements") diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index f772194e8..2f94b5f7c 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -22,6 +22,7 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore +from google.cloud import firestore_v1 from google.cloud.firestore_v1.base_query import ( BaseCollectionGroup, BaseQuery, @@ -32,7 +33,7 @@ ) from google.cloud.firestore_v1 import async_document -from typing import AsyncGenerator +from typing import AsyncGenerator, Type # Types needed only for Type Hints from google.cloud.firestore_v1.transaction import Transaction @@ -92,6 +93,9 @@ class AsyncQuery(BaseQuery): When false, selects only collections that are immediate children of the `parent` specified in the containing `RunQueryRequest`. When true, selects all descendant collections. + recursive (Optional[bool]): + When true, returns all documents and all documents in any subcollections + below them. Defaults to false. """ def __init__( @@ -106,6 +110,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, + recursive=False, ) -> None: super(AsyncQuery, self).__init__( parent=parent, @@ -118,6 +123,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) async def get( @@ -224,6 +230,14 @@ async def stream( if snapshot is not None: yield snapshot + @staticmethod + def _get_collection_reference_class() -> Type[ + "firestore_v1.async_collection.AsyncCollectionReference" + ]: + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + return AsyncCollectionReference + class AsyncCollectionGroup(AsyncQuery, BaseCollectionGroup): """Represents a Collection Group in the Firestore API. @@ -249,6 +263,7 @@ def __init__( start_at=None, end_at=None, all_descendants=True, + recursive=False, ) -> None: super(AsyncCollectionGroup, self).__init__( parent=parent, @@ -261,6 +276,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) @staticmethod diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index ce31bfb0a..02363efc2 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -124,7 +124,10 @@ def document(self, document_id: str = None) -> DocumentReference: if document_id is None: document_id = _auto_id() - child_path = self._path + (document_id,) + # Append `self._path` and the passed document's ID as long as the first + # element in the path is not an empty string, which comes from setting the + # parent to "" for recursive queries. + child_path = self._path + (document_id,) if self._path[0] else (document_id,) return self._client.document(*child_path) def _parent_info(self) -> Tuple[Any, str]: @@ -200,6 +203,9 @@ def list_documents( ]: raise NotImplementedError + def recursive(self) -> "BaseQuery": + return self._query().recursive() + def select(self, field_paths: Iterable[str]) -> BaseQuery: """Create a "select" query with this collection as parent. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 5d11ccb3c..1812cfca0 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -33,7 +33,17 @@ from google.cloud.firestore_v1.types import Cursor from google.cloud.firestore_v1.types import RunQueryResponse from google.cloud.firestore_v1.order import Order -from typing import Any, Dict, Generator, Iterable, NoReturn, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Generator, + Iterable, + NoReturn, + Optional, + Tuple, + Type, + Union, +) # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot @@ -144,6 +154,9 @@ class BaseQuery(object): When false, selects only collections that are immediate children of the `parent` specified in the containing `RunQueryRequest`. When true, selects all descendant collections. + recursive (Optional[bool]): + When true, returns all documents and all documents in any subcollections + below them. Defaults to false. """ ASCENDING = "ASCENDING" @@ -163,6 +176,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, + recursive=False, ) -> None: self._parent = parent self._projection = projection @@ -174,6 +188,7 @@ def __init__( self._start_at = start_at self._end_at = end_at self._all_descendants = all_descendants + self._recursive = recursive def __eq__(self, other): if not isinstance(other, self.__class__): @@ -247,6 +262,7 @@ def _copy( start_at: Optional[Tuple[dict, bool]] = _not_passed, end_at: Optional[Tuple[dict, bool]] = _not_passed, all_descendants: Optional[bool] = _not_passed, + recursive: Optional[bool] = _not_passed, ) -> "BaseQuery": return self.__class__( self._parent, @@ -261,6 +277,7 @@ def _copy( all_descendants=self._evaluate_param( all_descendants, self._all_descendants ), + recursive=self._evaluate_param(recursive, self._recursive), ) def _evaluate_param(self, value, fallback_value): @@ -813,6 +830,46 @@ def stream( def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError + def recursive(self) -> "BaseQuery": + """Returns a copy of this query whose iterator will yield all matching + documents as well as each of their descendent subcollections and documents. + + This differs from the `all_descendents` flag, which only returns descendents + whose subcollection names match the parent collection's name. To return + all descendents, regardless of their subcollection name, use this. + """ + copied = self._copy(recursive=True, all_descendants=True) + if copied._parent and copied._parent.id: + original_collection_id = "/".join(copied._parent._path) + + # Reset the parent to nothing so we can recurse through the entire + # database. This is required to have + # `CollectionSelector.collection_id` not override + # `CollectionSelector.all_descendants`, which happens if both are + # set. + copied._parent = copied._get_collection_reference_class()("") + copied._parent._client = self._parent._client + + # But wait! We don't want to load the entire database; only the + # collection the user originally specified. To accomplish that, we + # add the following arcane filters. + + REFERENCE_NAME_MIN_ID = "__id-9223372036854775808__" + start_at = f"{original_collection_id}/{REFERENCE_NAME_MIN_ID}" + + # The backend interprets this null character is flipping the filter + # to mean the end of the range instead of the beginning. + nullChar = "\0" + end_at = f"{original_collection_id}{nullChar}/{REFERENCE_NAME_MIN_ID}" + + copied = ( + copied.order_by(field_path_module.FieldPath.document_id()) + .start_at({field_path_module.FieldPath.document_id(): start_at}) + .end_at({field_path_module.FieldPath.document_id(): end_at}) + ) + + return copied + def _comparator(self, doc1, doc2) -> int: _orders = self._orders @@ -1073,6 +1130,7 @@ def __init__( start_at=None, end_at=None, all_descendants=True, + recursive=False, ) -> None: if not all_descendants: raise ValueError("all_descendants must be True for collection group query.") @@ -1088,6 +1146,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) def _validate_partition_query(self): @@ -1133,6 +1192,10 @@ def get_partitions( ) -> NoReturn: raise NotImplementedError + @staticmethod + def _get_collection_reference_class() -> Type["BaseCollectionGroup"]: + raise NotImplementedError + class QueryPartition: """Represents a bounded partition of a collection group query. diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index aa2f5ad09..f1e044cbd 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -19,6 +19,7 @@ a more common way to create a query than direct usage of the constructor. """ +from google.cloud import firestore_v1 from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -34,7 +35,7 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List +from typing import Any, Callable, Generator, List, Type class Query(BaseQuery): @@ -105,6 +106,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, + recursive=False, ) -> None: super(Query, self).__init__( parent=parent, @@ -117,6 +119,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) def get( @@ -254,6 +257,14 @@ def on_snapshot(docs, changes, read_time): self, callback, document.DocumentSnapshot, document.DocumentReference ) + @staticmethod + def _get_collection_reference_class() -> Type[ + "firestore_v1.collection.CollectionReference" + ]: + from google.cloud.firestore_v1.collection import CollectionReference + + return CollectionReference + class CollectionGroup(Query, BaseCollectionGroup): """Represents a Collection Group in the Firestore API. @@ -279,6 +290,7 @@ def __init__( start_at=None, end_at=None, all_descendants=True, + recursive=False, ) -> None: super(CollectionGroup, self).__init__( parent=parent, @@ -291,6 +303,7 @@ def __init__( start_at=start_at, end_at=end_at, all_descendants=all_descendants, + recursive=recursive, ) @staticmethod diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 6d4471461..6e72e65cf 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -1212,6 +1212,127 @@ def test_array_union(client, cleanup): assert doc_ref.get().to_dict() == expected +def test_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Socrates", "favoriteCity": "Athens"}, + "subcollections": { + "pets": [{"name": "Scruffy"}, {"name": "Snowflake"}], + "hobbies": [{"name": "pontificating"}, {"name": "journaling"}], + "philosophers": [{"name": "Aristotle"}, {"name": "Plato"}], + }, + }, + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + { + "data": {"name": "Plato", "favoriteCity": "Corinth"}, + "subcollections": { + "pets": [{"name": "Cuddles"}, {"name": "Sergeant-Puppers"}], + "hobbies": [{"name": "abstraction"}, {"name": "hypotheticals"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}" + ) + ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + inner_doc_ref.set(entry) + cleanup(inner_doc_ref.delete) + + ids = [doc.id for doc in db.collection_group("philosophers").recursive().get()] + + expected_ids = [ + # Aristotle doc and subdocs + f"Aristotle{UNIQUE_RESOURCE_ID}", + "meditation", + "questioning-stuff", + "Doggy-Dog", + "Floof-Boy", + # Plato doc and subdocs + f"Plato{UNIQUE_RESOURCE_ID}", + "abstraction", + "hypotheticals", + "Cuddles", + "Sergeant-Puppers", + # Socrates doc and subdocs + f"Socrates{UNIQUE_RESOURCE_ID}", + "journaling", + "pontificating", + "Scruffy", + "Snowflake", + "Aristotle", + "Plato", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + +def test_nested_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}" + ) + ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + inner_doc_ref.set(entry) + cleanup(inner_doc_ref.delete) + + aristotle = collection_ref.document(f"Aristotle{UNIQUE_RESOURCE_ID}") + ids = [doc.id for doc in aristotle.collection("pets")._query().recursive().get()] + + expected_ids = [ + # Aristotle pets + "Doggy-Dog", + "Floof-Boy", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + def test_watch_query_order(client, cleanup): db = client collection_ref = db.collection("users") diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 65a46d984..ef8022f0e 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -1071,6 +1071,131 @@ async def test_batch(client, cleanup): assert not (await document3.get()).exists +async def test_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Socrates", "favoriteCity": "Athens"}, + "subcollections": { + "pets": [{"name": "Scruffy"}, {"name": "Snowflake"}], + "hobbies": [{"name": "pontificating"}, {"name": "journaling"}], + "philosophers": [{"name": "Aristotle"}, {"name": "Plato"}], + }, + }, + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + { + "data": {"name": "Plato", "favoriteCity": "Corinth"}, + "subcollections": { + "pets": [{"name": "Cuddles"}, {"name": "Sergeant-Puppers"}], + "hobbies": [{"name": "abstraction"}, {"name": "hypotheticals"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}-async" + ) + await ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + await inner_doc_ref.set(entry) + cleanup(inner_doc_ref.delete) + + ids = [ + doc.id for doc in await db.collection_group("philosophers").recursive().get() + ] + + expected_ids = [ + # Aristotle doc and subdocs + f"Aristotle{UNIQUE_RESOURCE_ID}-async", + "meditation", + "questioning-stuff", + "Doggy-Dog", + "Floof-Boy", + # Plato doc and subdocs + f"Plato{UNIQUE_RESOURCE_ID}-async", + "abstraction", + "hypotheticals", + "Cuddles", + "Sergeant-Puppers", + # Socrates doc and subdocs + f"Socrates{UNIQUE_RESOURCE_ID}-async", + "journaling", + "pontificating", + "Scruffy", + "Snowflake", + "Aristotle", + "Plato", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + +async def test_nested_recursive_query(client, cleanup): + + philosophers = [ + { + "data": {"name": "Aristotle", "favoriteCity": "Sparta"}, + "subcollections": { + "pets": [{"name": "Floof-Boy"}, {"name": "Doggy-Dog"}], + "hobbies": [{"name": "questioning-stuff"}, {"name": "meditation"}], + }, + }, + ] + + db = client + collection_ref = db.collection("philosophers") + for philosopher in philosophers: + ref = collection_ref.document( + f"{philosopher['data']['name']}{UNIQUE_RESOURCE_ID}-async" + ) + await ref.set(philosopher["data"]) + cleanup(ref.delete) + for col_name, entries in philosopher["subcollections"].items(): + sub_col = ref.collection(col_name) + for entry in entries: + inner_doc_ref = sub_col.document(entry["name"]) + await inner_doc_ref.set(entry) + cleanup(inner_doc_ref.delete) + + aristotle = collection_ref.document(f"Aristotle{UNIQUE_RESOURCE_ID}-async") + ids = [ + doc.id for doc in await aristotle.collection("pets")._query().recursive().get() + ] + + expected_ids = [ + # Aristotle pets + "Doggy-Dog", + "Floof-Boy", + ] + + assert len(ids) == len(expected_ids) + + for index in range(len(ids)): + error_msg = ( + f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" + ) + assert ids[index] == expected_ids[index], error_msg + + async def _chain(*iterators): """Asynchronous reimplementation of `itertools.chain`.""" for iterator in iterators: diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index bf0959e04..33006e254 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -375,6 +375,12 @@ async def test_stream_with_transaction(self, query_class): query_instance = query_class.return_value query_instance.stream.assert_called_once_with(transaction=transaction) + def test_recursive(self): + from google.cloud.firestore_v1.async_query import AsyncQuery + + col = self._make_one("collection") + self.assertIsInstance(col.recursive(), AsyncQuery) + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index a61aaedb2..3fb9a687f 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1151,6 +1151,12 @@ def test_comparator_missing_order_by_field_in_data_raises(self): with self.assertRaisesRegex(ValueError, "Can only compare fields "): query._comparator(doc1, doc2) + def test_multiple_recursive_calls(self): + query = self._make_one(_make_client().collection("asdf")) + self.assertIsInstance( + query.recursive().recursive(), type(query), + ) + class Test__enum_from_op_string(unittest.TestCase): @staticmethod diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index feaec8119..5885a29d9 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -349,3 +349,9 @@ def test_on_snapshot(self, watch): collection = self._make_one("collection") collection.on_snapshot(None) watch.for_query.assert_called_once() + + def test_recursive(self): + from google.cloud.firestore_v1.query import Query + + col = self._make_one("collection") + self.assertIsInstance(col.recursive(), Query)