Skip to content

Commit

Permalink
feat: add retry/timeout to 'async_client.AsyncClient.{collections.get…
Browse files Browse the repository at this point in the history
…_all}'

Toward #221
  • Loading branch information
tseaver committed Oct 14, 2020
1 parent bad75e1 commit db01b59
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 128 deletions.
50 changes: 29 additions & 21 deletions google/cloud/firestore_v1/async_client.py
Expand Up @@ -24,17 +24,16 @@
:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`
"""

from google.api_core import retry as retries # type: ignore

from google.cloud.firestore_v1.base_client import (
BaseClient,
DEFAULT_DATABASE,
_CLIENT_INFO,
_reference_info, # type: ignore
_parse_batch_get, # type: ignore
_get_doc_mask,
_path_helper,
)

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.async_query import AsyncCollectionGroup
from google.cloud.firestore_v1.async_batch import AsyncWriteBatch
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
Expand Down Expand Up @@ -208,7 +207,12 @@ def document(self, *document_path: Tuple[str]) -> AsyncDocumentReference:
)

async def get_all(
self, references: list, field_paths: Iterable[str] = None, transaction=None,
self,
references: list,
field_paths: Iterable[str] = None,
transaction=None,
retry: retries.Retry = None,
timeout: float = None,
) -> AsyncGenerator[DocumentSnapshot, Any]:
"""Retrieve a batch of documents.
Expand Down Expand Up @@ -239,48 +243,52 @@ async def get_all(
transaction (Optional[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`]):
An existing transaction that these ``references`` will be
retrieved in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Yields:
.DocumentSnapshot: The next document snapshot that fulfills the
query, or :data:`None` if the document does not exist.
"""
document_paths, reference_map = _reference_info(references)
mask = _get_doc_mask(field_paths)
request, reference_map, kwargs = self._prep_get_all(
references, field_paths, transaction, retry, timeout
)

response_iterator = await self._firestore_api.batch_get_documents(
request={
"database": self._database_string,
"documents": document_paths,
"mask": mask,
"transaction": _helpers.get_transaction_id(transaction),
},
metadata=self._rpc_metadata,
request=request, metadata=self._rpc_metadata, **kwargs,
)

async for get_doc_response in response_iterator:
yield _parse_batch_get(get_doc_response, reference_map, self)

async def collections(self) -> AsyncGenerator[AsyncCollectionReference, Any]:
async def collections(
self, retry: retries.Retry = None, timeout: float = None,
) -> AsyncGenerator[AsyncCollectionReference, Any]:
"""List top-level collections of the client's database.
Args:
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]:
iterator of subcollections of the current document.
"""
request, kwargs = self._prep_collections(retry, timeout)
iterator = await self._firestore_api.list_collection_ids(
request={"parent": "{}/documents".format(self._database_string)},
metadata=self._rpc_metadata,
request=request, metadata=self._rpc_metadata, **kwargs,
)

while True:
for i in iterator.collection_ids:
yield self.collection(i)
if iterator.next_page_token:
next_request = request.copy()
next_request["page_token"] = iterator.next_page_token
iterator = await self._firestore_api.list_collection_ids(
request={
"parent": "{}/documents".format(self._database_string),
"page_token": iterator.next_page_token,
},
metadata=self._rpc_metadata,
request=next_request, metadata=self._rpc_metadata, **kwargs,
)
else:
return
Expand Down
209 changes: 102 additions & 107 deletions tests/unit/v1/test_async_client.py
Expand Up @@ -131,11 +131,11 @@ def test__get_collection_reference(self):

def test_collection_group(self):
client = self._make_default_one()
query = client.collection_group("collectionId").where("foo", "==", u"bar")
query = client.collection_group("collectionId").where("foo", "==", "bar")

self.assertTrue(query._all_descendants)
self.assertEqual(query._field_filters[0].field.field_path, "foo")
self.assertEqual(query._field_filters[0].value.string_value, u"bar")
self.assertEqual(query._field_filters[0].value.string_value, "bar")
self.assertEqual(
query._field_filters[0].op, query._field_filters[0].Operator.EQUAL
)
Expand Down Expand Up @@ -195,8 +195,7 @@ def test_document_factory_w_nested_path(self):
self.assertIs(document2._client, client)
self.assertIsInstance(document2, AsyncDocumentReference)

@pytest.mark.asyncio
async def test_collections(self):
async def _collections_helper(self, retry=None, timeout=None):
from google.api_core.page_iterator import Iterator
from google.api_core.page_iterator import Page
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
Expand All @@ -220,10 +219,18 @@ def _next_page(self):
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)

kwargs = {}

if retry is not None:
kwargs["retry"] = retry

if timeout is not None:
kwargs["timeout"] = timeout

iterator = _Iterator(pages=[collection_ids])
firestore_api.list_collection_ids.return_value = iterator

collections = [c async for c in client.collections()]
collections = [c async for c in client.collections(**kwargs)]

self.assertEqual(len(collections), len(collection_ids))
for collection, collection_id in zip(collections, collection_ids):
Expand All @@ -233,10 +240,22 @@ def _next_page(self):

base_path = client._database_string + "/documents"
firestore_api.list_collection_ids.assert_called_once_with(
request={"parent": base_path}, metadata=client._rpc_metadata
request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs,
)

async def _get_all_helper(self, client, references, document_pbs, **kwargs):
@pytest.mark.asyncio
async def test_collections(self):
await self._collections_helper()

@pytest.mark.asyncio
async def test_collections_w_retry_timeout(self):
from google.api_core.retry import Retry

retry = Retry(predicate=object())
timeout = 123.0
await self._collections_helper(retry=retry, timeout=timeout)

async def _invoke_get_all(self, client, references, document_pbs, **kwargs):
# Create a minimal fake GAPIC with a dummy response.
firestore_api = AsyncMock(spec=["batch_get_documents"])
response_iterator = AsyncIter(document_pbs)
Expand Down Expand Up @@ -265,145 +284,121 @@ def _info_for_get_all(self, data1, data2):

return client, document1, document2, response1, response2

@pytest.mark.asyncio
async def test_get_all(self):
async def _get_all_helper(
self, num_snapshots=2, txn_id=None, retry=None, timeout=None
):
from google.cloud.firestore_v1.types import common
from google.cloud.firestore_v1.async_document import DocumentSnapshot

data1 = {"a": u"cheese"}
client = self._make_default_one()

data1 = {"a": "cheese"}
document1 = client.document("pineapple", "lamp1")
document_pb1, read_time = _doc_get_info(document1._document_path, data1)
response1 = _make_batch_response(found=document_pb1, read_time=read_time)

data2 = {"b": True, "c": 18}
info = self._info_for_get_all(data1, data2)
client, document1, document2, response1, response2 = info
document2 = client.document("pineapple", "lamp2")
document, read_time = _doc_get_info(document2._document_path, data2)
response2 = _make_batch_response(found=document, read_time=read_time)

# Exercise the mocked ``batch_get_documents``.
field_paths = ["a", "b"]
snapshots = await self._get_all_helper(
client,
[document1, document2],
[response1, response2],
field_paths=field_paths,
)
self.assertEqual(len(snapshots), 2)
document3 = client.document("pineapple", "lamp3")
response3 = _make_batch_response(missing=document3._document_path)

snapshot1 = snapshots[0]
self.assertIsInstance(snapshot1, DocumentSnapshot)
self.assertIs(snapshot1._reference, document1)
self.assertEqual(snapshot1._data, data1)
expected_data = [data1, data2, None][:num_snapshots]
documents = [document1, document2, document3][:num_snapshots]
responses = [response1, response2, response3][:num_snapshots]
field_paths = [
field_path for field_path in ["a", "b", None][:num_snapshots] if field_path
]

snapshot2 = snapshots[1]
self.assertIsInstance(snapshot2, DocumentSnapshot)
self.assertIs(snapshot2._reference, document2)
self.assertEqual(snapshot2._data, data2)
kwargs = {}

# Verify the call to the mock.
doc_paths = [document1._document_path, document2._document_path]
mask = common.DocumentMask(field_paths=field_paths)
client._firestore_api.batch_get_documents.assert_called_once_with(
request={
"database": client._database_string,
"documents": doc_paths,
"mask": mask,
"transaction": None,
},
metadata=client._rpc_metadata,
)
if retry is not None:
kwargs["retry"] = retry

@pytest.mark.asyncio
async def test_get_all_with_transaction(self):
from google.cloud.firestore_v1.async_document import DocumentSnapshot
if timeout is not None:
kwargs["timeout"] = timeout

data = {"so-much": 484}
info = self._info_for_get_all(data, {})
client, document, _, response, _ = info
transaction = client.transaction()
txn_id = b"the-man-is-non-stop"
transaction._id = txn_id
if txn_id is not None:
transaction = client.transaction()
transaction._id = txn_id
kwargs["transaction"] = transaction

# Exercise the mocked ``batch_get_documents``.
snapshots = await self._get_all_helper(
client, [document], [response], transaction=transaction
snapshots = await self._invoke_get_all(
client, documents, responses, field_paths=field_paths, **kwargs,
)
self.assertEqual(len(snapshots), 1)

snapshot = snapshots[0]
self.assertIsInstance(snapshot, DocumentSnapshot)
self.assertIs(snapshot._reference, document)
self.assertEqual(snapshot._data, data)
self.assertEqual(len(snapshots), num_snapshots)

for data, document, snapshot in zip(expected_data, documents, snapshots):
self.assertIsInstance(snapshot, DocumentSnapshot)
self.assertIs(snapshot._reference, document)
if data is None:
self.assertFalse(snapshot.exists)
else:
self.assertEqual(snapshot._data, data)

# Verify the call to the mock.
doc_paths = [document._document_path]
doc_paths = [document._document_path for document in documents]
mask = common.DocumentMask(field_paths=field_paths)

kwargs.pop("transaction", None)

client._firestore_api.batch_get_documents.assert_called_once_with(
request={
"database": client._database_string,
"documents": doc_paths,
"mask": None,
"mask": mask,
"transaction": txn_id,
},
metadata=client._rpc_metadata,
**kwargs,
)

@pytest.mark.asyncio
async def test_get_all_unknown_result(self):
from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE
async def test_get_all(self):
await self._get_all_helper()

info = self._info_for_get_all({"z": 28.5}, {})
client, document, _, _, response = info
@pytest.mark.asyncio
async def test_get_all_with_transaction(self):
txn_id = b"the-man-is-non-stop"
await self._get_all_helper(num_snapshots=1, txn_id=txn_id)

# Exercise the mocked ``batch_get_documents``.
with self.assertRaises(ValueError) as exc_info:
await self._get_all_helper(client, [document], [response])
@pytest.mark.asyncio
async def test_get_all_w_retry_timeout(self):
from google.api_core.retry import Retry

err_msg = _BAD_DOC_TEMPLATE.format(response.found.name)
self.assertEqual(exc_info.exception.args, (err_msg,))

# Verify the call to the mock.
doc_paths = [document._document_path]
client._firestore_api.batch_get_documents.assert_called_once_with(
request={
"database": client._database_string,
"documents": doc_paths,
"mask": None,
"transaction": None,
},
metadata=client._rpc_metadata,
)
retry = Retry(predicate=object())
timeout = 123.0
await self._get_all_helper(retry=retry, timeout=timeout)

@pytest.mark.asyncio
async def test_get_all_wrong_order(self):
from google.cloud.firestore_v1.async_document import DocumentSnapshot
await self._get_all_helper(num_snapshots=3)

data1 = {"up": 10}
data2 = {"down": -10}
info = self._info_for_get_all(data1, data2)
client, document1, document2, response1, response2 = info
document3 = client.document("pineapple", "lamp3")
response3 = _make_batch_response(missing=document3._document_path)
@pytest.mark.asyncio
async def test_get_all_unknown_result(self):
from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE

# Exercise the mocked ``batch_get_documents``.
snapshots = await self._get_all_helper(
client, [document1, document2, document3], [response2, response1, response3]
)
client = self._make_default_one()

self.assertEqual(len(snapshots), 3)
document1 = client.document("pineapple", "lamp1")

snapshot1 = snapshots[0]
self.assertIsInstance(snapshot1, DocumentSnapshot)
self.assertIs(snapshot1._reference, document2)
self.assertEqual(snapshot1._data, data2)
data = {"z": 28.5}
wrong_document = client.document("pineapple", "lamp2")
document_pb, read_time = _doc_get_info(wrong_document._document_path, data)
response = _make_batch_response(found=document_pb, read_time=read_time)

snapshot2 = snapshots[1]
self.assertIsInstance(snapshot2, DocumentSnapshot)
self.assertIs(snapshot2._reference, document1)
self.assertEqual(snapshot2._data, data1)
# Exercise the mocked ``batch_get_documents``.
with self.assertRaises(ValueError) as exc_info:
await self._invoke_get_all(client, [document1], [response])

self.assertFalse(snapshots[2].exists)
err_msg = _BAD_DOC_TEMPLATE.format(response.found.name)
self.assertEqual(exc_info.exception.args, (err_msg,))

# Verify the call to the mock.
doc_paths = [
document1._document_path,
document2._document_path,
document3._document_path,
]
doc_paths = [document1._document_path]
client._firestore_api.batch_get_documents.assert_called_once_with(
request={
"database": client._database_string,
Expand Down

0 comments on commit db01b59

Please sign in to comment.