From db5f286772592460b2bf02df25a121994889585d Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 21 Oct 2020 17:22:41 -0400 Subject: [PATCH] feat: add retry/timeout to manual surface (#222) Closes #221 --- google/cloud/firestore_v1/_helpers.py | 16 +- google/cloud/firestore_v1/async_batch.py | 23 +- google/cloud/firestore_v1/async_client.py | 53 +++-- google/cloud/firestore_v1/async_collection.py | 77 ++++--- google/cloud/firestore_v1/async_document.py | 127 +++++++---- google/cloud/firestore_v1/async_query.py | 74 ++++--- .../cloud/firestore_v1/async_transaction.py | 40 +++- google/cloud/firestore_v1/base_batch.py | 11 + google/cloud/firestore_v1/base_client.py | 35 ++- google/cloud/firestore_v1/base_collection.py | 60 +++++- google/cloud/firestore_v1/base_document.py | 130 +++++++++++- google/cloud/firestore_v1/base_query.py | 59 +++++- google/cloud/firestore_v1/base_transaction.py | 9 +- google/cloud/firestore_v1/batch.py | 22 +- google/cloud/firestore_v1/client.py | 49 +++-- google/cloud/firestore_v1/collection.py | 81 ++++--- google/cloud/firestore_v1/document.py | 129 ++++++++---- google/cloud/firestore_v1/query.py | 85 ++++---- google/cloud/firestore_v1/transaction.py | 41 +++- tests/unit/v1/test__helpers.py | 47 ++++- tests/unit/v1/test_async_batch.py | 28 ++- tests/unit/v1/test_async_client.py | 199 ++++++++---------- tests/unit/v1/test_async_collection.py | 80 ++++++- tests/unit/v1/test_async_document.py | 118 +++++++++-- tests/unit/v1/test_async_query.py | 79 ++++++- tests/unit/v1/test_async_transaction.py | 66 +++++- tests/unit/v1/test_batch.py | 24 ++- tests/unit/v1/test_client.py | 189 ++++++++--------- tests/unit/v1/test_collection.py | 71 ++++++- tests/unit/v1/test_document.py | 108 ++++++++-- tests/unit/v1/test_query.py | 72 ++++++- tests/unit/v1/test_transaction.py | 63 +++++- 32 files changed, 1656 insertions(+), 609 deletions(-) diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index e98ec8547..fb2f73c83 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -16,13 +16,14 @@ import datetime +from google.api_core.datetime_helpers import DatetimeWithNanoseconds # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.protobuf import struct_pb2 from google.type import latlng_pb2 # type: ignore import grpc # type: ignore from google.cloud import exceptions # type: ignore from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore -from google.api_core.datetime_helpers import DatetimeWithNanoseconds # type: ignore from google.cloud.firestore_v1.types.write import DocumentTransform from google.cloud.firestore_v1 import transforms from google.cloud.firestore_v1 import types @@ -1042,3 +1043,16 @@ def modify_write(self, write, **unused_kwargs) -> None: """ current_doc = types.Precondition(exists=self._exists) write._pb.current_document.CopyFrom(current_doc._pb) + + +def make_retry_timeout_kwargs(retry, timeout) -> dict: + """Helper fo API methods which take optional 'retry' / 'timeout' args.""" + kwargs = {} + + if retry is not gapic_v1.method.DEFAULT: + kwargs["retry"] = retry + + if timeout is not None: + kwargs["timeout"] = timeout + + return kwargs diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index cc359d6b5..8c13102d9 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -15,6 +15,9 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_batch import BaseWriteBatch @@ -33,27 +36,33 @@ class AsyncWriteBatch(BaseWriteBatch): def __init__(self, client) -> None: super(AsyncWriteBatch, self).__init__(client=client) - async def commit(self) -> list: + async def commit( + self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, + ) -> list: """Commit the changes accumulated in this batch. + Args: + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + Returns: List[:class:`google.cloud.proto.firestore.v1.write.WriteResult`, ...]: The write results corresponding to the changes committed, returned in the same order as the changes were applied to this batch. A write result contains an ``update_time`` field. """ + request, kwargs = self._prep_commit(retry, timeout) + commit_response = await self._client._firestore_api.commit( - request={ - "database": self._client._database_string, - "writes": self._write_pbs, - "transaction": None, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) self._write_pbs = [] self.write_results = results = list(commit_response.write_results) self.commit_time = commit_response.commit_time + return results async def __aenter__(self): diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index b1376170e..8233fd509 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -24,17 +24,17 @@ :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_client import ( BaseClient, DEFAULT_DATABASE, _CLIENT_INFO, - _reference_info, # type: ignore _parse_batch_get, # type: ignore - _get_doc_mask, _path_helper, ) -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_query import AsyncCollectionGroup from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -208,7 +208,12 @@ def document(self, *document_path: Tuple[str]) -> AsyncDocumentReference: ) async def get_all( - self, references: list, field_paths: Iterable[str] = None, transaction=None, + self, + references: list, + field_paths: Iterable[str] = None, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieve a batch of documents. @@ -239,48 +244,54 @@ async def get_all( transaction (Optional[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`]): An existing transaction that these ``references`` will be retrieved in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ - document_paths, reference_map = _reference_info(references) - mask = _get_doc_mask(field_paths) + request, reference_map, kwargs = self._prep_get_all( + references, field_paths, transaction, retry, timeout + ) + response_iterator = await self._firestore_api.batch_get_documents( - request={ - "database": self._database_string, - "documents": document_paths, - "mask": mask, - "transaction": _helpers.get_transaction_id(transaction), - }, - metadata=self._rpc_metadata, + request=request, metadata=self._rpc_metadata, **kwargs, ) async for get_doc_response in response_iterator: yield _parse_batch_get(get_doc_response, reference_map, self) - async def collections(self) -> AsyncGenerator[AsyncCollectionReference, Any]: + async def collections( + self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, + ) -> AsyncGenerator[AsyncCollectionReference, Any]: """List top-level collections of the client's database. + Args: + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + Returns: Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: iterator of subcollections of the current document. """ + request, kwargs = self._prep_collections(retry, timeout) iterator = await self._firestore_api.list_collection_ids( - request={"parent": "{}/documents".format(self._database_string)}, - metadata=self._rpc_metadata, + request=request, metadata=self._rpc_metadata, **kwargs, ) while True: for i in iterator.collection_ids: yield self.collection(i) if iterator.next_page_token: + next_request = request.copy() + next_request["page_token"] = iterator.next_page_token iterator = await self._firestore_api.list_collection_ids( - request={ - "parent": "{}/documents".format(self._database_string), - "page_token": iterator.next_page_token, - }, - metadata=self._rpc_metadata, + request=next_request, metadata=self._rpc_metadata, **kwargs, ) else: return diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index f0d41985b..e3842f03e 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -13,9 +13,12 @@ # limitations under the License. """Classes for representing collections for the Google Cloud Firestore API.""" + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_collection import ( BaseCollectionReference, - _auto_id, _item_to_document_ref, ) from google.cloud.firestore_v1 import ( @@ -70,7 +73,11 @@ def _query(self) -> async_query.AsyncQuery: return async_query.AsyncQuery(self) async def add( - self, document_data: dict, document_id: str = None + self, + document_data: dict, + document_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. @@ -82,6 +89,10 @@ async def add( automatically assigned by the server (the assigned ID will be a random 20 character string composed of digits, uppercase and lowercase letters). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: Tuple[:class:`google.protobuf.timestamp_pb2.Timestamp`, \ @@ -95,22 +106,28 @@ async def add( ~google.cloud.exceptions.Conflict: If ``document_id`` is provided and the document already exists. """ - if document_id is None: - document_id = _auto_id() - - document_ref = self.document(document_id) - write_result = await document_ref.create(document_data) + document_ref, kwargs = self._prep_add( + document_data, document_id, retry, timeout, + ) + write_result = await document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref async def list_documents( - self, page_size: int = None + self, + page_size: int = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> AsyncGenerator[DocumentReference, None]: """List all subdocuments of the current collection. Args: page_size (Optional[int]]): The maximum number of documents - in each page of results from this request. Non-positive values - are ignored. Defaults to a sensible value set by the API. + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.DocumentReference`]: @@ -118,21 +135,20 @@ async def list_documents( collection does not exist at the time of `snapshot`, the iterator will be empty """ - parent, _ = self._parent_info() + request, kwargs = self._prep_list_documents(page_size, retry, timeout) iterator = await self._client._firestore_api.list_documents( - request={ - "parent": parent, - "collection_id": self.id, - "page_size": page_size, - "show_missing": True, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) async for i in iterator: yield _item_to_document_ref(self, i) - async def get(self, transaction: Transaction = None) -> list: + async def get( + self, + transaction: Transaction = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> list: """Read the documents in this collection. This sends a ``RunQuery`` RPC and returns a list of documents @@ -142,6 +158,10 @@ async def get(self, transaction: Transaction = None) -> list: transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -150,11 +170,15 @@ async def get(self, transaction: Transaction = None) -> list: Returns: list: The documents in this collection that match the query. """ - query = self._query() - return await query.get(transaction=transaction) + query, kwargs = self._prep_get_or_stream(retry, timeout) + + return await query.get(transaction=transaction, **kwargs) async def stream( - self, transaction: Transaction = None + self, + transaction: Transaction = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> AsyncIterator[async_document.DocumentSnapshot]: """Read the documents in this collection. @@ -177,11 +201,16 @@ async def stream( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ Transaction`]): An existing transaction that the query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: The next document that fulfills the query. """ - query = async_query.AsyncQuery(self) - async for d in query.stream(transaction=transaction): + query, kwargs = self._prep_get_or_stream(retry, timeout) + + async for d in query.stream(transaction=transaction, **kwargs): yield d # pytype: disable=name-error diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 064797f6d..5f821b655 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -14,6 +14,9 @@ """Classes for representing documents for the Google Cloud Firestore API.""" +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_document import ( BaseDocumentReference, DocumentSnapshot, @@ -22,7 +25,6 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.types import common from typing import Any, AsyncGenerator, Coroutine, Iterable, Union @@ -54,12 +56,21 @@ class AsyncDocumentReference(BaseDocumentReference): def __init__(self, *path, **kwargs) -> None: super(AsyncDocumentReference, self).__init__(*path, **kwargs) - async def create(self, document_data: dict) -> Coroutine: + async def create( + self, + document_data: dict, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Coroutine: """Create the current document in the Firestore database. Args: document_data (dict): Property names and values to use for creating a document. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.types.WriteResult`: @@ -70,12 +81,17 @@ async def create(self, document_data: dict) -> Coroutine: :class:`~google.cloud.exceptions.Conflict`: If the document already exists. """ - batch = self._client.batch() - batch.create(self, document_data) - write_results = await batch.commit() + batch, kwargs = self._prep_create(document_data, retry, timeout) + write_results = await batch.commit(**kwargs) return _first_write_result(write_results) - async def set(self, document_data: dict, merge: bool = False) -> Coroutine: + async def set( + self, + document_data: dict, + merge: bool = False, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Coroutine: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -95,19 +111,26 @@ async def set(self, document_data: dict, merge: bool = False) -> Coroutine: merge (Optional[bool] or Optional[List]): If True, apply merging instead of overwriting the state of the document. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.types.WriteResult`: The write result corresponding to the committed document. A write result contains an ``update_time`` field. """ - batch = self._client.batch() - batch.set(self, document_data, merge=merge) - write_results = await batch.commit() + batch, kwargs = self._prep_set(document_data, merge, retry, timeout) + write_results = await batch.commit(**kwargs) return _first_write_result(write_results) async def update( - self, field_updates: dict, option: _helpers.WriteOption = None + self, + field_updates: dict, + option: _helpers.WriteOption = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> Coroutine: """Update an existing document in the Firestore database. @@ -242,6 +265,10 @@ async def update( option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): A write option to make assertions / preconditions on the server state of the document before applying changes. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.types.WriteResult`: @@ -251,18 +278,26 @@ async def update( Raises: ~google.cloud.exceptions.NotFound: If the document does not exist. """ - batch = self._client.batch() - batch.update(self, field_updates, option=option) - write_results = await batch.commit() + batch, kwargs = self._prep_update(field_updates, option, retry, timeout) + write_results = await batch.commit(**kwargs) return _first_write_result(write_results) - async def delete(self, option: _helpers.WriteOption = None) -> Coroutine: + async def delete( + self, + option: _helpers.WriteOption = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Coroutine: """Delete the current document in the Firestore database. Args: option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): A write option to make assertions / preconditions on the server state of the document before applying changes. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`google.protobuf.timestamp_pb2.Timestamp`: @@ -271,20 +306,20 @@ async def delete(self, option: _helpers.WriteOption = None) -> Coroutine: nothing was deleted), this method will still succeed and will still return the time that the request was received by the server. """ - write_pb = _helpers.pb_for_delete(self._document_path, option) + request, kwargs = self._prep_delete(option, retry, timeout) + commit_response = await self._client._firestore_api.commit( - request={ - "database": self._client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) return commit_response.commit_time async def get( - self, field_paths: Iterable[str] = None, transaction=None + self, + field_paths: Iterable[str] = None, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> Union[DocumentSnapshot, Coroutine[Any, Any, DocumentSnapshot]]: """Retrieve a snapshot of the current document. @@ -303,6 +338,10 @@ async def get( transaction (Optional[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`]): An existing transaction that this reference will be retrieved in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot`: @@ -312,23 +351,12 @@ async def get( :attr:`create_time` attributes will all be ``None`` and its :attr:`exists` attribute will be ``False``. """ - if isinstance(field_paths, str): - raise ValueError("'field_paths' must be a sequence of paths, not a string.") - - if field_paths is not None: - mask = common.DocumentMask(field_paths=sorted(field_paths)) - else: - mask = None + request, kwargs = self._prep_get(field_paths, transaction, retry, timeout) firestore_api = self._client._firestore_api try: document_pb = await firestore_api.get_document( - request={ - "name": self._document_path, - "mask": mask, - "transaction": _helpers.get_transaction_id(transaction), - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) except exceptions.NotFound: data = None @@ -350,13 +378,22 @@ async def get( update_time=update_time, ) - async def collections(self, page_size: int = None) -> AsyncGenerator: + async def collections( + self, + page_size: int = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> AsyncGenerator: """List subcollections of the current document. Args: page_size (Optional[int]]): The maximum number of collections - in each page of results from this request. Non-positive values - are ignored. Defaults to a sensible value set by the API. + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: @@ -364,22 +401,20 @@ async def collections(self, page_size: int = None) -> AsyncGenerator: document does not exist at the time of `snapshot`, the iterator will be empty """ + request, kwargs = self._prep_collections(page_size, retry, timeout) + iterator = await self._client._firestore_api.list_collection_ids( - request={"parent": self._document_path, "page_size": page_size}, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) while True: for i in iterator.collection_ids: yield self.collection(i) if iterator.next_page_token: + next_request = request.copy() + next_request["page_token"] = iterator.next_page_token iterator = await self._client._firestore_api.list_collection_ids( - request={ - "parent": self._document_path, - "page_size": page_size, - "page_token": iterator.next_page_token, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs ) else: return diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 2750f290f..f772194e8 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -18,6 +18,10 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_query import ( BaseCollectionGroup, BaseQuery, @@ -27,7 +31,6 @@ _enum_from_direction, ) -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import async_document from typing import AsyncGenerator @@ -117,7 +120,12 @@ def __init__( all_descendants=all_descendants, ) - async def get(self, transaction: Transaction = None) -> list: + async def get( + self, + transaction: Transaction = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> list: """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and returns a list of documents @@ -127,6 +135,10 @@ async def get(self, transaction: Transaction = None) -> list: transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -149,7 +161,7 @@ async def get(self, transaction: Transaction = None) -> list: ) self._limit_to_last = False - result = self.stream(transaction=transaction) + result = self.stream(transaction=transaction, retry=retry, timeout=timeout) result = [d async for d in result] if is_limited_to_last: result = list(reversed(result)) @@ -157,7 +169,10 @@ async def get(self, transaction: Transaction = None) -> list: return result async def stream( - self, transaction: Transaction = None + self, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: """Read the documents in the collection that match this query. @@ -180,25 +195,21 @@ async def stream( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: :class:`~google.cloud.firestore_v1.async_document.DocumentSnapshot`: The next document that fulfills the query. """ - if self._limit_to_last: - raise ValueError( - "Query results for queries that include limit_to_last() " - "constraints cannot be streamed. Use Query.get() instead." - ) + request, expected_prefix, kwargs = self._prep_stream( + transaction, retry, timeout, + ) - parent_path, expected_prefix = self._parent._parent_info() response_iterator = await self._client._firestore_api.run_query( - request={ - "parent": parent_path, - "structured_query": self._to_protobuf(), - "transaction": _helpers.get_transaction_id(transaction), - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) async for response in response_iterator: @@ -252,8 +263,15 @@ def __init__( all_descendants=all_descendants, ) + @staticmethod + def _get_query_class(): + return AsyncQuery + async def get_partitions( - self, partition_count + self, + partition_count, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> AsyncGenerator[QueryPartition, None]: """Partition a query for parallelization. @@ -265,24 +283,14 @@ async def get_partitions( partition_count (int): The desired maximum number of partition points. The number must be strictly positive. The actual number of partitions returned may be fewer. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. """ - self._validate_partition_query() - query = AsyncQuery( - self._parent, - orders=self._PARTITION_QUERY_ORDER, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - parent_path, expected_prefix = self._parent._parent_info() + request, kwargs = self._prep_get_partitions(partition_count, retry, timeout) pager = await self._client._firestore_api.partition_query( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "partition_count": partition_count, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) start_at = None diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 81316b8e6..fd639e1ed 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -18,6 +18,9 @@ import asyncio import random +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_transaction import ( _BaseTransactional, BaseTransaction, @@ -34,6 +37,7 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import async_batch +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.async_document import AsyncDocumentReference @@ -144,32 +148,56 @@ async def _commit(self) -> list: self._clean_up() return list(commit_response.write_results) - async def get_all(self, references: list) -> Coroutine: + async def get_all( + self, + references: list, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Coroutine: """Retrieves multiple documents from Firestore. Args: references (List[.AsyncDocumentReference, ...]): Iterable of document references to be retrieved. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ - return await self._client.get_all(references, transaction=self) - - async def get(self, ref_or_query) -> AsyncGenerator[DocumentSnapshot, Any]: + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + return await self._client.get_all(references, transaction=self, **kwargs) + + async def get( + self, + ref_or_query, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> AsyncGenerator[DocumentSnapshot, Any]: """ Retrieve a document or a query result from the database. + Args: ref_or_query The document references or query object to return. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) if isinstance(ref_or_query, AsyncDocumentReference): - return await self._client.get_all([ref_or_query], transaction=self) + return await self._client.get_all( + [ref_or_query], transaction=self, **kwargs + ) elif isinstance(ref_or_query, AsyncQuery): - return await ref_or_query.stream(transaction=self) + return await ref_or_query.stream(transaction=self, **kwargs) else: raise ValueError( 'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.' diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index f84af4b3d..348a6ac45 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -19,6 +19,7 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.document import DocumentReference + from typing import Union @@ -146,3 +147,13 @@ def delete( """ write_pb = _helpers.pb_for_delete(reference._document_path, option) self._add_write_pbs([write_pb]) + + def _prep_commit(self, retry, timeout): + """Shared setup for async/sync :meth:`commit`.""" + request = { + "database": self._client._database_string, + "writes": self._write_pbs, + "transaction": None, + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + return request, kwargs diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index b2a422291..285ad82d5 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -28,6 +28,7 @@ import google.api_core.client_options # type: ignore import google.api_core.path_template # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core.gapic_v1 import client_info # type: ignore from google.cloud.client import ClientWithProject # type: ignore @@ -353,18 +354,50 @@ def write_option( extra = "{!r} was provided".format(name) raise TypeError(_BAD_OPTION_ERR, extra) + def _prep_get_all( + self, + references: list, + field_paths: Iterable[str] = None, + transaction: BaseTransaction = None, + retry: retries.Retry = None, + timeout: float = None, + ) -> Tuple[dict, dict, dict]: + """Shared setup for async/sync :meth:`get_all`.""" + document_paths, reference_map = _reference_info(references) + mask = _get_doc_mask(field_paths) + request = { + "database": self._database_string, + "documents": document_paths, + "mask": mask, + "transaction": _helpers.get_transaction_id(transaction), + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, reference_map, kwargs + def get_all( self, references: list, field_paths: Iterable[str] = None, transaction: BaseTransaction = None, + retry: retries.Retry = None, + timeout: float = None, ) -> Union[ AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any] ]: raise NotImplementedError + def _prep_collections( + self, retry: retries.Retry = None, timeout: float = None, + ) -> Tuple[dict, dict]: + """Shared setup for async/sync :meth:`collections`.""" + request = {"parent": "{}/documents".format(self._database_string)} + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, kwargs + def collections( - self, + self, retry: retries.Retry = None, timeout: float = None, ) -> Union[ AsyncGenerator[BaseCollectionReference, Any], Generator[BaseCollectionReference, Any, Any], diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 72480a911..ae58fe820 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -15,6 +15,8 @@ """Classes for representing collections for the Google Cloud Firestore API.""" import random +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.document import DocumentReference from typing import ( @@ -146,13 +148,48 @@ def _parent_info(self) -> Tuple[Any, str]: expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) return parent_path, expected_prefix + def _prep_add( + self, + document_data: dict, + document_id: str = None, + retry: retries.Retry = None, + timeout: float = None, + ) -> Tuple[DocumentReference, dict]: + """Shared setup for async / sync :method:`add`""" + if document_id is None: + document_id = _auto_id() + + document_ref = self.document(document_id) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return document_ref, kwargs + def add( - self, document_data: dict, document_id: str = None + self, + document_data: dict, + document_id: str = None, + retry: retries.Retry = None, + timeout: float = None, ) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]: raise NotImplementedError + def _prep_list_documents( + self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, + ) -> Tuple[dict, dict]: + """Shared setup for async / sync :method:`list_documents`""" + parent, _ = self._parent_info() + request = { + "parent": parent, + "collection_id": self.id, + "page_size": page_size, + "show_missing": True, + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, kwargs + def list_documents( - self, page_size: int = None + self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, ) -> Union[ Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] ]: @@ -373,15 +410,30 @@ def end_at( query = self._query() return query.end_at(document_fields) + def _prep_get_or_stream( + self, retry: retries.Retry = None, timeout: float = None, + ) -> Tuple[Any, dict]: + """Shared setup for async / sync :meth:`get` / :meth:`stream`""" + query = self._query() + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return query, kwargs + def get( - self, transaction: Transaction = None + self, + transaction: Transaction = None, + retry: retries.Retry = None, + timeout: float = None, ) -> Union[ Generator[DocumentSnapshot, Any, Any], AsyncGenerator[DocumentSnapshot, Any] ]: raise NotImplementedError def stream( - self, transaction: Transaction = None + self, + transaction: Transaction = None, + retry: retries.Retry = None, + timeout: float = None, ) -> Union[Iterator[DocumentSnapshot], AsyncIterator[DocumentSnapshot]]: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 68534c471..7dcf407ec 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -16,9 +16,16 @@ import copy +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import field_path as field_path_module -from typing import Any, Iterable, NoReturn, Tuple +from google.cloud.firestore_v1.types import common + +from typing import Any +from typing import Iterable +from typing import NoReturn +from typing import Tuple class BaseDocumentReference(object): @@ -178,26 +185,135 @@ def collection(self, collection_id: str) -> Any: child_path = self._path + (collection_id,) return self._client.collection(*child_path) - def create(self, document_data: dict) -> NoReturn: + def _prep_create( + self, document_data: dict, retry: retries.Retry = None, timeout: float = None, + ) -> Tuple[Any, dict]: + batch = self._client.batch() + batch.create(self, document_data) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return batch, kwargs + + def create( + self, document_data: dict, retry: retries.Retry = None, timeout: float = None, + ) -> NoReturn: raise NotImplementedError - def set(self, document_data: dict, merge: bool = False) -> NoReturn: + def _prep_set( + self, + document_data: dict, + merge: bool = False, + retry: retries.Retry = None, + timeout: float = None, + ) -> Tuple[Any, dict]: + batch = self._client.batch() + batch.set(self, document_data, merge=merge) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return batch, kwargs + + def set( + self, + document_data: dict, + merge: bool = False, + retry: retries.Retry = None, + timeout: float = None, + ) -> NoReturn: raise NotImplementedError + def _prep_update( + self, + field_updates: dict, + option: _helpers.WriteOption = None, + retry: retries.Retry = None, + timeout: float = None, + ) -> Tuple[Any, dict]: + batch = self._client.batch() + batch.update(self, field_updates, option=option) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return batch, kwargs + def update( - self, field_updates: dict, option: _helpers.WriteOption = None + self, + field_updates: dict, + option: _helpers.WriteOption = None, + retry: retries.Retry = None, + timeout: float = None, ) -> NoReturn: raise NotImplementedError - def delete(self, option: _helpers.WriteOption = None) -> NoReturn: + def _prep_delete( + self, + option: _helpers.WriteOption = None, + retry: retries.Retry = None, + timeout: float = None, + ) -> Tuple[dict, dict]: + """Shared setup for async/sync :meth:`delete`.""" + write_pb = _helpers.pb_for_delete(self._document_path, option) + request = { + "database": self._client._database_string, + "writes": [write_pb], + "transaction": None, + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, kwargs + + def delete( + self, + option: _helpers.WriteOption = None, + retry: retries.Retry = None, + timeout: float = None, + ) -> NoReturn: raise NotImplementedError + def _prep_get( + self, + field_paths: Iterable[str] = None, + transaction=None, + retry: retries.Retry = None, + timeout: float = None, + ) -> Tuple[dict, dict]: + """Shared setup for async/sync :meth:`get`.""" + if isinstance(field_paths, str): + raise ValueError("'field_paths' must be a sequence of paths, not a string.") + + if field_paths is not None: + mask = common.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None + + request = { + "name": self._document_path, + "mask": mask, + "transaction": _helpers.get_transaction_id(transaction), + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, kwargs + def get( - self, field_paths: Iterable[str] = None, transaction=None + self, + field_paths: Iterable[str] = None, + transaction=None, + retry: retries.Retry = None, + timeout: float = None, ) -> "DocumentSnapshot": raise NotImplementedError - def collections(self, page_size: int = None) -> NoReturn: + def _prep_collections( + self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, + ) -> Tuple[dict, dict]: + """Shared setup for async/sync :meth:`collections`.""" + request = {"parent": self._document_path, "page_size": page_size} + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, kwargs + + def collections( + self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, + ) -> NoReturn: raise NotImplementedError def on_snapshot(self, callback) -> NoReturn: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 38d08dd14..2393d3711 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -21,6 +21,7 @@ import copy import math +from google.api_core import retry as retries # type: ignore from google.protobuf import wrappers_pb2 from google.cloud.firestore_v1 import _helpers @@ -802,10 +803,34 @@ def _to_protobuf(self) -> StructuredQuery: return query.StructuredQuery(**query_kwargs) - def get(self, transaction=None) -> NoReturn: + def get( + self, transaction=None, retry: retries.Retry = None, timeout: float = None, + ) -> NoReturn: raise NotImplementedError - def stream(self, transaction=None) -> NoReturn: + def _prep_stream( + self, transaction=None, retry: retries.Retry = None, timeout: float = None, + ) -> Tuple[dict, str, dict]: + """Shared setup for async / sync :meth:`stream`""" + if self._limit_to_last: + raise ValueError( + "Query results for queries that include limit_to_last() " + "constraints cannot be streamed. Use Query.get() instead." + ) + + parent_path, expected_prefix = self._parent._parent_info() + request = { + "parent": parent_path, + "structured_query": self._to_protobuf(), + "transaction": _helpers.get_transaction_id(transaction), + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, expected_prefix, kwargs + + def stream( + self, transaction=None, retry: retries.Retry = None, timeout: float = None, + ) -> NoReturn: raise NotImplementedError def on_snapshot(self, callback) -> NoReturn: @@ -1101,6 +1126,36 @@ def _validate_partition_query(self): if self._offset: raise ValueError("Can't partition query with offset.") + def _get_query_class(self): + raise NotImplementedError + + def _prep_get_partitions( + self, partition_count, retry: retries.Retry = None, timeout: float = None, + ) -> Tuple[dict, dict]: + self._validate_partition_query() + parent_path, expected_prefix = self._parent._parent_info() + klass = self._get_query_class() + query = klass( + self._parent, + orders=self._PARTITION_QUERY_ORDER, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + request = { + "parent": parent_path, + "structured_query": query._to_protobuf(), + "partition_count": partition_count, + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, kwargs + + def get_partitions( + self, partition_count, retry: retries.Retry = None, timeout: float = None, + ) -> NoReturn: + raise NotImplementedError + class QueryPartition: """Represents a bounded partition of a collection group query. diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index c676d3d7a..5eac1d7fe 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -14,6 +14,7 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" +from google.api_core import retry as retries # type: ignore from google.cloud.firestore_v1 import types from typing import Any, Coroutine, NoReturn, Optional, Union @@ -141,10 +142,14 @@ def _rollback(self) -> NoReturn: def _commit(self) -> Union[list, Coroutine[Any, Any, list]]: raise NotImplementedError - def get_all(self, references: list) -> NoReturn: + def get_all( + self, references: list, retry: retries.Retry = None, timeout: float = None, + ) -> NoReturn: raise NotImplementedError - def get(self, ref_or_query) -> NoReturn: + def get( + self, ref_or_query, retry: retries.Retry = None, timeout: float = None, + ) -> NoReturn: raise NotImplementedError diff --git a/google/cloud/firestore_v1/batch.py b/google/cloud/firestore_v1/batch.py index c4e5c7a6f..175805122 100644 --- a/google/cloud/firestore_v1/batch.py +++ b/google/cloud/firestore_v1/batch.py @@ -14,6 +14,8 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.cloud.firestore_v1.base_batch import BaseWriteBatch @@ -33,27 +35,33 @@ class WriteBatch(BaseWriteBatch): def __init__(self, client) -> None: super(WriteBatch, self).__init__(client=client) - def commit(self) -> list: + def commit( + self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None + ) -> list: """Commit the changes accumulated in this batch. + Args: + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + Returns: List[:class:`google.cloud.proto.firestore.v1.write.WriteResult`, ...]: The write results corresponding to the changes committed, returned in the same order as the changes were applied to this batch. A write result contains an ``update_time`` field. """ + request, kwargs = self._prep_commit(retry, timeout) + commit_response = self._client._firestore_api.commit( - request={ - "database": self._client._database_string, - "writes": self._write_pbs, - "transaction": None, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) self._write_pbs = [] self.write_results = results = list(commit_response.write_results) self.commit_time = commit_response.commit_time + return results def __enter__(self): diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index e6c9f45c9..c3f75aba5 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -24,17 +24,17 @@ :class:`~google.cloud.firestore_v1.document.DocumentReference` """ +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_client import ( BaseClient, DEFAULT_DATABASE, _CLIENT_INFO, - _reference_info, _parse_batch_get, - _get_doc_mask, _path_helper, ) -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.query import CollectionGroup from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference @@ -207,6 +207,8 @@ def get_all( references: list, field_paths: Iterable[str] = None, transaction: Transaction = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> Generator[Any, Any, None]: """Retrieve a batch of documents. @@ -237,48 +239,55 @@ def get_all( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that these ``references`` will be retrieved in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ - document_paths, reference_map = _reference_info(references) - mask = _get_doc_mask(field_paths) + request, reference_map, kwargs = self._prep_get_all( + references, field_paths, transaction, retry, timeout + ) + response_iterator = self._firestore_api.batch_get_documents( - request={ - "database": self._database_string, - "documents": document_paths, - "mask": mask, - "transaction": _helpers.get_transaction_id(transaction), - }, - metadata=self._rpc_metadata, + request=request, metadata=self._rpc_metadata, **kwargs, ) for get_doc_response in response_iterator: yield _parse_batch_get(get_doc_response, reference_map, self) - def collections(self) -> Generator[Any, Any, None]: + def collections( + self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, + ) -> Generator[Any, Any, None]: """List top-level collections of the client's database. + Args: + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: iterator of subcollections of the current document. """ + request, kwargs = self._prep_collections(retry, timeout) + iterator = self._firestore_api.list_collection_ids( - request={"parent": "{}/documents".format(self._database_string)}, - metadata=self._rpc_metadata, + request=request, metadata=self._rpc_metadata, **kwargs, ) while True: for i in iterator.collection_ids: yield self.collection(i) if iterator.next_page_token: + next_request = request.copy() + next_request["page_token"] = iterator.next_page_token iterator = self._firestore_api.list_collection_ids( - request={ - "parent": "{}/documents".format(self._database_string), - "page_token": iterator.next_page_token, - }, - metadata=self._rpc_metadata, + request=next_request, metadata=self._rpc_metadata, **kwargs, ) else: return diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 4cd857095..96d076e2c 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -13,9 +13,12 @@ # limitations under the License. """Classes for representing collections for the Google Cloud Firestore API.""" + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_collection import ( BaseCollectionReference, - _auto_id, _item_to_document_ref, ) from google.cloud.firestore_v1 import query as query_mod @@ -64,7 +67,13 @@ def _query(self) -> query_mod.Query: """ return query_mod.Query(self) - def add(self, document_data: dict, document_id: str = None) -> Tuple[Any, Any]: + def add( + self, + document_data: dict, + document_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. Args: @@ -75,6 +84,10 @@ def add(self, document_data: dict, document_id: str = None) -> Tuple[Any, Any]: automatically assigned by the server (the assigned ID will be a random 20 character string composed of digits, uppercase and lowercase letters). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: Tuple[:class:`google.protobuf.timestamp_pb2.Timestamp`, \ @@ -88,20 +101,28 @@ def add(self, document_data: dict, document_id: str = None) -> Tuple[Any, Any]: ~google.cloud.exceptions.Conflict: If ``document_id`` is provided and the document already exists. """ - if document_id is None: - document_id = _auto_id() - - document_ref = self.document(document_id) - write_result = document_ref.create(document_data) + document_ref, kwargs = self._prep_add( + document_data, document_id, retry, timeout, + ) + write_result = document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref - def list_documents(self, page_size: int = None) -> Generator[Any, Any, None]: + def list_documents( + self, + page_size: int = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Generator[Any, Any, None]: """List all subdocuments of the current collection. Args: page_size (Optional[int]]): The maximum number of documents - in each page of results from this request. Non-positive values - are ignored. Defaults to a sensible value set by the API. + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.DocumentReference`]: @@ -109,20 +130,19 @@ def list_documents(self, page_size: int = None) -> Generator[Any, Any, None]: collection does not exist at the time of `snapshot`, the iterator will be empty """ - parent, _ = self._parent_info() + request, kwargs = self._prep_list_documents(page_size, retry, timeout) iterator = self._client._firestore_api.list_documents( - request={ - "parent": parent, - "collection_id": self.id, - "page_size": page_size, - "show_missing": True, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) return (_item_to_document_ref(self, i) for i in iterator) - def get(self, transaction: Transaction = None) -> list: + def get( + self, + transaction: Transaction = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> list: """Read the documents in this collection. This sends a ``RunQuery`` RPC and returns a list of documents @@ -132,6 +152,10 @@ def get(self, transaction: Transaction = None) -> list: transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -140,11 +164,15 @@ def get(self, transaction: Transaction = None) -> list: Returns: list: The documents in this collection that match the query. """ - query = query_mod.Query(self) - return query.get(transaction=transaction) + query, kwargs = self._prep_get_or_stream(retry, timeout) + + return query.get(transaction=transaction, **kwargs) def stream( - self, transaction: Transaction = None + self, + transaction: Transaction = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> Generator[document.DocumentSnapshot, Any, None]: """Read the documents in this collection. @@ -167,13 +195,18 @@ def stream( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ Transaction`]): An existing transaction that the query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: The next document that fulfills the query. """ - query = query_mod.Query(self) - return query.stream(transaction=transaction) + query, kwargs = self._prep_get_or_stream(retry, timeout) + + return query.stream(transaction=transaction, **kwargs) def on_snapshot(self, callback: Callable) -> Watch: """Monitor the documents in this collection. diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index ca5fc8378..55e8797c4 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -14,6 +14,9 @@ """Classes for representing documents for the Google Cloud Firestore API.""" +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_document import ( BaseDocumentReference, DocumentSnapshot, @@ -22,7 +25,6 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.watch import Watch from typing import Any, Callable, Generator, Iterable @@ -55,12 +57,21 @@ class DocumentReference(BaseDocumentReference): def __init__(self, *path, **kwargs) -> None: super(DocumentReference, self).__init__(*path, **kwargs) - def create(self, document_data) -> Any: + def create( + self, + document_data: dict, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Any: """Create the current document in the Firestore database. Args: document_data (dict): Property names and values to use for creating a document. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.types.WriteResult`: @@ -71,12 +82,17 @@ def create(self, document_data) -> Any: :class:`~google.cloud.exceptions.Conflict`: If the document already exists. """ - batch = self._client.batch() - batch.create(self, document_data) - write_results = batch.commit() + batch, kwargs = self._prep_create(document_data, retry, timeout) + write_results = batch.commit(**kwargs) return _first_write_result(write_results) - def set(self, document_data: dict, merge: bool = False) -> Any: + def set( + self, + document_data: dict, + merge: bool = False, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Any: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -96,18 +112,27 @@ def set(self, document_data: dict, merge: bool = False) -> Any: merge (Optional[bool] or Optional[List]): If True, apply merging instead of overwriting the state of the document. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.types.WriteResult`: The write result corresponding to the committed document. A write result contains an ``update_time`` field. """ - batch = self._client.batch() - batch.set(self, document_data, merge=merge) - write_results = batch.commit() + batch, kwargs = self._prep_set(document_data, merge, retry, timeout) + write_results = batch.commit(**kwargs) return _first_write_result(write_results) - def update(self, field_updates: dict, option: _helpers.WriteOption = None) -> Any: + def update( + self, + field_updates: dict, + option: _helpers.WriteOption = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Any: """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -241,6 +266,10 @@ def update(self, field_updates: dict, option: _helpers.WriteOption = None) -> An option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): A write option to make assertions / preconditions on the server state of the document before applying changes. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.types.WriteResult`: @@ -250,18 +279,26 @@ def update(self, field_updates: dict, option: _helpers.WriteOption = None) -> An Raises: ~google.cloud.exceptions.NotFound: If the document does not exist. """ - batch = self._client.batch() - batch.update(self, field_updates, option=option) - write_results = batch.commit() + batch, kwargs = self._prep_update(field_updates, option, retry, timeout) + write_results = batch.commit(**kwargs) return _first_write_result(write_results) - def delete(self, option: _helpers.WriteOption = None) -> Any: + def delete( + self, + option: _helpers.WriteOption = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Any: """Delete the current document in the Firestore database. Args: option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): A write option to make assertions / preconditions on the server state of the document before applying changes. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`google.protobuf.timestamp_pb2.Timestamp`: @@ -270,20 +307,20 @@ def delete(self, option: _helpers.WriteOption = None) -> Any: nothing was deleted), this method will still succeed and will still return the time that the request was received by the server. """ - write_pb = _helpers.pb_for_delete(self._document_path, option) + request, kwargs = self._prep_delete(option, retry, timeout) + commit_response = self._client._firestore_api.commit( - request={ - "database": self._client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) return commit_response.commit_time def get( - self, field_paths: Iterable[str] = None, transaction=None + self, + field_paths: Iterable[str] = None, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. @@ -302,6 +339,10 @@ def get( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this reference will be retrieved in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: :class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot`: @@ -311,23 +352,12 @@ def get( :attr:`create_time` attributes will all be ``None`` and its :attr:`exists` attribute will be ``False``. """ - if isinstance(field_paths, str): - raise ValueError("'field_paths' must be a sequence of paths, not a string.") - - if field_paths is not None: - mask = common.DocumentMask(field_paths=sorted(field_paths)) - else: - mask = None + request, kwargs = self._prep_get(field_paths, transaction, retry, timeout) firestore_api = self._client._firestore_api try: document_pb = firestore_api.get_document( - request={ - "name": self._document_path, - "mask": mask, - "transaction": _helpers.get_transaction_id(transaction), - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) except exceptions.NotFound: data = None @@ -349,13 +379,22 @@ def get( update_time=update_time, ) - def collections(self, page_size: int = None) -> Generator[Any, Any, None]: + def collections( + self, + page_size: int = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Generator[Any, Any, None]: """List subcollections of the current document. Args: page_size (Optional[int]]): The maximum number of collections - in each page of results from this request. Non-positive values - are ignored. Defaults to a sensible value set by the API. + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: @@ -363,22 +402,20 @@ def collections(self, page_size: int = None) -> Generator[Any, Any, None]: document does not exist at the time of `snapshot`, the iterator will be empty """ + request, kwargs = self._prep_collections(page_size, retry, timeout) + iterator = self._client._firestore_api.list_collection_ids( - request={"parent": self._document_path, "page_size": page_size}, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) while True: for i in iterator.collection_ids: yield self.collection(i) if iterator.next_page_token: + next_request = request.copy() + next_request["page_token"] = iterator.next_page_token iterator = self._client._firestore_api.list_collection_ids( - request={ - "parent": self._document_path, - "page_size": page_size, - "page_token": iterator.next_page_token, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs ) else: return diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index ef38b68f4..1716999be 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -18,6 +18,10 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_query import ( BaseCollectionGroup, BaseQuery, @@ -27,10 +31,11 @@ _enum_from_direction, ) -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator +from typing import Any +from typing import Callable +from typing import Generator class Query(BaseQuery): @@ -115,7 +120,12 @@ def __init__( all_descendants=all_descendants, ) - def get(self, transaction=None) -> list: + def get( + self, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> list: """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and returns a list of documents @@ -125,9 +135,13 @@ def get(self, transaction=None) -> list: transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. - If a ``transaction`` is used and it already has write operations - added, this method cannot be used (i.e. read-after-write is not - allowed). + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Returns: list: The documents in the collection that match this query. @@ -146,14 +160,17 @@ def get(self, transaction=None) -> list: ) self._limit_to_last = False - result = self.stream(transaction=transaction) + result = self.stream(transaction=transaction, retry=retry, timeout=timeout) if is_limited_to_last: result = reversed(list(result)) return list(result) def stream( - self, transaction=None + self, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, ) -> Generator[document.DocumentSnapshot, Any, None]: """Read the documents in the collection that match this query. @@ -176,25 +193,21 @@ def stream( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: The next document that fulfills the query. """ - if self._limit_to_last: - raise ValueError( - "Query results for queries that include limit_to_last() " - "constraints cannot be streamed. Use Query.get() instead." - ) + request, expected_prefix, kwargs = self._prep_stream( + transaction, retry, timeout, + ) - parent_path, expected_prefix = self._parent._parent_info() response_iterator = self._client._firestore_api.run_query( - request={ - "parent": parent_path, - "structured_query": self._to_protobuf(), - "transaction": _helpers.get_transaction_id(transaction), - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) for response in response_iterator: @@ -281,7 +294,16 @@ def __init__( all_descendants=all_descendants, ) - def get_partitions(self, partition_count) -> Generator[QueryPartition, None, None]: + @staticmethod + def _get_query_class(): + return Query + + def get_partitions( + self, + partition_count, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Generator[QueryPartition, None, None]: """Partition a query for parallelization. Partitions a query by returning partition cursors that can be used to run the @@ -292,24 +314,15 @@ def get_partitions(self, partition_count) -> Generator[QueryPartition, None, Non partition_count (int): The desired maximum number of partition points. The number must be strictly positive. The actual number of partitions returned may be fewer. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. """ - self._validate_partition_query() - query = Query( - self._parent, - orders=self._PARTITION_QUERY_ORDER, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) + request, kwargs = self._prep_get_partitions(partition_count, retry, timeout) - parent_path, expected_prefix = self._parent._parent_info() pager = self._client._firestore_api.partition_query( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "partition_count": partition_count, - }, - metadata=self._client._rpc_metadata, + request=request, metadata=self._client._rpc_metadata, **kwargs, ) start_at = None diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 1549fcf7d..7bab4b595 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -18,6 +18,9 @@ import random import time +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_transaction import ( _BaseTransactional, BaseTransaction, @@ -35,6 +38,7 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import batch from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.query import Query from typing import Any, Callable, Optional @@ -136,32 +140,53 @@ def _commit(self) -> list: self._clean_up() return list(commit_response.write_results) - def get_all(self, references: list) -> Any: + def get_all( + self, + references: list, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Any: """Retrieves multiple documents from Firestore. Args: references (List[.DocumentReference, ...]): Iterable of document references to be retrieved. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ - return self._client.get_all(references, transaction=self) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + return self._client.get_all(references, transaction=self, **kwargs) + + def get( + self, + ref_or_query, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Any: + """Retrieve a document or a query result from the database. - def get(self, ref_or_query) -> Any: - """ - Retrieve a document or a query result from the database. Args: - ref_or_query The document references or query object to return. + ref_or_query: The document references or query object to return. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) if isinstance(ref_or_query, DocumentReference): - return self._client.get_all([ref_or_query], transaction=self) + return self._client.get_all([ref_or_query], transaction=self, **kwargs) elif isinstance(ref_or_query, Query): - return ref_or_query.stream(transaction=self) + return ref_or_query.stream(transaction=self, **kwargs) else: raise ValueError( 'Value for argument "ref_or_query" must be a DocumentReference or a Query.' diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index c51084ac5..ff2aa3e1c 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -2173,7 +2173,7 @@ def test_without_option(self): self._helper(current_document=precondition) def test_with_exists_option(self): - from google.cloud.firestore_v1.client import _helpers + from google.cloud.firestore_v1 import _helpers option = _helpers.ExistsOption(False) self._helper(option=option) @@ -2387,6 +2387,51 @@ def test_modify_write(self): self.assertEqual(write_pb.current_document, expected_doc) +class Test_make_retry_timeout_kwargs(unittest.TestCase): + @staticmethod + def _call_fut(retry, timeout): + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs + + return make_retry_timeout_kwargs(retry, timeout) + + def test_default(self): + from google.api_core.gapic_v1.method import DEFAULT + + kwargs = self._call_fut(DEFAULT, None) + expected = {} + self.assertEqual(kwargs, expected) + + def test_retry_None(self): + kwargs = self._call_fut(None, None) + expected = {"retry": None} + self.assertEqual(kwargs, expected) + + def test_retry_only(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + kwargs = self._call_fut(retry, None) + expected = {"retry": retry} + self.assertEqual(kwargs, expected) + + def test_timeout_only(self): + from google.api_core.gapic_v1.method import DEFAULT + + timeout = 123.0 + kwargs = self._call_fut(DEFAULT, timeout) + expected = {"timeout": timeout} + self.assertEqual(kwargs, expected) + + def test_retry_and_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + kwargs = self._call_fut(retry, timeout) + expected = {"retry": retry, "timeout": timeout} + self.assertEqual(kwargs, expected) + + def _value_pb(**kwargs): from google.cloud.firestore_v1.types.document import Value diff --git a/tests/unit/v1/test_async_batch.py b/tests/unit/v1/test_async_batch.py index 59852fd88..dce1cefdf 100644 --- a/tests/unit/v1/test_async_batch.py +++ b/tests/unit/v1/test_async_batch.py @@ -37,9 +37,9 @@ def test_constructor(self): self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) - @pytest.mark.asyncio - async def test_commit(self): + async def _commit_helper(self, retry=None, timeout=None): from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import write @@ -51,6 +51,7 @@ async def test_commit(self): commit_time=timestamp, ) firestore_api.commit.return_value = commit_response + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Attach the fake GAPIC to a real client. client = _make_client("grand") @@ -59,12 +60,13 @@ async def test_commit(self): # Actually make a batch with some mutations and call commit(). batch = self._make_one(client) document1 = client.document("a", "b") - batch.create(document1, {"ten": 10, "buck": u"ets"}) + batch.create(document1, {"ten": 10, "buck": "ets"}) document2 = client.document("c", "d", "e", "f") batch.delete(document2) write_pbs = batch._write_pbs[::] - write_results = await batch.commit() + write_results = await batch.commit(**kwargs) + self.assertEqual(write_results, list(commit_response.write_results)) self.assertEqual(batch.write_results, write_results) self.assertEqual(batch.commit_time.timestamp_pb(), timestamp) @@ -79,8 +81,22 @@ async def test_commit(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) + @pytest.mark.asyncio + async def test_commit(self): + await self._commit_helper() + + @pytest.mark.asyncio + async def test_commit_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + + await self._commit_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_as_context_mgr_wo_error(self): from google.protobuf import timestamp_pb2 @@ -102,7 +118,7 @@ async def test_as_context_mgr_wo_error(self): async with batch as ctx_mgr: self.assertIs(ctx_mgr, batch) - ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) ctx_mgr.delete(document2) write_pbs = batch._write_pbs[::] @@ -132,7 +148,7 @@ async def test_as_context_mgr_w_error(self): with self.assertRaises(RuntimeError): async with batch as ctx_mgr: - ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) ctx_mgr.delete(document2) raise RuntimeError("testing") diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 770d6ae20..bf9787841 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -131,11 +131,11 @@ def test__get_collection_reference(self): def test_collection_group(self): client = self._make_default_one() - query = client.collection_group("collectionId").where("foo", "==", u"bar") + query = client.collection_group("collectionId").where("foo", "==", "bar") self.assertTrue(query._all_descendants) self.assertEqual(query._field_filters[0].field.field_path, "foo") - self.assertEqual(query._field_filters[0].value.string_value, u"bar") + self.assertEqual(query._field_filters[0].value.string_value, "bar") self.assertEqual( query._field_filters[0].op, query._field_filters[0].Operator.EQUAL ) @@ -195,11 +195,11 @@ def test_document_factory_w_nested_path(self): self.assertIs(document2._client, client) self.assertIsInstance(document2, AsyncDocumentReference) - @pytest.mark.asyncio - async def test_collections(self): + async def _collections_helper(self, retry=None, timeout=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + from google.cloud.firestore_v1 import _helpers collection_ids = ["users", "projects"] client = self._make_default_one() @@ -220,10 +220,11 @@ def _next_page(self): page, self._pages = self._pages[0], self._pages[1:] return Page(self, page, self.item_to_value) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = [c async for c in client.collections()] + collections = [c async for c in client.collections(**kwargs)] self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): @@ -233,10 +234,22 @@ def _next_page(self): base_path = client._database_string + "/documents" firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": base_path}, metadata=client._rpc_metadata + request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs, ) - async def _get_all_helper(self, client, references, document_pbs, **kwargs): + @pytest.mark.asyncio + async def test_collections(self): + await self._collections_helper() + + @pytest.mark.asyncio + async def test_collections_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._collections_helper(retry=retry, timeout=timeout) + + async def _invoke_get_all(self, client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["batch_get_documents"]) response_iterator = AsyncIter(document_pbs) @@ -251,159 +264,115 @@ async def _get_all_helper(self, client, references, document_pbs, **kwargs): return [s async for s in snapshots] - def _info_for_get_all(self, data1, data2): + async def _get_all_helper( + self, num_snapshots=2, txn_id=None, retry=None, timeout=None + ): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.async_document import DocumentSnapshot + client = self._make_default_one() - document1 = client.document("pineapple", "lamp1") - document2 = client.document("pineapple", "lamp2") - # Make response protobufs. + data1 = {"a": "cheese"} + document1 = client.document("pineapple", "lamp1") document_pb1, read_time = _doc_get_info(document1._document_path, data1) response1 = _make_batch_response(found=document_pb1, read_time=read_time) + data2 = {"b": True, "c": 18} + document2 = client.document("pineapple", "lamp2") document, read_time = _doc_get_info(document2._document_path, data2) response2 = _make_batch_response(found=document, read_time=read_time) - return client, document1, document2, response1, response2 + document3 = client.document("pineapple", "lamp3") + response3 = _make_batch_response(missing=document3._document_path) - @pytest.mark.asyncio - async def test_get_all(self): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.async_document import DocumentSnapshot + expected_data = [data1, data2, None][:num_snapshots] + documents = [document1, document2, document3][:num_snapshots] + responses = [response1, response2, response3][:num_snapshots] + field_paths = [ + field_path for field_path in ["a", "b", None][:num_snapshots] if field_path + ] + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - data1 = {"a": u"cheese"} - data2 = {"b": True, "c": 18} - info = self._info_for_get_all(data1, data2) - client, document1, document2, response1, response2 = info + if txn_id is not None: + transaction = client.transaction() + transaction._id = txn_id + kwargs["transaction"] = transaction - # Exercise the mocked ``batch_get_documents``. - field_paths = ["a", "b"] - snapshots = await self._get_all_helper( - client, - [document1, document2], - [response1, response2], - field_paths=field_paths, + snapshots = await self._invoke_get_all( + client, documents, responses, field_paths=field_paths, **kwargs, ) - self.assertEqual(len(snapshots), 2) - snapshot1 = snapshots[0] - self.assertIsInstance(snapshot1, DocumentSnapshot) - self.assertIs(snapshot1._reference, document1) - self.assertEqual(snapshot1._data, data1) + self.assertEqual(len(snapshots), num_snapshots) - snapshot2 = snapshots[1] - self.assertIsInstance(snapshot2, DocumentSnapshot) - self.assertIs(snapshot2._reference, document2) - self.assertEqual(snapshot2._data, data2) + for data, document, snapshot in zip(expected_data, documents, snapshots): + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertIs(snapshot._reference, document) + if data is None: + self.assertFalse(snapshot.exists) + else: + self.assertEqual(snapshot._data, data) # Verify the call to the mock. - doc_paths = [document1._document_path, document2._document_path] + doc_paths = [document._document_path for document in documents] mask = common.DocumentMask(field_paths=field_paths) + + kwargs.pop("transaction", None) + client._firestore_api.batch_get_documents.assert_called_once_with( request={ "database": client._database_string, "documents": doc_paths, "mask": mask, - "transaction": None, + "transaction": txn_id, }, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio - async def test_get_all_with_transaction(self): - from google.cloud.firestore_v1.async_document import DocumentSnapshot + async def test_get_all(self): + await self._get_all_helper() - data = {"so-much": 484} - info = self._info_for_get_all(data, {}) - client, document, _, response, _ = info - transaction = client.transaction() + @pytest.mark.asyncio + async def test_get_all_with_transaction(self): txn_id = b"the-man-is-non-stop" - transaction._id = txn_id + await self._get_all_helper(num_snapshots=1, txn_id=txn_id) - # Exercise the mocked ``batch_get_documents``. - snapshots = await self._get_all_helper( - client, [document], [response], transaction=transaction - ) - self.assertEqual(len(snapshots), 1) + @pytest.mark.asyncio + async def test_get_all_w_retry_timeout(self): + from google.api_core.retry import Retry - snapshot = snapshots[0] - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertIs(snapshot._reference, document) - self.assertEqual(snapshot._data, data) + retry = Retry(predicate=object()) + timeout = 123.0 + await self._get_all_helper(retry=retry, timeout=timeout) - # Verify the call to the mock. - doc_paths = [document._document_path] - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": None, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) + @pytest.mark.asyncio + async def test_get_all_wrong_order(self): + await self._get_all_helper(num_snapshots=3) @pytest.mark.asyncio async def test_get_all_unknown_result(self): from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE - info = self._info_for_get_all({"z": 28.5}, {}) - client, document, _, _, response = info + client = self._make_default_one() + + expected_document = client.document("pineapple", "lamp1") + + data = {"z": 28.5} + wrong_document = client.document("pineapple", "lamp2") + document_pb, read_time = _doc_get_info(wrong_document._document_path, data) + response = _make_batch_response(found=document_pb, read_time=read_time) # Exercise the mocked ``batch_get_documents``. with self.assertRaises(ValueError) as exc_info: - await self._get_all_helper(client, [document], [response]) + await self._invoke_get_all(client, [expected_document], [response]) err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) self.assertEqual(exc_info.exception.args, (err_msg,)) # Verify the call to the mock. - doc_paths = [document._document_path] - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": None, - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test_get_all_wrong_order(self): - from google.cloud.firestore_v1.async_document import DocumentSnapshot - - data1 = {"up": 10} - data2 = {"down": -10} - info = self._info_for_get_all(data1, data2) - client, document1, document2, response1, response2 = info - document3 = client.document("pineapple", "lamp3") - response3 = _make_batch_response(missing=document3._document_path) - - # Exercise the mocked ``batch_get_documents``. - snapshots = await self._get_all_helper( - client, [document1, document2, document3], [response2, response1, response3] - ) - - self.assertEqual(len(snapshots), 3) - - snapshot1 = snapshots[0] - self.assertIsInstance(snapshot1, DocumentSnapshot) - self.assertIs(snapshot1._reference, document2) - self.assertEqual(snapshot1._data, data2) - - snapshot2 = snapshots[1] - self.assertIsInstance(snapshot2, DocumentSnapshot) - self.assertIs(snapshot2._reference, document1) - self.assertEqual(snapshot2._data, data1) - - self.assertFalse(snapshots[2].exists) - - # Verify the call to the mock. - doc_paths = [ - document1._document_path, - document2._document_path, - document3._document_path, - ] + doc_paths = [expected_document._document_path] client._firestore_api.batch_get_documents.assert_called_once_with( request={ "database": client._database_string, diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 1b7587c73..4a2f30de1 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -100,7 +100,7 @@ async def test_add_auto_assigned(self): # sure transforms during adds work. document_data = {"been": "here", "now": SERVER_TIMESTAMP} - patch = mock.patch("google.cloud.firestore_v1.async_collection._auto_id") + patch = mock.patch("google.cloud.firestore_v1.base_collection._auto_id") random_doc_id = "DEADBEEF" with patch as patched: patched.return_value = random_doc_id @@ -139,9 +139,9 @@ def _write_pb_for_create(document_path, document_data): current_document=common.Precondition(exists=False), ) - @pytest.mark.asyncio - async def test_add_explicit_id(self): + async def _add_helper(self, retry=None, timeout=None): from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1 import _helpers # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["commit"]) @@ -163,8 +163,10 @@ async def test_add_explicit_id(self): collection = self._make_one("parent", client=client) document_data = {"zorp": 208.75, "i-did-not": b"know that"} doc_id = "child" + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + update_time, document_ref = await collection.add( - document_data, document_id=doc_id + document_data, document_id=doc_id, **kwargs, ) # Verify the response and the mocks. @@ -181,10 +183,24 @@ async def test_add_explicit_id(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio - async def _list_documents_helper(self, page_size=None): + async def test_add_explicit_id(self): + await self._add_helper() + + @pytest.mark.asyncio + async def test_add_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._add_helper(retry=retry, timeout=timeout) + + @pytest.mark.asyncio + async def _list_documents_helper(self, page_size=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers from google.api_core.page_iterator_async import AsyncIterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_document import AsyncDocumentReference @@ -212,13 +228,15 @@ async def _next_page(self): firestore_api.list_documents.return_value = iterator client._firestore_api_internal = firestore_api collection = self._make_one("collection", client=client) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) if page_size is not None: documents = [ - i async for i in collection.list_documents(page_size=page_size) + i + async for i in collection.list_documents(page_size=page_size, **kwargs,) ] else: - documents = [i async for i in collection.list_documents()] + documents = [i async for i in collection.list_documents(**kwargs)] # Verify the response and the mocks. self.assertEqual(len(documents), len(document_ids)) @@ -236,12 +254,21 @@ async def _next_page(self): "show_missing": True, }, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio async def test_list_documents_wo_page_size(self): await self._list_documents_helper() + @pytest.mark.asyncio + async def test_list_documents_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._list_documents_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_list_documents_w_page_size(self): await self._list_documents_helper(page_size=25) @@ -258,6 +285,24 @@ async def test_get(self, query_class): self.assertIs(get_response, query_instance.get.return_value) query_instance.get.assert_called_once_with(transaction=None) + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) + @pytest.mark.asyncio + async def test_get_w_retry_timeout(self, query_class): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + collection = self._make_one("collection") + get_response = await collection.get(retry=retry, timeout=timeout) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + + self.assertIs(get_response, query_instance.get.return_value) + query_instance.get.assert_called_once_with( + transaction=None, retry=retry, timeout=timeout, + ) + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_get_with_transaction(self, query_class): @@ -286,6 +331,27 @@ async def test_stream(self, query_class): query_instance = query_class.return_value query_instance.stream.assert_called_once_with(transaction=None) + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) + @pytest.mark.asyncio + async def test_stream_w_retry_timeout(self, query_class): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + query_class.return_value.stream.return_value = AsyncIter(range(3)) + + collection = self._make_one("collection") + stream_response = collection.stream(retry=retry, timeout=timeout) + + async for _ in stream_response: + pass + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with( + transaction=None, retry=retry, timeout=timeout, + ) + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_stream_with_transaction(self, query_class): diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 79a89d4ab..04214fda8 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -71,8 +71,9 @@ def _write_pb_for_create(document_path, document_data): current_document=common.Precondition(exists=False), ) - @pytest.mark.asyncio - async def test_create(self): + async def _create_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock() firestore_api.commit.mock_add_spec(spec=["commit"]) @@ -85,7 +86,9 @@ async def test_create(self): # Actually make a document and call create(). document = self._make_one("foo", "twelve", client=client) document_data = {"hello": "goodbye", "count": 99} - write_result = await document.create(document_data) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + write_result = await document.create(document_data, **kwargs) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -97,8 +100,21 @@ async def test_create(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) + @pytest.mark.asyncio + async def test_create(self): + await self._create_helper() + + @pytest.mark.asyncio + async def test_create_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._create_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_create_empty(self): # Create a minimal fake GAPIC with a dummy response. @@ -153,7 +169,9 @@ def _write_pb_for_set(document_path, document_data, merge): return write_pbs @pytest.mark.asyncio - async def _set_helper(self, merge=False, **option_kwargs): + async def _set_helper(self, merge=False, retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["commit"]) firestore_api.commit.return_value = self._make_commit_repsonse() @@ -165,7 +183,9 @@ async def _set_helper(self, merge=False, **option_kwargs): # Actually make a document and call create(). document = self._make_one("User", "Interface", client=client) document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} - write_result = await document.set(document_data, merge) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + write_result = await document.set(document_data, merge, **kwargs) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -178,12 +198,21 @@ async def _set_helper(self, merge=False, **option_kwargs): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio async def test_set(self): await self._set_helper() + @pytest.mark.asyncio + async def test_set_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._set_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_set_merge(self): await self._set_helper(merge=True) @@ -204,7 +233,8 @@ def _write_pb_for_update(document_path, update_values, field_paths): ) @pytest.mark.asyncio - async def _update_helper(self, **option_kwargs): + async def _update_helper(self, retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.transforms import DELETE_FIELD # Create a minimal fake GAPIC with a dummy response. @@ -221,12 +251,14 @@ async def _update_helper(self, **option_kwargs): field_updates = collections.OrderedDict( (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) ) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if option_kwargs: option = client.write_option(**option_kwargs) - write_result = await document.update(field_updates, option=option) + write_result = await document.update(field_updates, option=option, **kwargs) else: option = None - write_result = await document.update(field_updates) + write_result = await document.update(field_updates, **kwargs) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -247,6 +279,7 @@ async def _update_helper(self, **option_kwargs): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio @@ -258,6 +291,14 @@ async def test_update_with_exists(self): async def test_update(self): await self._update_helper() + @pytest.mark.asyncio + async def test_update_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._update_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_update_with_precondition(self): from google.protobuf import timestamp_pb2 @@ -283,7 +324,8 @@ async def test_empty_update(self): await document.update(field_updates) @pytest.mark.asyncio - async def _delete_helper(self, **option_kwargs): + async def _delete_helper(self, retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import write # Create a minimal fake GAPIC with a dummy response. @@ -293,15 +335,16 @@ async def _delete_helper(self, **option_kwargs): # Attach the fake GAPIC to a real client. client = _make_client("donut-base") client._firestore_api_internal = firestore_api + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) if option_kwargs: option = client.write_option(**option_kwargs) - delete_time = await document.delete(option=option) + delete_time = await document.delete(option=option, **kwargs) else: option = None - delete_time = await document.delete() + delete_time = await document.delete(**kwargs) # Verify the response and the mocks. self.assertIs(delete_time, mock.sentinel.commit_time) @@ -315,6 +358,7 @@ async def _delete_helper(self, **option_kwargs): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio @@ -328,11 +372,25 @@ async def test_delete_with_option(self): timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) await self._delete_helper(last_update_time=timestamp_pb) + @pytest.mark.asyncio + async def test_delete_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._delete_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def _get_helper( - self, field_paths=None, use_transaction=False, not_found=False + self, + field_paths=None, + use_transaction=False, + not_found=False, + retry=None, + timeout=None, ): from google.api_core.exceptions import NotFound + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.transaction import Transaction @@ -362,7 +420,11 @@ async def _get_helper( else: transaction = None - snapshot = await document.get(field_paths=field_paths, transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + snapshot = await document.get( + field_paths=field_paths, transaction=transaction, **kwargs, + ) self.assertIs(snapshot.reference, document) if not_found: @@ -396,6 +458,7 @@ async def _get_helper( "transaction": expected_transaction_id, }, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio @@ -406,6 +469,14 @@ async def test_get_not_found(self): async def test_get_default(self): await self._get_helper() + @pytest.mark.asyncio + async def test_get_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._get_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_get_w_string_field_path(self): with self.assertRaises(ValueError): @@ -424,7 +495,8 @@ async def test_get_with_transaction(self): await self._get_helper(use_transaction=True) @pytest.mark.asyncio - async def _collections_helper(self, page_size=None): + async def _collections_helper(self, page_size=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -449,13 +521,16 @@ def _next_page(self): client = _make_client() client._firestore_api_internal = firestore_api + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) if page_size is not None: - collections = [c async for c in document.collections(page_size=page_size)] + collections = [ + c async for c in document.collections(page_size=page_size, **kwargs) + ] else: - collections = [c async for c in document.collections()] + collections = [c async for c in document.collections(**kwargs)] # Verify the response and the mocks. self.assertEqual(len(collections), len(collection_ids)) @@ -467,12 +542,21 @@ def _next_page(self): firestore_api.list_collection_ids.assert_called_once_with( request={"parent": document._document_path, "page_size": page_size}, metadata=client._rpc_metadata, + **kwargs, ) @pytest.mark.asyncio - async def test_collections_wo_page_size(self): + async def test_collections(self): await self._collections_helper() + @pytest.mark.asyncio + async def test_collections_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._collections_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_collections_w_page_size(self): await self._collections_helper(page_size=10) diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 944c63ae0..23173ba17 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -58,8 +58,9 @@ def test_constructor(self): self.assertIsNone(query._end_at) self.assertFalse(query._all_descendants) - @pytest.mark.asyncio - async def test_get(self): + async def _get_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_query"]) @@ -76,12 +77,12 @@ async def test_get(self): data = {"snooze": 10} response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = AsyncIter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. query = self._make_one(parent) - returned = await query.get() + returned = await query.get(**kwargs) self.assertIsInstance(returned, list) self.assertEqual(len(returned), 1) @@ -90,6 +91,30 @@ async def test_get(self): self.assertEqual(snapshot.reference._path, ("dee", "sleep")) self.assertEqual(snapshot.to_dict(), data) + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + @pytest.mark.asyncio + async def test_get(self): + await self._get_helper() + + @pytest.mark.asyncio + async def test_get_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._get_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_get_limit_to_last(self): from google.cloud import firestore @@ -119,7 +144,7 @@ async def test_get_limit_to_last(self): # Execute the query and check the response. query = self._make_one(parent) query = query.order_by( - u"snooze", direction=firestore.AsyncQuery.DESCENDING + "snooze", direction=firestore.AsyncQuery.DESCENDING ).limit_to_last(2) returned = await query.get() @@ -149,8 +174,9 @@ async def test_get_limit_to_last(self): metadata=client._rpc_metadata, ) - @pytest.mark.asyncio - async def test_stream_simple(self): + async def _stream_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_query"]) @@ -167,10 +193,13 @@ async def test_stream_simple(self): data = {"snooze": 10} response_pb = _make_query_response(name=name, data=data) firestore_api.run_query.return_value = AsyncIter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. query = self._make_one(parent) - get_response = query.stream() + + get_response = query.stream(**kwargs) + self.assertIsInstance(get_response, types.AsyncGeneratorType) returned = [x async for x in get_response] self.assertEqual(len(returned), 1) @@ -187,8 +216,21 @@ async def test_stream_simple(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) + @pytest.mark.asyncio + async def test_stream_simple(self): + await self._stream_helper() + + @pytest.mark.asyncio + async def test_stream_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._stream_helper(retry=retry, timeout=timeout) + @pytest.mark.asyncio async def test_stream_with_limit_to_last(self): # Attach the fake GAPIC to a real client. @@ -466,7 +508,9 @@ def test_constructor_all_descendents_is_false(self): self._make_one(mock.sentinel.parent, all_descendants=False) @pytest.mark.asyncio - async def test_get_partitions(self): + async def _get_partitions_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["partition_query"]) @@ -485,10 +529,12 @@ async def test_get_partitions(self): cursor_pb1 = _make_cursor_pb(([document1], False)) cursor_pb2 = _make_cursor_pb(([document2], False)) firestore_api.partition_query.return_value = AsyncIter([cursor_pb1, cursor_pb2]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. query = self._make_one(parent) - get_response = query.get_partitions(2) + get_response = query.get_partitions(2, **kwargs) + self.assertIsInstance(get_response, types.AsyncGeneratorType) returned = [i async for i in get_response] self.assertEqual(len(returned), 3) @@ -505,8 +551,21 @@ async def test_get_partitions(self): "partition_count": 2, }, metadata=client._rpc_metadata, + **kwargs, ) + @pytest.mark.asyncio + async def test_get_partitions(self): + await self._get_partitions_helper() + + @pytest.mark.asyncio + async def test_get_partitions_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._get_partitions_helper(retry=retry, timeout=timeout) + async def test_get_partitions_w_filter(self): # Make a **real** collection reference as parent. client = _make_client() diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index ed732ae92..2e0f572b0 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -279,38 +279,84 @@ async def test__commit_failure(self): metadata=client._rpc_metadata, ) - @pytest.mark.asyncio - async def test_get_all(self): + async def _get_all_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + client = AsyncMock(spec=["get_all"]) transaction = self._make_one(client) ref1, ref2 = mock.Mock(), mock.Mock() - result = await transaction.get_all([ref1, ref2]) - client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = await transaction.get_all([ref1, ref2], **kwargs) + + client.get_all.assert_called_once_with( + [ref1, ref2], transaction=transaction, **kwargs, + ) self.assertIs(result, client.get_all.return_value) @pytest.mark.asyncio - async def test_get_document_ref(self): + async def test_get_all(self): + await self._get_all_helper() + + @pytest.mark.asyncio + async def test_get_all_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._get_all_helper(retry=retry, timeout=timeout) + + async def _get_w_document_ref_helper(self, retry=None, timeout=None): from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1 import _helpers client = AsyncMock(spec=["get_all"]) transaction = self._make_one(client) ref = AsyncDocumentReference("documents", "doc-id") - result = await transaction.get(ref) - client.get_all.assert_called_once_with([ref], transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = await transaction.get(ref, **kwargs) + + client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs) self.assertIs(result, client.get_all.return_value) @pytest.mark.asyncio - async def test_get_w_query(self): + async def test_get_w_document_ref(self): + await self._get_w_document_ref_helper() + + @pytest.mark.asyncio + async def test_get_w_document_ref_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await self._get_w_document_ref_helper(retry=retry, timeout=timeout) + + async def _get_w_query_helper(self, retry=None, timeout=None): from google.cloud.firestore_v1.async_query import AsyncQuery + from google.cloud.firestore_v1 import _helpers client = AsyncMock(spec=[]) transaction = self._make_one(client) query = AsyncQuery(parent=AsyncMock(spec=[])) query.stream = AsyncMock() - result = await transaction.get(query) - query.stream.assert_called_once_with(transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = await transaction.get(query, **kwargs,) + + query.stream.assert_called_once_with( + transaction=transaction, **kwargs, + ) self.assertIs(result, query.stream.return_value) + @pytest.mark.asyncio + async def test_get_w_query(self): + await self._get_w_query_helper() + + @pytest.mark.asyncio + async def test_get_w_query_w_retry_timeout(self): + await self._get_w_query_helper() + @pytest.mark.asyncio async def test_get_failure(self): client = _make_client() diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index f21dee622..119942fc3 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -35,8 +35,9 @@ def test_constructor(self): self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) - def test_commit(self): + def _commit_helper(self, retry=None, timeout=None): from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import write @@ -48,6 +49,7 @@ def test_commit(self): commit_time=timestamp, ) firestore_api.commit.return_value = commit_response + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Attach the fake GAPIC to a real client. client = _make_client("grand") @@ -56,12 +58,12 @@ def test_commit(self): # Actually make a batch with some mutations and call commit(). batch = self._make_one(client) document1 = client.document("a", "b") - batch.create(document1, {"ten": 10, "buck": u"ets"}) + batch.create(document1, {"ten": 10, "buck": "ets"}) document2 = client.document("c", "d", "e", "f") batch.delete(document2) write_pbs = batch._write_pbs[::] - write_results = batch.commit() + write_results = batch.commit(**kwargs) self.assertEqual(write_results, list(commit_response.write_results)) self.assertEqual(batch.write_results, write_results) self.assertEqual(batch.commit_time.timestamp_pb(), timestamp) @@ -76,8 +78,20 @@ def test_commit(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) + def test_commit(self): + self._commit_helper() + + def test_commit_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + + self._commit_helper(retry=retry, timeout=timeout) + def test_as_context_mgr_wo_error(self): from google.protobuf import timestamp_pb2 from google.cloud.firestore_v1.types import firestore @@ -98,7 +112,7 @@ def test_as_context_mgr_wo_error(self): with batch as ctx_mgr: self.assertIs(ctx_mgr, batch) - ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) ctx_mgr.delete(document2) write_pbs = batch._write_pbs[::] @@ -127,7 +141,7 @@ def test_as_context_mgr_w_error(self): with self.assertRaises(RuntimeError): with batch as ctx_mgr: - ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) ctx_mgr.delete(document2) raise RuntimeError("testing") diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index b943fd1e1..e1995e5d4 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -129,11 +129,11 @@ def test__get_collection_reference(self): def test_collection_group(self): client = self._make_default_one() - query = client.collection_group("collectionId").where("foo", "==", u"bar") + query = client.collection_group("collectionId").where("foo", "==", "bar") self.assertTrue(query._all_descendants) self.assertEqual(query._field_filters[0].field.field_path, "foo") - self.assertEqual(query._field_filters[0].value.string_value, u"bar") + self.assertEqual(query._field_filters[0].value.string_value, "bar") self.assertEqual( query._field_filters[0].op, query._field_filters[0].Operator.EQUAL ) @@ -193,7 +193,8 @@ def test_document_factory_w_nested_path(self): self.assertIs(document2._client, client) self.assertIsInstance(document2, DocumentReference) - def test_collections(self): + def _collections_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.collection import CollectionReference @@ -216,10 +217,11 @@ def _next_page(self): page, self._pages = self._pages[0], self._pages[1:] return Page(self, page, self.item_to_value) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = list(client.collections()) + collections = list(client.collections(**kwargs)) self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): @@ -229,10 +231,20 @@ def _next_page(self): base_path = client._database_string + "/documents" firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": base_path}, metadata=client._rpc_metadata + request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs, ) - def _get_all_helper(self, client, references, document_pbs, **kwargs): + def test_collections(self): + self._collections_helper() + + def test_collections_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._collections_helper(retry=retry, timeout=timeout) + + def _invoke_get_all(self, client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["batch_get_documents"]) response_iterator = iter(document_pbs) @@ -261,141 +273,108 @@ def _info_for_get_all(self, data1, data2): return client, document1, document2, response1, response2 - def test_get_all(self): + def _get_all_helper(self, num_snapshots=2, txn_id=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + client = self._make_default_one() + + data1 = {"a": "cheese"} + document1 = client.document("pineapple", "lamp1") + document_pb1, read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=read_time) - data1 = {"a": u"cheese"} data2 = {"b": True, "c": 18} - info = self._info_for_get_all(data1, data2) - client, document1, document2, response1, response2 = info + document2 = client.document("pineapple", "lamp2") + document, read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=read_time) - # Exercise the mocked ``batch_get_documents``. - field_paths = ["a", "b"] - snapshots = self._get_all_helper( - client, - [document1, document2], - [response1, response2], - field_paths=field_paths, + document3 = client.document("pineapple", "lamp3") + response3 = _make_batch_response(missing=document3._document_path) + + expected_data = [data1, data2, None][:num_snapshots] + documents = [document1, document2, document3][:num_snapshots] + responses = [response1, response2, response3][:num_snapshots] + field_paths = [ + field_path for field_path in ["a", "b", None][:num_snapshots] if field_path + ] + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + if txn_id is not None: + transaction = client.transaction() + transaction._id = txn_id + kwargs["transaction"] = transaction + + snapshots = self._invoke_get_all( + client, documents, responses, field_paths=field_paths, **kwargs, ) - self.assertEqual(len(snapshots), 2) - snapshot1 = snapshots[0] - self.assertIsInstance(snapshot1, DocumentSnapshot) - self.assertIs(snapshot1._reference, document1) - self.assertEqual(snapshot1._data, data1) + self.assertEqual(len(snapshots), num_snapshots) - snapshot2 = snapshots[1] - self.assertIsInstance(snapshot2, DocumentSnapshot) - self.assertIs(snapshot2._reference, document2) - self.assertEqual(snapshot2._data, data2) + for data, document, snapshot in zip(expected_data, documents, snapshots): + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertIs(snapshot._reference, document) + if data is None: + self.assertFalse(snapshot.exists) + else: + self.assertEqual(snapshot._data, data) # Verify the call to the mock. - doc_paths = [document1._document_path, document2._document_path] + doc_paths = [document._document_path for document in documents] mask = common.DocumentMask(field_paths=field_paths) + + kwargs.pop("transaction", None) + client._firestore_api.batch_get_documents.assert_called_once_with( request={ "database": client._database_string, "documents": doc_paths, "mask": mask, - "transaction": None, + "transaction": txn_id, }, metadata=client._rpc_metadata, + **kwargs, ) - def test_get_all_with_transaction(self): - from google.cloud.firestore_v1.document import DocumentSnapshot + def test_get_all(self): + self._get_all_helper() - data = {"so-much": 484} - info = self._info_for_get_all(data, {}) - client, document, _, response, _ = info - transaction = client.transaction() + def test_get_all_with_transaction(self): txn_id = b"the-man-is-non-stop" - transaction._id = txn_id + self._get_all_helper(num_snapshots=1, txn_id=txn_id) - # Exercise the mocked ``batch_get_documents``. - snapshots = self._get_all_helper( - client, [document], [response], transaction=transaction - ) - self.assertEqual(len(snapshots), 1) + def test_get_all_w_retry_timeout(self): + from google.api_core.retry import Retry - snapshot = snapshots[0] - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertIs(snapshot._reference, document) - self.assertEqual(snapshot._data, data) + retry = Retry(predicate=object()) + timeout = 123.0 + self._get_all_helper(retry=retry, timeout=timeout) - # Verify the call to the mock. - doc_paths = [document._document_path] - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": None, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) + def test_get_all_wrong_order(self): + self._get_all_helper(num_snapshots=3) def test_get_all_unknown_result(self): from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE - info = self._info_for_get_all({"z": 28.5}, {}) - client, document, _, _, response = info + client = self._make_default_one() + + expected_document = client.document("pineapple", "lamp1") + + data = {"z": 28.5} + wrong_document = client.document("pineapple", "lamp2") + document_pb, read_time = _doc_get_info(wrong_document._document_path, data) + response = _make_batch_response(found=document_pb, read_time=read_time) # Exercise the mocked ``batch_get_documents``. with self.assertRaises(ValueError) as exc_info: - self._get_all_helper(client, [document], [response]) + self._invoke_get_all(client, [expected_document], [response]) err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) self.assertEqual(exc_info.exception.args, (err_msg,)) # Verify the call to the mock. - doc_paths = [document._document_path] - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": None, - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - - def test_get_all_wrong_order(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - - data1 = {"up": 10} - data2 = {"down": -10} - info = self._info_for_get_all(data1, data2) - client, document1, document2, response1, response2 = info - document3 = client.document("pineapple", "lamp3") - response3 = _make_batch_response(missing=document3._document_path) - - # Exercise the mocked ``batch_get_documents``. - snapshots = self._get_all_helper( - client, [document1, document2, document3], [response2, response1, response3] - ) - - self.assertEqual(len(snapshots), 3) - - snapshot1 = snapshots[0] - self.assertIsInstance(snapshot1, DocumentSnapshot) - self.assertIs(snapshot1._reference, document2) - self.assertEqual(snapshot1._data, data2) - - snapshot2 = snapshots[1] - self.assertIsInstance(snapshot2, DocumentSnapshot) - self.assertIs(snapshot2._reference, document1) - self.assertEqual(snapshot2._data, data1) - - self.assertFalse(snapshots[2].exists) - - # Verify the call to the mock. - doc_paths = [ - document1._document_path, - document2._document_path, - document3._document_path, - ] + doc_paths = [expected_document._document_path] client._firestore_api.batch_get_documents.assert_called_once_with( request={ "database": client._database_string, diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 982cacdbc..b75dfdfa2 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -99,7 +99,7 @@ def test_add_auto_assigned(self): # sure transforms during adds work. document_data = {"been": "here", "now": SERVER_TIMESTAMP} - patch = mock.patch("google.cloud.firestore_v1.collection._auto_id") + patch = mock.patch("google.cloud.firestore_v1.base_collection._auto_id") random_doc_id = "DEADBEEF" with patch as patched: patched.return_value = random_doc_id @@ -138,8 +138,9 @@ def _write_pb_for_create(document_path, document_data): current_document=common.Precondition(exists=False), ) - def test_add_explicit_id(self): + def _add_helper(self, retry=None, timeout=None): from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1 import _helpers # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) @@ -161,7 +162,11 @@ def test_add_explicit_id(self): collection = self._make_one("parent", client=client) document_data = {"zorp": 208.75, "i-did-not": b"know that"} doc_id = "child" - update_time, document_ref = collection.add(document_data, document_id=doc_id) + + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + update_time, document_ref = collection.add( + document_data, document_id=doc_id, **kwargs + ) # Verify the response and the mocks. self.assertIs(update_time, mock.sentinel.update_time) @@ -177,9 +182,21 @@ def test_add_explicit_id(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) - def _list_documents_helper(self, page_size=None): + def test_add_explicit_id(self): + self._add_helper() + + def test_add_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._add_helper(retry=retry, timeout=timeout) + + def _list_documents_helper(self, page_size=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.document import DocumentReference @@ -207,11 +224,12 @@ def _next_page(self): api_client.list_documents.return_value = iterator client._firestore_api_internal = api_client collection = self._make_one("collection", client=client) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) if page_size is not None: - documents = list(collection.list_documents(page_size=page_size)) + documents = list(collection.list_documents(page_size=page_size, **kwargs)) else: - documents = list(collection.list_documents()) + documents = list(collection.list_documents(**kwargs)) # Verify the response and the mocks. self.assertEqual(len(documents), len(document_ids)) @@ -229,11 +247,19 @@ def _next_page(self): "show_missing": True, }, metadata=client._rpc_metadata, + **kwargs, ) def test_list_documents_wo_page_size(self): self._list_documents_helper() + def test_list_documents_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._list_documents_helper(retry=retry, timeout=timeout) + def test_list_documents_w_page_size(self): self._list_documents_helper(page_size=25) @@ -248,6 +274,23 @@ def test_get(self, query_class): self.assertIs(get_response, query_instance.get.return_value) query_instance.get.assert_called_once_with(transaction=None) + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + def test_get_w_retry_timeout(self, query_class): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + collection = self._make_one("collection") + get_response = collection.get(retry=retry, timeout=timeout) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + + self.assertIs(get_response, query_instance.get.return_value) + query_instance.get.assert_called_once_with( + transaction=None, retry=retry, timeout=timeout, + ) + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) def test_get_with_transaction(self, query_class): @@ -271,6 +314,22 @@ def test_stream(self, query_class): self.assertIs(stream_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=None) + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + def test_stream_w_retry_timeout(self, query_class): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + collection = self._make_one("collection") + stream_response = collection.stream(retry=retry, timeout=timeout) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + self.assertIs(stream_response, query_instance.stream.return_value) + query_instance.stream.assert_called_once_with( + transaction=None, retry=retry, timeout=timeout, + ) + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) def test_stream_with_transaction(self, query_class): collection = self._make_one("collection") diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index ff06532c4..ef55508d1 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -69,7 +69,9 @@ def _write_pb_for_create(document_path, document_data): current_document=common.Precondition(exists=False), ) - def test_create(self): + def _create_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock() firestore_api.commit.mock_add_spec(spec=["commit"]) @@ -82,7 +84,9 @@ def test_create(self): # Actually make a document and call create(). document = self._make_one("foo", "twelve", client=client) document_data = {"hello": "goodbye", "count": 99} - write_result = document.create(document_data) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + write_result = document.create(document_data, **kwargs) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -94,8 +98,19 @@ def test_create(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) + def test_create(self): + self._create_helper() + + def test_create_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._create_helper(retry=retry, timeout=timeout) + def test_create_empty(self): # Create a minimal fake GAPIC with a dummy response. from google.cloud.firestore_v1.document import DocumentReference @@ -148,7 +163,9 @@ def _write_pb_for_set(document_path, document_data, merge): write_pbs._pb.update_mask.CopyFrom(mask._pb) return write_pbs - def _set_helper(self, merge=False, **option_kwargs): + def _set_helper(self, merge=False, retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) firestore_api.commit.return_value = self._make_commit_repsonse() @@ -160,7 +177,9 @@ def _set_helper(self, merge=False, **option_kwargs): # Actually make a document and call create(). document = self._make_one("User", "Interface", client=client) document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} - write_result = document.set(document_data, merge) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + write_result = document.set(document_data, merge, **kwargs) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -173,11 +192,19 @@ def _set_helper(self, merge=False, **option_kwargs): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) def test_set(self): self._set_helper() + def test_set_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._set_helper(retry=retry, timeout=timeout) + def test_set_merge(self): self._set_helper(merge=True) @@ -196,7 +223,8 @@ def _write_pb_for_update(document_path, update_values, field_paths): current_document=common.Precondition(exists=True), ) - def _update_helper(self, **option_kwargs): + def _update_helper(self, retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.transforms import DELETE_FIELD # Create a minimal fake GAPIC with a dummy response. @@ -213,12 +241,14 @@ def _update_helper(self, **option_kwargs): field_updates = collections.OrderedDict( (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) ) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if option_kwargs: option = client.write_option(**option_kwargs) - write_result = document.update(field_updates, option=option) + write_result = document.update(field_updates, option=option, **kwargs) else: option = None - write_result = document.update(field_updates) + write_result = document.update(field_updates, **kwargs) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -239,6 +269,7 @@ def _update_helper(self, **option_kwargs): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) def test_update_with_exists(self): @@ -248,6 +279,13 @@ def test_update_with_exists(self): def test_update(self): self._update_helper() + def test_update_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._update_helper(retry=retry, timeout=timeout) + def test_update_with_precondition(self): from google.protobuf import timestamp_pb2 @@ -270,7 +308,8 @@ def test_empty_update(self): with self.assertRaises(ValueError): document.update(field_updates) - def _delete_helper(self, **option_kwargs): + def _delete_helper(self, retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import write # Create a minimal fake GAPIC with a dummy response. @@ -280,15 +319,16 @@ def _delete_helper(self, **option_kwargs): # Attach the fake GAPIC to a real client. client = _make_client("donut-base") client._firestore_api_internal = firestore_api + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) if option_kwargs: option = client.write_option(**option_kwargs) - delete_time = document.delete(option=option) + delete_time = document.delete(option=option, **kwargs) else: option = None - delete_time = document.delete() + delete_time = document.delete(**kwargs) # Verify the response and the mocks. self.assertIs(delete_time, mock.sentinel.commit_time) @@ -302,6 +342,7 @@ def _delete_helper(self, **option_kwargs): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) def test_delete(self): @@ -313,8 +354,23 @@ def test_delete_with_option(self): timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) self._delete_helper(last_update_time=timestamp_pb) - def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): + def test_delete_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._delete_helper(retry=retry, timeout=timeout) + + def _get_helper( + self, + field_paths=None, + use_transaction=False, + not_found=False, + retry=None, + timeout=None, + ): from google.api_core.exceptions import NotFound + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.transaction import Transaction @@ -344,7 +400,11 @@ def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): else: transaction = None - snapshot = document.get(field_paths=field_paths, transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + snapshot = document.get( + field_paths=field_paths, transaction=transaction, **kwargs + ) self.assertIs(snapshot.reference, document) if not_found: @@ -378,6 +438,7 @@ def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): "transaction": expected_transaction_id, }, metadata=client._rpc_metadata, + **kwargs, ) def test_get_not_found(self): @@ -386,6 +447,13 @@ def test_get_not_found(self): def test_get_default(self): self._get_helper() + def test_get_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._get_helper(retry=retry, timeout=timeout) + def test_get_w_string_field_path(self): with self.assertRaises(ValueError): self._get_helper(field_paths="foo") @@ -399,10 +467,11 @@ def test_get_with_multiple_field_paths(self): def test_get_with_transaction(self): self._get_helper(use_transaction=True) - def _collections_helper(self, page_size=None): + def _collections_helper(self, page_size=None, retry=None, timeout=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.services.firestore.client import FirestoreClient # TODO(microgen): https://github.com/googleapis/gapic-generator-python/issues/516 @@ -424,13 +493,14 @@ def _next_page(self): client = _make_client() client._firestore_api_internal = api_client + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) if page_size is not None: - collections = list(document.collections(page_size=page_size)) + collections = list(document.collections(page_size=page_size, **kwargs)) else: - collections = list(document.collections()) + collections = list(document.collections(**kwargs)) # Verify the response and the mocks. self.assertEqual(len(collections), len(collection_ids)) @@ -442,6 +512,7 @@ def _next_page(self): api_client.list_collection_ids.assert_called_once_with( request={"parent": document._document_path, "page_size": page_size}, metadata=client._rpc_metadata, + **kwargs, ) def test_collections_wo_page_size(self): @@ -450,6 +521,13 @@ def test_collections_wo_page_size(self): def test_collections_w_page_size(self): self._collections_helper(page_size=10) + def test_collections_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._collections_helper(retry=retry, timeout=timeout) + @mock.patch("google.cloud.firestore_v1.document.Watch", autospec=True) def test_on_snapshot(self, watch): client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index e2290db37..91172b120 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -46,7 +46,9 @@ def test_constructor(self): self.assertIsNone(query._end_at) self.assertFalse(query._all_descendants) - def test_get(self): + def _get_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -63,12 +65,12 @@ def test_get(self): data = {"snooze": 10} response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = iter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. query = self._make_one(parent) - returned = query.get() + returned = query.get(**kwargs) self.assertIsInstance(returned, list) self.assertEqual(len(returned), 1) @@ -77,6 +79,28 @@ def test_get(self): self.assertEqual(snapshot.reference._path, ("dee", "sleep")) self.assertEqual(snapshot.to_dict(), data) + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + def test_get(self): + self._get_helper() + + def test_get_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._get_helper(retry=retry, timeout=timeout) + def test_get_limit_to_last(self): from google.cloud import firestore from google.cloud.firestore_v1.base_query import _enum_from_direction @@ -105,7 +129,7 @@ def test_get_limit_to_last(self): # Execute the query and check the response. query = self._make_one(parent) query = query.order_by( - u"snooze", direction=firestore.Query.DESCENDING + "snooze", direction=firestore.Query.DESCENDING ).limit_to_last(2) returned = query.get() @@ -134,7 +158,9 @@ def test_get_limit_to_last(self): metadata=client._rpc_metadata, ) - def test_stream_simple(self): + def _stream_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -151,10 +177,13 @@ def test_stream_simple(self): data = {"snooze": 10} response_pb = _make_query_response(name=name, data=data) firestore_api.run_query.return_value = iter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. query = self._make_one(parent) - get_response = query.stream() + + get_response = query.stream(**kwargs) + self.assertIsInstance(get_response, types.GeneratorType) returned = list(get_response) self.assertEqual(len(returned), 1) @@ -171,8 +200,19 @@ def test_stream_simple(self): "transaction": None, }, metadata=client._rpc_metadata, + **kwargs, ) + def test_stream_simple(self): + self._stream_helper() + + def test_stream_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._stream_helper(retry=retry, timeout=timeout) + def test_stream_with_limit_to_last(self): # Attach the fake GAPIC to a real client. client = _make_client() @@ -448,7 +488,9 @@ def test_constructor_all_descendents_is_false(self): with pytest.raises(ValueError): self._make_one(mock.sentinel.parent, all_descendants=False) - def test_get_partitions(self): + def _get_partitions_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["partition_query"]) @@ -467,10 +509,13 @@ def test_get_partitions(self): cursor_pb1 = _make_cursor_pb(([document1], False)) cursor_pb2 = _make_cursor_pb(([document2], False)) firestore_api.partition_query.return_value = iter([cursor_pb1, cursor_pb2]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. query = self._make_one(parent) - get_response = query.get_partitions(2) + + get_response = query.get_partitions(2, **kwargs) + self.assertIsInstance(get_response, types.GeneratorType) returned = list(get_response) self.assertEqual(len(returned), 3) @@ -487,8 +532,19 @@ def test_get_partitions(self): "partition_count": 2, }, metadata=client._rpc_metadata, + **kwargs, ) + def test_get_partitions(self): + self._get_partitions_helper() + + def test_get_partitions_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._get_partitions_helper(retry=retry, timeout=timeout) + def test_get_partitions_w_filter(self): # Make a **real** collection reference as parent. client = _make_client() diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index a32e58c10..3a093a335 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -291,34 +291,79 @@ def test__commit_failure(self): metadata=client._rpc_metadata, ) - def test_get_all(self): + def _get_all_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + client = mock.Mock(spec=["get_all"]) transaction = self._make_one(client) ref1, ref2 = mock.Mock(), mock.Mock() - result = transaction.get_all([ref1, ref2]) - client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = transaction.get_all([ref1, ref2], **kwargs) + + client.get_all.assert_called_once_with( + [ref1, ref2], transaction=transaction, **kwargs, + ) self.assertIs(result, client.get_all.return_value) - def test_get_document_ref(self): + def test_get_all(self): + self._get_all_helper() + + def test_get_all_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._get_all_helper(retry=retry, timeout=timeout) + + def _get_w_document_ref_helper(self, retry=None, timeout=None): from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1 import _helpers client = mock.Mock(spec=["get_all"]) transaction = self._make_one(client) ref = DocumentReference("documents", "doc-id") - result = transaction.get(ref) - client.get_all.assert_called_once_with([ref], transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = transaction.get(ref, **kwargs) + self.assertIs(result, client.get_all.return_value) + client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs) - def test_get_w_query(self): + def test_get_w_document_ref(self): + self._get_w_document_ref_helper() + + def test_get_w_document_ref_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._get_w_document_ref_helper(retry=retry, timeout=timeout) + + def _get_w_query_helper(self, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.query import Query client = mock.Mock(spec=[]) transaction = self._make_one(client) query = Query(parent=mock.Mock(spec=[])) query.stream = mock.MagicMock() - result = transaction.get(query) - query.stream.assert_called_once_with(transaction=transaction) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = transaction.get(query, **kwargs) + self.assertIs(result, query.stream.return_value) + query.stream.assert_called_once_with(transaction=transaction, **kwargs) + + def test_get_w_query(self): + self._get_w_query_helper() + + def test_get_w_query_w_retry_timeout(self): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + self._get_w_query_helper(retry=retry, timeout=timeout) def test_get_failure(self): client = _make_client()