From 9095368eaec4271b87ad792ff9bbd065364109f6 Mon Sep 17 00:00:00 2001 From: Raphael Long Date: Thu, 23 Jul 2020 12:39:06 -0500 Subject: [PATCH] fix: asyncio microgen client get_all type (#126) * feat: create AsyncIter class for mocking * fix: type error on mocked return on batch_get_documents --- google/cloud/firestore_v1/async_client.py | 2 +- tests/unit/v1/test__helpers.py | 9 +++++++++ tests/unit/v1/test_async_client.py | 4 ++-- tests/unit/v1/test_async_collection.py | 19 +++++-------------- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 00029074b..f37b28ddc 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -252,7 +252,7 @@ async def get_all(self, references, field_paths=None, transaction=None): metadata=self._rpc_metadata, ) - for get_doc_response in response_iterator: + async for get_doc_response in response_iterator: yield _parse_batch_get(get_doc_response, reference_map, self) async def collections(self): diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index caa456c91..55b74f89d 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -25,6 +25,15 @@ async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) +class AsyncIter: + def __init__(self, items): + self.items = items + + async def __aiter__(self, **_): + for i in self.items: + yield i + + class TestGeoPoint(unittest.TestCase): @staticmethod def _get_target_class(): diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 1a4724e13..0beb0157c 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -18,7 +18,7 @@ import aiounittest import mock -from tests.unit.v1.test__helpers import AsyncMock +from tests.unit.v1.test__helpers import AsyncMock, AsyncIter class TestAsyncClient(aiounittest.AsyncTestCase): @@ -237,7 +237,7 @@ def _next_page(self): async def _get_all_helper(self, client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["batch_get_documents"]) - response_iterator = iter(document_pbs) + response_iterator = AsyncIter(document_pbs) firestore_api.batch_get_documents.return_value = response_iterator # Attach the fake GAPIC to a real client. diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index bb002ea97..742a381db 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -17,16 +17,7 @@ import aiounittest import mock -from tests.unit.v1.test__helpers import AsyncMock - - -class MockAsyncIter: - def __init__(self, count): - self.count = count - - async def __aiter__(self, **_): - for i in range(self.count): - yield i +from tests.unit.v1.test__helpers import AsyncMock, AsyncIter class TestAsyncCollectionReference(aiounittest.AsyncTestCase): @@ -258,7 +249,7 @@ async def test_list_documents_w_page_size(self): async def test_get(self, query_class): import warnings - query_class.return_value.stream.return_value = MockAsyncIter(3) + query_class.return_value.stream.return_value = AsyncIter(range(3)) collection = self._make_one("collection") with warnings.catch_warnings(record=True) as warned: @@ -280,7 +271,7 @@ async def test_get(self, query_class): async def test_get_with_transaction(self, query_class): import warnings - query_class.return_value.stream.return_value = MockAsyncIter(3) + query_class.return_value.stream.return_value = AsyncIter(range(3)) collection = self._make_one("collection") transaction = mock.sentinel.txn @@ -301,7 +292,7 @@ async def test_get_with_transaction(self, query_class): @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_stream(self, query_class): - query_class.return_value.stream.return_value = MockAsyncIter(3) + query_class.return_value.stream.return_value = AsyncIter(range(3)) collection = self._make_one("collection") stream_response = collection.stream() @@ -316,7 +307,7 @@ async def test_stream(self, query_class): @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_stream_with_transaction(self, query_class): - query_class.return_value.stream.return_value = MockAsyncIter(3) + query_class.return_value.stream.return_value = AsyncIter(range(3)) collection = self._make_one("collection") transaction = mock.sentinel.txn