From e640e663f525233a8173767f6886537dfd97b121 Mon Sep 17 00:00:00 2001 From: Raphael Long Date: Fri, 7 Aug 2020 12:34:52 -0500 Subject: [PATCH] fix: await on to_wrap in AsyncTransactional (#147) --- google/cloud/firestore_v1/async_transaction.py | 18 +++++++++--------- tests/unit/v1/test_async_transaction.py | 14 +++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 19a436b0b..4793e216c 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -188,31 +188,31 @@ class _AsyncTransactional(_BaseTransactional): :func:`~google.cloud.firestore_v1.async_transaction.transactional`. Args: - to_wrap (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]): - A callable that should be run (and retried) in a transaction. + to_wrap (Coroutine[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]): + A coroutine that should be run (and retried) in a transaction. """ def __init__(self, to_wrap) -> None: super(_AsyncTransactional, self).__init__(to_wrap) async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine: - """Begin transaction and call the wrapped callable. + """Begin transaction and call the wrapped coroutine. - If the callable raises an exception, the transaction will be rolled + If the coroutine raises an exception, the transaction will be rolled back. If not, the transaction will be "ready" for ``Commit`` (i.e. it will have staged writes). Args: transaction (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`): - A transaction to execute the callable within. + A transaction to execute the coroutine within. args (Tuple[Any, ...]): The extra positional arguments to pass - along to the wrapped callable. + along to the wrapped coroutine. kwargs (Dict[str, Any]): The extra keyword arguments to pass - along to the wrapped callable. + along to the wrapped coroutine. Returns: - Any: result of the wrapped callable. + Any: result of the wrapped coroutine. Raises: Exception: Any failure caused by ``to_wrap``. @@ -226,7 +226,7 @@ async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine: if self.retry_id is None: self.retry_id = self.current_id try: - return self.to_wrap(transaction, *args, **kwargs) + return await self.to_wrap(transaction, *args, **kwargs) except: # noqa # NOTE: If ``rollback`` fails this will lose the information # from the original failure. diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index a7774a28c..ed732ae92 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -339,7 +339,7 @@ def test_constructor(self): @pytest.mark.asyncio async def test__pre_commit_success(self): - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"totes-began" @@ -368,7 +368,7 @@ async def test__pre_commit_success(self): async def test__pre_commit_retry_id_already_set_success(self): from google.cloud.firestore_v1.types import common - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) txn_id1 = b"already-set" wrapped.retry_id = txn_id1 @@ -401,7 +401,7 @@ async def test__pre_commit_retry_id_already_set_success(self): @pytest.mark.asyncio async def test__pre_commit_failure(self): exc = RuntimeError("Nope not today.") - to_wrap = mock.Mock(side_effect=exc, spec=[]) + to_wrap = AsyncMock(side_effect=exc, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"gotta-fail" @@ -438,7 +438,7 @@ async def test__pre_commit_failure_with_rollback_failure(self): from google.api_core import exceptions exc1 = ValueError("I will not be only failure.") - to_wrap = mock.Mock(side_effect=exc1, spec=[]) + to_wrap = AsyncMock(side_effect=exc1, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"both-will-fail" @@ -614,7 +614,7 @@ async def test__maybe_commit_failure_cannot_retry(self): @pytest.mark.asyncio async def test___call__success_first_attempt(self): - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"whole-enchilada" @@ -650,7 +650,7 @@ async def test___call__success_second_attempt(self): from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import write - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"whole-enchilada" @@ -707,7 +707,7 @@ async def test___call__failure(self): _EXCEED_ATTEMPTS_TEMPLATE, ) - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"only-one-shot"