From c54de50922b810cac030a71526bf62b9e4785a2f Mon Sep 17 00:00:00 2001 From: Craig Labenz Date: Mon, 29 Mar 2021 13:21:54 -0700 Subject: [PATCH] feat: add firestore bundles (#319) * chore: manual synth * ran synthtool to add bundle proto definitions * beginning of bundle implementation added methods to assemble bundles, but not yet serialize them into length-prefixed json strings with tests for bundle assembly * linting * Added bundle build method * fixed incomplete document id * fixed git merge error * Added first draft of docstrings * Added FirestoreBundle deserialization * Fixed import desync * Improved test coverage for bundles * linting * test coverage * CI happiness * converted redundant exception to assertion * removed todo * Updated comments * linted * Moved query limit type into bundle code * Added typed response for parsing reference values * refactored document reference parsing * removed auto import of bundles from firestore * small tweaks * added tests for document iters * Updated FirestoreBundle imports and synthtool gen * linting * extra test coverage * responses to code review * linting * Fixed stale docstring * camelCased bundle output * updated stale comments * Added test for binary data * linting Co-authored-by: Craig Labenz Co-authored-by: Christopher Wilcox --- google/cloud/firestore_bundle/__init__.py | 3 + google/cloud/firestore_bundle/_helpers.py | 13 + google/cloud/firestore_bundle/bundle.py | 362 ++++++++++++++ google/cloud/firestore_v1/_helpers.py | 269 +++++++++- google/cloud/firestore_v1/base_document.py | 6 +- google/cloud/firestore_v1/base_query.py | 7 +- google/cloud/firestore_v1/query.py | 7 +- synth.py | 12 + tests/unit/v1/_test_helpers.py | 84 ++++ tests/unit/v1/test__helpers.py | 68 +++ tests/unit/v1/test_bundle.py | 554 +++++++++++++++++++++ tests/unit/v1/test_collection.py | 29 +- 12 files changed, 1376 insertions(+), 38 deletions(-) create mode 100644 google/cloud/firestore_bundle/_helpers.py create mode 100644 google/cloud/firestore_bundle/bundle.py create mode 100644 tests/unit/v1/_test_helpers.py create mode 100644 tests/unit/v1/test_bundle.py diff --git a/google/cloud/firestore_bundle/__init__.py b/google/cloud/firestore_bundle/__init__.py index 75cf63e02..d1ffaeff5 100644 --- a/google/cloud/firestore_bundle/__init__.py +++ b/google/cloud/firestore_bundle/__init__.py @@ -21,6 +21,8 @@ from .types.bundle import BundledQuery from .types.bundle import NamedQuery +from .bundle import FirestoreBundle + __all__ = ( "BundleElement", @@ -28,4 +30,5 @@ "BundledDocumentMetadata", "NamedQuery", "BundledQuery", + "FirestoreBundle", ) diff --git a/google/cloud/firestore_bundle/_helpers.py b/google/cloud/firestore_bundle/_helpers.py new file mode 100644 index 000000000..8b7ce7a69 --- /dev/null +++ b/google/cloud/firestore_bundle/_helpers.py @@ -0,0 +1,13 @@ +from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_bundle.types import BundledQuery + + +def limit_type_of_query(query: BaseQuery) -> int: + """BundledQuery.LimitType equivalent of this query. + """ + + return ( + BundledQuery.LimitType.LAST + if query._limit_to_last + else BundledQuery.LimitType.FIRST + ) diff --git a/google/cloud/firestore_bundle/bundle.py b/google/cloud/firestore_bundle/bundle.py new file mode 100644 index 000000000..eae1fa3f4 --- /dev/null +++ b/google/cloud/firestore_bundle/bundle.py @@ -0,0 +1,362 @@ +# Copyright 2021 Google LLC All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classes for representing bundles for the Google Cloud Firestore API."""
+
+import datetime
+import json
+
+from google.cloud.firestore_bundle.types.bundle import (
+    BundledDocumentMetadata,
+    BundledQuery,
+    BundleElement,
+    BundleMetadata,
+    NamedQuery,
+)
+from google.cloud._helpers import _datetime_to_pb_timestamp, UTC  # type: ignore
+from google.cloud.firestore_bundle._helpers import limit_type_of_query
+from google.cloud.firestore_v1.async_query import AsyncQuery
+from google.cloud.firestore_v1.base_client import BaseClient
+from google.cloud.firestore_v1.base_document import DocumentSnapshot
+from google.cloud.firestore_v1.base_query import BaseQuery
+from google.cloud.firestore_v1.document import DocumentReference
+from google.cloud.firestore_v1 import _helpers
+from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore
+from google.protobuf import json_format  # type: ignore
+from typing import (
+    Dict,
+    List,
+    Optional,
+    Union,
+)
+
+
+class FirestoreBundle:
+    """A group of serialized documents and queries, suitable for
+    long-term storage or query resumption.
+
+    If any queries are added to this bundle, all associated documents will be
+    loaded and stored in memory for serialization.
+
+    Usage:
+
+        from google.cloud.firestore import Client
+        from google.cloud.firestore_bundle import FirestoreBundle
+
+        db = Client()
+        bundle = FirestoreBundle('my-bundle')
+        bundle.add_named_query('all-users', db.collection('users')._query())
+        bundle.add_named_query(
+            'top-ten-hamburgers',
+            db.collection('hamburgers').limit(10)._query(),
+        )
+        serialized: str = bundle.build()
+
+        # Store the serialized bundle somewhere, such as Google Cloud Storage,
+        # for retrieval by a client SDK.
+
+    Args:
+        name (str): The ID of the bundle.
+    """
+
+    BUNDLE_SCHEMA_VERSION: int = 1
+
+    def __init__(self, name: str) -> None:
+        self.name: str = name
+        self.documents: Dict[str, "_BundledDocument"] = {}
+        self.named_queries: Dict[str, NamedQuery] = {}
+        self.latest_read_time: Timestamp = Timestamp(seconds=0, nanos=0)
+        self._deserialized_metadata: Optional[BundleMetadata] = None
+
+    def add_document(self, snapshot: DocumentSnapshot) -> "FirestoreBundle":
+        """Adds a document to the bundle.
+
+        Args:
+            snapshot (DocumentSnapshot): The fully-loaded Firestore document to
+                be preserved.
+
+        Example:
+
+            from google.cloud import firestore
+
+            db = firestore.Client()
+            collection_ref = db.collection(u'users')
+
+            bundle = firestore.FirestoreBundle('my bundle')
+            bundle.add_document(collection_ref.document('some_id').get())
+
+        Returns:
+            FirestoreBundle: self
+        """
+        original_document: Optional[_BundledDocument]
+        original_queries: Optional[List[str]] = []
+        full_document_path: str = snapshot.reference._document_path
+
+        original_document = self.documents.get(full_document_path)
+        if original_document:
+            original_queries = original_document.metadata.queries  # type: ignore
+
+        should_use_snapshot: bool = (
+            original_document is None
+            # equivalent to:
+            #   `if snapshot.read_time >= original_document.snapshot.read_time`
+            or _helpers.compare_timestamps(
+                snapshot.read_time, original_document.snapshot.read_time,
+            )
+            >= 0
+        )
+
+        if should_use_snapshot:
+            self.documents[full_document_path] = _BundledDocument(
+                snapshot=snapshot,
+                metadata=BundledDocumentMetadata(
+                    name=full_document_path,
+                    read_time=snapshot.read_time,
+                    exists=snapshot.exists,
+                    queries=original_queries,
+                ),
+            )
+
+        self._update_last_read_time(snapshot.read_time)
+        self._reset_metadata()
+        return self
+
+    def add_named_query(self, name: str, query: BaseQuery) -> "FirestoreBundle":
+        """Adds a query to the bundle, referenced by the provided name.
+
+        Args:
+            name (str): The name by which the provided query should be referenced.
+            query (Query): Query of documents to be fully loaded and stored in
+                the bundle for future access.
+
+        Example:
+
+            from google.cloud import firestore
+
+            db = firestore.Client()
+            collection_ref = db.collection(u'users')
+
+            bundle = firestore.FirestoreBundle('my bundle')
+            bundle.add_named_query('all the users', collection_ref._query())
+
+        Returns:
+            FirestoreBundle: self
+
+        Raises:
+            ValueError: If anything other than a BaseQuery (e.g., a Collection)
+                is supplied. If you have a Collection, call its `_query()`
+                method to get what this method expects.
+            ValueError: If the supplied name has already been added.
+        """
+        if not isinstance(query, BaseQuery):
+            raise ValueError(
+                "Attempted to add named query of type: "
+                f"{type(query).__name__}.
Expected BaseQuery.", + ) + + if name in self.named_queries: + raise ValueError(f"Query name conflict: {name} has already been added.") + + # Execute the query and save each resulting document + _read_time = self._save_documents_from_query(query, query_name=name) + + # Actually save the query to our local object cache + self._save_named_query(name, query, _read_time) + self._reset_metadata() + return self + + def _save_documents_from_query( + self, query: BaseQuery, query_name: str + ) -> datetime.datetime: + _read_time = datetime.datetime.min.replace(tzinfo=UTC) + if isinstance(query, AsyncQuery): + import asyncio + + loop = asyncio.get_event_loop() + return loop.run_until_complete(self._process_async_query(query, query_name)) + + # `query` is now known to be a non-async `BaseQuery` + doc: DocumentSnapshot + for doc in query.stream(): # type: ignore + self.add_document(doc) + bundled_document = self.documents.get(doc.reference._document_path) + bundled_document.metadata.queries.append(query_name) # type: ignore + _read_time = doc.read_time + return _read_time + + def _save_named_query( + self, name: str, query: BaseQuery, read_time: datetime.datetime, + ) -> None: + self.named_queries[name] = self._build_named_query( + name=name, snapshot=query, read_time=read_time, + ) + self._update_last_read_time(read_time) + + async def _process_async_query( + self, snapshot: AsyncQuery, query_name: str, + ) -> datetime.datetime: + doc: DocumentSnapshot + _read_time = datetime.datetime.min.replace(tzinfo=UTC) + async for doc in snapshot.stream(): + self.add_document(doc) + bundled_document = self.documents.get(doc.reference._document_path) + bundled_document.metadata.queries.append(query_name) # type: ignore + _read_time = doc.read_time + return _read_time + + def _build_named_query( + self, name: str, snapshot: BaseQuery, read_time: datetime.datetime, + ) -> NamedQuery: + return NamedQuery( + name=name, + bundled_query=BundledQuery( + parent=name, + structured_query=snapshot._to_protobuf()._pb, + limit_type=limit_type_of_query(snapshot), + ), + read_time=_helpers.build_timestamp(read_time), + ) + + def _update_last_read_time( + self, read_time: Union[datetime.datetime, Timestamp] + ) -> None: + _ts: Timestamp = ( + read_time + if isinstance(read_time, Timestamp) + else _datetime_to_pb_timestamp(read_time) + ) + + # if `_ts` is greater than `self.latest_read_time` + if _helpers.compare_timestamps(_ts, self.latest_read_time) == 1: + self.latest_read_time = _ts + + def _add_bundle_element(self, bundle_element: BundleElement, *, client: BaseClient, type: str): # type: ignore + """Applies BundleElements to this FirestoreBundle instance as a part of + deserializing a FirestoreBundle string. 
+ """ + from google.cloud.firestore_v1.types.document import Document + + if getattr(self, "_doc_metadata_map", None) is None: + self._doc_metadata_map = {} + if type == "metadata": + self._deserialized_metadata = bundle_element.metadata # type: ignore + elif type == "namedQuery": + self.named_queries[bundle_element.named_query.name] = bundle_element.named_query # type: ignore + elif type == "documentMetadata": + self._doc_metadata_map[ + bundle_element.document_metadata.name + ] = bundle_element.document_metadata + elif type == "document": + doc_ref_value = _helpers.DocumentReferenceValue( + bundle_element.document.name + ) + snapshot = DocumentSnapshot( + data=_helpers.decode_dict( + Document(mapping=bundle_element.document).fields, client + ), + exists=True, + reference=DocumentReference( + doc_ref_value.collection_name, + doc_ref_value.document_id, + client=client, + ), + read_time=self._doc_metadata_map[ + bundle_element.document.name + ].read_time, + create_time=bundle_element.document.create_time, # type: ignore + update_time=bundle_element.document.update_time, # type: ignore + ) + self.add_document(snapshot) + + bundled_document = self.documents.get(snapshot.reference._document_path) + for query_name in self._doc_metadata_map[ + bundle_element.document.name + ].queries: + bundled_document.metadata.queries.append(query_name) # type: ignore + else: + raise ValueError(f"Unexpected type of BundleElement: {type}") + + def build(self) -> str: + """Iterates over the bundle's stored documents and queries and produces + a single length-prefixed json string suitable for long-term storage. + + Example: + + from google.cloud import firestore + + db = firestore.Client() + collection_ref = db.collection(u'users') + + bundle = firestore.FirestoreBundle('my bundle') + bundle.add_named_query('app-users', collection_ref._query()) + + serialized_bundle: str = bundle.build() + + # Now upload `serialized_bundle` to Google Cloud Storage, store it + # in Memorystore, or any other storage solution. + + Returns: + str: The length-prefixed string representation of this bundle' + contents. 
+ """ + buffer: str = "" + + named_query: NamedQuery + for named_query in self.named_queries.values(): + buffer += self._compile_bundle_element( + BundleElement(named_query=named_query) + ) + + bundled_document: "_BundledDocument" # type: ignore + document_count: int = 0 + for bundled_document in self.documents.values(): + buffer += self._compile_bundle_element( + BundleElement(document_metadata=bundled_document.metadata) + ) + document_count += 1 + buffer += self._compile_bundle_element( + BundleElement(document=bundled_document.snapshot._to_protobuf()._pb,) + ) + + metadata: BundleElement = BundleElement( + metadata=self._deserialized_metadata + or BundleMetadata( + id=self.name, + create_time=_helpers.build_timestamp(), + version=FirestoreBundle.BUNDLE_SCHEMA_VERSION, + total_documents=document_count, + total_bytes=len(buffer.encode("utf-8")), + ) + ) + return f"{self._compile_bundle_element(metadata)}{buffer}" + + def _compile_bundle_element(self, bundle_element: BundleElement) -> str: + serialized_be = json.dumps(json_format.MessageToDict(bundle_element._pb)) + return f"{len(serialized_be)}{serialized_be}" + + def _reset_metadata(self): + """Hydrating bundles stores cached data we must reset anytime new + queries or documents are added""" + self._deserialized_metadata = None + + +class _BundledDocument: + """Convenience class to hold both the metadata and the actual content + of a document to be bundled.""" + + def __init__( + self, snapshot: DocumentSnapshot, metadata: BundledDocumentMetadata, + ) -> None: + self.snapshot = snapshot + self.metadata = metadata diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 89cf3b002..aebdbee47 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -15,7 +15,9 @@ """Common helpers shared across Google Cloud Firestore modules.""" import datetime +import json +import google from google.api_core.datetime_helpers import DatetimeWithNanoseconds # type: ignore from google.api_core import gapic_v1 # type: ignore from google.protobuf import struct_pb2 @@ -32,7 +34,18 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import write -from typing import Any, Generator, List, NoReturn, Optional, Tuple, Union +from google.protobuf.timestamp_pb2 import Timestamp # type: ignore +from typing import ( + Any, + Dict, + Generator, + Iterator, + List, + NoReturn, + Optional, + Tuple, + Union, +) _EmptyDict: transforms.Sentinel _GRPC_ERROR_MAPPING: dict @@ -219,6 +232,72 @@ def encode_dict(values_dict) -> dict: return {key: encode_value(value) for key, value in values_dict.items()} +def document_snapshot_to_protobuf(snapshot: "google.cloud.firestore_v1.base_document.DocumentSnapshot") -> Optional["google.cloud.firestore_v1.types.Document"]: # type: ignore + from google.cloud.firestore_v1.types import Document + + if not snapshot.exists: + return None + + return Document( + name=snapshot.reference._document_path, + fields=encode_dict(snapshot._data), + create_time=snapshot.create_time, + update_time=snapshot.update_time, + ) + + +class DocumentReferenceValue: + """DocumentReference path container with accessors for each relevant chunk. 
+ + Usage: + doc_ref_val = DocumentReferenceValue( + 'projects/my-proj/databases/(default)/documents/my-col/my-doc', + ) + assert doc_ref_val.project_name == 'my-proj' + assert doc_ref_val.collection_name == 'my-col' + assert doc_ref_val.document_id == 'my-doc' + assert doc_ref_val.database_name == '(default)' + + Raises: + ValueError: If the supplied value cannot satisfy a complete path. + """ + + def __init__(self, reference_value: str): + self._reference_value = reference_value + + # The first 5 parts are + # projects, {project}, databases, {database}, documents + parts = reference_value.split(DOCUMENT_PATH_DELIMITER) + if len(parts) < 7: + msg = BAD_REFERENCE_ERROR.format(reference_value) + raise ValueError(msg) + + self.project_name = parts[1] + self.collection_name = parts[5] + self.database_name = parts[3] + self.document_id = "/".join(parts[6:]) + + @property + def full_key(self) -> str: + """Computed property for a DocumentReference's collection_name and + document Id""" + return "/".join([self.collection_name, self.document_id]) + + @property + def full_path(self) -> str: + return self._reference_value or "/".join( + [ + "projects", + self.project_name, + "databases", + self.database_name, + "documents", + self.collection_name, + self.document_id, + ] + ) + + def reference_value_to_document(reference_value, client) -> Any: """Convert a reference value string to a document. @@ -237,15 +316,11 @@ def reference_value_to_document(reference_value, client) -> Any: ValueError: If the ``reference_value`` does not come from the same project / database combination as the ``client``. """ - # The first 5 parts are - # projects, {project}, databases, {database}, documents - parts = reference_value.split(DOCUMENT_PATH_DELIMITER, 5) - if len(parts) != 6: - msg = BAD_REFERENCE_ERROR.format(reference_value) - raise ValueError(msg) + from google.cloud.firestore_v1.base_document import BaseDocumentReference - # The sixth part is `a/b/c/d` (i.e. the document path) - document = client.document(parts[-1]) + doc_ref_value = DocumentReferenceValue(reference_value) + + document: BaseDocumentReference = client.document(doc_ref_value.full_key) if document._document_path != reference_value: msg = WRONG_APP_REFERENCE.format(reference_value, client._database_string) raise ValueError(msg) @@ -1041,3 +1116,179 @@ def make_retry_timeout_kwargs(retry, timeout) -> dict: kwargs["timeout"] = timeout return kwargs + + +def build_timestamp( + dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None +) -> Timestamp: + """Returns the supplied datetime (or "now") as a Timestamp""" + return _datetime_to_pb_timestamp(dt or DatetimeWithNanoseconds.utcnow()) + + +def compare_timestamps( + ts1: Union[Timestamp, datetime.datetime], ts2: Union[Timestamp, datetime.datetime], +) -> int: + ts1 = build_timestamp(ts1) if not isinstance(ts1, Timestamp) else ts1 + ts2 = build_timestamp(ts2) if not isinstance(ts2, Timestamp) else ts2 + ts1_nanos = ts1.nanos + ts1.seconds * 1e9 + ts2_nanos = ts2.nanos + ts2.seconds * 1e9 + if ts1_nanos == ts2_nanos: + return 0 + return 1 if ts1_nanos > ts2_nanos else -1 + + +def deserialize_bundle( + serialized: Union[str, bytes], + client: "google.cloud.firestore_v1.client.BaseClient", # type: ignore +) -> "google.cloud.firestore_bundle.FirestoreBundle": # type: ignore + """Inverse operation to a `FirestoreBundle` instance's `build()` method. + + Args: + serialized (Union[str, bytes]): The result of `FirestoreBundle.build()`. + Should be a list of dictionaries in string format. 
+ client (BaseClient): A connected Client instance. + + Returns: + FirestoreBundle: A bundle equivalent to that which called `build()` and + initially created the `serialized` value. + + Raises: + ValueError: If any of the dictionaries in the list contain any more than + one top-level key. + ValueError: If any unexpected BundleElement types are encountered. + ValueError: If the serialized bundle ends before expected. + """ + from google.cloud.firestore_bundle import BundleElement, FirestoreBundle + + # Outlines the legal transitions from one BundleElement to another. + bundle_state_machine = { + "__initial__": ["metadata"], + "metadata": ["namedQuery", "documentMetadata", "__end__"], + "namedQuery": ["namedQuery", "documentMetadata", "__end__"], + "documentMetadata": ["document"], + "document": ["documentMetadata", "__end__"], + } + allowed_next_element_types: List[str] = bundle_state_machine["__initial__"] + + # This must be saved and added last, since we cache it to preserve timestamps, + # yet must flush it whenever a new document or query is added to a bundle. + # The process of deserializing a bundle uses these methods which flush a + # cached metadata element, and thus, it must be the last BundleElement + # added during deserialization. + metadata_bundle_element: Optional[BundleElement] = None + + bundle: Optional[FirestoreBundle] = None + data: Dict + for data in _parse_bundle_elements_data(serialized): + + # BundleElements are serialized as JSON containing one key outlining + # the type, with all further data nested under that key + keys: List[str] = list(data.keys()) + + if len(keys) != 1: + raise ValueError("Expected serialized BundleElement with one top-level key") + + key: str = keys[0] + + if key not in allowed_next_element_types: + raise ValueError( + f"Encountered BundleElement of type {key}. " + f"Expected one of {allowed_next_element_types}" + ) + + # Create and add our BundleElement + bundle_element: BundleElement + try: + bundle_element: BundleElement = BundleElement.from_json(json.dumps(data)) # type: ignore + except AttributeError as e: + # Some bad serialization formats cannot be universally deserialized. + if e.args[0] == "'dict' object has no attribute 'find'": + raise ValueError( + "Invalid serialization of datetimes. " + "Cannot deserialize Bundles created from the NodeJS SDK." + ) + raise e # pragma: NO COVER + + if bundle is None: + # This must be the first bundle type encountered + assert key == "metadata" + bundle = FirestoreBundle(data[key]["id"]) + metadata_bundle_element = bundle_element + + else: + bundle._add_bundle_element(bundle_element, client=client, type=key) + + # Update the allowed next BundleElement types + allowed_next_element_types = bundle_state_machine[key] + + if "__end__" not in allowed_next_element_types: + raise ValueError("Unexpected end to serialized FirestoreBundle") + + # Now, finally add the metadata element + bundle._add_bundle_element( + metadata_bundle_element, client=client, type="metadata", # type: ignore + ) + + return bundle + + +def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict, None, None]: # type: ignore + """Reads through a serialized FirestoreBundle and yields JSON chunks that + were created via `BundleElement.to_json(bundle_element)`. 
+ + Serialized FirestoreBundle instances are length-prefixed JSON objects, and + so are of the form "123{...}57{...}" + To correctly and safely read a bundle, we must first detect these length + prefixes, read that many bytes of data, and attempt to JSON-parse that. + + Raises: + ValueError: If a chunk of JSON ever starts without following a length + prefix. + """ + _serialized: Iterator[int] = iter( + serialized if isinstance(serialized, bytes) else serialized.encode("utf-8") + ) + + length_prefix: str = "" + while True: + byte: Optional[int] = next(_serialized, None) + + if byte is None: + return None + + _str: str = chr(byte) + if _str.isnumeric(): + length_prefix += _str + else: + if length_prefix == "": + raise ValueError("Expected length prefix") + + _length_prefix = int(length_prefix) + length_prefix = "" + _bytes = bytearray([byte]) + _counter = 1 + while _counter < _length_prefix: + _bytes.append(next(_serialized)) + _counter += 1 + + yield json.loads(_bytes.decode("utf-8")) + + +def _get_documents_from_bundle( + bundle, *, query_name: Optional[str] = None +) -> Generator["google.cloud.firestore.DocumentSnapshot", None, None]: # type: ignore + from google.cloud.firestore_bundle.bundle import _BundledDocument + + bundled_doc: _BundledDocument + for bundled_doc in bundle.documents.values(): + if query_name and query_name not in bundled_doc.metadata.queries: + continue + yield bundled_doc.snapshot + + +def _get_document_from_bundle( + bundle, *, document_id: str, +) -> Optional["google.cloud.firestore.DocumentSnapshot"]: # type: ignore + bundled_doc = bundle.documents.get(document_id) + if bundled_doc: + return bundled_doc.snapshot diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 2438409b7..32694ac47 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -18,6 +18,7 @@ from google.api_core import retry as retries # type: ignore +from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import field_path as field_path_module from google.cloud.firestore_v1.types import common @@ -25,7 +26,7 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import write -from typing import Any, Dict, Iterable, NoReturn, Union, Tuple +from typing import Any, Dict, Iterable, NoReturn, Optional, Union, Tuple class BaseDocumentReference(object): @@ -491,6 +492,9 @@ def to_dict(self) -> Union[Dict[str, Any], None]: return None return copy.deepcopy(self._data) + def _to_protobuf(self) -> Optional[Document]: + return _helpers.document_snapshot_to_protobuf(self) + def _get_document_path(client, path: Tuple[str]) -> str: """Convert a path tuple into a full path string. 
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 27897ee23..564483b5e 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -33,7 +33,7 @@ from google.cloud.firestore_v1.types import Cursor from google.cloud.firestore_v1.types import RunQueryResponse from google.cloud.firestore_v1.order import Order -from typing import Any, Dict, Iterable, NoReturn, Optional, Tuple, Union +from typing import Any, Dict, Generator, Iterable, NoReturn, Optional, Tuple, Union # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot @@ -804,12 +804,11 @@ def _to_protobuf(self) -> StructuredQuery: query_kwargs["offset"] = self._offset if self._limit is not None: query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) - return query.StructuredQuery(**query_kwargs) def get( self, transaction=None, retry: retries.Retry = None, timeout: float = None, - ) -> NoReturn: + ) -> Iterable[DocumentSnapshot]: raise NotImplementedError def _prep_stream( @@ -834,7 +833,7 @@ def _prep_stream( def stream( self, transaction=None, retry: retries.Retry = None, timeout: float = None, - ) -> NoReturn: + ) -> Generator[document.DocumentSnapshot, Any, None]: raise NotImplementedError def on_snapshot(self, callback) -> NoReturn: diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 1716999be..aa2f5ad09 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -19,6 +19,7 @@ a more common way to create a query than direct usage of the constructor. """ +from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -33,9 +34,7 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any -from typing import Callable -from typing import Generator +from typing import Any, Callable, Generator, List class Query(BaseQuery): @@ -125,7 +124,7 @@ def get( transaction=None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> list: + ) -> List[DocumentSnapshot]: """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and returns a list of documents diff --git a/synth.py b/synth.py index e5626d223..b4fa23153 100644 --- a/synth.py +++ b/synth.py @@ -247,6 +247,18 @@ def lint_setup_py(session): """, ) +s.replace( + "google/cloud/firestore_bundle/__init__.py", + "from .types.bundle import NamedQuery\n", + "from .types.bundle import NamedQuery\n\nfrom .bundle import FirestoreBundle\n", +) + +s.replace( + "google/cloud/firestore_bundle/__init__.py", + "\'BundledQuery\',", + "\"BundledQuery\",\n \"FirestoreBundle\",", +) + s.shell.run(["nox", "-s", "blacken"], hide_output=False) s.replace( diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py new file mode 100644 index 000000000..65aece0d4 --- /dev/null +++ b/tests/unit/v1/_test_helpers.py @@ -0,0 +1,84 @@ +# Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import mock
+import typing
+
+import google
+from google.cloud.firestore_v1.base_client import BaseClient
+from google.cloud.firestore_v1.document import DocumentReference, DocumentSnapshot
+from google.cloud._helpers import _datetime_to_pb_timestamp, UTC  # type: ignore
+from google.cloud.firestore_v1._helpers import build_timestamp
+from google.cloud.firestore_v1.async_client import AsyncClient
+from google.cloud.firestore_v1.client import Client
+from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore
+
+
+def make_test_credentials() -> google.auth.credentials.Credentials:  # type: ignore
+    import google.auth.credentials  # type: ignore
+
+    return mock.Mock(spec=google.auth.credentials.Credentials)
+
+
+def make_client(project_name: typing.Optional[str] = None) -> Client:
+    return Client(
+        project=project_name or "project-project", credentials=make_test_credentials(),
+    )
+
+
+def make_async_client() -> AsyncClient:
+    return AsyncClient(project="project-project", credentials=make_test_credentials())
+
+
+def build_test_timestamp(
+    year: int = 2021,
+    month: int = 1,
+    day: int = 1,
+    hour: int = 12,
+    minute: int = 0,
+    second: int = 0,
+) -> Timestamp:
+    return _datetime_to_pb_timestamp(
+        datetime.datetime(
+            year=year,
+            month=month,
+            day=day,
+            hour=hour,
+            minute=minute,
+            second=second,
+            tzinfo=UTC,
+        ),
+    )
+
+
+def build_document_snapshot(
+    *,
+    collection_name: str = "col",
+    document_id: str = "doc",
+    client: typing.Optional[BaseClient] = None,
+    data: typing.Optional[typing.Dict] = None,
+    exists: bool = True,
+    create_time: typing.Optional[Timestamp] = None,
+    read_time: typing.Optional[Timestamp] = None,
+    update_time: typing.Optional[Timestamp] = None,
+) -> DocumentSnapshot:
+    return DocumentSnapshot(
+        DocumentReference(collection_name, document_id, client=client),
+        data or {"hello": "world"},
+        exists=exists,
+        read_time=read_time or build_timestamp(),
+        create_time=create_time or build_timestamp(),
+        update_time=update_time or build_timestamp(),
+    )
diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py
index 82fbfcf12..f558f3fe9 100644
--- a/tests/unit/v1/test__helpers.py
+++ b/tests/unit/v1/test__helpers.py
@@ -388,6 +388,74 @@ def test_different_client(self):
         self.assertEqual(exc_info.exception.args, (err_msg,))
 
 
+class TestDocumentReferenceValue(unittest.TestCase):
+    @staticmethod
+    def _call(ref_value: str):
+        from google.cloud.firestore_v1._helpers import DocumentReferenceValue
+
+        return DocumentReferenceValue(ref_value)
+
+    def test_normal(self):
+        orig = "projects/name/databases/(default)/documents/col/doc"
+        parsed = self._call(orig)
+        self.assertEqual(parsed.collection_name, "col")
+        self.assertEqual(parsed.database_name, "(default)")
+        self.assertEqual(parsed.document_id, "doc")
+
+        self.assertEqual(parsed.full_path, orig)
+        parsed._reference_value = None  # type: ignore
+        self.assertEqual(parsed.full_path, orig)
+
+    def test_nested(self):
+        parsed = self._call(
+            "projects/name/databases/(default)/documents/col/doc/nested"
+        )
+        self.assertEqual(parsed.collection_name, "col")
+
self.assertEqual(parsed.database_name, "(default)") + self.assertEqual(parsed.document_id, "doc/nested") + + def test_broken(self): + self.assertRaises( + ValueError, self._call, "projects/name/databases/(default)/documents/col", + ) + + +class Test_document_snapshot_to_protobuf(unittest.TestCase): + def test_real_snapshot(self): + from google.cloud.firestore_v1._helpers import document_snapshot_to_protobuf + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.document import DocumentReference + from google.protobuf import timestamp_pb2 # type: ignore + + client = _make_client() + snapshot = DocumentSnapshot( + data={"hello": "world"}, + reference=DocumentReference("col", "doc", client=client), + exists=True, + read_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), + update_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), + create_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), + ) + self.assertIsInstance(document_snapshot_to_protobuf(snapshot), Document) + + def test_non_existant_snapshot(self): + from google.cloud.firestore_v1._helpers import document_snapshot_to_protobuf + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.document import DocumentReference + + client = _make_client() + snapshot = DocumentSnapshot( + data=None, + reference=DocumentReference("col", "doc", client=client), + exists=False, + read_time=None, + update_time=None, + create_time=None, + ) + self.assertIsNone(document_snapshot_to_protobuf(snapshot)) + + class Test_decode_value(unittest.TestCase): @staticmethod def _call_fut(value, client=mock.sentinel.client): diff --git a/tests/unit/v1/test_bundle.py b/tests/unit/v1/test_bundle.py new file mode 100644 index 000000000..4332a92fa --- /dev/null +++ b/tests/unit/v1/test_bundle.py @@ -0,0 +1,554 @@ +# -*- coding: utf-8 -*- +# +# # Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import typing +import unittest + +import mock +from google.cloud.firestore_bundle import BundleElement, FirestoreBundle +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.async_collection import AsyncCollectionReference +from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_v1.collection import CollectionReference +from google.cloud.firestore_v1.query import Query +from google.cloud.firestore_v1.services.firestore.client import FirestoreClient +from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse +from google.protobuf.timestamp_pb2 import Timestamp # type: ignore +from tests.unit.v1 import _test_helpers +from tests.unit.v1 import test__helpers + + +class _CollectionQueryMixin: + + # Path to each document where we don't specify custom collection names or + # document Ids + doc_key: str = "projects/project-project/databases/(default)/documents/col/doc" + + @staticmethod + def build_results_iterable(items): + raise NotImplementedError() + + @staticmethod + def get_collection_class(): + raise NotImplementedError() + + @staticmethod + def get_internal_client_mock(): + raise NotImplementedError() + + @staticmethod + def get_client(): + raise NotImplementedError() + + def _bundled_collection_helper( + self, + document_ids: typing.Optional[typing.List[str]] = None, + data: typing.Optional[typing.List[typing.Dict]] = None, + ) -> CollectionReference: + """Builder of a mocked Query for the sake of testing Bundles. + + Bundling queries involves loading the actual documents for cold storage, + and this method arranges all of the necessary mocks so that unit tests + can think they are evaluating a live query. 
+ """ + client = self.get_client() + template = client._database_string + "/documents/col/{}" + document_ids = document_ids or ["doc-1", "doc-2"] + + def _index_from_data(index: int): + if data is None or len(data) < index + 1: + return None + return data[index] + + documents = [ + RunQueryResponse( + transaction=b"", + document=Document( + name=template.format(document_id), + fields=_helpers.encode_dict( + _index_from_data(index) or {"hello": "world"} + ), + create_time=Timestamp(seconds=1, nanos=1), + update_time=Timestamp(seconds=1, nanos=1), + ), + read_time=_test_helpers.build_timestamp(), + ) + for index, document_id in enumerate(document_ids) + ] + iterator = self.build_results_iterable(documents) + api_client = self.get_internal_client_mock() + api_client.run_query.return_value = iterator + client._firestore_api_internal = api_client + return self.get_collection_class()("col", client=client) + + def _bundled_query_helper( + self, + document_ids: typing.Optional[typing.List[str]] = None, + data: typing.Optional[typing.List[typing.Dict]] = None, + ) -> BaseQuery: + return self._bundled_collection_helper( + document_ids=document_ids, data=data, + )._query() + + +class TestBundle(_CollectionQueryMixin, unittest.TestCase): + @staticmethod + def build_results_iterable(items): + return iter(items) + + @staticmethod + def get_client(): + return _test_helpers.make_client() + + @staticmethod + def get_internal_client_mock(): + return mock.create_autospec(FirestoreClient) + + @classmethod + def get_collection_class(cls): + return CollectionReference + + def test_add_document(self): + bundle = FirestoreBundle("test") + doc = _test_helpers.build_document_snapshot(client=_test_helpers.make_client()) + bundle.add_document(doc) + self.assertEqual(bundle.documents[self.doc_key].snapshot, doc) + + def test_add_newer_document(self): + bundle = FirestoreBundle("test") + old_doc = _test_helpers.build_document_snapshot( + data={"version": 1}, + client=_test_helpers.make_client(), + read_time=Timestamp(seconds=1, nanos=1), + ) + bundle.add_document(old_doc) + self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 1) + + # Builds the same ID by default + new_doc = _test_helpers.build_document_snapshot( + data={"version": 2}, + client=_test_helpers.make_client(), + read_time=Timestamp(seconds=1, nanos=2), + ) + bundle.add_document(new_doc) + self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 2) + + def test_add_older_document(self): + bundle = FirestoreBundle("test") + new_doc = _test_helpers.build_document_snapshot( + data={"version": 2}, + client=_test_helpers.make_client(), + read_time=Timestamp(seconds=1, nanos=2), + ) + bundle.add_document(new_doc) + self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 2) + + # Builds the same ID by default + old_doc = _test_helpers.build_document_snapshot( + data={"version": 1}, + client=_test_helpers.make_client(), + read_time=Timestamp(seconds=1, nanos=1), + ) + bundle.add_document(old_doc) + self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 2) + + def test_add_document_with_different_read_times(self): + bundle = FirestoreBundle("test") + doc = _test_helpers.build_document_snapshot( + client=_test_helpers.make_client(), + data={"version": 1}, + read_time=_test_helpers.build_test_timestamp(second=1), + ) + # Create another reference to the same document, but with new + # data and a more recent `read_time` + doc_refreshed = _test_helpers.build_document_snapshot( + 
client=_test_helpers.make_client(), + data={"version": 2}, + read_time=_test_helpers.build_test_timestamp(second=2), + ) + + bundle.add_document(doc) + self.assertEqual( + bundle.documents[self.doc_key].snapshot._data, {"version": 1}, + ) + bundle.add_document(doc_refreshed) + self.assertEqual( + bundle.documents[self.doc_key].snapshot._data, {"version": 2}, + ) + + def test_add_query(self): + query = self._bundled_query_helper() + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", query) + self.assertIsNotNone(bundle.named_queries.get("asdf")) + self.assertIsNotNone( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-1" + ] + ) + self.assertIsNotNone( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-2" + ] + ) + + def test_add_query_twice(self): + query = self._bundled_query_helper() + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", query) + self.assertRaises(ValueError, bundle.add_named_query, "asdf", query) + + def test_adding_collection_raises_error(self): + col = self._bundled_collection_helper() + bundle = FirestoreBundle("test") + self.assertRaises(ValueError, bundle.add_named_query, "asdf", col) + + def test_bundle_build(self): + bundle = FirestoreBundle("test") + bundle.add_named_query("best name", self._bundled_query_helper()) + self.assertIsInstance(bundle.build(), str) + + def test_get_documents(self): + bundle = FirestoreBundle("test") + query: Query = self._bundled_query_helper() # type: ignore + bundle.add_named_query("sweet query", query) + docs_iter = _helpers._get_documents_from_bundle( + bundle, query_name="sweet query" + ) + doc = next(docs_iter) + self.assertEqual(doc.id, "doc-1") + doc = next(docs_iter) + self.assertEqual(doc.id, "doc-2") + + # Now an empty one + docs_iter = _helpers._get_documents_from_bundle( + bundle, query_name="wrong query" + ) + doc = next(docs_iter, None) + self.assertIsNone(doc) + + def test_get_documents_two_queries(self): + bundle = FirestoreBundle("test") + query: Query = self._bundled_query_helper() # type: ignore + bundle.add_named_query("sweet query", query) + + query: Query = self._bundled_query_helper(document_ids=["doc-3", "doc-4"]) # type: ignore + bundle.add_named_query("second query", query) + + docs_iter = _helpers._get_documents_from_bundle( + bundle, query_name="sweet query" + ) + doc = next(docs_iter) + self.assertEqual(doc.id, "doc-1") + doc = next(docs_iter) + self.assertEqual(doc.id, "doc-2") + + docs_iter = _helpers._get_documents_from_bundle( + bundle, query_name="second query" + ) + doc = next(docs_iter) + self.assertEqual(doc.id, "doc-3") + doc = next(docs_iter) + self.assertEqual(doc.id, "doc-4") + + def test_get_document(self): + bundle = FirestoreBundle("test") + query: Query = self._bundled_query_helper() # type: ignore + bundle.add_named_query("sweet query", query) + + self.assertIsNotNone( + _helpers._get_document_from_bundle( + bundle, + document_id="projects/project-project/databases/(default)/documents/col/doc-1", + ), + ) + + self.assertIsNone( + _helpers._get_document_from_bundle( + bundle, + document_id="projects/project-project/databases/(default)/documents/col/doc-0", + ), + ) + + +class TestAsyncBundle(_CollectionQueryMixin, unittest.TestCase): + @staticmethod + def get_client(): + return _test_helpers.make_async_client() + + @staticmethod + def build_results_iterable(items): + return test__helpers.AsyncIter(items) + + @staticmethod + def get_internal_client_mock(): + return 
test__helpers.AsyncMock(spec=["run_query"]) + + @classmethod + def get_collection_class(cls): + return AsyncCollectionReference + + def test_async_query(self): + # Create an async query, but this test does not need to be + # marked as async by pytest because `bundle.add_named_query()` + # seemlessly handles accepting async iterables. + async_query = self._bundled_query_helper() + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", async_query) + self.assertIsNotNone(bundle.named_queries.get("asdf")) + self.assertIsNotNone( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-1" + ] + ) + self.assertIsNotNone( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-2" + ] + ) + + +class TestBundleBuilder(_CollectionQueryMixin, unittest.TestCase): + @staticmethod + def build_results_iterable(items): + return iter(items) + + @staticmethod + def get_client(): + return _test_helpers.make_client() + + @staticmethod + def get_internal_client_mock(): + return mock.create_autospec(FirestoreClient) + + @classmethod + def get_collection_class(cls): + return CollectionReference + + def test_build_round_trip(self): + query = self._bundled_query_helper() + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", query) + serialized = bundle.build() + self.assertEqual( + serialized, _helpers.deserialize_bundle(serialized, query._client).build(), + ) + + def test_build_round_trip_emojis(self): + smile = "😂" + mermaid = "🧜🏿‍♀️" + query = self._bundled_query_helper( + data=[{"smile": smile}, {"compound": mermaid}], + ) + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", query) + serialized = bundle.build() + reserialized_bundle = _helpers.deserialize_bundle(serialized, query._client) + + self.assertEqual( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-1" + ].snapshot._data["smile"], + smile, + ) + self.assertEqual( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-2" + ].snapshot._data["compound"], + mermaid, + ) + self.assertEqual( + serialized, reserialized_bundle.build(), + ) + + def test_build_round_trip_more_unicode(self): + bano = "baño" + chinese_characters = "殷周金文集成引得" + query = self._bundled_query_helper( + data=[{"bano": bano}, {"international": chinese_characters}], + ) + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", query) + serialized = bundle.build() + reserialized_bundle = _helpers.deserialize_bundle(serialized, query._client) + + self.assertEqual( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-1" + ].snapshot._data["bano"], + bano, + ) + self.assertEqual( + bundle.documents[ + "projects/project-project/databases/(default)/documents/col/doc-2" + ].snapshot._data["international"], + chinese_characters, + ) + self.assertEqual( + serialized, reserialized_bundle.build(), + ) + + def test_roundtrip_binary_data(self): + query = self._bundled_query_helper(data=[{"binary_data": b"\x0f"}],) + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", query) + serialized = bundle.build() + reserialized_bundle = _helpers.deserialize_bundle(serialized, query._client) + gen = _helpers._get_documents_from_bundle(reserialized_bundle) + snapshot = next(gen) + self.assertEqual( + int.from_bytes(snapshot._data["binary_data"], byteorder=sys.byteorder), 15, + ) + + def test_deserialize_from_seconds_nanos(self): + """Some SDKs (Node) serialize Timestamp values to + 
'{"seconds": 123, "nanos": 456}', instead of an ISO-formatted string. + This tests deserialization from that format.""" + + client = _test_helpers.make_client(project_name="fir-bundles-test") + + _serialized: str = ( + '139{"metadata":{"id":"test-bundle","createTime":' + + '{"seconds":"1616434660","nanos":913764000},"version":1,"totalDocuments"' + + ':1,"totalBytes":"829"}}224{"namedQuery":{"name":"self","bundledQuery":' + + '{"parent":"projects/fir-bundles-test/databases/(default)/documents",' + + '"structuredQuery":{"from":[{"collectionId":"bundles"}]}},"readTime":' + + '{"seconds":"1616434660","nanos":913764000}}}194{"documentMetadata":' + + '{"name":"projects/fir-bundles-test/databases/(default)/documents/' + + 'bundles/test-bundle","readTime":{"seconds":"1616434660","nanos":' + + '913764000},"exists":true,"queries":["self"]}}402{"document":{"name":' + + '"projects/fir-bundles-test/databases/(default)/documents/bundles/' + + 'test-bundle","fields":{"clientCache":{"stringValue":"1200"},' + + '"serverCache":{"stringValue":"600"},"queries":{"mapValue":{"fields":' + + '{"self":{"mapValue":{"fields":{"collection":{"stringValue":"bundles"' + + '}}}}}}}},"createTime":{"seconds":"1615488796","nanos":163327000},' + + '"updateTime":{"seconds":"1615492486","nanos":34157000}}}' + ) + + self.assertRaises( + ValueError, _helpers.deserialize_bundle, _serialized, client=client, + ) + + # The following assertions would test deserialization of NodeJS bundles + # were explicit handling of that edge case to be added. + + # First, deserialize that value into a Bundle instance. If this succeeds, + # we're off to a good start. + # bundle = _helpers.deserialize_bundle(_serialized, client=client) + # Second, re-serialize it into a Python-centric format (aka, ISO timestamps) + # instead of seconds/nanos. + # re_serialized = bundle.build() + # # Finally, confirm the round trip. 
+ # self.assertEqual( + # re_serialized, + # _helpers.deserialize_bundle(re_serialized, client=client).build(), + # ) + + def test_deserialized_bundle_cached_metadata(self): + query = self._bundled_query_helper() + bundle = FirestoreBundle("test") + bundle.add_named_query("asdf", query) + bundle_copy = _helpers.deserialize_bundle(bundle.build(), query._client) + self.assertIsInstance(bundle_copy, FirestoreBundle) + self.assertIsNotNone(bundle_copy._deserialized_metadata) + bundle_copy.add_named_query("second query", query) + self.assertIsNone(bundle_copy._deserialized_metadata) + + @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") + def test_invalid_json(self, fnc): + client = _test_helpers.make_client() + fnc.return_value = iter([{}]) + self.assertRaises( + ValueError, _helpers.deserialize_bundle, "does not matter", client, + ) + + @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") + def test_not_metadata_first(self, fnc): + client = _test_helpers.make_client() + fnc.return_value = iter([{"document": {}}]) + self.assertRaises( + ValueError, _helpers.deserialize_bundle, "does not matter", client, + ) + + @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") + @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") + def test_unexpected_termination(self, fnc, _): + client = _test_helpers.make_client() + # invalid bc `document_metadata` must be followed by a `document` + fnc.return_value = [{"metadata": {"id": "asdf"}}, {"documentMetadata": {}}] + self.assertRaises( + ValueError, _helpers.deserialize_bundle, "does not matter", client, + ) + + @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") + @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") + def test_valid_passes(self, fnc, _): + client = _test_helpers.make_client() + fnc.return_value = [ + {"metadata": {"id": "asdf"}}, + {"documentMetadata": {}}, + {"document": {}}, + ] + _helpers.deserialize_bundle("does not matter", client) + + @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") + @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") + def test_invalid_bundle(self, fnc, _): + client = _test_helpers.make_client() + # invalid bc `document` must follow `document_metadata` + fnc.return_value = [{"metadata": {"id": "asdf"}}, {"document": {}}] + self.assertRaises( + ValueError, _helpers.deserialize_bundle, "does not matter", client, + ) + + @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") + @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") + def test_invalid_bundle_element_type(self, fnc, _): + client = _test_helpers.make_client() + # invalid bc `wtfisthis?` is obviously invalid + fnc.return_value = [{"metadata": {"id": "asdf"}}, {"wtfisthis?": {}}] + self.assertRaises( + ValueError, _helpers.deserialize_bundle, "does not matter", client, + ) + + @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") + @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") + def test_invalid_bundle_start(self, fnc, _): + client = _test_helpers.make_client() + # invalid bc first element must be of key `metadata` + fnc.return_value = [{"document": {}}] + self.assertRaises( + ValueError, _helpers.deserialize_bundle, "does not matter", client, + ) + + def test_not_actually_a_bundle_at_all(self): + client = _test_helpers.make_client() + 
self.assertRaises( + ValueError, _helpers.deserialize_bundle, "{}", client, + ) + + def test_add_invalid_bundle_element_type(self): + client = _test_helpers.make_client() + bundle = FirestoreBundle("asdf") + self.assertRaises( + ValueError, + bundle._add_bundle_element, + BundleElement(), + client=client, + type="asdf", + ) diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 3e6b1d7be..feaec8119 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -17,6 +17,8 @@ import mock +from tests.unit.v1 import _test_helpers + class TestCollectionReference(unittest.TestCase): @staticmethod @@ -89,7 +91,7 @@ def test_add_auto_assigned(self): firestore_api.commit.return_value = commit_response create_doc_response = document.Document() firestore_api.create_document.return_value = create_doc_response - client = _make_client() + client = _test_helpers.make_client() client._firestore_api_internal = firestore_api # Actually make a collection. @@ -140,7 +142,7 @@ def _write_pb_for_create(document_path, document_data): def _add_helper(self, retry=None, timeout=None): from google.cloud.firestore_v1.document import DocumentReference - from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) @@ -155,7 +157,7 @@ def _add_helper(self, retry=None, timeout=None): firestore_api.commit.return_value = commit_response # Attach the fake GAPIC to a real client. - client = _make_client() + client = _test_helpers.make_client() client._firestore_api_internal = firestore_api # Actually make a collection and call add(). @@ -163,7 +165,7 @@ def _add_helper(self, retry=None, timeout=None): document_data = {"zorp": 208.75, "i-did-not": b"know that"} doc_id = "child" - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + kwargs = _fs_v1_helpers.make_retry_timeout_kwargs(retry, timeout) update_time, document_ref = collection.add( document_data, document_id=doc_id, **kwargs ) @@ -196,7 +198,7 @@ def test_add_w_retry_timeout(self): self._add_helper(retry=retry, timeout=timeout) def _list_documents_helper(self, page_size=None, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.document import DocumentReference @@ -213,7 +215,7 @@ def _next_page(self): page, self._pages = self._pages[0], self._pages[1:] return Page(self, page, self.item_to_value) - client = _make_client() + client = _test_helpers.make_client() template = client._database_string + "/documents/{}" document_ids = ["doc-1", "doc-2"] documents = [ @@ -224,7 +226,7 @@ def _next_page(self): api_client.list_documents.return_value = iterator client._firestore_api_internal = api_client collection = self._make_one("collection", client=client) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + kwargs = _fs_v1_helpers.make_retry_timeout_kwargs(retry, timeout) if page_size is not None: documents = list(collection.list_documents(page_size=page_size, **kwargs)) @@ -347,16 +349,3 @@ def test_on_snapshot(self, watch): collection = self._make_one("collection") collection.on_snapshot(None) watch.for_query.assert_called_once() - - -def _make_credentials(): - import google.auth.credentials - - return 
mock.Mock(spec=google.auth.credentials.Credentials) - - -def _make_client(): - from google.cloud.firestore_v1.client import Client - - credentials = _make_credentials() - return Client(project="project-project", credentials=credentials)
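
A minimal usage sketch of the bundle API introduced by this patch, assuming a configured google-cloud-firestore client and an existing 'users' collection; the output file name is illustrative only:

    from google.cloud import firestore
    from google.cloud.firestore_bundle import FirestoreBundle
    from google.cloud.firestore_v1 import _helpers

    db = firestore.Client()

    # Assemble a bundle containing a named query plus every document it matches.
    bundle = FirestoreBundle("all-users")
    bundle.add_named_query("all-users", db.collection("users")._query())
    serialized: str = bundle.build()

    # Persist the length-prefixed JSON string, e.g. for client SDKs to download.
    with open("all-users.bundle", "w") as fp:
        fp.write(serialized)

    # deserialize_bundle() reverses build(); the round trip is stable, as
    # asserted by test_build_round_trip above.
    hydrated = _helpers.deserialize_bundle(serialized, db)
    assert hydrated.build() == serialized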