Skip to content

Commit

Permalink
feat: add retry/timeout to 'async_transaction.AsyncTransaction' methods
Browse files Browse the repository at this point in the history
Methods affected:

- 'get_all'
- 'get'

Towards #221
  • Loading branch information
tseaver committed Oct 14, 2020
1 parent b998db2 commit 46f27e6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 15 deletions.
29 changes: 24 additions & 5 deletions google/cloud/firestore_v1/async_transaction.py
Expand Up @@ -18,6 +18,8 @@
import asyncio
import random

from google.api_core import retry as retries # type: ignore

from google.cloud.firestore_v1.base_transaction import (
_BaseTransactional,
BaseTransaction,
Expand All @@ -34,6 +36,7 @@

from google.api_core import exceptions # type: ignore
from google.cloud.firestore_v1 import async_batch
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1 import types

from google.cloud.firestore_v1.async_document import AsyncDocumentReference
Expand Down Expand Up @@ -144,32 +147,48 @@ async def _commit(self) -> list:
self._clean_up()
return list(commit_response.write_results)

async def get_all(self, references: list) -> Coroutine:
async def get_all(
self, references: list, retry: retries.Retry = None, timeout: float = None
) -> Coroutine:
"""Retrieves multiple documents from Firestore.
Args:
references (List[.AsyncDocumentReference, ...]): Iterable of document
references to be retrieved.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Yields:
.DocumentSnapshot: The next document snapshot that fulfills the
query, or :data:`None` if the document does not exist.
"""
return await self._client.get_all(references, transaction=self)
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
return await self._client.get_all(references, transaction=self, **kwargs)

async def get(self, ref_or_query) -> AsyncGenerator[DocumentSnapshot, Any]:
async def get(
self, ref_or_query, retry: retries.Retry = None, timeout: float = None
) -> AsyncGenerator[DocumentSnapshot, Any]:
"""
Retrieve a document or a query result from the database.
Args:
ref_or_query The document references or query object to return.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Yields:
.DocumentSnapshot: The next document snapshot that fulfills the
query, or :data:`None` if the document does not exist.
"""
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
if isinstance(ref_or_query, AsyncDocumentReference):
return await self._client.get_all([ref_or_query], transaction=self)
return await self._client.get_all(
[ref_or_query], transaction=self, **kwargs
)
elif isinstance(ref_or_query, AsyncQuery):
return await ref_or_query.stream(transaction=self)
return await ref_or_query.stream(transaction=self, **kwargs)
else:
raise ValueError(
'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.'
Expand Down
66 changes: 56 additions & 10 deletions tests/unit/v1/test_async_transaction.py
Expand Up @@ -279,38 +279,84 @@ async def test__commit_failure(self):
metadata=client._rpc_metadata,
)

@pytest.mark.asyncio
async def test_get_all(self):
async def _get_all_helper(self, retry=None, timeout=None):
from google.cloud.firestore_v1 import _helpers

client = AsyncMock(spec=["get_all"])
transaction = self._make_one(client)
ref1, ref2 = mock.Mock(), mock.Mock()
result = await transaction.get_all([ref1, ref2])
client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction)
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

result = await transaction.get_all([ref1, ref2], **kwargs)

client.get_all.assert_called_once_with(
[ref1, ref2], transaction=transaction, **kwargs,
)
self.assertIs(result, client.get_all.return_value)

@pytest.mark.asyncio
async def test_get_document_ref(self):
async def test_get_all(self):
await self._get_all_helper()

@pytest.mark.asyncio
async def test_get_all_w_retry_timeout(self):
from google.api_core.retry import Retry

retry = Retry(predicate=object())
timeout = 123.0
await self._get_all_helper(retry=retry, timeout=timeout)

async def _get_w_document_ref_helper(self, retry=None, timeout=None):
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1 import _helpers

client = AsyncMock(spec=["get_all"])
transaction = self._make_one(client)
ref = AsyncDocumentReference("documents", "doc-id")
result = await transaction.get(ref)
client.get_all.assert_called_once_with([ref], transaction=transaction)
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

result = await transaction.get(ref, **kwargs)

client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs)
self.assertIs(result, client.get_all.return_value)

@pytest.mark.asyncio
async def test_get_w_query(self):
async def test_get_w_document_ref(self):
await self._get_w_document_ref_helper()

@pytest.mark.asyncio
async def test_get_w_document_ref_w_retry_timeout(self):
from google.api_core.retry import Retry

retry = Retry(predicate=object())
timeout = 123.0
await self._get_w_document_ref_helper()

async def _get_w_query_helper(self, retry=None, timeout=None):
from google.cloud.firestore_v1.async_query import AsyncQuery
from google.cloud.firestore_v1 import _helpers

client = AsyncMock(spec=[])
transaction = self._make_one(client)
query = AsyncQuery(parent=AsyncMock(spec=[]))
query.stream = AsyncMock()
result = await transaction.get(query)
query.stream.assert_called_once_with(transaction=transaction)
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

result = await transaction.get(query, **kwargs,)

query.stream.assert_called_once_with(
transaction=transaction, **kwargs,
)
self.assertIs(result, query.stream.return_value)

@pytest.mark.asyncio
async def test_get_w_query(self):
await self._get_w_query_helper()

@pytest.mark.asyncio
async def test_get_w_query_w_retry_timeout(self):
await self._get_w_query_helper()

@pytest.mark.asyncio
async def test_get_failure(self):
client = _make_client()
Expand Down

0 comments on commit 46f27e6

Please sign in to comment.