Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for recursive queries #407

Merged
merged 15 commits into from Aug 11, 2021
Merged
1 change: 1 addition & 0 deletions google/cloud/firestore_v1/_helpers.py
Expand Up @@ -144,6 +144,7 @@ def verify_path(path, is_collection) -> None:
if is_collection:
if num_elements % 2 == 0:
raise ValueError("A collection must have an odd number of path elements")

else:
if num_elements % 2 == 1:
raise ValueError("A document must have an even number of path elements")
Expand Down
18 changes: 17 additions & 1 deletion google/cloud/firestore_v1/async_query.py
Expand Up @@ -22,6 +22,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore

from google.cloud import firestore_v1
from google.cloud.firestore_v1.base_query import (
BaseCollectionGroup,
BaseQuery,
Expand All @@ -32,7 +33,7 @@
)

from google.cloud.firestore_v1 import async_document
from typing import AsyncGenerator
from typing import AsyncGenerator, Type

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
Expand Down Expand Up @@ -92,6 +93,9 @@ class AsyncQuery(BaseQuery):
When false, selects only collections that are immediate children
of the `parent` specified in the containing `RunQueryRequest`.
When true, selects all descendant collections.
recursive (Optional[bool]):
When true, returns all documents and all documents in any subcollections
below them. Defaults to false.
"""

def __init__(
Expand All @@ -106,6 +110,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=False,
recursive=False,
) -> None:
super(AsyncQuery, self).__init__(
parent=parent,
Expand All @@ -118,6 +123,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

async def get(
Expand Down Expand Up @@ -224,6 +230,14 @@ async def stream(
if snapshot is not None:
yield snapshot

@staticmethod
def _get_collection_reference_class() -> Type[
    "firestore_v1.async_collection.AsyncCollectionReference"
]:
    """Return the collection reference class paired with this async query.

    Returns:
        The ``AsyncCollectionReference`` class itself (not an instance).
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids a circular import between the
    # query and collection modules -- confirm before hoisting it.
    from google.cloud.firestore_v1 import async_collection

    return async_collection.AsyncCollectionReference


class AsyncCollectionGroup(AsyncQuery, BaseCollectionGroup):
"""Represents a Collection Group in the Firestore API.
Expand All @@ -249,6 +263,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=True,
recursive=False,
) -> None:
super(AsyncCollectionGroup, self).__init__(
parent=parent,
Expand All @@ -261,6 +276,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

@staticmethod
Expand Down
8 changes: 7 additions & 1 deletion google/cloud/firestore_v1/base_collection.py
Expand Up @@ -124,7 +124,10 @@ def document(self, document_id: str = None) -> DocumentReference:
if document_id is None:
document_id = _auto_id()

child_path = self._path + (document_id,)
# Append `self._path` and the passed document's ID as long as the first
# element in the path is not an empty string, which comes from setting the
# parent to "" for recursive queries.
child_path = self._path + (document_id,) if self._path[0] else (document_id,)
return self._client.document(*child_path)

def _parent_info(self) -> Tuple[Any, str]:
Expand Down Expand Up @@ -200,6 +203,9 @@ def list_documents(
]:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
    """Build a query over this collection and return its recursive form.

    Returns:
        Whatever ``recursive()`` returns on the query produced by
        ``self._query()``.
    """
    collection_query = self._query()
    return collection_query.recursive()

def select(self, field_paths: Iterable[str]) -> BaseQuery:
"""Create a "select" query with this collection as parent.

