Skip to content

Commit

Permalink
feat: add retry/timeout to 'client.Client.get_all'
Browse files Browse the repository at this point in the history
Toward #221
  • Loading branch information
tseaver committed Oct 13, 2020
1 parent c122e41 commit 1b61f49
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
16 changes: 16 additions & 0 deletions google/cloud/firestore_v1/client.py
Expand Up @@ -24,6 +24,8 @@
:class:`~google.cloud.firestore_v1.document.DocumentReference`
"""

from google.api_core import retry as retries # type: ignore

from google.cloud.firestore_v1.base_client import (
BaseClient,
DEFAULT_DATABASE,
Expand Down Expand Up @@ -207,6 +209,8 @@ def get_all(
references: list,
field_paths: Iterable[str] = None,
transaction: Transaction = None,
retry: retries.Retry = None,
timeout: float = None,
) -> Generator[Any, Any, None]:
"""Retrieve a batch of documents.
Expand Down Expand Up @@ -237,13 +241,24 @@ def get_all(
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that these ``references`` will be
retrieved in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Yields:
.DocumentSnapshot: The next document snapshot that fulfills the
query, or :data:`None` if the document does not exist.
"""
document_paths, reference_map = _reference_info(references)
mask = _get_doc_mask(field_paths)
kwargs = {}

if retry is not None:
kwargs["retry"] = retry

if timeout is not None:
kwargs["timeout"] = timeout

response_iterator = self._firestore_api.batch_get_documents(
request={
"database": self._database_string,
Expand All @@ -252,6 +267,7 @@ def get_all(
"transaction": _helpers.get_transaction_id(transaction),
},
metadata=self._rpc_metadata,
**kwargs,
)

for get_doc_response in response_iterator:
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/v1/test_client.py
Expand Up @@ -303,6 +303,55 @@ def test_get_all(self):
metadata=client._rpc_metadata,
)

def test_get_all_w_retry_timeout(self):
from google.api_core.retry import Retry
from google.cloud.firestore_v1.types import common
from google.cloud.firestore_v1.document import DocumentSnapshot

data1 = {"a": u"cheese"}
data2 = {"b": True, "c": 18}
retry = Retry(predicate=object())
timeout = 123.0
info = self._info_for_get_all(data1, data2)
client, document1, document2, response1, response2 = info

# Exercise the mocked ``batch_get_documents``.
field_paths = ["a", "b"]
snapshots = self._get_all_helper(
client,
[document1, document2],
[response1, response2],
field_paths=field_paths,
retry=retry,
timeout=timeout,
)
self.assertEqual(len(snapshots), 2)

snapshot1 = snapshots[0]
self.assertIsInstance(snapshot1, DocumentSnapshot)
self.assertIs(snapshot1._reference, document1)
self.assertEqual(snapshot1._data, data1)

snapshot2 = snapshots[1]
self.assertIsInstance(snapshot2, DocumentSnapshot)
self.assertIs(snapshot2._reference, document2)
self.assertEqual(snapshot2._data, data2)

# Verify the call to the mock.
doc_paths = [document1._document_path, document2._document_path]
mask = common.DocumentMask(field_paths=field_paths)
client._firestore_api.batch_get_documents.assert_called_once_with(
request={
"database": client._database_string,
"documents": doc_paths,
"mask": mask,
"transaction": None,
},
retry=retry,
timeout=timeout,
metadata=client._rpc_metadata,
)

def test_get_all_with_transaction(self):
from google.cloud.firestore_v1.document import DocumentSnapshot

Expand Down

0 comments on commit 1b61f49

Please sign in to comment.