Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: harden 'query.stream' against retriable exceptions #456

Merged
merged 3 commits into from Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 43 additions & 5 deletions google/cloud/firestore_v1/query.py
Expand Up @@ -20,6 +20,7 @@
"""
from google.cloud import firestore_v1
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore

Expand Down Expand Up @@ -208,6 +209,29 @@ def _chunkify(
):
return

def _get_stream_iterator(self, transaction, retry, timeout):
"""Helper method for :meth:`stream`."""
request, expected_prefix, kwargs = self._prep_stream(
transaction, retry, timeout,
)

response_iterator = self._client._firestore_api.run_query(
request=request, metadata=self._client._rpc_metadata, **kwargs,
)

return response_iterator, expected_prefix

def _retry_query_after_exception(self, exc, retry, transaction):
"""Helper method for :meth:`stream`."""
if transaction is None: # no snapshot-based retry inside transaction
if retry is gapic_v1.method.DEFAULT:
transport = self._client._firestore_api._transport
gapic_callable = transport.run_query
retry = gapic_callable._retry
return retry._predicate(exc)

return False

def stream(
self,
transaction=None,
Expand Down Expand Up @@ -244,15 +268,28 @@ def stream(
:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`:
The next document that fulfills the query.
"""
request, expected_prefix, kwargs = self._prep_stream(
response_iterator, expected_prefix = self._get_stream_iterator(
transaction, retry, timeout,
)

response_iterator = self._client._firestore_api.run_query(
request=request, metadata=self._client._rpc_metadata, **kwargs,
)
last_snapshot = None

while True:
try:
response = next(response_iterator, None)
except exceptions.GoogleAPICallError as exc:
if self._retry_query_after_exception(exc, retry, transaction):
new_query = self.start_after(last_snapshot)
response_iterator, _ = new_query._get_stream_iterator(
transaction, retry, timeout,
)
continue
else:
raise

if response is None: # EOI
break

for response in response_iterator:
if self._all_descendants:
snapshot = _collection_group_query_response_to_snapshot(
response, self._parent
Expand All @@ -262,6 +299,7 @@ def stream(
response, self._parent, expected_prefix
)
if snapshot is not None:
last_snapshot = snapshot
yield snapshot

def on_snapshot(self, callback: Callable) -> Watch:
Expand Down
123 changes: 121 additions & 2 deletions tests/unit/v1/test_query.py
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.firestore_v1.types.document import Document
from google.cloud.firestore_v1.types.firestore import RunQueryResponse
import types
import unittest

import mock
import pytest

from google.api_core import gapic_v1
from google.cloud.firestore_v1.types.document import Document
from google.cloud.firestore_v1.types.firestore import RunQueryResponse
from tests.unit.v1.test_base_query import _make_credentials
from tests.unit.v1.test_base_query import _make_cursor_pb
from tests.unit.v1.test_base_query import _make_query_response
Expand Down Expand Up @@ -456,6 +457,124 @@ def test_stream_w_collection_group(self):
metadata=client._rpc_metadata,
)

def _stream_w_retriable_exc_helper(
self,
retry=gapic_v1.method.DEFAULT,
timeout=None,
transaction=None,
expect_retry=True,
):
from google.api_core import exceptions
from google.cloud.firestore_v1 import _helpers

if transaction is not None:
expect_retry = False

# Create a minimal fake GAPIC.
firestore_api = mock.Mock(spec=["run_query", "_transport"])
transport = firestore_api._transport = mock.Mock(spec=["run_query"])
stub = transport.run_query = mock.create_autospec(
gapic_v1.method._GapicCallable
)
stub._retry = mock.Mock(spec=["_predicate"])
stub._predicate = lambda exc: True # pragma: NO COVER

# Attach the fake GAPIC to a real client.
client = _make_client()
client._firestore_api_internal = firestore_api

# Make a **real** collection reference as parent.
parent = client.collection("dee")

# Add a dummy response to the minimal fake GAPIC.
_, expected_prefix = parent._parent_info()
name = "{}/sleep".format(expected_prefix)
data = {"snooze": 10}
response_pb = _make_query_response(name=name, data=data)
retriable_exc = exceptions.ServiceUnavailable("testing")

def _stream_w_exception(*_args, **_kw):
yield response_pb
raise retriable_exc

firestore_api.run_query.side_effect = [_stream_w_exception(), iter([])]
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

# Execute the query and check the response.
query = self._make_one(parent)

get_response = query.stream(transaction=transaction, **kwargs)

self.assertIsInstance(get_response, types.GeneratorType)
if expect_retry:
returned = list(get_response)
else:
returned = [next(get_response)]
with self.assertRaises(exceptions.ServiceUnavailable):
next(get_response)

self.assertEqual(len(returned), 1)
snapshot = returned[0]
self.assertEqual(snapshot.reference._path, ("dee", "sleep"))
self.assertEqual(snapshot.to_dict(), data)

# Verify the mock call.
parent_path, _ = parent._parent_info()
calls = firestore_api.run_query.call_args_list

if expect_retry:
self.assertEqual(len(calls), 2)
else:
self.assertEqual(len(calls), 1)

if transaction is not None:
expected_transaction_id = transaction.id
else:
expected_transaction_id = None

self.assertEqual(
calls[0],
mock.call(
request={
"parent": parent_path,
"structured_query": query._to_protobuf(),
"transaction": expected_transaction_id,
},
metadata=client._rpc_metadata,
**kwargs,
),
)

if expect_retry:
new_query = query.start_after(snapshot)
self.assertEqual(
calls[1],
mock.call(
request={
"parent": parent_path,
"structured_query": new_query._to_protobuf(),
"transaction": None,
},
metadata=client._rpc_metadata,
**kwargs,
),
)

def test_stream_w_retriable_exc_w_defaults(self):
self._stream_w_retriable_exc_helper()

def test_stream_w_retriable_exc_w_retry(self):
retry = mock.Mock(spec=["_predicate"])
retry._predicate = lambda exc: False
self._stream_w_retriable_exc_helper(retry=retry, expect_retry=False)

def test_stream_w_retriable_exc_w_transaction(self):
from google.cloud.firestore_v1 import transaction

txn = transaction.Transaction(client=mock.Mock(spec=[]))
txn._id = b"DEADBEEF"
self._stream_w_retriable_exc_helper(transaction=txn)

@mock.patch("google.cloud.firestore_v1.query.Watch", autospec=True)
def test_on_snapshot(self, watch):
query = self._make_one(mock.sentinel.parent)
Expand Down