Skip to content

Commit

Permalink
feat: add retry/timeout to 'async_document.AsyncDocument` methods
Browse files Browse the repository at this point in the history
Methods include:

- 'create'
- 'set'
- 'update'
- 'delete'
- 'get'
- 'collections'

Toward #221
  • Loading branch information
tseaver committed Oct 14, 2020
1 parent a557a15 commit 4e3be50
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 109 deletions.
114 changes: 68 additions & 46 deletions google/cloud/firestore_v1/async_document.py
Expand Up @@ -14,6 +14,8 @@

"""Classes for representing documents for the Google Cloud Firestore API."""

from google.api_core import retry as retries # type: ignore

from google.cloud.firestore_v1.base_document import (
BaseDocumentReference,
DocumentSnapshot,
Expand All @@ -22,7 +24,6 @@

from google.api_core import exceptions # type: ignore
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import common
from typing import Any, AsyncGenerator, Coroutine, Iterable, Union


Expand Down Expand Up @@ -54,12 +55,17 @@ class AsyncDocumentReference(BaseDocumentReference):
def __init__(self, *path, **kwargs) -> None:
super(AsyncDocumentReference, self).__init__(*path, **kwargs)

async def create(self, document_data: dict) -> Coroutine:
async def create(
self, document_data: dict, retry: retries.Retry = None, timeout: float = None,
) -> Coroutine:
"""Create the current document in the Firestore database.
Args:
document_data (dict): Property names and values to use for
creating a document.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
:class:`~google.cloud.firestore_v1.types.WriteResult`:
Expand All @@ -70,12 +76,17 @@ async def create(self, document_data: dict) -> Coroutine:
:class:`~google.cloud.exceptions.Conflict`:
If the document already exists.
"""
batch = self._client.batch()
batch.create(self, document_data)
write_results = await batch.commit()
batch, kwargs = self._prep_create(document_data, retry, timeout)
write_results = await batch.commit(**kwargs)
return _first_write_result(write_results)

async def set(self, document_data: dict, merge: bool = False) -> Coroutine:
async def set(
self,
document_data: dict,
merge: bool = False,
retry: retries.Retry = None,
timeout: float = None,
) -> Coroutine:
"""Replace the current document in the Firestore database.
A write ``option`` can be specified to indicate preconditions of
Expand All @@ -95,19 +106,25 @@ async def set(self, document_data: dict, merge: bool = False) -> Coroutine:
merge (Optional[bool] or Optional[List<apispec>]):
If True, apply merging instead of overwriting the state
of the document.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
:class:`~google.cloud.firestore_v1.types.WriteResult`:
The write result corresponding to the committed document. A write
result contains an ``update_time`` field.
"""
batch = self._client.batch()
batch.set(self, document_data, merge=merge)
write_results = await batch.commit()
batch, kwargs = self._prep_set(document_data, merge, retry, timeout)
write_results = await batch.commit(**kwargs)
return _first_write_result(write_results)

async def update(
self, field_updates: dict, option: _helpers.WriteOption = None
self,
field_updates: dict,
option: _helpers.WriteOption = None,
retry: retries.Retry = None,
timeout: float = None,
) -> Coroutine:
"""Update an existing document in the Firestore database.
Expand Down Expand Up @@ -242,6 +259,9 @@ async def update(
option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
A write option to make assertions / preconditions on the server
state of the document before applying changes.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
:class:`~google.cloud.firestore_v1.types.WriteResult`:
Expand All @@ -251,18 +271,25 @@ async def update(
Raises:
~google.cloud.exceptions.NotFound: If the document does not exist.
"""
batch = self._client.batch()
batch.update(self, field_updates, option=option)
write_results = await batch.commit()
batch, kwargs = self._prep_update(field_updates, option, retry, timeout)
write_results = await batch.commit(**kwargs)
return _first_write_result(write_results)

async def delete(self, option: _helpers.WriteOption = None) -> Coroutine:
async def delete(
self,
option: _helpers.WriteOption = None,
retry: retries.Retry = None,
timeout: float = None,
) -> Coroutine:
"""Delete the current document in the Firestore database.
Args:
option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
A write option to make assertions / preconditions on the server
state of the document before applying changes.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
:class:`google.protobuf.timestamp_pb2.Timestamp`:
Expand All @@ -271,20 +298,20 @@ async def delete(self, option: _helpers.WriteOption = None) -> Coroutine:
nothing was deleted), this method will still succeed and will
still return the time that the request was received by the server.
"""
write_pb = _helpers.pb_for_delete(self._document_path, option)
request, kwargs = self._prep_delete(option, retry, timeout)

commit_response = await self._client._firestore_api.commit(
request={
"database": self._client._database_string,
"writes": [write_pb],
"transaction": None,
},
metadata=self._client._rpc_metadata,
request=request, metadata=self._client._rpc_metadata, **kwargs,
)

return commit_response.commit_time

async def get(
self, field_paths: Iterable[str] = None, transaction=None
self,
field_paths: Iterable[str] = None,
transaction=None,
retry: retries.Retry = None,
timeout: float = None,
) -> Union[DocumentSnapshot, Coroutine[Any, Any, DocumentSnapshot]]:
"""Retrieve a snapshot of the current document.
Expand All @@ -303,6 +330,9 @@ async def get(
transaction (Optional[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`]):
An existing transaction that this reference
will be retrieved in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
:class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot`:
Expand All @@ -312,23 +342,12 @@ async def get(
:attr:`create_time` attributes will all be ``None`` and
its :attr:`exists` attribute will be ``False``.
"""
if isinstance(field_paths, str):
raise ValueError("'field_paths' must be a sequence of paths, not a string.")

if field_paths is not None:
mask = common.DocumentMask(field_paths=sorted(field_paths))
else:
mask = None
request, kwargs = self._prep_get(field_paths, transaction, retry, timeout)

firestore_api = self._client._firestore_api
try:
document_pb = await firestore_api.get_document(
request={
"name": self._document_path,
"mask": mask,
"transaction": _helpers.get_transaction_id(transaction),
},
metadata=self._client._rpc_metadata,
request=request, metadata=self._client._rpc_metadata, **kwargs,
)
except exceptions.NotFound:
data = None
Expand All @@ -350,36 +369,39 @@ async def get(
update_time=update_time,
)

async def collections(self, page_size: int = None) -> AsyncGenerator:
async def collections(
self, page_size: int = None, retry: retries.Retry = None, timeout: float = None,
) -> AsyncGenerator:
"""List subcollections of the current document.
Args:
page_size (Optional[int]]): The maximum number of collections
in each page of results from this request. Non-positive values
are ignored. Defaults to a sensible value set by the API.
in each page of results from this request. Non-positive values
are ignored. Defaults to a sensible value set by the API.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
Returns:
Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]:
iterator of subcollections of the current document. If the
document does not exist at the time of `snapshot`, the
iterator will be empty
"""
request, kwargs = self._prep_collections(page_size, retry, timeout)

iterator = await self._client._firestore_api.list_collection_ids(
request={"parent": self._document_path, "page_size": page_size},
metadata=self._client._rpc_metadata,
request=request, metadata=self._client._rpc_metadata, **kwargs,
)

while True:
for i in iterator.collection_ids:
yield self.collection(i)
if iterator.next_page_token:
next_request = request.cpoy()
next_request["page_token"] = iterator.next_page_token
iterator = await self._client._firestore_api.list_collection_ids(
request={
"parent": self._document_path,
"page_size": page_size,
"page_token": iterator.next_page_token,
},
metadata=self._client._rpc_metadata,
request=request, metadata=self._client._rpc_metadata, **kwargs
)
else:
return
Expand Down
93 changes: 92 additions & 1 deletion google/cloud/firestore_v1/base_document.py
Expand Up @@ -20,7 +20,12 @@

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1 import field_path as field_path_module
from typing import Any, Iterable, NoReturn, Tuple
from google.cloud.firestore_v1.types import common

from typing import Any
from typing import Iterable
from typing import NoReturn
from typing import Tuple


class BaseDocumentReference(object):
Expand Down Expand Up @@ -180,11 +185,33 @@ def collection(self, collection_id: str) -> Any:
child_path = self._path + (collection_id,)
return self._client.collection(*child_path)

def _prep_create(
self, document_data: dict, retry: retries.Retry = None, timeout: float = None,
) -> Tuple[Any, dict]:
batch = self._client.batch()
batch.create(self, document_data)
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return batch, kwargs

def create(
self, document_data: dict, retry: retries.Retry = None, timeout: float = None,
) -> NoReturn:
raise NotImplementedError

def _prep_set(
self,
document_data: dict,
merge: bool = False,
retry: retries.Retry = None,
timeout: float = None,
) -> Tuple[Any, dict]:
batch = self._client.batch()
batch.set(self, document_data, merge=merge)
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return batch, kwargs

def set(
self,
document_data: dict,
Expand All @@ -194,6 +221,19 @@ def set(
) -> NoReturn:
raise NotImplementedError

def _prep_update(
self,
field_updates: dict,
option: _helpers.WriteOption = None,
retry: retries.Retry = None,
timeout: float = None,
) -> Tuple[Any, dict]:
batch = self._client.batch()
batch.update(self, field_updates, option=option)
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return batch, kwargs

def update(
self,
field_updates: dict,
Expand All @@ -203,6 +243,23 @@ def update(
) -> NoReturn:
raise NotImplementedError

def _prep_delete(
self,
option: _helpers.WriteOption = None,
retry: retries.Retry = None,
timeout: float = None,
) -> Tuple[dict, dict]:
"""Shared setup for async/sync :meth:`delete`."""
write_pb = _helpers.pb_for_delete(self._document_path, option)
request = {
"database": self._client._database_string,
"writes": [write_pb],
"transaction": None,
}
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return request, kwargs

def delete(
self,
option: _helpers.WriteOption = None,
Expand All @@ -211,6 +268,31 @@ def delete(
) -> NoReturn:
raise NotImplementedError

def _prep_get(
self,
field_paths: Iterable[str] = None,
transaction=None,
retry: retries.Retry = None,
timeout: float = None,
) -> Tuple[dict, dict]:
"""Shared setup for async/sync :meth:`get`."""
if isinstance(field_paths, str):
raise ValueError("'field_paths' must be a sequence of paths, not a string.")

if field_paths is not None:
mask = common.DocumentMask(field_paths=sorted(field_paths))
else:
mask = None

request = {
"name": self._document_path,
"mask": mask,
"transaction": _helpers.get_transaction_id(transaction),
}
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return request, kwargs

def get(
self,
field_paths: Iterable[str] = None,
Expand All @@ -220,6 +302,15 @@ def get(
) -> "DocumentSnapshot":
raise NotImplementedError

def _prep_collections(
self, page_size: int = None, retry: retries.Retry = None, timeout: float = None,
) -> Tuple[dict, dict]:
"""Shared setup for async/sync :meth:`collections`."""
request = {"parent": self._document_path, "page_size": page_size}
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return request, kwargs

def collections(
self, page_size: int = None, retry: retries.Retry = None, timeout: float = None,
) -> NoReturn:
Expand Down

0 comments on commit 4e3be50

Please sign in to comment.