diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 50c5559b1..e8af7a667 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -20,6 +20,7 @@ """ from google.cloud import firestore_v1 from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -208,6 +209,29 @@ def _chunkify( ): return + def _get_stream_iterator(self, transaction, retry, timeout): + """Helper method for :meth:`stream`.""" + request, expected_prefix, kwargs = self._prep_stream( + transaction, retry, timeout, + ) + + response_iterator = self._client._firestore_api.run_query( + request=request, metadata=self._client._rpc_metadata, **kwargs, + ) + + return response_iterator, expected_prefix + + def _retry_query_after_exception(self, exc, retry, transaction): + """Helper method for :meth:`stream`.""" + if transaction is None: # no snapshot-based retry inside transaction + if retry is gapic_v1.method.DEFAULT: + transport = self._client._firestore_api._transport + gapic_callable = transport.run_query + retry = gapic_callable._retry + return retry._predicate(exc) + + return False + def stream( self, transaction=None, @@ -244,15 +268,28 @@ def stream( :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: The next document that fulfills the query. """ - request, expected_prefix, kwargs = self._prep_stream( + response_iterator, expected_prefix = self._get_stream_iterator( transaction, retry, timeout, ) - response_iterator = self._client._firestore_api.run_query( - request=request, metadata=self._client._rpc_metadata, **kwargs, - ) + last_snapshot = None + + while True: + try: + response = next(response_iterator, None) + except exceptions.GoogleAPICallError as exc: + if self._retry_query_after_exception(exc, retry, transaction): + new_query = self.start_after(last_snapshot) + response_iterator, _ = new_query._get_stream_iterator( + transaction, retry, timeout, + ) + continue + else: + raise + + if response is None: # EOI + break - for response in response_iterator: if self._all_descendants: snapshot = _collection_group_query_response_to_snapshot( response, self._parent @@ -262,6 +299,7 @@ def stream( response, self._parent, expected_prefix ) if snapshot is not None: + last_snapshot = snapshot yield snapshot def on_snapshot(self, callback: Callable) -> Watch: diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index ea28969a8..6ca82090b 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse import types import unittest import mock import pytest +from google.api_core import gapic_v1 +from google.cloud.firestore_v1.types.document import Document +from google.cloud.firestore_v1.types.firestore import RunQueryResponse from tests.unit.v1.test_base_query import _make_credentials from tests.unit.v1.test_base_query import _make_cursor_pb from tests.unit.v1.test_base_query import _make_query_response @@ -456,6 +457,124 @@ def test_stream_w_collection_group(self): metadata=client._rpc_metadata, ) + def _stream_w_retriable_exc_helper( + self, + retry=gapic_v1.method.DEFAULT, + timeout=None, + transaction=None, + expect_retry=True, + ): + from google.api_core import exceptions + from google.cloud.firestore_v1 import _helpers + + if transaction is not None: + expect_retry = False + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query", "_transport"]) + transport = firestore_api._transport = mock.Mock(spec=["run_query"]) + stub = transport.run_query = mock.create_autospec( + gapic_v1.method._GapicCallable + ) + stub._retry = mock.Mock(spec=["_predicate"]) + stub._predicate = lambda exc: True # pragma: NO COVER + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + retriable_exc = exceptions.ServiceUnavailable("testing") + + def _stream_w_exception(*_args, **_kw): + yield response_pb + raise retriable_exc + + firestore_api.run_query.side_effect = [_stream_w_exception(), iter([])] + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + query = self._make_one(parent) + + get_response = query.stream(transaction=transaction, **kwargs) + + self.assertIsInstance(get_response, types.GeneratorType) + if expect_retry: + returned = list(get_response) + else: + returned = [next(get_response)] + with self.assertRaises(exceptions.ServiceUnavailable): + next(get_response) + + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("dee", "sleep")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + calls = firestore_api.run_query.call_args_list + + if expect_retry: + self.assertEqual(len(calls), 2) + else: + self.assertEqual(len(calls), 1) + + if transaction is not None: + expected_transaction_id = transaction.id + else: + expected_transaction_id = None + + self.assertEqual( + calls[0], + mock.call( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": expected_transaction_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ), + ) + + if expect_retry: + new_query = query.start_after(snapshot) + self.assertEqual( + calls[1], + mock.call( + request={ + "parent": parent_path, + "structured_query": new_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ), + ) + + def test_stream_w_retriable_exc_w_defaults(self): + self._stream_w_retriable_exc_helper() + + def test_stream_w_retriable_exc_w_retry(self): + retry = mock.Mock(spec=["_predicate"]) + retry._predicate = lambda exc: False + self._stream_w_retriable_exc_helper(retry=retry, expect_retry=False) + + def test_stream_w_retriable_exc_w_transaction(self): + from google.cloud.firestore_v1 import transaction + + txn = transaction.Transaction(client=mock.Mock(spec=[])) + txn._id = b"DEADBEEF" + self._stream_w_retriable_exc_helper(transaction=txn) + @mock.patch("google.cloud.firestore_v1.query.Watch", autospec=True) def test_on_snapshot(self, watch): query = self._make_one(mock.sentinel.parent)