diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py new file mode 100644 index 000000000..d29c30235 --- /dev/null +++ b/google/cloud/firestore_v1/async_batch.py @@ -0,0 +1,64 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for batch requests to the Google Cloud Firestore API.""" + + +from google.cloud.firestore_v1.base_batch import BaseWriteBatch + + +class AsyncWriteBatch(BaseWriteBatch): + """Accumulate write operations to be sent in a batch. + + This has the same set of methods for write operations that + :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` does, + e.g. :meth:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference.create`. + + Args: + client (:class:`~google.cloud.firestore_v1.async_client.AsyncClient`): + The client that created this batch. + """ + + def __init__(self, client): + super(AsyncWriteBatch, self).__init__(client=client) + + async def commit(self): + """Commit the changes accumulated in this batch. + + Returns: + List[:class:`google.cloud.proto.firestore.v1.write.WriteResult`, ...]: + The write results corresponding to the changes committed, returned + in the same order as the changes were applied to this batch. A + write result contains an ``update_time`` field. + """ + commit_response = self._client._firestore_api.commit( + request={ + "database": self._client._database_string, + "writes": self._write_pbs, + "transaction": None, + }, + metadata=self._client._rpc_metadata, + ) + + self._write_pbs = [] + self.write_results = results = list(commit_response.write_results) + self.commit_time = commit_response.commit_time + return results + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_type is None: + await self.commit() diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py new file mode 100644 index 000000000..4dd17035c --- /dev/null +++ b/google/cloud/firestore_v1/async_client.py @@ -0,0 +1,288 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client for interacting with the Google Cloud Firestore API. + +This is the base from which all interactions with the API occur. + +In the hierarchy of API concepts + +* a :class:`~google.cloud.firestore_v1.client.Client` owns a + :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference` +* a :class:`~google.cloud.firestore_v1.client.Client` owns a + :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` +""" + +from google.cloud.firestore_v1.base_client import ( + BaseClient, + DEFAULT_DATABASE, + _CLIENT_INFO, + _reference_info, + _parse_batch_get, + _get_doc_mask, + _path_helper, +) + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.async_query import AsyncQuery +from google.cloud.firestore_v1.async_batch import AsyncWriteBatch +from google.cloud.firestore_v1.async_collection import AsyncCollectionReference +from google.cloud.firestore_v1.async_document import AsyncDocumentReference +from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + +class AsyncClient(BaseClient): + """Client for interacting with Google Cloud Firestore API. + + .. note:: + + Since the Cloud Firestore API requires the gRPC transport, no + ``_http`` argument is accepted by this class. + + Args: + project (Optional[str]): The project which the client acts on behalf + of. If not passed, falls back to the default inferred + from the environment. + credentials (Optional[~google.auth.credentials.Credentials]): The + OAuth2 Credentials to use for this client. If not passed, falls + back to the default inferred from the environment. + database (Optional[str]): The database name that the client targets. + For now, :attr:`DEFAULT_DATABASE` (the default value) is the + only valid database. + client_info (Optional[google.api_core.gapic_v1.client_info.ClientInfo]): + The client info used to send a user-agent string along with API + requests. If ``None``, then default info will be used. Generally, + you only need to set this if you're developing your own library + or partner tool. + client_options (Union[dict, google.api_core.client_options.ClientOptions]): + Client options used to set user options on the client. API Endpoint + should be set through client_options. + """ + + def __init__( + self, + project=None, + credentials=None, + database=DEFAULT_DATABASE, + client_info=_CLIENT_INFO, + client_options=None, + ): + super(AsyncClient, self).__init__( + project=project, + credentials=credentials, + database=database, + client_info=client_info, + client_options=client_options, + ) + + def collection(self, *collection_path): + """Get a reference to a collection. + + For a top-level collection: + + .. code-block:: python + + >>> client.collection('top') + + For a sub-collection: + + .. code-block:: python + + >>> client.collection('mydocs/doc/subcol') + >>> # is the same as + >>> client.collection('mydocs', 'doc', 'subcol') + + Sub-collections can be nested deeper in a similar fashion. + + Args: + collection_path (Tuple[str, ...]): Can either be + + * A single ``/``-delimited path to a collection + * A tuple of collection path segments + + Returns: + :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`: + A reference to a collection in the Firestore database. + """ + return AsyncCollectionReference(*_path_helper(collection_path), client=self) + + def collection_group(self, collection_id): + """ + Creates and returns a new AsyncQuery that includes all documents in the + database that are contained in a collection or subcollection with the + given collection_id. + + .. code-block:: python + + >>> query = client.collection_group('mygroup') + + Args: + collection_id (str) Identifies the collections to query over. + + Every collection or subcollection with this ID as the last segment of its + path will be included. Cannot contain a slash. + + Returns: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: + The created AsyncQuery. + """ + return AsyncQuery( + self._get_collection_reference(collection_id), all_descendants=True + ) + + def document(self, *document_path): + """Get a reference to a document in a collection. + + For a top-level document: + + .. code-block:: python + + >>> client.document('collek/shun') + >>> # is the same as + >>> client.document('collek', 'shun') + + For a document in a sub-collection: + + .. code-block:: python + + >>> client.document('mydocs/doc/subcol/child') + >>> # is the same as + >>> client.document('mydocs', 'doc', 'subcol', 'child') + + Documents in sub-collections can be nested deeper in a similar fashion. + + Args: + document_path (Tuple[str, ...]): Can either be + + * A single ``/``-delimited path to a document + * A tuple of document path segments + + Returns: + :class:`~google.cloud.firestore_v1.document.AsyncDocumentReference`: + A reference to a document in a collection. + """ + return AsyncDocumentReference( + *self._document_path_helper(*document_path), client=self + ) + + async def get_all(self, references, field_paths=None, transaction=None): + """Retrieve a batch of documents. + + .. note:: + + Documents returned by this method are not guaranteed to be + returned in the same order that they are given in ``references``. + + .. note:: + + If multiple ``references`` refer to the same document, the server + will only return one result. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + references (List[.AsyncDocumentReference, ...]): Iterable of document + references to be retrieved. + field_paths (Optional[Iterable[str, ...]]): An iterable of field + paths (``.``-delimited list of field names) to use as a + projection of document fields in the returned results. If + no value is provided, all fields will be returned. + transaction (Optional[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`]): + An existing transaction that these ``references`` will be + retrieved in. + + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + document_paths, reference_map = _reference_info(references) + mask = _get_doc_mask(field_paths) + response_iterator = self._firestore_api.batch_get_documents( + request={ + "database": self._database_string, + "documents": document_paths, + "mask": mask, + "transaction": _helpers.get_transaction_id(transaction), + }, + metadata=self._rpc_metadata, + ) + + for get_doc_response in response_iterator: + yield _parse_batch_get(get_doc_response, reference_map, self) + + async def collections(self): + """List top-level collections of the client's database. + + Returns: + Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: + iterator of subcollections of the current document. + """ + iterator = self._firestore_api.list_collection_ids( + request={"parent": "{}/documents".format(self._database_string)}, + metadata=self._rpc_metadata, + ) + + while True: + for i in iterator.collection_ids: + yield self.collection(i) + if iterator.next_page_token: + iterator = self._firestore_api.list_collection_ids( + request={ + "parent": "{}/documents".format(self._database_string), + "page_token": iterator.next_page_token, + }, + metadata=self._rpc_metadata, + ) + else: + return + + # TODO(microgen): currently this method is rewritten to iterate/page itself. + # https://github.com/googleapis/gapic-generator-python/issues/516 + # it seems the generator ought to be able to do this itself. + # iterator.client = self + # iterator.item_to_value = _item_to_collection_ref + # return iterator + + def batch(self): + """Get a batch instance from this client. + + Returns: + :class:`~google.cloud.firestore_v1.async_batch.AsyncWriteBatch`: + A "write" batch to be used for accumulating document changes and + sending the changes all at once. + """ + return AsyncWriteBatch(self) + + def transaction(self, **kwargs): + """Get a transaction that uses this client. + + See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for + more information on transactions and the constructor arguments. + + Args: + kwargs (Dict[str, Any]): The keyword arguments (other than + ``client``) to pass along to the + :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` + constructor. + + Returns: + :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`: + A transaction attached to this client. + """ + return AsyncTransaction(self, **kwargs) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py new file mode 100644 index 000000000..aa09e3d9a --- /dev/null +++ b/google/cloud/firestore_v1/async_collection.py @@ -0,0 +1,196 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing collections for the Google Cloud Firestore API.""" +import warnings + + +from google.cloud.firestore_v1.base_collection import ( + BaseCollectionReference, + _auto_id, + _item_to_document_ref, +) +from google.cloud.firestore_v1 import async_query +from google.cloud.firestore_v1.watch import Watch +from google.cloud.firestore_v1 import async_document + + +class AsyncCollectionReference(BaseCollectionReference): + """A reference to a collection in a Firestore database. + + The collection may already exist or this class can facilitate creation + of documents within the collection. + + Args: + path (Tuple[str, ...]): The components in the collection path. + This is a series of strings representing each collection and + sub-collection ID, as well as the document IDs for any documents + that contain a sub-collection. + kwargs (dict): The keyword arguments for the constructor. The only + supported keyword is ``client`` and it must be a + :class:`~google.cloud.firestore_v1.client.Client` if provided. It + represents the client that created this collection reference. + + Raises: + ValueError: if + + * the ``path`` is empty + * there are an even number of elements + * a collection ID in ``path`` is not a string + * a document ID in ``path`` is not a string + TypeError: If a keyword other than ``client`` is used. + """ + + def __init__(self, *path, **kwargs): + super(AsyncCollectionReference, self).__init__(*path, **kwargs) + + def _query(self): + """Query factory. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query` + """ + return async_query.AsyncQuery(self) + + async def add(self, document_data, document_id=None): + """Create a document in the Firestore database with the provided data. + + Args: + document_data (dict): Property names and values to use for + creating the document. + document_id (Optional[str]): The document identifier within the + current collection. If not provided, an ID will be + automatically assigned by the server (the assigned ID will be + a random 20 character string composed of digits, + uppercase and lowercase letters). + + Returns: + Tuple[:class:`google.protobuf.timestamp_pb2.Timestamp`, \ + :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`]: + Pair of + + * The ``update_time`` when the document was created/overwritten. + * A document reference for the created document. + + Raises: + ~google.cloud.exceptions.Conflict: If ``document_id`` is provided + and the document already exists. + """ + if document_id is None: + document_id = _auto_id() + + document_ref = self.document(document_id) + write_result = await document_ref.create(document_data) + return write_result.update_time, document_ref + + async def list_documents(self, page_size=None): + """List all subdocuments of the current collection. + + Args: + page_size (Optional[int]]): The maximum number of documents + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + + Returns: + Sequence[:class:`~google.cloud.firestore_v1.collection.DocumentReference`]: + iterator of subdocuments of the current collection. If the + collection does not exist at the time of `snapshot`, the + iterator will be empty + """ + parent, _ = self._parent_info() + + iterator = self._client._firestore_api.list_documents( + request={ + "parent": parent, + "collection_id": self.id, + "page_size": page_size, + "show_missing": True, + }, + metadata=self._client._rpc_metadata, + ) + return (_item_to_document_ref(self, i) for i in iterator) + + async def get(self, transaction=None): + """Deprecated alias for :meth:`stream`.""" + warnings.warn( + "'Collection.get' is deprecated: please use 'Collection.stream' instead.", + DeprecationWarning, + stacklevel=2, + ) + async for d in self.stream(transaction=transaction): + yield d + + async def stream(self, transaction=None): + """Read the documents in this collection. + + This sends a ``RunQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + .. note:: + + The underlying stream of responses will time out after + the ``max_rpc_timeout_millis`` value set in the GAPIC + client configuration for the ``RunQuery`` API. Snapshots + not consumed from the iterator before that point will be lost. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ + Transaction`]): + An existing transaction that the query will run in. + + Yields: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + The next document that fulfills the query. + """ + query = async_query.AsyncQuery(self) + async for d in query.stream(transaction=transaction): + yield d + + def on_snapshot(self, callback): + """Monitor the documents in this collection. + + This starts a watch on this collection using a background thread. The + provided callback is run on the snapshot of the documents. + + Args: + callback (Callable[[:class:`~google.cloud.firestore.collection.CollectionSnapshot`], NoneType]): + a callback to run when a change occurs. + + Example: + from google.cloud import firestore_v1 + + db = firestore_v1.Client() + collection_ref = db.collection(u'users') + + def on_snapshot(collection_snapshot, changes, read_time): + for doc in collection_snapshot.documents: + print(u'{} => {}'.format(doc.id, doc.to_dict())) + + # Watch this collection + collection_watch = collection_ref.on_snapshot(on_snapshot) + + # Terminate this watch + collection_watch.unsubscribe() + """ + return Watch.for_query( + self._query(), + callback, + async_document.DocumentSnapshot, + async_document.AsyncDocumentReference, + ) diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py new file mode 100644 index 000000000..00672153c --- /dev/null +++ b/google/cloud/firestore_v1/async_document.py @@ -0,0 +1,425 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing documents for the Google Cloud Firestore API.""" + +import six + +from google.cloud.firestore_v1.base_document import ( + BaseDocumentReference, + DocumentSnapshot, + _first_write_result, +) + +from google.api_core import exceptions +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.types import common +from google.cloud.firestore_v1.watch import Watch + + +class AsyncDocumentReference(BaseDocumentReference): + """A reference to a document in a Firestore database. + + The document may already exist or can be created by this class. + + Args: + path (Tuple[str, ...]): The components in the document path. + This is a series of strings representing each collection and + sub-collection ID, as well as the document IDs for any documents + that contain a sub-collection (as well as the base document). + kwargs (dict): The keyword arguments for the constructor. The only + supported keyword is ``client`` and it must be a + :class:`~google.cloud.firestore_v1.client.Client`. It represents + the client that created this document reference. + + Raises: + ValueError: if + + * the ``path`` is empty + * there are an even number of elements + * a collection ID in ``path`` is not a string + * a document ID in ``path`` is not a string + TypeError: If a keyword other than ``client`` is used. + """ + + def __init__(self, *path, **kwargs): + super(AsyncDocumentReference, self).__init__(*path, **kwargs) + + async def create(self, document_data): + """Create the current document in the Firestore database. + + Args: + document_data (dict): Property names and values to use for + creating a document. + + Returns: + :class:`~google.cloud.firestore_v1.types.WriteResult`: + The write result corresponding to the committed document. + A write result contains an ``update_time`` field. + + Raises: + :class:`~google.cloud.exceptions.Conflict`: + If the document already exists. + """ + batch = self._client.batch() + batch.create(self, document_data) + write_results = await batch.commit() + return _first_write_result(write_results) + + async def set(self, document_data, merge=False): + """Replace the current document in the Firestore database. + + A write ``option`` can be specified to indicate preconditions of + the "set" operation. If no ``option`` is specified and this document + doesn't exist yet, this method will create it. + + Overwrites all content for the document with the fields in + ``document_data``. This method performs almost the same functionality + as :meth:`create`. The only difference is that this method doesn't + make any requirements on the existence of the document (unless + ``option`` is used), whereas as :meth:`create` will fail if the + document already exists. + + Args: + document_data (dict): Property names and values to use for + replacing a document. + merge (Optional[bool] or Optional[List]): + If True, apply merging instead of overwriting the state + of the document. + + Returns: + :class:`~google.cloud.firestore_v1.types.WriteResult`: + The write result corresponding to the committed document. A write + result contains an ``update_time`` field. + """ + batch = self._client.batch() + batch.set(self, document_data, merge=merge) + write_results = await batch.commit() + return _first_write_result(write_results) + + async def update(self, field_updates, option=None): + """Update an existing document in the Firestore database. + + By default, this method verifies that the document exists on the + server before making updates. A write ``option`` can be specified to + override these preconditions. + + Each key in ``field_updates`` can either be a field name or a + **field path** (For more information on **field paths**, see + :meth:`~google.cloud.firestore_v1.client.Client.field_path`.) To + illustrate this, consider a document with + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + }, + 'other': True, + } + + stored on the server. If the field name is used in the update: + + .. code-block:: python + + >>> field_updates = { + ... 'foo': { + ... 'quux': 800, + ... }, + ... } + >>> document.update(field_updates) + + then all of ``foo`` will be overwritten on the server and the new + value will be + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'quux': 800, + }, + 'other': True, + } + + On the other hand, if a ``.``-delimited **field path** is used in the + update: + + .. code-block:: python + + >>> field_updates = { + ... 'foo.quux': 800, + ... } + >>> document.update(field_updates) + + then only ``foo.quux`` will be updated on the server and the + field ``foo.bar`` will remain intact: + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + 'quux': 800, + }, + 'other': True, + } + + .. warning:: + + A **field path** can only be used as a top-level key in + ``field_updates``. + + To delete / remove a field from an existing document, use the + :attr:`~google.cloud.firestore_v1.transforms.DELETE_FIELD` sentinel. + So with the example above, sending + + .. code-block:: python + + >>> field_updates = { + ... 'other': firestore.DELETE_FIELD, + ... } + >>> document.update(field_updates) + + would update the value on the server to: + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + }, + } + + To set a field to the current time on the server when the + update is received, use the + :attr:`~google.cloud.firestore_v1.transforms.SERVER_TIMESTAMP` + sentinel. + Sending + + .. code-block:: python + + >>> field_updates = { + ... 'foo.now': firestore.SERVER_TIMESTAMP, + ... } + >>> document.update(field_updates) + + would update the value on the server to: + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + 'now': datetime.datetime(2012, ...), + }, + 'other': True, + } + + Args: + field_updates (dict): Field names or paths to update and values + to update with. + option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + + Returns: + :class:`~google.cloud.firestore_v1.types.WriteResult`: + The write result corresponding to the updated document. A write + result contains an ``update_time`` field. + + Raises: + ~google.cloud.exceptions.NotFound: If the document does not exist. + """ + batch = self._client.batch() + batch.update(self, field_updates, option=option) + write_results = await batch.commit() + return _first_write_result(write_results) + + async def delete(self, option=None): + """Delete the current document in the Firestore database. + + Args: + option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + + Returns: + :class:`google.protobuf.timestamp_pb2.Timestamp`: + The time that the delete request was received by the server. + If the document did not exist when the delete was sent (i.e. + nothing was deleted), this method will still succeed and will + still return the time that the request was received by the server. + """ + write_pb = _helpers.pb_for_delete(self._document_path, option) + commit_response = self._client._firestore_api.commit( + request={ + "database": self._client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=self._client._rpc_metadata, + ) + + return commit_response.commit_time + + async def get(self, field_paths=None, transaction=None): + """Retrieve a snapshot of the current document. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + field_paths (Optional[Iterable[str, ...]]): An iterable of field + paths (``.``-delimited list of field names) to use as a + projection of document fields in the returned results. If + no value is provided, all fields will be returned. + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this reference + will be retrieved in. + + Returns: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + A snapshot of the current document. If the document does not + exist at the time of the snapshot is taken, the snapshot's + :attr:`reference`, :attr:`data`, :attr:`update_time`, and + :attr:`create_time` attributes will all be ``None`` and + its :attr:`exists` attribute will be ``False``. + """ + if isinstance(field_paths, six.string_types): + raise ValueError("'field_paths' must be a sequence of paths, not a string.") + + if field_paths is not None: + mask = common.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None + + firestore_api = self._client._firestore_api + try: + document_pb = firestore_api.get_document( + request={ + "name": self._document_path, + "mask": mask, + "transaction": _helpers.get_transaction_id(transaction), + }, + metadata=self._client._rpc_metadata, + ) + except exceptions.NotFound: + data = None + exists = False + create_time = None + update_time = None + else: + data = _helpers.decode_dict(document_pb.fields, self._client) + exists = True + create_time = document_pb.create_time + update_time = document_pb.update_time + + return DocumentSnapshot( + reference=self, + data=data, + exists=exists, + read_time=None, # No server read_time available + create_time=create_time, + update_time=update_time, + ) + + async def collections(self, page_size=None): + """List subcollections of the current document. + + Args: + page_size (Optional[int]]): The maximum number of collections + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + + Returns: + Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: + iterator of subcollections of the current document. If the + document does not exist at the time of `snapshot`, the + iterator will be empty + """ + iterator = self._client._firestore_api.list_collection_ids( + request={"parent": self._document_path, "page_size": page_size}, + metadata=self._client._rpc_metadata, + ) + + while True: + for i in iterator.collection_ids: + yield self.collection(i) + if iterator.next_page_token: + iterator = self._client._firestore_api.list_collection_ids( + request={ + "parent": self._document_path, + "page_size": page_size, + "page_token": iterator.next_page_token, + }, + metadata=self._client._rpc_metadata, + ) + else: + return + + # TODO(microgen): currently this method is rewritten to iterate/page itself. + # it seems the generator ought to be able to do this itself. + # iterator.document = self + # iterator.item_to_value = _item_to_collection_ref + # return iterator + + def on_snapshot(self, callback): + """Watch this document. + + This starts a watch on this document using a background thread. The + provided callback is run on the snapshot. + + Args: + callback(Callable[[:class:`~google.cloud.firestore.document.DocumentSnapshot`], NoneType]): + a callback to run when a change occurs + + Example: + + .. code-block:: python + + from google.cloud import firestore_v1 + + db = firestore_v1.Client() + collection_ref = db.collection(u'users') + + def on_snapshot(document_snapshot, changes, read_time): + doc = document_snapshot + print(u'{} => {}'.format(doc.id, doc.to_dict())) + + doc_ref = db.collection(u'users').document( + u'alovelace' + unique_resource_id()) + + # Watch this document + doc_watch = doc_ref.on_snapshot(on_snapshot) + + # Terminate this watch + doc_watch.unsubscribe() + """ + return Watch.for_document( + self, callback, DocumentSnapshot, AsyncDocumentReference + ) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py new file mode 100644 index 000000000..dea0c960b --- /dev/null +++ b/google/cloud/firestore_v1/async_query.py @@ -0,0 +1,207 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing queries for the Google Cloud Firestore API. + +A :class:`~google.cloud.firestore_v1.query.Query` can be created directly from +a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be +a more common way to create a query than direct usage of the constructor. +""" +import warnings + +from google.cloud.firestore_v1.base_query import ( + BaseQuery, + _query_response_to_snapshot, + _collection_group_query_response_to_snapshot, +) + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import async_document +from google.cloud.firestore_v1.watch import Watch + + +class AsyncQuery(BaseQuery): + """Represents a query to the Firestore API. + + Instances of this class are considered immutable: all methods that + would modify an instance instead return a new instance. + + Args: + parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + The collection that this query applies to. + projection (Optional[:class:`google.cloud.proto.firestore.v1.\ + query.StructuredQuery.Projection`]): + A projection of document fields to limit the query results to. + field_filters (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ + query.StructuredQuery.FieldFilter`, ...]]): + The filters to be applied in the query. + orders (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ + query.StructuredQuery.Order`, ...]]): + The "order by" entries to use in the query. + limit (Optional[int]): + The maximum number of documents the query is allowed to return. + offset (Optional[int]): + The number of results to skip. + start_at (Optional[Tuple[dict, bool]]): + Two-tuple of : + + * a mapping of fields. Any field that is present in this mapping + must also be present in ``orders`` + * an ``after`` flag + + The fields and the flag combine to form a cursor used as + a starting point in a query result set. If the ``after`` + flag is :data:`True`, the results will start just after any + documents which have fields matching the cursor, otherwise + any matching documents will be included in the result set. + When the query is formed, the document values + will be used in the order given by ``orders``. + end_at (Optional[Tuple[dict, bool]]): + Two-tuple of: + + * a mapping of fields. Any field that is present in this mapping + must also be present in ``orders`` + * a ``before`` flag + + The fields and the flag combine to form a cursor used as + an ending point in a query result set. If the ``before`` + flag is :data:`True`, the results will end just before any + documents which have fields matching the cursor, otherwise + any matching documents will be included in the result set. + When the query is formed, the document values + will be used in the order given by ``orders``. + all_descendants (Optional[bool]): + When false, selects only collections that are immediate children + of the `parent` specified in the containing `RunQueryRequest`. + When true, selects all descendant collections. + """ + + def __init__( + self, + parent, + projection=None, + field_filters=(), + orders=(), + limit=None, + offset=None, + start_at=None, + end_at=None, + all_descendants=False, + ): + super(AsyncQuery, self).__init__( + parent=parent, + projection=projection, + field_filters=field_filters, + orders=orders, + limit=limit, + offset=offset, + start_at=start_at, + end_at=end_at, + all_descendants=all_descendants, + ) + + async def get(self, transaction=None): + """Deprecated alias for :meth:`stream`.""" + warnings.warn( + "'AsyncQuery.get' is deprecated: please use 'AsyncQuery.stream' instead.", + DeprecationWarning, + stacklevel=2, + ) + async for d in self.stream(transaction=transaction): + yield d + + async def stream(self, transaction=None): + """Read the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + .. note:: + + The underlying stream of responses will time out after + the ``max_rpc_timeout_millis`` value set in the GAPIC + client configuration for the ``RunQuery`` API. Snapshots + not consumed from the iterator before that point will be lost. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + + Yields: + :class:`~google.cloud.firestore_v1.async_document.DocumentSnapshot`: + The next document that fulfills the query. + """ + parent_path, expected_prefix = self._parent._parent_info() + response_iterator = self._client._firestore_api.run_query( + request={ + "parent": parent_path, + "structured_query": self._to_protobuf(), + "transaction": _helpers.get_transaction_id(transaction), + }, + metadata=self._client._rpc_metadata, + ) + + for response in response_iterator: + if self._all_descendants: + snapshot = _collection_group_query_response_to_snapshot( + response, self._parent + ) + else: + snapshot = _query_response_to_snapshot( + response, self._parent, expected_prefix + ) + if snapshot is not None: + yield snapshot + + def on_snapshot(self, callback): + """Monitor the documents in this collection that match this query. + + This starts a watch on this query using a background thread. The + provided callback is run on the snapshot of the documents. + + Args: + callback(Callable[[:class:`~google.cloud.firestore.query.QuerySnapshot`], NoneType]): + a callback to run when a change occurs. + + Example: + + .. code-block:: python + + from google.cloud import firestore_v1 + + db = firestore_v1.Client() + query_ref = db.collection(u'users').where("user", "==", u'Ada') + + def on_snapshot(docs, changes, read_time): + for doc in docs: + print(u'{} => {}'.format(doc.id, doc.to_dict())) + + # Watch this query + query_watch = query_ref.on_snapshot(on_snapshot) + + # Terminate this watch + query_watch.unsubscribe() + """ + return Watch.for_query( + self, + callback, + async_document.DocumentSnapshot, + async_document.AsyncDocumentReference, + ) diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py new file mode 100644 index 000000000..569025465 --- /dev/null +++ b/google/cloud/firestore_v1/async_transaction.py @@ -0,0 +1,372 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for applying Google Cloud Firestore changes in a transaction.""" + + +import asyncio +import random + +import six + +from google.cloud.firestore_v1.base_transaction import ( + _BaseTransactional, + BaseTransaction, + MAX_ATTEMPTS, + _CANT_BEGIN, + _CANT_ROLLBACK, + _CANT_COMMIT, + _WRITE_READ_ONLY, + _INITIAL_SLEEP, + _MAX_SLEEP, + _MULTIPLIER, + _EXCEED_ATTEMPTS_TEMPLATE, +) + +from google.api_core import exceptions +from google.cloud.firestore_v1 import async_batch +from google.cloud.firestore_v1.async_document import AsyncDocumentReference +from google.cloud.firestore_v1.async_query import AsyncQuery + + +class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): + """Accumulate read-and-write operations to be sent in a transaction. + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + The client that created this transaction. + max_attempts (Optional[int]): The maximum number of attempts for + the transaction (i.e. allowing retries). Defaults to + :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`. + read_only (Optional[bool]): Flag indicating if the transaction + should be read-only or should allow writes. Defaults to + :data:`False`. + """ + + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): + super(AsyncTransaction, self).__init__(client) + BaseTransaction.__init__(self, max_attempts, read_only) + + def _add_write_pbs(self, write_pbs): + """Add `Write`` protobufs to this transaction. + + Args: + write_pbs (List[google.cloud.proto.firestore.v1.\ + write.Write]): A list of write protobufs to be added. + + Raises: + ValueError: If this transaction is read-only. + """ + if self._read_only: + raise ValueError(_WRITE_READ_ONLY) + + super(AsyncTransaction, self)._add_write_pbs(write_pbs) + + async def _begin(self, retry_id=None): + """Begin the transaction. + + Args: + retry_id (Optional[bytes]): Transaction ID of a transaction to be + retried. + + Raises: + ValueError: If the current transaction has already begun. + """ + if self.in_progress: + msg = _CANT_BEGIN.format(self._id) + raise ValueError(msg) + + transaction_response = self._client._firestore_api.begin_transaction( + request={ + "database": self._client._database_string, + "options": self._options_protobuf(retry_id), + }, + metadata=self._client._rpc_metadata, + ) + self._id = transaction_response.transaction + + async def _rollback(self): + """Roll back the transaction. + + Raises: + ValueError: If no transaction is in progress. + """ + if not self.in_progress: + raise ValueError(_CANT_ROLLBACK) + + try: + # NOTE: The response is just ``google.protobuf.Empty``. + self._client._firestore_api.rollback( + request={ + "database": self._client._database_string, + "transaction": self._id, + }, + metadata=self._client._rpc_metadata, + ) + finally: + self._clean_up() + + async def _commit(self): + """Transactionally commit the changes accumulated. + + Returns: + List[:class:`google.cloud.proto.firestore.v1.write.WriteResult`, ...]: + The write results corresponding to the changes committed, returned + in the same order as the changes were applied to this transaction. + A write result contains an ``update_time`` field. + + Raises: + ValueError: If no transaction is in progress. + """ + if not self.in_progress: + raise ValueError(_CANT_COMMIT) + + commit_response = await _commit_with_retry( + self._client, self._write_pbs, self._id + ) + + self._clean_up() + return list(commit_response.write_results) + + async def get_all(self, references): + """Retrieves multiple documents from Firestore. + + Args: + references (List[.AsyncDocumentReference, ...]): Iterable of document + references to be retrieved. + + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + return self._client.get_all(references, transaction=self) + + async def get(self, ref_or_query): + """ + Retrieve a document or a query result from the database. + Args: + ref_or_query The document references or query object to return. + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + if isinstance(ref_or_query, AsyncDocumentReference): + return self._client.get_all([ref_or_query], transaction=self) + elif isinstance(ref_or_query, AsyncQuery): + return ref_or_query.stream(transaction=self) + else: + raise ValueError( + 'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.' + ) + + +class _AsyncTransactional(_BaseTransactional): + """Provide a callable object to use as a transactional decorater. + + This is surfaced via + :func:`~google.cloud.firestore_v1.async_transaction.transactional`. + + Args: + to_wrap (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]): + A callable that should be run (and retried) in a transaction. + """ + + def __init__(self, to_wrap): + super(_AsyncTransactional, self).__init__(to_wrap) + + async def _pre_commit(self, transaction, *args, **kwargs): + """Begin transaction and call the wrapped callable. + + If the callable raises an exception, the transaction will be rolled + back. If not, the transaction will be "ready" for ``Commit`` (i.e. + it will have staged writes). + + Args: + transaction + (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + A transaction to execute the callable within. + args (Tuple[Any, ...]): The extra positional arguments to pass + along to the wrapped callable. + kwargs (Dict[str, Any]): The extra keyword arguments to pass + along to the wrapped callable. + + Returns: + Any: result of the wrapped callable. + + Raises: + Exception: Any failure caused by ``to_wrap``. + """ + # Force the ``transaction`` to be not "in progress". + transaction._clean_up() + await transaction._begin(retry_id=self.retry_id) + + # Update the stored transaction IDs. + self.current_id = transaction._id + if self.retry_id is None: + self.retry_id = self.current_id + try: + return self.to_wrap(transaction, *args, **kwargs) + except: # noqa + # NOTE: If ``rollback`` fails this will lose the information + # from the original failure. + await transaction._rollback() + raise + + async def _maybe_commit(self, transaction): + """Try to commit the transaction. + + If the transaction is read-write and the ``Commit`` fails with the + ``ABORTED`` status code, it will be retried. Any other failure will + not be caught. + + Args: + transaction + (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + The transaction to be ``Commit``-ed. + + Returns: + bool: Indicating if the commit succeeded. + """ + try: + await transaction._commit() + return True + except exceptions.GoogleAPICallError as exc: + if transaction._read_only: + raise + + if isinstance(exc, exceptions.Aborted): + # If a read-write transaction returns ABORTED, retry. + return False + else: + raise + + async def __call__(self, transaction, *args, **kwargs): + """Execute the wrapped callable within a transaction. + + Args: + transaction + (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + A transaction to execute the callable within. + args (Tuple[Any, ...]): The extra positional arguments to pass + along to the wrapped callable. + kwargs (Dict[str, Any]): The extra keyword arguments to pass + along to the wrapped callable. + + Returns: + Any: The result of the wrapped callable. + + Raises: + ValueError: If the transaction does not succeed in + ``max_attempts``. + """ + self._reset() + + for attempt in six.moves.xrange(transaction._max_attempts): + result = await self._pre_commit(transaction, *args, **kwargs) + succeeded = await self._maybe_commit(transaction) + if succeeded: + return result + + # Subsequent requests will use the failed transaction ID as part of + # the ``BeginTransactionRequest`` when restarting this transaction + # (via ``options.retry_transaction``). This preserves the "spot in + # line" of the transaction, so exponential backoff is not required + # in this case. + + await transaction._rollback() + msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + raise ValueError(msg) + + +def transactional(to_wrap): + """Decorate a callable so that it runs in a transaction. + + Args: + to_wrap + (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): + A callable that should be run (and retried) in a transaction. + + Returns: + Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]: + the wrapped callable. + """ + return _AsyncTransactional(to_wrap) + + +async def _commit_with_retry(client, write_pbs, transaction_id): + """Call ``Commit`` on the GAPIC client with retry / sleep. + + Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level + retry is handled by the underlying GAPICd client, but in this case it + doesn't because ``Commit`` is not always idempotent. But here we know it + is "idempotent"-like because it has a transaction ID. We also need to do + our own retry to special-case the ``INVALID_ARGUMENT`` error. + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + A client with GAPIC client and configuration details. + write_pbs (List[:class:`google.cloud.proto.firestore.v1.write.Write`, ...]): + A ``Write`` protobuf instance to be committed. + transaction_id (bytes): + ID of an existing transaction that this commit will run in. + + Returns: + :class:`google.cloud.firestore_v1.types.CommitResponse`: + The protobuf response from ``Commit``. + + Raises: + ~google.api_core.exceptions.GoogleAPICallError: If a non-retryable + exception is encountered. + """ + current_sleep = _INITIAL_SLEEP + while True: + try: + return client._firestore_api.commit( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": transaction_id, + }, + metadata=client._rpc_metadata, + ) + except exceptions.ServiceUnavailable: + # Retry + pass + + current_sleep = await _sleep(current_sleep) + + +async def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER): + """Sleep and produce a new sleep time. + + .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ + 2015/03/backoff.html + + Select a duration between zero and ``current_sleep``. It might seem + counterintuitive to have so much jitter, but + `Exponential Backoff And Jitter`_ argues that "full jitter" is + the best strategy. + + Args: + current_sleep (float): The current "max" for sleep interval. + max_sleep (Optional[float]): Eventual "max" sleep time + multiplier (Optional[float]): Multiplier for exponential backoff. + + Returns: + float: Newly doubled ``current_sleep`` or ``max_sleep`` (whichever + is smaller) + """ + actual_sleep = random.uniform(0.0, current_sleep) + await asyncio.sleep(actual_sleep) + return min(multiplier * current_sleep, max_sleep) diff --git a/noxfile.py b/noxfile.py index e02ef59ef..600ee8338 100644 --- a/noxfile.py +++ b/noxfile.py @@ -68,7 +68,7 @@ def lint_setup_py(session): session.run("python", "setup.py", "check", "--restructuredtext", "--strict") -def default(session): +def default(session, test_dir, ignore_dir): # Install all test dependencies, then install this package in-place. session.install("asyncmock", "pytest-asyncio") @@ -76,8 +76,7 @@ def default(session): session.install("-e", ".") # Run py.test against the unit tests. - session.run( - "py.test", + args = [ "--quiet", "--cov=google.cloud.firestore", "--cov=google.cloud", @@ -86,15 +85,31 @@ def default(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", - os.path.join("tests", "unit"), + test_dir, *session.posargs, - ) + ] + + if ignore_dir: + args.insert(0, f"--ignore={ignore_dir}") + + session.run("py.test", *args) @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) def unit(session): - """Run the unit test suite.""" - default(session) + """Run the unit test suite for sync tests.""" + default( + session, + os.path.join("tests", "unit"), + os.path.join("tests", "unit", "v1", "async"), + ) + + +@nox.session(python=["3.6", "3.7", "3.8"]) +def unit_async(session): + """Run the unit test suite for async tests.""" + session.install("pytest-asyncio", "aiounittest") + default(session, os.path.join("tests", "unit", "v1", "async"), None) @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) diff --git a/tests/unit/v1/async/__init__.py b/tests/unit/v1/async/__init__.py new file mode 100644 index 000000000..c6334245a --- /dev/null +++ b/tests/unit/v1/async/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py new file mode 100644 index 000000000..acb977d86 --- /dev/null +++ b/tests/unit/v1/async/test_async_batch.py @@ -0,0 +1,159 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import aiounittest + +import mock + + +class TestAsyncWriteBatch(aiounittest.AsyncTestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch + + return AsyncWriteBatch + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + batch = self._make_one(mock.sentinel.client) + self.assertIs(batch._client, mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + self.assertIsNone(batch.write_results) + self.assertIsNone(batch.commit_time) + + @pytest.mark.asyncio + async def test_commit(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.Mock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], + commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client("grand") + client._firestore_api_internal = firestore_api + + # Actually make a batch with some mutations and call commit(). + batch = self._make_one(client) + document1 = client.document("a", "b") + batch.create(document1, {"ten": 10, "buck": u"ets"}) + document2 = client.document("c", "d", "e", "f") + batch.delete(document2) + write_pbs = batch._write_pbs[::] + + write_results = await batch.commit() + self.assertEqual(write_results, list(commit_response.write_results)) + self.assertEqual(batch.write_results, write_results) + # TODO(microgen): v2: commit time is already a datetime, though not with nano + # self.assertEqual(batch.commit_time, timestamp) + # Make sure batch has no more "changes". + self.assertEqual(batch._write_pbs, []) + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_as_context_mgr_wo_error(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + firestore_api = mock.Mock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], + commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + client = _make_client() + client._firestore_api_internal = firestore_api + batch = self._make_one(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + async with batch as ctx_mgr: + self.assertIs(ctx_mgr, batch) + ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.delete(document2) + write_pbs = batch._write_pbs[::] + + self.assertEqual(batch.write_results, list(commit_response.write_results)) + # TODO(microgen): v2: commit time is already a datetime, though not with nano + # self.assertEqual(batch.commit_time, timestamp) + # Make sure batch has no more "changes". + self.assertEqual(batch._write_pbs, []) + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_as_context_mgr_w_error(self): + firestore_api = mock.Mock(spec=["commit"]) + client = _make_client() + client._firestore_api_internal = firestore_api + batch = self._make_one(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + with self.assertRaises(RuntimeError): + async with batch as ctx_mgr: + ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.delete(document2) + raise RuntimeError("testing") + + # batch still has its changes, as _aexit_ (and commit) is not invoked + # changes are preserved so commit can be retried + self.assertIsNone(batch.write_results) + self.assertIsNone(batch.commit_time) + self.assertEqual(len(batch._write_pbs), 2) + + firestore_api.commit.assert_not_called() + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="seventy-nine"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py new file mode 100644 index 000000000..6fd9b93d2 --- /dev/null +++ b/tests/unit/v1/async/test_async_client.py @@ -0,0 +1,464 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import datetime +import types +import aiounittest + +import mock + + +class TestAsyncClient(aiounittest.AsyncTestCase): + + PROJECT = "my-prahjekt" + + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_client import AsyncClient + + return AsyncClient + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def _make_default_one(self): + credentials = _make_credentials() + return self._make_one(project=self.PROJECT, credentials=credentials) + + def test_constructor(self): + from google.cloud.firestore_v1.async_client import _CLIENT_INFO + from google.cloud.firestore_v1.async_client import DEFAULT_DATABASE + + credentials = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=credentials) + self.assertEqual(client.project, self.PROJECT) + self.assertEqual(client._credentials, credentials) + self.assertEqual(client._database, DEFAULT_DATABASE) + self.assertIs(client._client_info, _CLIENT_INFO) + self.assertIsNone(client._emulator_host) + + def test_constructor_with_emulator_host(self): + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST + + credentials = _make_credentials() + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + client = self._make_one(project=self.PROJECT, credentials=credentials) + self.assertEqual(client._emulator_host, emulator_host) + getenv.assert_called_once_with(_FIRESTORE_EMULATOR_HOST) + + def test_constructor_explicit(self): + credentials = _make_credentials() + database = "now-db" + client_info = mock.Mock() + client_options = mock.Mock() + client = self._make_one( + project=self.PROJECT, + credentials=credentials, + database=database, + client_info=client_info, + client_options=client_options, + ) + self.assertEqual(client.project, self.PROJECT) + self.assertEqual(client._credentials, credentials) + self.assertEqual(client._database, database) + self.assertIs(client._client_info, client_info) + self.assertIs(client._client_options, client_options) + + def test_constructor_w_client_options(self): + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, + credentials=credentials, + client_options={"api_endpoint": "foo-firestore.googleapis.com"}, + ) + self.assertEqual(client._target, "foo-firestore.googleapis.com") + + def test_collection_factory(self): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + collection_id = "users" + client = self._make_default_one() + collection = client.collection(collection_id) + + self.assertEqual(collection._path, (collection_id,)) + self.assertIs(collection._client, client) + self.assertIsInstance(collection, AsyncCollectionReference) + + def test_collection_factory_nested(self): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + client = self._make_default_one() + parts = ("users", "alovelace", "beep") + collection_path = "/".join(parts) + collection1 = client.collection(collection_path) + + self.assertEqual(collection1._path, parts) + self.assertIs(collection1._client, client) + self.assertIsInstance(collection1, AsyncCollectionReference) + + # Make sure using segments gives the same result. + collection2 = client.collection(*parts) + self.assertEqual(collection2._path, parts) + self.assertIs(collection2._client, client) + self.assertIsInstance(collection2, AsyncCollectionReference) + + def test__get_collection_reference(self): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + client = self._make_default_one() + collection = client._get_collection_reference("collectionId") + + self.assertIs(collection._client, client) + self.assertIsInstance(collection, AsyncCollectionReference) + + def test_collection_group(self): + client = self._make_default_one() + query = client.collection_group("collectionId").where("foo", "==", u"bar") + + self.assertTrue(query._all_descendants) + self.assertEqual(query._field_filters[0].field.field_path, "foo") + self.assertEqual(query._field_filters[0].value.string_value, u"bar") + self.assertEqual( + query._field_filters[0].op, query._field_filters[0].Operator.EQUAL + ) + self.assertEqual(query._parent.id, "collectionId") + + def test_collection_group_no_slashes(self): + client = self._make_default_one() + with self.assertRaises(ValueError): + client.collection_group("foo/bar") + + def test_document_factory(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + parts = ("rooms", "roomA") + client = self._make_default_one() + doc_path = "/".join(parts) + document1 = client.document(doc_path) + + self.assertEqual(document1._path, parts) + self.assertIs(document1._client, client) + self.assertIsInstance(document1, AsyncDocumentReference) + + # Make sure using segments gives the same result. + document2 = client.document(*parts) + self.assertEqual(document2._path, parts) + self.assertIs(document2._client, client) + self.assertIsInstance(document2, AsyncDocumentReference) + + def test_document_factory_w_absolute_path(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + parts = ("rooms", "roomA") + client = self._make_default_one() + doc_path = "/".join(parts) + to_match = client.document(doc_path) + document1 = client.document(to_match._document_path) + + self.assertEqual(document1._path, parts) + self.assertIs(document1._client, client) + self.assertIsInstance(document1, AsyncDocumentReference) + + def test_document_factory_w_nested_path(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + client = self._make_default_one() + parts = ("rooms", "roomA", "shoes", "dressy") + doc_path = "/".join(parts) + document1 = client.document(doc_path) + + self.assertEqual(document1._path, parts) + self.assertIs(document1._client, client) + self.assertIsInstance(document1, AsyncDocumentReference) + + # Make sure using segments gives the same result. + document2 = client.document(*parts) + self.assertEqual(document2._path, parts) + self.assertIs(document2._client, client) + self.assertIsInstance(document2, AsyncDocumentReference) + + @pytest.mark.asyncio + async def test_collections(self): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + collection_ids = ["users", "projects"] + client = self._make_default_one() + firestore_api = mock.Mock(spec=["list_collection_ids"]) + client._firestore_api_internal = firestore_api + + # TODO(microgen): list_collection_ids isn't a pager. + # https://github.com/googleapis/gapic-generator-python/issues/516 + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + iterator = _Iterator(pages=[collection_ids]) + firestore_api.list_collection_ids.return_value = iterator + + collections = [c async for c in client.collections()] + + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, AsyncCollectionReference) + self.assertEqual(collection.parent, None) + self.assertEqual(collection.id, collection_id) + + base_path = client._database_string + "/documents" + firestore_api.list_collection_ids.assert_called_once_with( + request={"parent": base_path}, metadata=client._rpc_metadata + ) + + async def _get_all_helper(self, client, references, document_pbs, **kwargs): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["batch_get_documents"]) + response_iterator = iter(document_pbs) + firestore_api.batch_get_documents.return_value = response_iterator + + # Attach the fake GAPIC to a real client. + client._firestore_api_internal = firestore_api + + # Actually call get_all(). + snapshots = client.get_all(references, **kwargs) + self.assertIsInstance(snapshots, types.AsyncGeneratorType) + + return [s async for s in snapshots] + + def _info_for_get_all(self, data1, data2): + client = self._make_default_one() + document1 = client.document("pineapple", "lamp1") + document2 = client.document("pineapple", "lamp2") + + # Make response protobufs. + document_pb1, read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=read_time) + + document, read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=read_time) + + return client, document1, document2, response1, response2 + + @pytest.mark.asyncio + async def test_get_all(self): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + data1 = {"a": u"cheese"} + data2 = {"b": True, "c": 18} + info = self._info_for_get_all(data1, data2) + client, document1, document2, response1, response2 = info + + # Exercise the mocked ``batch_get_documents``. + field_paths = ["a", "b"] + snapshots = await self._get_all_helper( + client, + [document1, document2], + [response1, response2], + field_paths=field_paths, + ) + self.assertEqual(len(snapshots), 2) + + snapshot1 = snapshots[0] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document1) + self.assertEqual(snapshot1._data, data1) + + snapshot2 = snapshots[1] + self.assertIsInstance(snapshot2, DocumentSnapshot) + self.assertIs(snapshot2._reference, document2) + self.assertEqual(snapshot2._data, data2) + + # Verify the call to the mock. + doc_paths = [document1._document_path, document2._document_path] + mask = common.DocumentMask(field_paths=field_paths) + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": mask, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_get_all_with_transaction(self): + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + data = {"so-much": 484} + info = self._info_for_get_all(data, {}) + client, document, _, response, _ = info + transaction = client.transaction() + txn_id = b"the-man-is-non-stop" + transaction._id = txn_id + + # Exercise the mocked ``batch_get_documents``. + snapshots = await self._get_all_helper( + client, [document], [response], transaction=transaction + ) + self.assertEqual(len(snapshots), 1) + + snapshot = snapshots[0] + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertIs(snapshot._reference, document) + self.assertEqual(snapshot._data, data) + + # Verify the call to the mock. + doc_paths = [document._document_path] + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_get_all_unknown_result(self): + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE + + info = self._info_for_get_all({"z": 28.5}, {}) + client, document, _, _, response = info + + # Exercise the mocked ``batch_get_documents``. + with self.assertRaises(ValueError) as exc_info: + await self._get_all_helper(client, [document], [response]) + + err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + # Verify the call to the mock. + doc_paths = [document._document_path] + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_get_all_wrong_order(self): + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + data1 = {"up": 10} + data2 = {"down": -10} + info = self._info_for_get_all(data1, data2) + client, document1, document2, response1, response2 = info + document3 = client.document("pineapple", "lamp3") + response3 = _make_batch_response(missing=document3._document_path) + + # Exercise the mocked ``batch_get_documents``. + snapshots = await self._get_all_helper( + client, [document1, document2, document3], [response2, response1, response3] + ) + + self.assertEqual(len(snapshots), 3) + + snapshot1 = snapshots[0] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document2) + self.assertEqual(snapshot1._data, data2) + + snapshot2 = snapshots[1] + self.assertIsInstance(snapshot2, DocumentSnapshot) + self.assertIs(snapshot2._reference, document1) + self.assertEqual(snapshot2._data, data1) + + self.assertFalse(snapshots[2].exists) + + # Verify the call to the mock. + doc_paths = [ + document1._document_path, + document2._document_path, + document3._document_path, + ] + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + def test_batch(self): + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch + + client = self._make_default_one() + batch = client.batch() + self.assertIsInstance(batch, AsyncWriteBatch) + self.assertIs(batch._client, client) + self.assertEqual(batch._write_pbs, []) + + def test_transaction(self): + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + client = self._make_default_one() + transaction = client.transaction(max_attempts=3, read_only=True) + self.assertIsInstance(transaction, AsyncTransaction) + self.assertEqual(transaction._write_pbs, []) + self.assertEqual(transaction._max_attempts, 3) + self.assertTrue(transaction._read_only) + self.assertIsNone(transaction._id) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_batch_response(**kwargs): + from google.cloud.firestore_v1.types import firestore + + return firestore.BatchGetDocumentsResponse(**kwargs) + + +def _doc_get_info(ref_string, values): + from google.cloud.firestore_v1.types import document + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + + document_pb = document.Document( + name=ref_string, + fields=_helpers.encode_dict(values), + create_time=create_time, + update_time=update_time, + ) + + return document_pb, read_time diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py new file mode 100644 index 000000000..680b0eb85 --- /dev/null +++ b/tests/unit/v1/async/test_async_collection.py @@ -0,0 +1,363 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import types +import aiounittest + +import mock +import six + + +class MockAsyncIter: + def __init__(self, count): + self.count = count + + async def __aiter__(self, **_): + for i in range(self.count): + yield i + + +class TestAsyncCollectionReference(aiounittest.AsyncTestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + return AsyncCollectionReference + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + @staticmethod + def _get_public_methods(klass): + return set().union( + *( + ( + name + for name, value in six.iteritems(class_.__dict__) + if ( + not name.startswith("_") + and isinstance(value, types.FunctionType) + ) + ) + for class_ in (klass,) + klass.__bases__ + ) + ) + + def test_query_method_matching(self): + from google.cloud.firestore_v1.async_query import AsyncQuery + + query_methods = self._get_public_methods(AsyncQuery) + klass = self._get_target_class() + collection_methods = self._get_public_methods(klass) + # Make sure every query method is present on + # ``AsyncCollectionReference``. + self.assertLessEqual(query_methods, collection_methods) + + def test_constructor(self): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = self._make_one( + collection_id1, document_id, collection_id2, client=client + ) + self.assertIs(collection._client, client) + expected_path = (collection_id1, document_id, collection_id2) + self.assertEqual(collection._path, expected_path) + + def test_constructor_invalid_path(self): + with self.assertRaises(ValueError): + self._make_one() + with self.assertRaises(ValueError): + self._make_one(99, "doc", "bad-collection-id") + with self.assertRaises(ValueError): + self._make_one("bad-document-ID", None, "sub-collection") + with self.assertRaises(ValueError): + self._make_one("Just", "A-Document") + + def test_constructor_invalid_kwarg(self): + with self.assertRaises(TypeError): + self._make_one("Coh-lek-shun", donut=True) + + @pytest.mark.asyncio + async def test_add_auto_assigned(self): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import pbs_for_create + + # Create a minimal fake GAPIC add attach it to a real client. + firestore_api = mock.Mock(spec=["create_document", "commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + firestore_api.commit.return_value = commit_response + create_doc_response = document.Document() + firestore_api.create_document.return_value = create_doc_response + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection. + collection = self._make_one("grand-parent", "parent", "child", client=client) + + # Actually call add() on our collection; include a transform to make + # sure transforms during adds work. + document_data = {"been": "here", "now": SERVER_TIMESTAMP} + + patch = mock.patch("google.cloud.firestore_v1.async_collection._auto_id") + random_doc_id = "DEADBEEF" + with patch as patched: + patched.return_value = random_doc_id + update_time, document_ref = await collection.add(document_data) + + # Verify the response and the mocks. + self.assertIs(update_time, mock.sentinel.update_time) + self.assertIsInstance(document_ref, AsyncDocumentReference) + self.assertIs(document_ref._client, client) + expected_path = collection._path + (random_doc_id,) + self.assertEqual(document_ref._path, expected_path) + + write_pbs = pbs_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + # Since we generate the ID locally, we don't call 'create_document'. + firestore_api.create_document.assert_not_called() + + @staticmethod + def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common.Precondition(exists=False), + ) + + @pytest.mark.asyncio + async def test_add_explicit_id(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection and call add(). + collection = self._make_one("parent", client=client) + document_data = {"zorp": 208.75, "i-did-not": b"know that"} + doc_id = "child" + update_time, document_ref = await collection.add( + document_data, document_id=doc_id + ) + + # Verify the response and the mocks. + self.assertIs(update_time, mock.sentinel.update_time) + self.assertIsInstance(document_ref, AsyncDocumentReference) + self.assertIs(document_ref._client, client) + self.assertEqual(document_ref._path, (collection.id, doc_id)) + + write_pb = self._write_pb_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def _list_documents_helper(self, page_size=None): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.services.firestore.client import FirestoreClient + from google.cloud.firestore_v1.types.document import Document + + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + client = _make_client() + template = client._database_string + "/documents/{}" + document_ids = ["doc-1", "doc-2"] + documents = [ + Document(name=template.format(document_id)) for document_id in document_ids + ] + iterator = _Iterator(pages=[documents]) + api_client = mock.create_autospec(FirestoreClient) + api_client.list_documents.return_value = iterator + client._firestore_api_internal = api_client + collection = self._make_one("collection", client=client) + + if page_size is not None: + documents = list(await collection.list_documents(page_size=page_size)) + else: + documents = list(await collection.list_documents()) + + # Verify the response and the mocks. + self.assertEqual(len(documents), len(document_ids)) + for document, document_id in zip(documents, document_ids): + self.assertIsInstance(document, AsyncDocumentReference) + self.assertEqual(document.parent, collection) + self.assertEqual(document.id, document_id) + + parent, _ = collection._parent_info() + api_client.list_documents.assert_called_once_with( + request={ + "parent": parent, + "collection_id": collection.id, + "page_size": page_size, + "show_missing": True, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_list_documents_wo_page_size(self): + await self._list_documents_helper() + + @pytest.mark.asyncio + async def test_list_documents_w_page_size(self): + await self._list_documents_helper(page_size=25) + + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) + @pytest.mark.asyncio + async def test_get(self, query_class): + import warnings + + query_class.return_value.stream.return_value = MockAsyncIter(3) + + collection = self._make_one("collection") + with warnings.catch_warnings(record=True) as warned: + get_response = collection.get() + + async for _ in get_response: + pass + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with(transaction=None) + + # Verify the deprecation + self.assertEqual(len(warned), 1) + self.assertIs(warned[0].category, DeprecationWarning) + + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) + @pytest.mark.asyncio + async def test_get_with_transaction(self, query_class): + import warnings + + query_class.return_value.stream.return_value = MockAsyncIter(3) + + collection = self._make_one("collection") + transaction = mock.sentinel.txn + with warnings.catch_warnings(record=True) as warned: + get_response = collection.get(transaction=transaction) + + async for _ in get_response: + pass + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with(transaction=transaction) + + # Verify the deprecation + self.assertEqual(len(warned), 1) + self.assertIs(warned[0].category, DeprecationWarning) + + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) + @pytest.mark.asyncio + async def test_stream(self, query_class): + query_class.return_value.stream.return_value = MockAsyncIter(3) + + collection = self._make_one("collection") + stream_response = collection.stream() + + async for _ in stream_response: + pass + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with(transaction=None) + + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) + @pytest.mark.asyncio + async def test_stream_with_transaction(self, query_class): + query_class.return_value.stream.return_value = MockAsyncIter(3) + + collection = self._make_one("collection") + transaction = mock.sentinel.txn + stream_response = collection.stream(transaction=transaction) + + async for _ in stream_response: + pass + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with(transaction=transaction) + + @mock.patch("google.cloud.firestore_v1.async_collection.Watch", autospec=True) + def test_on_snapshot(self, watch): + collection = self._make_one("collection") + collection.on_snapshot(None) + watch.for_query.assert_called_once() + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(): + from google.cloud.firestore_v1.async_client import AsyncClient + + credentials = _make_credentials() + return AsyncClient(project="project-project", credentials=credentials) diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py new file mode 100644 index 000000000..b59c7282b --- /dev/null +++ b/tests/unit/v1/async/test_async_document.py @@ -0,0 +1,511 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import collections +import aiounittest + +import mock + + +class TestAsyncDocumentReference(aiounittest.AsyncTestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + return AsyncDocumentReference + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + collection_id1 = "users" + document_id1 = "alovelace" + collection_id2 = "platform" + document_id2 = "*nix" + client = mock.MagicMock() + client.__hash__.return_value = 1234 + + document = self._make_one( + collection_id1, document_id1, collection_id2, document_id2, client=client + ) + self.assertIs(document._client, client) + expected_path = "/".join( + (collection_id1, document_id1, collection_id2, document_id2) + ) + self.assertEqual(document.path, expected_path) + + def test_constructor_invalid_path(self): + with self.assertRaises(ValueError): + self._make_one() + with self.assertRaises(ValueError): + self._make_one(None, "before", "bad-collection-id", "fifteen") + with self.assertRaises(ValueError): + self._make_one("bad-document-ID", None) + with self.assertRaises(ValueError): + self._make_one("Just", "A-Collection", "Sub") + + def test_constructor_invalid_kwarg(self): + with self.assertRaises(TypeError): + self._make_one("Coh-lek-shun", "Dahk-yu-mehnt", burger=18.75) + + @staticmethod + def _make_commit_repsonse(write_results=None): + from google.cloud.firestore_v1.types import firestore + + response = mock.create_autospec(firestore.CommitResponse) + response.write_results = write_results or [mock.sentinel.write_result] + response.commit_time = mock.sentinel.commit_time + return response + + @staticmethod + def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common.Precondition(exists=False), + ) + + @pytest.mark.asyncio + async def test_create(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock() + firestore_api.commit.mock_add_spec(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("dignity") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("foo", "twelve", client=client) + document_data = {"hello": "goodbye", "count": 99} + write_result = await document.create(document_data) + + # Verify the response and the mocks. + self.assertIs(write_result, mock.sentinel.write_result) + write_pb = self._write_pb_for_create(document._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_create_empty(self): + # Create a minimal fake GAPIC with a dummy response. + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + firestore_api = mock.Mock(spec=["commit"]) + document_reference = mock.create_autospec(AsyncDocumentReference) + snapshot = mock.create_autospec(DocumentSnapshot) + snapshot.exists = True + document_reference.get.return_value = snapshot + firestore_api.commit.return_value = self._make_commit_repsonse( + write_results=[document_reference] + ) + + # Attach the fake GAPIC to a real client. + client = _make_client("dignity") + client._firestore_api_internal = firestore_api + client.get_all = mock.MagicMock() + client.get_all.exists.return_value = True + + # Actually make a document and call create(). + document = self._make_one("foo", "twelve", client=client) + document_data = {} + write_result = await document.create(document_data) + self.assertTrue((await write_result.get()).exists) + + @staticmethod + def _write_pb_for_set(document_path, document_data, merge): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + write_pbs = write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ) + ) + if merge: + field_paths = [ + field_path + for field_path, value in _helpers.extract_fields( + document_data, _helpers.FieldPath() + ) + ] + field_paths = [ + field_path.to_api_repr() for field_path in sorted(field_paths) + ] + mask = common.DocumentMask(field_paths=sorted(field_paths)) + write_pbs._pb.update_mask.CopyFrom(mask._pb) + return write_pbs + + @pytest.mark.asyncio + async def _set_helper(self, merge=False, **option_kwargs): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("db-dee-bee") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("User", "Interface", client=client) + document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} + write_result = await document.set(document_data, merge) + + # Verify the response and the mocks. + self.assertIs(write_result, mock.sentinel.write_result) + write_pb = self._write_pb_for_set(document._document_path, document_data, merge) + + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_set(self): + await self._set_helper() + + @pytest.mark.asyncio + async def test_set_merge(self): + await self._set_helper(merge=True) + + @staticmethod + def _write_pb_for_update(document_path, update_values, field_paths): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(update_values) + ), + update_mask=common.DocumentMask(field_paths=field_paths), + current_document=common.Precondition(exists=True), + ) + + @pytest.mark.asyncio + async def _update_helper(self, **option_kwargs): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("potato-chip") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("baked", "Alaska", client=client) + # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. + field_updates = collections.OrderedDict( + (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) + ) + if option_kwargs: + option = client.write_option(**option_kwargs) + write_result = await document.update(field_updates, option=option) + else: + option = None + write_result = await document.update(field_updates) + + # Verify the response and the mocks. + self.assertIs(write_result, mock.sentinel.write_result) + update_values = { + "hello": field_updates["hello"], + "then": {"do": field_updates["then.do"]}, + } + field_paths = list(field_updates.keys()) + write_pb = self._write_pb_for_update( + document._document_path, update_values, sorted(field_paths) + ) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_update_with_exists(self): + with self.assertRaises(ValueError): + await self._update_helper(exists=True) + + @pytest.mark.asyncio + async def test_update(self): + await self._update_helper() + + @pytest.mark.asyncio + async def test_update_with_precondition(self): + from google.protobuf import timestamp_pb2 + + timestamp = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + await self._update_helper(last_update_time=timestamp) + + @pytest.mark.asyncio + async def test_empty_update(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("potato-chip") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("baked", "Alaska", client=client) + # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. + field_updates = {} + with self.assertRaises(ValueError): + await document.update(field_updates) + + @pytest.mark.asyncio + async def _delete_helper(self, **option_kwargs): + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("donut-base") + client._firestore_api_internal = firestore_api + + # Actually make a document and call delete(). + document = self._make_one("where", "we-are", client=client) + if option_kwargs: + option = client.write_option(**option_kwargs) + delete_time = await document.delete(option=option) + else: + option = None + delete_time = await document.delete() + + # Verify the response and the mocks. + self.assertIs(delete_time, mock.sentinel.commit_time) + write_pb = write.Write(delete=document._document_path) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_delete(self): + await self._delete_helper() + + @pytest.mark.asyncio + async def test_delete_with_option(self): + from google.protobuf import timestamp_pb2 + + timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + await self._delete_helper(last_update_time=timestamp_pb) + + @pytest.mark.asyncio + async def _get_helper( + self, field_paths=None, use_transaction=False, not_found=False + ): + from google.api_core.exceptions import NotFound + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.transaction import Transaction + + # Create a minimal fake GAPIC with a dummy response. + create_time = 123 + update_time = 234 + firestore_api = mock.Mock(spec=["get_document"]) + response = mock.create_autospec(document.Document) + response.fields = {} + response.create_time = create_time + response.update_time = update_time + + if not_found: + firestore_api.get_document.side_effect = NotFound("testing") + else: + firestore_api.get_document.return_value = response + + client = _make_client("donut-base") + client._firestore_api_internal = firestore_api + + document = self._make_one("where", "we-are", client=client) + + if use_transaction: + transaction = Transaction(client) + transaction_id = transaction._id = b"asking-me-2" + else: + transaction = None + + snapshot = await document.get(field_paths=field_paths, transaction=transaction) + + self.assertIs(snapshot.reference, document) + if not_found: + self.assertIsNone(snapshot._data) + self.assertFalse(snapshot.exists) + self.assertIsNone(snapshot.read_time) + self.assertIsNone(snapshot.create_time) + self.assertIsNone(snapshot.update_time) + else: + self.assertEqual(snapshot.to_dict(), {}) + self.assertTrue(snapshot.exists) + self.assertIsNone(snapshot.read_time) + self.assertIs(snapshot.create_time, create_time) + self.assertIs(snapshot.update_time, update_time) + + # Verify the request made to the API + if field_paths is not None: + mask = common.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None + + if use_transaction: + expected_transaction_id = transaction_id + else: + expected_transaction_id = None + + firestore_api.get_document.assert_called_once_with( + request={ + "name": document._document_path, + "mask": mask, + "transaction": expected_transaction_id, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_get_not_found(self): + await self._get_helper(not_found=True) + + @pytest.mark.asyncio + async def test_get_default(self): + await self._get_helper() + + @pytest.mark.asyncio + async def test_get_w_string_field_path(self): + with self.assertRaises(ValueError): + await self._get_helper(field_paths="foo") + + @pytest.mark.asyncio + async def test_get_with_field_path(self): + await self._get_helper(field_paths=["foo"]) + + @pytest.mark.asyncio + async def test_get_with_multiple_field_paths(self): + await self._get_helper(field_paths=["foo", "bar.baz"]) + + @pytest.mark.asyncio + async def test_get_with_transaction(self): + await self._get_helper(use_transaction=True) + + @pytest.mark.asyncio + async def _collections_helper(self, page_size=None): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + from google.cloud.firestore_v1.services.firestore.client import FirestoreClient + + # TODO(microgen): https://github.com/googleapis/gapic-generator-python/issues/516 + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + collection_ids = ["coll-1", "coll-2"] + iterator = _Iterator(pages=[collection_ids]) + api_client = mock.create_autospec(FirestoreClient) + api_client.list_collection_ids.return_value = iterator + + client = _make_client() + client._firestore_api_internal = api_client + + # Actually make a document and call delete(). + document = self._make_one("where", "we-are", client=client) + if page_size is not None: + collections = [c async for c in document.collections(page_size=page_size)] + else: + collections = [c async for c in document.collections()] + + # Verify the response and the mocks. + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, AsyncCollectionReference) + self.assertEqual(collection.parent, document) + self.assertEqual(collection.id, collection_id) + + api_client.list_collection_ids.assert_called_once_with( + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_collections_wo_page_size(self): + await self._collections_helper() + + @pytest.mark.asyncio + async def test_collections_w_page_size(self): + await self._collections_helper(page_size=10) + + @mock.patch("google.cloud.firestore_v1.async_document.Watch", autospec=True) + def test_on_snapshot(self, watch): + client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) + document = self._make_one("yellow", "mellow", client=client) + document.on_snapshot(None) + watch.for_document.assert_called_once() + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="project-project"): + from google.cloud.firestore_v1.async_client import AsyncClient + + credentials = _make_credentials() + return AsyncClient(project=project, credentials=credentials) diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py new file mode 100644 index 000000000..87305bfbc --- /dev/null +++ b/tests/unit/v1/async/test_async_query.py @@ -0,0 +1,380 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import types +import aiounittest + +import mock + +from tests.unit.v1.test_base_query import _make_credentials, _make_query_response + + +class TestAsyncQuery(aiounittest.AsyncTestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_query import AsyncQuery + + return AsyncQuery + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor_defaults(self): + query = self._make_one(mock.sentinel.parent) + self.assertIs(query._parent, mock.sentinel.parent) + self.assertIsNone(query._projection) + self.assertEqual(query._field_filters, ()) + self.assertEqual(query._orders, ()) + self.assertIsNone(query._limit) + self.assertIsNone(query._offset) + self.assertIsNone(query._start_at) + self.assertIsNone(query._end_at) + self.assertFalse(query._all_descendants) + + @pytest.mark.asyncio + async def test_get_simple(self): + import warnings + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = self._make_one(parent) + + with warnings.catch_warnings(record=True) as warned: + get_response = query.get() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("dee", "sleep")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + # Verify the deprecation + self.assertEqual(len(warned), 1) + self.assertIs(warned[0].category, DeprecationWarning) + + @pytest.mark.asyncio + async def test_stream_simple(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("dee", "sleep")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_stream_with_transaction(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Create a real-ish transaction for this client. + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + # Make a **real** collection reference as parent. + parent = client.collection("declaration") + + # Add a dummy response to the minimal fake GAPIC. + parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream(transaction=transaction) + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("declaration", "burger")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_stream_no_results(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["run_query"]) + empty_response = _make_query_response() + run_query_response = iter([empty_response]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = self._make_one(parent) + + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + self.assertEqual([x async for x in get_response], []) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_stream_second_response_in_empty_stream(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["run_query"]) + empty_response1 = _make_query_response() + empty_response2 = _make_query_response() + run_query_response = iter([empty_response1, empty_response2]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = self._make_one(parent) + + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + self.assertEqual([x async for x in get_response], []) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_stream_with_skipped_results(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("talk", "and", "chew-gum") + + # Add two dummy responses to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + response_pb1 = _make_query_response(skipped_results=1) + name = "{}/clock".format(expected_prefix) + data = {"noon": 12, "nested": {"bird": 10.5}} + response_pb2 = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("talk", "and", "chew-gum", "clock")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_stream_empty_after_first_response(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + + # Add two dummy responses to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/bark".format(expected_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("charles", "bark")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_stream_w_collection_group(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + other = client.collection("dora") + + # Add two dummy responses to the minimal fake GAPIC. + _, other_prefix = other._parent_info() + name = "{}/bark".format(other_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + query._all_descendants = True + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + to_match = other.document("bark") + self.assertEqual(snapshot.reference._document_path, to_match._document_path) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + @mock.patch("google.cloud.firestore_v1.async_query.Watch", autospec=True) + def test_on_snapshot(self, watch): + query = self._make_one(mock.sentinel.parent) + query.on_snapshot(None) + watch.for_query.assert_called_once() + + +def _make_client(project="project-project"): + from google.cloud.firestore_v1.async_client import AsyncClient + + credentials = _make_credentials() + return AsyncClient(project=project, credentials=credentials) diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py new file mode 100644 index 000000000..b27f30e9c --- /dev/null +++ b/tests/unit/v1/async/test_async_transaction.py @@ -0,0 +1,1056 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import aiounittest +import mock + + +class TestAsyncTransaction(aiounittest.AsyncTestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + return AsyncTransaction + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor_defaults(self): + from google.cloud.firestore_v1.async_transaction import MAX_ATTEMPTS + + transaction = self._make_one(mock.sentinel.client) + self.assertIs(transaction._client, mock.sentinel.client) + self.assertEqual(transaction._write_pbs, []) + self.assertEqual(transaction._max_attempts, MAX_ATTEMPTS) + self.assertFalse(transaction._read_only) + self.assertIsNone(transaction._id) + + def test_constructor_explicit(self): + transaction = self._make_one( + mock.sentinel.client, max_attempts=10, read_only=True + ) + self.assertIs(transaction._client, mock.sentinel.client) + self.assertEqual(transaction._write_pbs, []) + self.assertEqual(transaction._max_attempts, 10) + self.assertTrue(transaction._read_only) + self.assertIsNone(transaction._id) + + def test__add_write_pbs_failure(self): + from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY + + batch = self._make_one(mock.sentinel.client, read_only=True) + self.assertEqual(batch._write_pbs, []) + with self.assertRaises(ValueError) as exc_info: + batch._add_write_pbs([mock.sentinel.write]) + + self.assertEqual(exc_info.exception.args, (_WRITE_READ_ONLY,)) + self.assertEqual(batch._write_pbs, []) + + def test__add_write_pbs(self): + batch = self._make_one(mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + batch._add_write_pbs([mock.sentinel.write]) + self.assertEqual(batch._write_pbs, [mock.sentinel.write]) + + def test__clean_up(self): + transaction = self._make_one(mock.sentinel.client) + transaction._write_pbs.extend( + [mock.sentinel.write_pb1, mock.sentinel.write_pb2] + ) + transaction._id = b"not-this-time-my-friend" + + ret_val = transaction._clean_up() + self.assertIsNone(ret_val) + + self.assertEqual(transaction._write_pbs, []) + self.assertIsNone(transaction._id) + + @pytest.mark.asyncio + async def test__begin(self): + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + from google.cloud.firestore_v1.types import firestore + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + txn_id = b"to-begin" + response = firestore.BeginTransactionResponse(transaction=txn_id) + firestore_api.begin_transaction.return_value = response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and ``begin()`` it. + transaction = self._make_one(client) + self.assertIsNone(transaction._id) + + ret_val = await transaction._begin() + self.assertIsNone(ret_val) + self.assertEqual(transaction._id, txn_id) + + # Verify the called mock. + firestore_api.begin_transaction.assert_called_once_with( + request={"database": client._database_string, "options": None}, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test__begin_failure(self): + from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN + + client = _make_client() + transaction = self._make_one(client) + transaction._id = b"not-none" + + with self.assertRaises(ValueError) as exc_info: + await transaction._begin() + + err_msg = _CANT_BEGIN.format(transaction._id) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + @pytest.mark.asyncio + async def test__rollback(self): + from google.protobuf import empty_pb2 + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + firestore_api.rollback.return_value = empty_pb2.Empty() + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and roll it back. + transaction = self._make_one(client) + txn_id = b"to-be-r\x00lled" + transaction._id = txn_id + ret_val = await transaction._rollback() + self.assertIsNone(ret_val) + self.assertIsNone(transaction._id) + + # Verify the called mock. + firestore_api.rollback.assert_called_once_with( + request={"database": client._database_string, "transaction": txn_id}, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test__rollback_not_allowed(self): + from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK + + client = _make_client() + transaction = self._make_one(client) + self.assertIsNone(transaction._id) + + with self.assertRaises(ValueError) as exc_info: + await transaction._rollback() + + self.assertEqual(exc_info.exception.args, (_CANT_ROLLBACK,)) + + @pytest.mark.asyncio + async def test__rollback_failure(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + + # Create a minimal fake GAPIC with a dummy failure. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + exc = exceptions.InternalServerError("Fire during rollback.") + firestore_api.rollback.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and roll it back. + transaction = self._make_one(client) + txn_id = b"roll-bad-server" + transaction._id = txn_id + + with self.assertRaises(exceptions.InternalServerError) as exc_info: + await transaction._rollback() + + self.assertIs(exc_info.exception, exc) + self.assertIsNone(transaction._id) + self.assertEqual(transaction._write_pbs, []) + + # Verify the called mock. + firestore_api.rollback.assert_called_once_with( + request={"database": client._database_string, "transaction": txn_id}, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test__commit(self): + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + commit_response = firestore.CommitResponse(write_results=[write.WriteResult()]) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client("phone-joe") + client._firestore_api_internal = firestore_api + + # Actually make a transaction with some mutations and call _commit(). + transaction = self._make_one(client) + txn_id = b"under-over-thru-woods" + transaction._id = txn_id + document = client.document("zap", "galaxy", "ship", "space") + transaction.set(document, {"apple": 4.5}) + write_pbs = transaction._write_pbs[::] + + write_results = await transaction._commit() + self.assertEqual(write_results, list(commit_response.write_results)) + # Make sure transaction has no more "changes". + self.assertIsNone(transaction._id) + self.assertEqual(transaction._write_pbs, []) + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test__commit_not_allowed(self): + from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT + + transaction = self._make_one(mock.sentinel.client) + self.assertIsNone(transaction._id) + with self.assertRaises(ValueError) as exc_info: + await transaction._commit() + + self.assertEqual(exc_info.exception.args, (_CANT_COMMIT,)) + + @pytest.mark.asyncio + async def test__commit_failure(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + + # Create a minimal fake GAPIC with a dummy failure. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + exc = exceptions.InternalServerError("Fire during commit.") + firestore_api.commit.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction with some mutations and call _commit(). + transaction = self._make_one(client) + txn_id = b"beep-fail-commit" + transaction._id = txn_id + transaction.create(client.document("up", "down"), {"water": 1.0}) + transaction.delete(client.document("up", "left")) + write_pbs = transaction._write_pbs[::] + + with self.assertRaises(exceptions.InternalServerError) as exc_info: + await transaction._commit() + + self.assertIs(exc_info.exception, exc) + self.assertEqual(transaction._id, txn_id) + self.assertEqual(transaction._write_pbs, write_pbs) + + # Verify the called mock. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test_get_all(self): + client = mock.Mock(spec=["get_all"]) + transaction = self._make_one(client) + ref1, ref2 = mock.Mock(), mock.Mock() + result = await transaction.get_all([ref1, ref2]) + client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction) + self.assertIs(result, client.get_all.return_value) + + @pytest.mark.asyncio + async def test_get_document_ref(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + client = mock.Mock(spec=["get_all"]) + transaction = self._make_one(client) + ref = AsyncDocumentReference("documents", "doc-id") + result = await transaction.get(ref) + client.get_all.assert_called_once_with([ref], transaction=transaction) + self.assertIs(result, client.get_all.return_value) + + @pytest.mark.asyncio + async def test_get_w_query(self): + from google.cloud.firestore_v1.async_query import AsyncQuery + + client = mock.Mock(spec=[]) + transaction = self._make_one(client) + query = AsyncQuery(parent=mock.Mock(spec=[])) + query.stream = mock.MagicMock() + result = await transaction.get(query) + query.stream.assert_called_once_with(transaction=transaction) + self.assertIs(result, query.stream.return_value) + + @pytest.mark.asyncio + async def test_get_failure(self): + client = _make_client() + transaction = self._make_one(client) + ref_or_query = object() + with self.assertRaises(ValueError): + await transaction.get(ref_or_query) + + +class Test_Transactional(aiounittest.AsyncTestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_transaction import _AsyncTransactional + + return _AsyncTransactional + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + wrapped = self._make_one(mock.sentinel.callable_) + self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) + self.assertIsNone(wrapped.current_id) + self.assertIsNone(wrapped.retry_id) + + @pytest.mark.asyncio + async def test__pre_commit_success(self): + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"totes-began" + transaction = _make_transaction(txn_id) + result = await wrapped._pre_commit(transaction, "pos", key="word") + self.assertIs(result, mock.sentinel.result) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "pos", key="word") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": None, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + @pytest.mark.asyncio + async def test__pre_commit_retry_id_already_set_success(self): + from google.cloud.firestore_v1.types import common + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + txn_id1 = b"already-set" + wrapped.retry_id = txn_id1 + + txn_id2 = b"ok-here-too" + transaction = _make_transaction(txn_id2) + result = await wrapped._pre_commit(transaction) + self.assertIs(result, mock.sentinel.result) + + self.assertEqual(transaction._id, txn_id2) + self.assertEqual(wrapped.current_id, txn_id2) + self.assertEqual(wrapped.retry_id, txn_id1) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction) + firestore_api = transaction._client._firestore_api + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id1) + ) + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": options_, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + @pytest.mark.asyncio + async def test__pre_commit_failure(self): + exc = RuntimeError("Nope not today.") + to_wrap = mock.Mock(side_effect=exc, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"gotta-fail" + transaction = _make_transaction(txn_id) + with self.assertRaises(RuntimeError) as exc_info: + await wrapped._pre_commit(transaction, 10, 20) + self.assertIs(exc_info.exception, exc) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, 10, 20) + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": None, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + @pytest.mark.asyncio + async def test__pre_commit_failure_with_rollback_failure(self): + from google.api_core import exceptions + + exc1 = ValueError("I will not be only failure.") + to_wrap = mock.Mock(side_effect=exc1, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"both-will-fail" + transaction = _make_transaction(txn_id) + # Actually force the ``rollback`` to fail as well. + exc2 = exceptions.InternalServerError("Rollback blues.") + firestore_api = transaction._client._firestore_api + firestore_api.rollback.side_effect = exc2 + + # Try to ``_pre_commit`` + with self.assertRaises(exceptions.InternalServerError) as exc_info: + await wrapped._pre_commit(transaction, a="b", c="zebra") + self.assertIs(exc_info.exception, exc2) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, a="b", c="zebra") + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": None, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + @pytest.mark.asyncio + async def test__maybe_commit_success(self): + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"nyet" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + succeeded = await wrapped._maybe_commit(transaction) + self.assertTrue(succeeded) + + # On success, _id is reset. + self.assertIsNone(transaction._id) + + # Verify mocks. + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test__maybe_commit_failure_read_only(self): + from google.api_core import exceptions + + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"failed" + transaction = _make_transaction(txn_id, read_only=True) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail (use ABORTED, but cannot + # retry since read-only). + exc = exceptions.Aborted("Read-only did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with self.assertRaises(exceptions.Aborted) as exc_info: + await wrapped._maybe_commit(transaction) + self.assertIs(exc_info.exception, exc) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test__maybe_commit_failure_can_retry(self): + from google.api_core import exceptions + + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"failed-but-retry" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Read-write did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + succeeded = await wrapped._maybe_commit(transaction) + self.assertFalse(succeeded) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test__maybe_commit_failure_cannot_retry(self): + from google.api_core import exceptions + + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"failed-but-not-retryable" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.InternalServerError("Real bad thing") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with self.assertRaises(exceptions.InternalServerError) as exc_info: + await wrapped._maybe_commit(transaction) + self.assertIs(exc_info.exception, exc) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test___call__success_first_attempt(self): + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction(txn_id) + result = await wrapped(transaction, "a", b="c") + self.assertIs(result, mock.sentinel.result) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "a", b="c") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + @pytest.mark.asyncio + async def test___call__success_second_attempt(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction(txn_id) + + # Actually force the ``commit`` to fail on first / succeed on second. + exc = exceptions.Aborted("Contention junction.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = [ + exc, + firestore.CommitResponse(write_results=[write.WriteResult()]), + ] + + # Call the __call__-able ``wrapped``. + result = await wrapped(transaction, "a", b="c") + self.assertIs(result, mock.sentinel.result) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + wrapped_call = mock.call(transaction, "a", b="c") + self.assertEqual(to_wrap.mock_calls, [wrapped_call, wrapped_call]) + firestore_api = transaction._client._firestore_api + db_str = transaction._client._database_string + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + ) + self.assertEqual( + firestore_api.begin_transaction.mock_calls, + [ + mock.call( + request={"database": db_str, "options": None}, + metadata=transaction._client._rpc_metadata, + ), + mock.call( + request={"database": db_str, "options": options_}, + metadata=transaction._client._rpc_metadata, + ), + ], + ) + firestore_api.rollback.assert_not_called() + commit_call = mock.call( + request={"database": db_str, "writes": [], "transaction": txn_id}, + metadata=transaction._client._rpc_metadata, + ) + self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) + + @pytest.mark.asyncio + async def test___call__failure(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import ( + _EXCEED_ATTEMPTS_TEMPLATE, + ) + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"only-one-shot" + transaction = _make_transaction(txn_id, max_attempts=1) + + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Contention just once.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + # Call the __call__-able ``wrapped``. + with self.assertRaises(ValueError) as exc_info: + await wrapped(transaction, "here", there=1.5) + + err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": None, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +class Test_transactional(aiounittest.AsyncTestCase): + @staticmethod + def _call_fut(to_wrap): + from google.cloud.firestore_v1.async_transaction import transactional + + return transactional(to_wrap) + + def test_it(self): + from google.cloud.firestore_v1.async_transaction import _AsyncTransactional + + wrapped = self._call_fut(mock.sentinel.callable_) + self.assertIsInstance(wrapped, _AsyncTransactional) + self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) + + +class Test__commit_with_retry(aiounittest.AsyncTestCase): + @staticmethod + @pytest.mark.asyncio + async def _call_fut(client, write_pbs, transaction_id): + from google.cloud.firestore_v1.async_transaction import _commit_with_retry + + return await _commit_with_retry(client, write_pbs, transaction_id) + + @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") + @pytest.mark.asyncio + async def test_success_first_attempt(self, _sleep): + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + + # Attach the fake GAPIC to a real client. + client = _make_client("summer") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"cheeeeeez" + commit_response = await self._call_fut(client, mock.sentinel.write_pbs, txn_id) + self.assertIs(commit_response, firestore_api.commit.return_value) + + # Verify mocks used. + _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + @mock.patch( + "google.cloud.firestore_v1.async_transaction._sleep", side_effect=[2.0, 4.0] + ) + @pytest.mark.asyncio + async def test_success_third_attempt(self, _sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first two requests fail and the third succeeds. + firestore_api.commit.side_effect = [ + exceptions.ServiceUnavailable("Server sleepy."), + exceptions.ServiceUnavailable("Server groggy."), + mock.sentinel.commit_response, + ] + + # Attach the fake GAPIC to a real client. + client = _make_client("outside") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"the-world\x00" + commit_response = await self._call_fut(client, mock.sentinel.write_pbs, txn_id) + self.assertIs(commit_response, mock.sentinel.commit_response) + + # Verify mocks used. + # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds + self.assertEqual(_sleep.call_count, 2) + _sleep.assert_any_call(1.0) + _sleep.assert_any_call(2.0) + # commit() called same way 3 times. + commit_call = mock.call( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + self.assertEqual( + firestore_api.commit.mock_calls, [commit_call, commit_call, commit_call] + ) + + @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") + @pytest.mark.asyncio + async def test_failure_first_attempt(self, _sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first request fails with an un-retryable error. + exc = exceptions.ResourceExhausted("We ran out of fries.") + firestore_api.commit.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" + with self.assertRaises(exceptions.ResourceExhausted) as exc_info: + await self._call_fut(client, mock.sentinel.write_pbs, txn_id) + + self.assertIs(exc_info.exception, exc) + + # Verify mocks used. + _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + @mock.patch("google.cloud.firestore_v1.async_transaction._sleep", return_value=2.0) + @pytest.mark.asyncio + async def test_failure_second_attempt(self, _sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first request fails retry-able and second + # fails non-retryable. + exc1 = exceptions.ServiceUnavailable("Come back next time.") + exc2 = exceptions.InternalServerError("Server on fritz.") + firestore_api.commit.side_effect = [exc1, exc2] + + # Attach the fake GAPIC to a real client. + client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"the-journey-when-and-where-well-go" + with self.assertRaises(exceptions.InternalServerError) as exc_info: + await self._call_fut(client, mock.sentinel.write_pbs, txn_id) + + self.assertIs(exc_info.exception, exc2) + + # Verify mocks used. + _sleep.assert_called_once_with(1.0) + # commit() called same way 2 times. + commit_call = mock.call( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) + + +class Test__sleep(aiounittest.AsyncTestCase): + @staticmethod + @pytest.mark.asyncio + async def _call_fut(current_sleep, **kwargs): + from google.cloud.firestore_v1.async_transaction import _sleep + + return await _sleep(current_sleep, **kwargs) + + @mock.patch("random.uniform", return_value=5.5) + @mock.patch("asyncio.sleep", return_value=None) + @pytest.mark.asyncio + async def test_defaults(self, sleep, uniform): + curr_sleep = 10.0 + self.assertLessEqual(uniform.return_value, curr_sleep) + + new_sleep = await self._call_fut(curr_sleep) + self.assertEqual(new_sleep, 2.0 * curr_sleep) + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + @mock.patch("random.uniform", return_value=10.5) + @mock.patch("asyncio.sleep", return_value=None) + @pytest.mark.asyncio + async def test_explicit(self, sleep, uniform): + curr_sleep = 12.25 + self.assertLessEqual(uniform.return_value, curr_sleep) + + multiplier = 1.5 + new_sleep = await self._call_fut( + curr_sleep, max_sleep=100.0, multiplier=multiplier + ) + self.assertEqual(new_sleep, multiplier * curr_sleep) + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + @mock.patch("random.uniform", return_value=6.75) + @mock.patch("asyncio.sleep", return_value=None) + @pytest.mark.asyncio + async def test_exceeds_max(self, sleep, uniform): + curr_sleep = 20.0 + self.assertLessEqual(uniform.return_value, curr_sleep) + + max_sleep = 38.5 + new_sleep = await self._call_fut( + curr_sleep, max_sleep=max_sleep, multiplier=2.0 + ) + self.assertEqual(new_sleep, max_sleep) + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="feral-tom-cat"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) + + +def _make_transaction(txn_id, **txn_kwargs): + from google.protobuf import empty_pb2 + from google.cloud.firestore_v1.services.firestore import client as firestore_client + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + # Create a fake GAPIC ... + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # ... with a dummy ``BeginTransactionResponse`` result ... + begin_response = firestore.BeginTransactionResponse(transaction=txn_id) + firestore_api.begin_transaction.return_value = begin_response + # ... and a dummy ``Rollback`` result ... + firestore_api.rollback.return_value = empty_pb2.Empty() + # ... and a dummy ``Commit`` result. + commit_response = firestore.CommitResponse(write_results=[write.WriteResult()]) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + return AsyncTransaction(client, **txn_kwargs) diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index e8ab7a267..5396540c6 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -133,9 +133,10 @@ def test_as_context_mgr_w_error(self): ctx_mgr.delete(document2) raise RuntimeError("testing") + # batch still has its changes, as _exit_ (and commit) is not invoked + # changes are preserved so commit can be retried self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) - # batch still has its changes self.assertEqual(len(batch._write_pbs), 2) firestore_api.commit.assert_not_called() diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 8aa5f41d4..433fcadfa 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -129,11 +129,13 @@ def test_collection_group(self): client = self._make_default_one() query = client.collection_group("collectionId").where("foo", "==", u"bar") - assert query._all_descendants - assert query._field_filters[0].field.field_path == "foo" - assert query._field_filters[0].value.string_value == u"bar" - assert query._field_filters[0].op == query._field_filters[0].Operator.EQUAL - assert query._parent.id == "collectionId" + self.assertTrue(query._all_descendants) + self.assertEqual(query._field_filters[0].field.field_path, "foo") + self.assertEqual(query._field_filters[0].value.string_value, u"bar") + self.assertEqual( + query._field_filters[0].op, query._field_filters[0].Operator.EQUAL + ) + self.assertEqual(query._parent.id, "collectionId") def test_collection_group_no_slashes(self): client = self._make_default_one() diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index e4c838992..a32e58c10 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -831,6 +831,7 @@ def test_success_third_attempt(self, _sleep): self.assertIs(commit_response, mock.sentinel.commit_response) # Verify mocks used. + # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds self.assertEqual(_sleep.call_count, 2) _sleep.assert_any_call(1.0) _sleep.assert_any_call(2.0)