Expand Down
65 changes: 64 additions & 1 deletion google/cloud/firestore_v1/base_query.py
Expand Up @@ -33,7 +33,17 @@
from google.cloud.firestore_v1.types import Cursor
from google.cloud.firestore_v1.types import RunQueryResponse
from google.cloud.firestore_v1.order import Order
from typing import Any, Dict, Generator, Iterable, NoReturn, Optional, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
NoReturn,
Optional,
Tuple,
Type,
Union,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
Expand Down Expand Up @@ -144,6 +154,9 @@ class BaseQuery(object):
When false, selects only collections that are immediate children
of the `parent` specified in the containing `RunQueryRequest`.
When true, selects all descendant collections.
recursive (Optional[bool]):
When true, returns all documents and all documents in any subcollections
below them. Defaults to false.
"""

ASCENDING = "ASCENDING"
Expand All @@ -163,6 +176,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=False,
recursive=False,
) -> None:
self._parent = parent
self._projection = projection
Expand All @@ -174,6 +188,7 @@ def __init__(
self._start_at = start_at
self._end_at = end_at
self._all_descendants = all_descendants
self._recursive = recursive

def __eq__(self, other):
if not isinstance(other, self.__class__):
Expand Down Expand Up @@ -247,6 +262,7 @@ def _copy(
start_at: Optional[Tuple[dict, bool]] = _not_passed,
end_at: Optional[Tuple[dict, bool]] = _not_passed,
all_descendants: Optional[bool] = _not_passed,
recursive: Optional[bool] = _not_passed,
) -> "BaseQuery":
return self.__class__(
self._parent,
Expand All @@ -261,6 +277,7 @@ def _copy(
all_descendants=self._evaluate_param(
all_descendants, self._all_descendants
),
recursive=self._evaluate_param(recursive, self._recursive),
)

def _evaluate_param(self, value, fallback_value):
Expand Down Expand Up @@ -813,6 +830,46 @@ def stream(
def on_snapshot(self, callback) -> NoReturn:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
    """Return a copy of this query that also yields descendant documents.

    The copy is created with ``recursive=True`` and ``all_descendants=True``.
    Per the original author's note, this differs from using the
    ``all_descendants`` flag alone, which only returns descendants whose
    subcollection names match the parent collection's name; the recursive
    copy is meant to return all descendants regardless of subcollection name.

    When the copied query has a parent with a truthy ``id``, the parent is
    swapped for a collection reference built from ``""`` (sharing the
    original parent's client), and the copy is ordered by document ID and
    bounded with ``start_at`` / ``end_at`` cursors derived from the original
    parent's path. Otherwise the plain copy is returned.

    Returns:
        BaseQuery: The recursive copy of this query.
    """
    copied = self._copy(recursive=True, all_descendants=True)

    # Nothing to re-anchor when there is no parent or the parent has no id.
    if not copied._parent or not copied._parent.id:
        return copied

    # Capture the original location before the parent is replaced below.
    parent_path = "/".join(copied._parent._path)

    # Replace the parent with an empty-path collection so the query can
    # recurse through the entire database. This is needed so that
    # `CollectionSelector.collection_id` does not override
    # `CollectionSelector.all_descendants`, which happens if both are set.
    collection_cls = copied._get_collection_reference_class()
    copied._parent = collection_cls("")
    copied._parent._client = self._parent._client

    # The whole database must not be loaded -- only the collection the
    # caller originally specified -- so the document-ID range is bounded.
    min_reference_id = "__id-9223372036854775808__"
    range_start = f"{parent_path}/{min_reference_id}"

    # The backend interprets this null character as flipping the filter to
    # mean the end of the range instead of the beginning.
    null_char = "\0"
    range_end = f"{parent_path}{null_char}/{min_reference_id}"

    ordered = copied.order_by(field_path_module.FieldPath.document_id())
    lower_bounded = ordered.start_at(
        {field_path_module.FieldPath.document_id(): range_start}
    )
    return lower_bounded.end_at(
        {field_path_module.FieldPath.document_id(): range_end}
    )

def _comparator(self, doc1, doc2) -> int:
_orders = self._orders

Expand Down Expand Up @@ -1073,6 +1130,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=True,
recursive=False,
) -> None:
if not all_descendants:
raise ValueError("all_descendants must be True for collection group query.")
Expand All @@ -1088,6 +1146,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

def _validate_partition_query(self):
Expand Down Expand Up @@ -1133,6 +1192,10 @@ def get_partitions(
) -> NoReturn:
raise NotImplementedError

@staticmethod
def _get_collection_reference_class() -> Type["BaseCollectionGroup"]:
    """Placeholder hook for the collection reference class lookup.

    This base implementation has no class to offer and always raises.

    NOTE(review): the return annotation names ``BaseCollectionGroup``; it is
    presumably meant to describe a collection reference class -- confirm.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError


class QueryPartition:
"""Represents a bounded partition of a collection group query.
Expand Down
15 changes: 14 additions & 1 deletion google/cloud/firestore_v1/query.py
Expand Up @@ -19,6 +19,7 @@
a more common way to create a query than direct usage of the constructor.
"""

from google.cloud import firestore_v1
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
Expand All @@ -34,7 +35,7 @@

from google.cloud.firestore_v1 import document
from google.cloud.firestore_v1.watch import Watch
from typing import Any, Callable, Generator, List
from typing import Any, Callable, Generator, List, Type


class Query(BaseQuery):
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=False,
recursive=False,
) -> None:
super(Query, self).__init__(
parent=parent,
Expand All @@ -117,6 +119,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

def get(
Expand Down Expand Up @@ -254,6 +257,14 @@ def on_snapshot(docs, changes, read_time):
self, callback, document.DocumentSnapshot, document.DocumentReference
)

@staticmethod
def _get_collection_reference_class() -> Type[
    "firestore_v1.collection.CollectionReference"
]:
    """Return the collection reference class paired with this query.

    Returns:
        The ``CollectionReference`` class itself (not an instance).
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids a circular import between the
    # query and collection modules -- confirm before hoisting it.
    from google.cloud.firestore_v1 import collection

    return collection.CollectionReference


class CollectionGroup(Query, BaseCollectionGroup):
"""Represents a Collection Group in the Firestore API.
Expand All @@ -279,6 +290,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=True,
recursive=False,
) -> None:
super(CollectionGroup, self).__init__(
parent=parent,
Expand All @@ -291,6 +303,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

@staticmethod
Expand Down