Skip to content

Commit

Permalink
feat: add retry/timeout to 'async_batch.AsyncBatch.commit'
Browse files Browse the repository at this point in the history
Toward #221
  • Loading branch information
tseaver committed Oct 14, 2020
1 parent df615e4 commit a557a15
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 22 deletions.
19 changes: 12 additions & 7 deletions google/cloud/firestore_v1/async_batch.py
Expand Up @@ -15,6 +15,8 @@
"""Helpers for batch requests to the Google Cloud Firestore API."""


from google.api_core import retry as retries # type: ignore

from google.cloud.firestore_v1.base_batch import BaseWriteBatch


Expand All @@ -33,27 +35,30 @@ class AsyncWriteBatch(BaseWriteBatch):
def __init__(self, client) -> None:
super(AsyncWriteBatch, self).__init__(client=client)

async def commit(self) -> list:
async def commit(self, retry: retries.Retry = None, timeout: float = None) -> list:
"""Commit the changes accumulated in this batch.
Args:
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
List[:class:`google.cloud.proto.firestore.v1.write.WriteResult`, ...]:
The write results corresponding to the changes committed, returned
in the same order as the changes were applied to this batch. A
write result contains an ``update_time`` field.
"""
request, kwargs = self._prep_commit(retry, timeout)

commit_response = await self._client._firestore_api.commit(
request={
"database": self._client._database_string,
"writes": self._write_pbs,
"transaction": None,
},
metadata=self._client._rpc_metadata,
request=request, metadata=self._client._rpc_metadata, **kwargs,
)

self._write_pbs = []
self.write_results = results = list(commit_response.write_results)
self.commit_time = commit_response.commit_time

return results

async def __aenter__(self):
Expand Down
11 changes: 11 additions & 0 deletions google/cloud/firestore_v1/base_batch.py
Expand Up @@ -19,6 +19,7 @@

# Types needed only for Type Hints
from google.cloud.firestore_v1.document import DocumentReference

from typing import Union


Expand Down Expand Up @@ -146,3 +147,13 @@ def delete(
"""
write_pb = _helpers.pb_for_delete(reference._document_path, option)
self._add_write_pbs([write_pb])

def _prep_commit(self, retry, timeout):
"""Shared setup for async/sync :meth:`commit`."""
request = {
"database": self._client._database_string,
"writes": self._write_pbs,
"transaction": None,
}
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
return request, kwargs
12 changes: 3 additions & 9 deletions google/cloud/firestore_v1/batch.py
Expand Up @@ -17,7 +17,6 @@
from google.api_core import retry as retries # type: ignore

from google.cloud.firestore_v1.base_batch import BaseWriteBatch
from google.cloud.firestore_v1 import _helpers


class WriteBatch(BaseWriteBatch):
Expand Down Expand Up @@ -49,21 +48,16 @@ def commit(self, retry: retries.Retry = None, timeout: float = None) -> list:
in the same order as the changes were applied to this batch. A
write result contains an ``update_time`` field.
"""
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
request, kwargs = self._prep_commit(retry, timeout)

commit_response = self._client._firestore_api.commit(
request={
"database": self._client._database_string,
"writes": self._write_pbs,
"transaction": None,
},
metadata=self._client._rpc_metadata,
**kwargs,
request=request, metadata=self._client._rpc_metadata, **kwargs,
)

self._write_pbs = []
self.write_results = results = list(commit_response.write_results)
self.commit_time = commit_response.commit_time

return results

def __enter__(self):
Expand Down
28 changes: 22 additions & 6 deletions tests/unit/v1/test_async_batch.py
Expand Up @@ -37,9 +37,9 @@ def test_constructor(self):
self.assertIsNone(batch.write_results)
self.assertIsNone(batch.commit_time)

@pytest.mark.asyncio
async def test_commit(self):
async def _commit_helper(self, retry=None, timeout=None):
from google.protobuf import timestamp_pb2
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import firestore
from google.cloud.firestore_v1.types import write

Expand All @@ -51,6 +51,7 @@ async def test_commit(self):
commit_time=timestamp,
)
firestore_api.commit.return_value = commit_response
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

# Attach the fake GAPIC to a real client.
client = _make_client("grand")
Expand All @@ -59,12 +60,13 @@ async def test_commit(self):
# Actually make a batch with some mutations and call commit().
batch = self._make_one(client)
document1 = client.document("a", "b")
batch.create(document1, {"ten": 10, "buck": u"ets"})
batch.create(document1, {"ten": 10, "buck": "ets"})
document2 = client.document("c", "d", "e", "f")
batch.delete(document2)
write_pbs = batch._write_pbs[::]

write_results = await batch.commit()
write_results = await batch.commit(**kwargs)

self.assertEqual(write_results, list(commit_response.write_results))
self.assertEqual(batch.write_results, write_results)
self.assertEqual(batch.commit_time.timestamp_pb(), timestamp)
Expand All @@ -79,8 +81,22 @@ async def test_commit(self):
"transaction": None,
},
metadata=client._rpc_metadata,
**kwargs,
)

@pytest.mark.asyncio
async def test_commit(self):
await self._commit_helper()

@pytest.mark.asyncio
async def test_commit_w_retry_timeout(self):
from google.api_core.retry import Retry

retry = Retry(predicate=object())
timeout = 123.0

await self._commit_helper(retry=retry, timeout=timeout)

@pytest.mark.asyncio
async def test_as_context_mgr_wo_error(self):
from google.protobuf import timestamp_pb2
Expand All @@ -102,7 +118,7 @@ async def test_as_context_mgr_wo_error(self):

async with batch as ctx_mgr:
self.assertIs(ctx_mgr, batch)
ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"})
ctx_mgr.create(document1, {"ten": 10, "buck": "ets"})
ctx_mgr.delete(document2)
write_pbs = batch._write_pbs[::]

Expand Down Expand Up @@ -132,7 +148,7 @@ async def test_as_context_mgr_w_error(self):

with self.assertRaises(RuntimeError):
async with batch as ctx_mgr:
ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"})
ctx_mgr.create(document1, {"ten": 10, "buck": "ets"})
ctx_mgr.delete(document2)
raise RuntimeError("testing")

Expand Down

0 comments on commit a557a15

Please sign in to comment.