New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add retry/timeout to manual surface #222
Changes from 1 commit
8d03921
e7d2119
6bfe32f
6ba6d21
e6ad4a1
6e806b0
00736fe
2b80b91
49d3b03
16a1e34
a9547e6
d6df19c
6eae0e7
559f2eb
8038cce
2d413df
e15b8f6
9f5bbb4
5a1ef50
6dec6f3
812c41f
8c67138
bad75e1
db01b59
4090a00
40fae96
df615e4
a557a15
4e3be50
ec8002c
c81cd8c
1acbde3
b998db2
46f27e6
9b4707a
7a976a5
4b1ec26
c5e4056
2360eac
f15a523
3feb63b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,11 +131,11 @@ def test__get_collection_reference(self): | |
|
||
def test_collection_group(self): | ||
client = self._make_default_one() | ||
query = client.collection_group("collectionId").where("foo", "==", u"bar") | ||
query = client.collection_group("collectionId").where("foo", "==", "bar") | ||
|
||
self.assertTrue(query._all_descendants) | ||
self.assertEqual(query._field_filters[0].field.field_path, "foo") | ||
self.assertEqual(query._field_filters[0].value.string_value, u"bar") | ||
self.assertEqual(query._field_filters[0].value.string_value, "bar") | ||
self.assertEqual( | ||
query._field_filters[0].op, query._field_filters[0].Operator.EQUAL | ||
) | ||
|
@@ -195,8 +195,7 @@ def test_document_factory_w_nested_path(self): | |
self.assertIs(document2._client, client) | ||
self.assertIsInstance(document2, AsyncDocumentReference) | ||
|
||
@pytest.mark.asyncio | ||
async def test_collections(self): | ||
async def _collections_helper(self, retry=None, timeout=None): | ||
from google.api_core.page_iterator import Iterator | ||
from google.api_core.page_iterator import Page | ||
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference | ||
|
@@ -220,10 +219,18 @@ def _next_page(self): | |
page, self._pages = self._pages[0], self._pages[1:] | ||
return Page(self, page, self.item_to_value) | ||
|
||
kwargs = {} | ||
|
||
if retry is not None: | ||
kwargs["retry"] = retry | ||
|
||
if timeout is not None: | ||
kwargs["timeout"] = timeout | ||
|
||
iterator = _Iterator(pages=[collection_ids]) | ||
firestore_api.list_collection_ids.return_value = iterator | ||
|
||
collections = [c async for c in client.collections()] | ||
collections = [c async for c in client.collections(**kwargs)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just confirming, you have added these changes to a few tests to get complete coverage? The changes of this PR shouldn't result in any required changes to users correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. See |
||
|
||
self.assertEqual(len(collections), len(collection_ids)) | ||
for collection, collection_id in zip(collections, collection_ids): | ||
|
@@ -233,10 +240,22 @@ def _next_page(self): | |
|
||
base_path = client._database_string + "/documents" | ||
firestore_api.list_collection_ids.assert_called_once_with( | ||
request={"parent": base_path}, metadata=client._rpc_metadata | ||
request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs, | ||
) | ||
|
||
async def _get_all_helper(self, client, references, document_pbs, **kwargs): | ||
@pytest.mark.asyncio | ||
async def test_collections(self): | ||
await self._collections_helper() | ||
|
||
@pytest.mark.asyncio | ||
async def test_collections_w_retry_timeout(self): | ||
from google.api_core.retry import Retry | ||
|
||
retry = Retry(predicate=object()) | ||
timeout = 123.0 | ||
await self._collections_helper(retry=retry, timeout=timeout) | ||
|
||
async def _invoke_get_all(self, client, references, document_pbs, **kwargs): | ||
# Create a minimal fake GAPIC with a dummy response. | ||
firestore_api = AsyncMock(spec=["batch_get_documents"]) | ||
response_iterator = AsyncIter(document_pbs) | ||
|
@@ -265,145 +284,121 @@ def _info_for_get_all(self, data1, data2): | |
|
||
return client, document1, document2, response1, response2 | ||
|
||
@pytest.mark.asyncio | ||
async def test_get_all(self): | ||
async def _get_all_helper( | ||
self, num_snapshots=2, txn_id=None, retry=None, timeout=None | ||
): | ||
from google.cloud.firestore_v1.types import common | ||
from google.cloud.firestore_v1.async_document import DocumentSnapshot | ||
|
||
data1 = {"a": u"cheese"} | ||
client = self._make_default_one() | ||
|
||
data1 = {"a": "cheese"} | ||
document1 = client.document("pineapple", "lamp1") | ||
document_pb1, read_time = _doc_get_info(document1._document_path, data1) | ||
response1 = _make_batch_response(found=document_pb1, read_time=read_time) | ||
|
||
data2 = {"b": True, "c": 18} | ||
info = self._info_for_get_all(data1, data2) | ||
client, document1, document2, response1, response2 = info | ||
document2 = client.document("pineapple", "lamp2") | ||
document, read_time = _doc_get_info(document2._document_path, data2) | ||
response2 = _make_batch_response(found=document, read_time=read_time) | ||
|
||
# Exercise the mocked ``batch_get_documents``. | ||
field_paths = ["a", "b"] | ||
snapshots = await self._get_all_helper( | ||
client, | ||
[document1, document2], | ||
[response1, response2], | ||
field_paths=field_paths, | ||
) | ||
self.assertEqual(len(snapshots), 2) | ||
document3 = client.document("pineapple", "lamp3") | ||
response3 = _make_batch_response(missing=document3._document_path) | ||
|
||
snapshot1 = snapshots[0] | ||
self.assertIsInstance(snapshot1, DocumentSnapshot) | ||
self.assertIs(snapshot1._reference, document1) | ||
self.assertEqual(snapshot1._data, data1) | ||
expected_data = [data1, data2, None][:num_snapshots] | ||
documents = [document1, document2, document3][:num_snapshots] | ||
responses = [response1, response2, response3][:num_snapshots] | ||
field_paths = [ | ||
field_path for field_path in ["a", "b", None][:num_snapshots] if field_path | ||
] | ||
|
||
snapshot2 = snapshots[1] | ||
self.assertIsInstance(snapshot2, DocumentSnapshot) | ||
self.assertIs(snapshot2._reference, document2) | ||
self.assertEqual(snapshot2._data, data2) | ||
kwargs = {} | ||
|
||
# Verify the call to the mock. | ||
doc_paths = [document1._document_path, document2._document_path] | ||
mask = common.DocumentMask(field_paths=field_paths) | ||
client._firestore_api.batch_get_documents.assert_called_once_with( | ||
request={ | ||
"database": client._database_string, | ||
"documents": doc_paths, | ||
"mask": mask, | ||
"transaction": None, | ||
}, | ||
metadata=client._rpc_metadata, | ||
) | ||
if retry is not None: | ||
kwargs["retry"] = retry | ||
|
||
@pytest.mark.asyncio | ||
async def test_get_all_with_transaction(self): | ||
from google.cloud.firestore_v1.async_document import DocumentSnapshot | ||
if timeout is not None: | ||
kwargs["timeout"] = timeout | ||
|
||
data = {"so-much": 484} | ||
info = self._info_for_get_all(data, {}) | ||
client, document, _, response, _ = info | ||
transaction = client.transaction() | ||
txn_id = b"the-man-is-non-stop" | ||
transaction._id = txn_id | ||
if txn_id is not None: | ||
transaction = client.transaction() | ||
transaction._id = txn_id | ||
kwargs["transaction"] = transaction | ||
|
||
# Exercise the mocked ``batch_get_documents``. | ||
snapshots = await self._get_all_helper( | ||
client, [document], [response], transaction=transaction | ||
snapshots = await self._invoke_get_all( | ||
client, documents, responses, field_paths=field_paths, **kwargs, | ||
) | ||
self.assertEqual(len(snapshots), 1) | ||
|
||
snapshot = snapshots[0] | ||
self.assertIsInstance(snapshot, DocumentSnapshot) | ||
self.assertIs(snapshot._reference, document) | ||
self.assertEqual(snapshot._data, data) | ||
self.assertEqual(len(snapshots), num_snapshots) | ||
|
||
for data, document, snapshot in zip(expected_data, documents, snapshots): | ||
self.assertIsInstance(snapshot, DocumentSnapshot) | ||
self.assertIs(snapshot._reference, document) | ||
if data is None: | ||
self.assertFalse(snapshot.exists) | ||
else: | ||
self.assertEqual(snapshot._data, data) | ||
|
||
# Verify the call to the mock. | ||
doc_paths = [document._document_path] | ||
doc_paths = [document._document_path for document in documents] | ||
mask = common.DocumentMask(field_paths=field_paths) | ||
|
||
kwargs.pop("transaction", None) | ||
|
||
client._firestore_api.batch_get_documents.assert_called_once_with( | ||
request={ | ||
"database": client._database_string, | ||
"documents": doc_paths, | ||
"mask": None, | ||
"mask": mask, | ||
"transaction": txn_id, | ||
}, | ||
metadata=client._rpc_metadata, | ||
**kwargs, | ||
) | ||
|
||
@pytest.mark.asyncio | ||
async def test_get_all_unknown_result(self): | ||
from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE | ||
async def test_get_all(self): | ||
await self._get_all_helper() | ||
|
||
info = self._info_for_get_all({"z": 28.5}, {}) | ||
client, document, _, _, response = info | ||
@pytest.mark.asyncio | ||
async def test_get_all_with_transaction(self): | ||
txn_id = b"the-man-is-non-stop" | ||
await self._get_all_helper(num_snapshots=1, txn_id=txn_id) | ||
|
||
# Exercise the mocked ``batch_get_documents``. | ||
with self.assertRaises(ValueError) as exc_info: | ||
await self._get_all_helper(client, [document], [response]) | ||
@pytest.mark.asyncio | ||
async def test_get_all_w_retry_timeout(self): | ||
from google.api_core.retry import Retry | ||
|
||
err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) | ||
self.assertEqual(exc_info.exception.args, (err_msg,)) | ||
|
||
# Verify the call to the mock. | ||
doc_paths = [document._document_path] | ||
client._firestore_api.batch_get_documents.assert_called_once_with( | ||
request={ | ||
"database": client._database_string, | ||
"documents": doc_paths, | ||
"mask": None, | ||
"transaction": None, | ||
}, | ||
metadata=client._rpc_metadata, | ||
) | ||
retry = Retry(predicate=object()) | ||
timeout = 123.0 | ||
await self._get_all_helper(retry=retry, timeout=timeout) | ||
|
||
@pytest.mark.asyncio | ||
async def test_get_all_wrong_order(self): | ||
from google.cloud.firestore_v1.async_document import DocumentSnapshot | ||
await self._get_all_helper(num_snapshots=3) | ||
|
||
data1 = {"up": 10} | ||
data2 = {"down": -10} | ||
info = self._info_for_get_all(data1, data2) | ||
client, document1, document2, response1, response2 = info | ||
document3 = client.document("pineapple", "lamp3") | ||
response3 = _make_batch_response(missing=document3._document_path) | ||
@pytest.mark.asyncio | ||
async def test_get_all_unknown_result(self): | ||
from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE | ||
|
||
# Exercise the mocked ``batch_get_documents``. | ||
snapshots = await self._get_all_helper( | ||
client, [document1, document2, document3], [response2, response1, response3] | ||
) | ||
client = self._make_default_one() | ||
|
||
self.assertEqual(len(snapshots), 3) | ||
document1 = client.document("pineapple", "lamp1") | ||
|
||
snapshot1 = snapshots[0] | ||
self.assertIsInstance(snapshot1, DocumentSnapshot) | ||
self.assertIs(snapshot1._reference, document2) | ||
self.assertEqual(snapshot1._data, data2) | ||
data = {"z": 28.5} | ||
wrong_document = client.document("pineapple", "lamp2") | ||
document_pb, read_time = _doc_get_info(wrong_document._document_path, data) | ||
response = _make_batch_response(found=document_pb, read_time=read_time) | ||
|
||
snapshot2 = snapshots[1] | ||
self.assertIsInstance(snapshot2, DocumentSnapshot) | ||
self.assertIs(snapshot2._reference, document1) | ||
self.assertEqual(snapshot2._data, data1) | ||
# Exercise the mocked ``batch_get_documents``. | ||
with self.assertRaises(ValueError) as exc_info: | ||
await self._invoke_get_all(client, [document1], [response]) | ||
|
||
self.assertFalse(snapshots[2].exists) | ||
err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) | ||
self.assertEqual(exc_info.exception.args, (err_msg,)) | ||
|
||
# Verify the call to the mock. | ||
doc_paths = [ | ||
document1._document_path, | ||
document2._document_path, | ||
document3._document_path, | ||
] | ||
doc_paths = [document1._document_path] | ||
client._firestore_api.batch_get_documents.assert_called_once_with( | ||
request={ | ||
"database": client._database_string, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Factored out later to use the helper.