diff --git a/google/cloud/firestore.py b/google/cloud/firestore.py
index 545b31b18..4c5cb3fe2 100644
--- a/google/cloud/firestore.py
+++ b/google/cloud/firestore.py
@@ -18,6 +18,13 @@
 from google.cloud.firestore_v1 import __version__
 from google.cloud.firestore_v1 import ArrayRemove
 from google.cloud.firestore_v1 import ArrayUnion
+from google.cloud.firestore_v1 import AsyncClient
+from google.cloud.firestore_v1 import AsyncCollectionReference
+from google.cloud.firestore_v1 import AsyncDocumentReference
+from google.cloud.firestore_v1 import AsyncQuery
+from google.cloud.firestore_v1 import async_transactional
+from google.cloud.firestore_v1 import AsyncTransaction
+from google.cloud.firestore_v1 import AsyncWriteBatch
 from google.cloud.firestore_v1 import Client
 from google.cloud.firestore_v1 import CollectionReference
 from google.cloud.firestore_v1 import DELETE_FIELD
@@ -45,6 +52,13 @@
     "__version__",
     "ArrayRemove",
     "ArrayUnion",
+    "AsyncClient",
+    "AsyncCollectionReference",
+    "AsyncDocumentReference",
+    "AsyncQuery",
+    "async_transactional",
+    "AsyncTransaction",
+    "AsyncWriteBatch",
     "Client",
     "CollectionReference",
     "DELETE_FIELD",
diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py
index 5b96029a1..74652de3e 100644
--- a/google/cloud/firestore_v1/__init__.py
+++ b/google/cloud/firestore_v1/__init__.py
@@ -29,9 +29,21 @@
 from google.cloud.firestore_v1._helpers import LastUpdateOption
 from google.cloud.firestore_v1._helpers import ReadAfterWriteError
 from google.cloud.firestore_v1._helpers import WriteOption
+from google.cloud.firestore_v1.async_batch import AsyncWriteBatch
+from google.cloud.firestore_v1.async_client import AsyncClient
+from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
+from google.cloud.firestore_v1.async_document import AsyncDocumentReference
+from google.cloud.firestore_v1.async_query import AsyncQuery
+from google.cloud.firestore_v1.async_transaction import async_transactional
+from google.cloud.firestore_v1.async_transaction import AsyncTransaction
+from google.cloud.firestore_v1.base_document import DocumentSnapshot
 from google.cloud.firestore_v1.batch import WriteBatch
 from google.cloud.firestore_v1.client import Client
 from google.cloud.firestore_v1.collection import CollectionReference
+from google.cloud.firestore_v1.document import DocumentReference
+from google.cloud.firestore_v1.query import Query
+from google.cloud.firestore_v1.transaction import Transaction
+from google.cloud.firestore_v1.transaction import transactional
 from google.cloud.firestore_v1.transforms import ArrayRemove
 from google.cloud.firestore_v1.transforms import ArrayUnion
 from google.cloud.firestore_v1.transforms import DELETE_FIELD
@@ -39,11 +51,6 @@
 from google.cloud.firestore_v1.transforms import Maximum
 from google.cloud.firestore_v1.transforms import Minimum
 from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
-from google.cloud.firestore_v1.document import DocumentReference
-from google.cloud.firestore_v1.document import DocumentSnapshot
-from google.cloud.firestore_v1.query import Query
-from google.cloud.firestore_v1.transaction import Transaction
-from google.cloud.firestore_v1.transaction import transactional
 from google.cloud.firestore_v1.watch import Watch
@@ -100,6 +107,13 @@
     "__version__",
     "ArrayRemove",
     "ArrayUnion",
+    "AsyncClient",
+    "AsyncCollectionReference",
+    "AsyncDocumentReference",
+    "AsyncQuery",
+    "async_transactional",
+    "AsyncTransaction",
+    "AsyncWriteBatch",
     "Client",
"CollectionReference", "DELETE_FIELD", diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index f37b28ddc..e6e9656ae 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -242,7 +242,7 @@ async def get_all(self, references, field_paths=None, transaction=None): """ document_paths, reference_map = _reference_info(references) mask = _get_doc_mask(field_paths) - response_iterator = self._firestore_api.batch_get_documents( + response_iterator = await self._firestore_api.batch_get_documents( request={ "database": self._database_string, "documents": document_paths, diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 70676360e..95967b294 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -22,8 +22,6 @@ _item_to_document_ref, ) from google.cloud.firestore_v1 import async_query -from google.cloud.firestore_v1.watch import Watch -from google.cloud.firestore_v1 import async_document class AsyncCollectionReference(BaseCollectionReference): @@ -119,7 +117,8 @@ async def list_documents(self, page_size=None): }, metadata=self._client._rpc_metadata, ) - return (_item_to_document_ref(self, i) for i in iterator) + async for i in iterator: + yield _item_to_document_ref(self, i) async def get(self, transaction=None): """Deprecated alias for :meth:`stream`.""" @@ -161,36 +160,3 @@ async def stream(self, transaction=None): query = async_query.AsyncQuery(self) async for d in query.stream(transaction=transaction): yield d - - def on_snapshot(self, callback): - """Monitor the documents in this collection. - - This starts a watch on this collection using a background thread. The - provided callback is run on the snapshot of the documents. - - Args: - callback (Callable[[:class:`~google.cloud.firestore.collection.CollectionSnapshot`], NoneType]): - a callback to run when a change occurs. - - Example: - from google.cloud import firestore_v1 - - db = firestore_v1.Client() - collection_ref = db.collection(u'users') - - def on_snapshot(collection_snapshot, changes, read_time): - for doc in collection_snapshot.documents: - print(u'{} => {}'.format(doc.id, doc.to_dict())) - - # Watch this collection - collection_watch = collection_ref.on_snapshot(on_snapshot) - - # Terminate this watch - collection_watch.unsubscribe() - """ - return Watch.for_query( - self._query(), - callback, - async_document.DocumentSnapshot, - async_document.AsyncDocumentReference, - ) diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 0b7c3bfd3..a36d8894a 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -23,7 +23,6 @@ from google.api_core import exceptions from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common -from google.cloud.firestore_v1.watch import Watch class AsyncDocumentReference(BaseDocumentReference): @@ -385,39 +384,3 @@ async def collections(self, page_size=None): # iterator.document = self # iterator.item_to_value = _item_to_collection_ref # return iterator - - def on_snapshot(self, callback): - """Watch this document. - - This starts a watch on this document using a background thread. The - provided callback is run on the snapshot. 
-
-        Args:
-            callback(Callable[[:class:`~google.cloud.firestore.document.DocumentSnapshot`], NoneType]):
-                a callback to run when a change occurs
-
-        Example:
-
-            .. code-block:: python
-
-                from google.cloud import firestore_v1
-
-                db = firestore_v1.Client()
-                collection_ref = db.collection(u'users')
-
-                def on_snapshot(document_snapshot, changes, read_time):
-                    doc = document_snapshot
-                    print(u'{} => {}'.format(doc.id, doc.to_dict()))
-
-                doc_ref = db.collection(u'users').document(
-                    u'alovelace' + unique_resource_id())
-
-                # Watch this document
-                doc_watch = doc_ref.on_snapshot(on_snapshot)
-
-                # Terminate this watch
-                doc_watch.unsubscribe()
-        """
-        return Watch.for_document(
-            self, callback, DocumentSnapshot, AsyncDocumentReference
-        )
diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py
index a4a46d6ec..14e17e71a 100644
--- a/google/cloud/firestore_v1/async_query.py
+++ b/google/cloud/firestore_v1/async_query.py
@@ -27,8 +27,6 @@
 )
 
 from google.cloud.firestore_v1 import _helpers
-from google.cloud.firestore_v1 import async_document
-from google.cloud.firestore_v1.watch import Watch
 
 
 class AsyncQuery(BaseQuery):
@@ -149,7 +147,7 @@ async def stream(self, transaction=None):
             The next document that fulfills the query.
         """
         parent_path, expected_prefix = self._parent._parent_info()
-        response_iterator = self._client._firestore_api.run_query(
+        response_iterator = await self._client._firestore_api.run_query(
             request={
                 "parent": parent_path,
                 "structured_query": self._to_protobuf(),
@@ -169,39 +167,3 @@ async def stream(self, transaction=None):
             )
             if snapshot is not None:
                 yield snapshot
-
-    def on_snapshot(self, callback):
-        """Monitor the documents in this collection that match this query.
-
-        This starts a watch on this query using a background thread. The
-        provided callback is run on the snapshot of the documents.
-
-        Args:
-            callback(Callable[[:class:`~google.cloud.firestore.query.QuerySnapshot`], NoneType]):
-                a callback to run when a change occurs.
-
-        Example:
-
-            .. code-block:: python
-
-                from google.cloud import firestore_v1
-
-                db = firestore_v1.Client()
-                query_ref = db.collection(u'users').where("user", "==", u'Ada')
-
-                def on_snapshot(docs, changes, read_time):
-                    for doc in docs:
-                        print(u'{} => {}'.format(doc.id, doc.to_dict()))
-
-                # Watch this query
-                query_watch = query_ref.on_snapshot(on_snapshot)
-
-                # Terminate this watch
-                query_watch.unsubscribe()
-        """
-        return Watch.for_query(
-            self,
-            callback,
-            async_document.DocumentSnapshot,
-            async_document.AsyncDocumentReference,
-        )
diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py
index 0b1e83788..33a81a292 100644
--- a/google/cloud/firestore_v1/async_transaction.py
+++ b/google/cloud/firestore_v1/async_transaction.py
@@ -287,7 +287,7 @@ async def __call__(self, transaction, *args, **kwargs):
         raise ValueError(msg)
 
 
-def transactional(to_wrap):
+def async_transactional(to_wrap):
     """Decorate a callable so that it runs in a transaction.
 
     Args:
diff --git a/noxfile.py b/noxfile.py
index fff963ae9..55f2da88e 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -124,7 +124,7 @@ def system(session):
     # Install all test dependencies, then install this package into the
     # virtualenv's dist-packages.
     session.install(
-        "mock", "pytest", "google-cloud-testutils",
+        "mock", "pytest", "pytest-asyncio", "google-cloud-testutils",
     )
     session.install("-e", ".")
diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py
new file mode 100644
index 000000000..c114efaf3
--- /dev/null
+++ b/tests/system/test__helpers.py
@@ -0,0 +1,10 @@
+import os
+import re
+from test_utils.system import unique_resource_id
+
+FIRESTORE_CREDS = os.environ.get("FIRESTORE_APPLICATION_CREDENTIALS")
+FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT")
+RANDOM_ID_REGEX = re.compile("^[a-zA-Z0-9]{20}$")
+MISSING_DOCUMENT = "No document to update: "
+DOCUMENT_EXISTS = "Document already exists: "
+UNIQUE_RESOURCE_ID = unique_resource_id("-")
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index 4800014da..15efa81e6 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -15,8 +15,6 @@
 import datetime
 import math
 import operator
-import os
-import re
 
 from google.oauth2 import service_account
 import pytest
@@ -28,16 +26,16 @@
 from google.cloud._helpers import _datetime_to_pb_timestamp
 from google.cloud._helpers import UTC
 from google.cloud import firestore_v1 as firestore
-from test_utils.system import unique_resource_id
 
 from time import sleep
 
-FIRESTORE_CREDS = os.environ.get("FIRESTORE_APPLICATION_CREDENTIALS")
-FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT")
-RANDOM_ID_REGEX = re.compile("^[a-zA-Z0-9]{20}$")
-MISSING_DOCUMENT = "No document to update: "
-DOCUMENT_EXISTS = "Document already exists: "
-UNIQUE_RESOURCE_ID = unique_resource_id("-")
+from tests.system.test__helpers import (
+    FIRESTORE_CREDS,
+    FIRESTORE_PROJECT,
+    RANDOM_ID_REGEX,
+    MISSING_DOCUMENT,
+    UNIQUE_RESOURCE_ID,
+)
 
 
 @pytest.fixture(scope=u"module")
@@ -683,7 +681,7 @@ def test_query_stream_w_offset(query_docs):
 
 def test_query_with_order_dot_key(client, cleanup):
     db = client
-    collection_id = "collek" + unique_resource_id("-")
+    collection_id = "collek" + UNIQUE_RESOURCE_ID
     collection = db.collection(collection_id)
     for index in range(100, -1, -1):
         doc = collection.document("test_{:09d}".format(index))
diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py
new file mode 100644
index 000000000..4dfe36a87
--- /dev/null
+++ b/tests/system/test_system_async.py
@@ -0,0 +1,998 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import datetime
+import math
+import pytest
+import operator
+
+from google.oauth2 import service_account
+
+from google.api_core.exceptions import AlreadyExists
+from google.api_core.exceptions import FailedPrecondition
+from google.api_core.exceptions import InvalidArgument
+from google.api_core.exceptions import NotFound
+from google.cloud._helpers import _datetime_to_pb_timestamp
+from google.cloud._helpers import UTC
+from google.cloud import firestore_v1 as firestore
+
+from tests.system.test__helpers import (
+    FIRESTORE_CREDS,
+    FIRESTORE_PROJECT,
+    RANDOM_ID_REGEX,
+    MISSING_DOCUMENT,
+    UNIQUE_RESOURCE_ID,
+)
+
+_test_event_loop = asyncio.new_event_loop()
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture(scope=u"module")
+def client():
+    credentials = service_account.Credentials.from_service_account_file(FIRESTORE_CREDS)
+    project = FIRESTORE_PROJECT or credentials.project_id
+    yield firestore.AsyncClient(project=project, credentials=credentials)
+
+
+@pytest.fixture
+async def cleanup():
+    operations = []
+    yield operations.append
+
+    for operation in operations:
+        await operation()
+
+
+@pytest.fixture
+def event_loop():
+    asyncio.set_event_loop(_test_event_loop)
+    return asyncio.get_event_loop()
+
+
+async def test_collections(client):
+    collections = [x async for x in client.collections()]
+    assert isinstance(collections, list)
+
+
+async def test_collections_w_import():
+    from google.cloud import firestore
+
+    client = firestore.AsyncClient()
+    collections = [x async for x in client.collections()]
+
+    assert isinstance(collections, list)
+
+
+async def test_create_document(client, cleanup):
+    now = datetime.datetime.utcnow().replace(tzinfo=UTC)
+    collection_id = "doc-create" + UNIQUE_RESOURCE_ID
+    document_id = "doc" + UNIQUE_RESOURCE_ID
+    document = client.document(collection_id, document_id)
+    # Add to clean-up before API request (in case ``create()`` fails).
+    cleanup(document.delete)
+
+    data = {
+        "now": firestore.SERVER_TIMESTAMP,
+        "eenta-ger": 11,
+        "bites": b"\xe2\x98\x83 \xe2\x9b\xb5",
+        "also": {"nestednow": firestore.SERVER_TIMESTAMP, "quarter": 0.25},
+    }
+    write_result = await document.create(data)
+
+    updated = write_result.update_time
+    delta = updated - now
+    # Allow a bit of clock skew, but make sure timestamps are close.
+    assert -300.0 < delta.total_seconds() < 300.0
+
+    with pytest.raises(AlreadyExists):
+        await document.create(data)
+
+    # Verify the server times.
+    snapshot = await document.get()
+    stored_data = snapshot.to_dict()
+    server_now = stored_data["now"]
+
+    delta = updated - server_now
+    # NOTE: We could check the ``transform_results`` from the write result
+    #       for the document transform, but this value gets dropped. Instead
+    #       we make sure the timestamps are close.
+    # TODO(microgen): this was 0.0 - 5.0 before. After microgen, this started
+    #       getting very small negative times.
+    assert -0.2 <= delta.total_seconds() < 5.0
+    expected_data = {
+        "now": server_now,
+        "eenta-ger": data["eenta-ger"],
+        "bites": data["bites"],
+        "also": {"nestednow": server_now, "quarter": data["also"]["quarter"]},
+    }
+    assert stored_data == expected_data
+
+
+async def test_create_document_w_subcollection(client, cleanup):
+    collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID
+    document_id = "doc" + UNIQUE_RESOURCE_ID
+    document = client.document(collection_id, document_id)
+    # Add to clean-up before API request (in case ``create()`` fails).
+    cleanup(document.delete)
+
+    data = {"now": firestore.SERVER_TIMESTAMP}
+    await document.create(data)
+
+    child_ids = ["child1", "child2"]
+
+    for child_id in child_ids:
+        subcollection = document.collection(child_id)
+        _, subdoc = await subcollection.add({"foo": "bar"})
+        cleanup(subdoc.delete)
+
+    children = document.collections()
+    assert sorted([child.id async for child in children]) == sorted(child_ids)
+
+
+async def test_cannot_use_foreign_key(client, cleanup):
+    document_id = "cannot" + UNIQUE_RESOURCE_ID
+    document = client.document("foreign-key", document_id)
+    # Add to clean-up before API request (in case ``create()`` fails).
+    cleanup(document.delete)
+
+    other_client = firestore.Client(
+        project="other-prahj", credentials=client._credentials, database="dee-bee"
+    )
+    assert other_client._database_string != client._database_string
+    fake_doc = other_client.document("foo", "bar")
+    with pytest.raises(InvalidArgument):
+        await document.create({"ref": fake_doc})
+
+
+def assert_timestamp_less(timestamp_pb1, timestamp_pb2):
+    assert timestamp_pb1 < timestamp_pb2
+
+
+async def test_no_document(client):
+    document_id = "no_document" + UNIQUE_RESOURCE_ID
+    document = client.document("abcde", document_id)
+    snapshot = await document.get()
+    assert snapshot.to_dict() is None
+
+
+async def test_document_set(client, cleanup):
+    document_id = "for-set" + UNIQUE_RESOURCE_ID
+    document = client.document("i-did-it", document_id)
+    # Add to clean-up before API request (in case ``set()`` fails).
+    cleanup(document.delete)
+
+    # 0. Make sure the document doesn't exist yet
+    snapshot = await document.get()
+    assert snapshot.to_dict() is None
+
+    # 1. Use ``create()`` to create the document.
+    data1 = {"foo": 88}
+    write_result1 = await document.create(data1)
+    snapshot1 = await document.get()
+    assert snapshot1.to_dict() == data1
+    # Make sure the update is what created the document.
+    assert snapshot1.create_time == snapshot1.update_time
+    assert snapshot1.update_time == write_result1.update_time
+
+    # 2. Call ``set()`` again to overwrite.
+    data2 = {"bar": None}
+    write_result2 = await document.set(data2)
+    snapshot2 = await document.get()
+    assert snapshot2.to_dict() == data2
+    # Make sure the create time hasn't changed.
+    assert snapshot2.create_time == snapshot1.create_time
+    assert snapshot2.update_time == write_result2.update_time
+
+
+async def test_document_integer_field(client, cleanup):
+    document_id = "for-set" + UNIQUE_RESOURCE_ID
+    document = client.document("i-did-it", document_id)
+    # Add to clean-up before API request (in case ``set()`` fails).
+    cleanup(document.delete)
+
+    data1 = {"1a": {"2b": "3c", "ab": "5e"}, "6f": {"7g": "8h", "cd": "0j"}}
+    await document.create(data1)
+
+    data2 = {"1a.ab": "4d", "6f.7g": "9h"}
+    await document.update(data2)
+    snapshot = await document.get()
+    expected = {"1a": {"2b": "3c", "ab": "4d"}, "6f": {"7g": "9h", "cd": "0j"}}
+    assert snapshot.to_dict() == expected
+
+
+async def test_document_set_merge(client, cleanup):
+    document_id = "for-set" + UNIQUE_RESOURCE_ID
+    document = client.document("i-did-it", document_id)
+    # Add to clean-up before API request (in case ``set()`` fails).
+    cleanup(document.delete)
+
+    # 0. Make sure the document doesn't exist yet
+    snapshot = await document.get()
+    assert not snapshot.exists
+
+    # 1. Use ``create()`` to create the document.
+ data1 = {"name": "Sam", "address": {"city": "SF", "state": "CA"}} + write_result1 = await document.create(data1) + snapshot1 = await document.get() + assert snapshot1.to_dict() == data1 + # Make sure the update is what created the document. + assert snapshot1.create_time == snapshot1.update_time + assert snapshot1.update_time == write_result1.update_time + + # 2. Call ``set()`` to merge + data2 = {"address": {"city": "LA"}} + write_result2 = await document.set(data2, merge=True) + snapshot2 = await document.get() + assert snapshot2.to_dict() == { + "name": "Sam", + "address": {"city": "LA", "state": "CA"}, + } + # Make sure the create time hasn't changed. + assert snapshot2.create_time == snapshot1.create_time + assert snapshot2.update_time == write_result2.update_time + + +async def test_document_set_w_int_field(client, cleanup): + document_id = "set-int-key" + UNIQUE_RESOURCE_ID + document = client.document("i-did-it", document_id) + # Add to clean-up before API request (in case ``set()`` fails). + cleanup(document.delete) + + # 0. Make sure the document doesn't exist yet + snapshot = await document.get() + assert not snapshot.exists + + # 1. Use ``create()`` to create the document. + before = {"testing": "1"} + await document.create(before) + + # 2. Replace using ``set()``. + data = {"14": {"status": "active"}} + await document.set(data) + + # 3. Verify replaced data. + snapshot1 = await document.get() + assert snapshot1.to_dict() == data + + +async def test_document_update_w_int_field(client, cleanup): + # Attempt to reproduce #5489. + document_id = "update-int-key" + UNIQUE_RESOURCE_ID + document = client.document("i-did-it", document_id) + # Add to clean-up before API request (in case ``set()`` fails). + cleanup(document.delete) + + # 0. Make sure the document doesn't exist yet + snapshot = await document.get() + assert not snapshot.exists + + # 1. Use ``create()`` to create the document. + before = {"testing": "1"} + await document.create(before) + + # 2. Add values using ``update()``. + data = {"14": {"status": "active"}} + await document.update(data) + + # 3. Verify updated data. + expected = before.copy() + expected.update(data) + snapshot1 = await document.get() + assert snapshot1.to_dict() == expected + + +async def test_update_document(client, cleanup): + document_id = "for-update" + UNIQUE_RESOURCE_ID + document = client.document("made", document_id) + # Add to clean-up before API request (in case ``create()`` fails). + cleanup(document.delete) + + # 0. Try to update before the document exists. + with pytest.raises(NotFound) as exc_info: + await document.update({"not": "there"}) + assert exc_info.value.message.startswith(MISSING_DOCUMENT) + assert document_id in exc_info.value.message + + # 1. Try to update before the document exists (now with an option). + with pytest.raises(NotFound) as exc_info: + await document.update({"still": "not-there"}) + assert exc_info.value.message.startswith(MISSING_DOCUMENT) + assert document_id in exc_info.value.message + + # 2. Update and create the document (with an option). + data = {"foo": {"bar": "baz"}, "scoop": {"barn": 981}, "other": True} + write_result2 = await document.create(data) + + # 3. Send an update without a field path (no option). 
+    field_updates3 = {"foo": {"quux": 800}}
+    write_result3 = await document.update(field_updates3)
+    assert_timestamp_less(write_result2.update_time, write_result3.update_time)
+    snapshot3 = await document.get()
+    expected3 = {
+        "foo": field_updates3["foo"],
+        "scoop": data["scoop"],
+        "other": data["other"],
+    }
+    assert snapshot3.to_dict() == expected3
+
+    # 4. Send an update **with** a field path and a delete and a valid
+    #    "last timestamp" option.
+    field_updates4 = {"scoop.silo": None, "other": firestore.DELETE_FIELD}
+    option4 = client.write_option(last_update_time=snapshot3.update_time)
+    write_result4 = await document.update(field_updates4, option=option4)
+    assert_timestamp_less(write_result3.update_time, write_result4.update_time)
+    snapshot4 = await document.get()
+    expected4 = {
+        "foo": field_updates3["foo"],
+        "scoop": {"barn": data["scoop"]["barn"], "silo": field_updates4["scoop.silo"]},
+    }
+    assert snapshot4.to_dict() == expected4
+
+    # 5. Call ``update()`` with invalid (in the past) "last timestamp" option.
+    assert_timestamp_less(option4._last_update_time, snapshot4.update_time)
+    with pytest.raises(FailedPrecondition) as exc_info:
+        await document.update({"bad": "time-past"}, option=option4)
+
+    # 6. Call ``update()`` with invalid (in future) "last timestamp" option.
+    # TODO(microgen): start using custom datetime with nanos in protoplus?
+    timestamp_pb = _datetime_to_pb_timestamp(snapshot4.update_time)
+    timestamp_pb.seconds += 3600
+
+    option6 = client.write_option(last_update_time=timestamp_pb)
+    # TODO(microgen): invalid argument thrown after microgen.
+    # with pytest.raises(FailedPrecondition) as exc_info:
+    with pytest.raises(InvalidArgument) as exc_info:
+        await document.update({"bad": "time-future"}, option=option6)
+
+
+def check_snapshot(snapshot, document, data, write_result):
+    assert snapshot.reference is document
+    assert snapshot.to_dict() == data
+    assert snapshot.exists
+    assert snapshot.create_time == write_result.update_time
+    assert snapshot.update_time == write_result.update_time
+
+
+async def test_document_get(client, cleanup):
+    now = datetime.datetime.utcnow().replace(tzinfo=UTC)
+    document_id = "for-get" + UNIQUE_RESOURCE_ID
+    document = client.document("created", document_id)
+    # Add to clean-up before API request (in case ``create()`` fails).
+    cleanup(document.delete)
+
+    # First make sure it doesn't exist.
+    assert not (await document.get()).exists
+
+    ref_doc = client.document("top", "middle1", "middle2", "bottom")
+    data = {
+        "turtle": "power",
+        "cheese": 19.5,
+        "fire": 199099299,
+        "referee": ref_doc,
+        "gio": firestore.GeoPoint(45.5, 90.0),
+        "deep": [u"some", b"\xde\xad\xbe\xef"],
+        "map": {"ice": True, "water": None, "vapor": {"deeper": now}},
+    }
+    write_result = await document.create(data)
+    snapshot = await document.get()
+    check_snapshot(snapshot, document, data, write_result)
+
+
+async def test_document_delete(client, cleanup):
+    document_id = "deleted" + UNIQUE_RESOURCE_ID
+    document = client.document("here-to-be", document_id)
+    # Add to clean-up before API request (in case ``create()`` fails).
+    cleanup(document.delete)
+    await document.create({"not": "much"})
+
+    # 1. Call ``delete()`` with invalid (in the past) "last timestamp" option.
+    snapshot1 = await document.get()
+    timestamp_pb = _datetime_to_pb_timestamp(snapshot1.update_time)
+    timestamp_pb.seconds += 3600
+
+    option1 = client.write_option(last_update_time=timestamp_pb)
+    # TODO(microgen): invalid argument thrown after microgen.
+    # with pytest.raises(FailedPrecondition):
+    with pytest.raises(InvalidArgument):
+        await document.delete(option=option1)
+
+    # 2. Call ``delete()`` with invalid (in future) "last timestamp" option.
+    timestamp_pb = _datetime_to_pb_timestamp(snapshot1.update_time)
+    timestamp_pb.seconds += 3600
+
+    option2 = client.write_option(last_update_time=timestamp_pb)
+    # TODO(microgen): invalid argument thrown after microgen.
+    # with pytest.raises(FailedPrecondition):
+    with pytest.raises(InvalidArgument):
+        await document.delete(option=option2)
+
+    # 3. Actually ``delete()`` the document.
+    delete_time3 = await document.delete()
+
+    # 4. ``delete()`` again, even though we know the document is gone.
+    delete_time4 = await document.delete()
+    assert_timestamp_less(delete_time3, delete_time4)
+
+
+async def test_collection_add(client, cleanup):
+    # TODO(microgen): list_documents is returning a generator, not a list.
+    # Consider if this is desired. Also, Document isn't hashable.
+    collection_id = "coll-add" + UNIQUE_RESOURCE_ID
+    collection1 = client.collection(collection_id)
+    collection2 = client.collection(collection_id, "doc", "child")
+    collection3 = client.collection(collection_id, "table", "child")
+    explicit_doc_id = "hula" + UNIQUE_RESOURCE_ID
+
+    assert set([i async for i in collection1.list_documents()]) == set()
+    assert set([i async for i in collection2.list_documents()]) == set()
+    assert set([i async for i in collection3.list_documents()]) == set()
+
+    # Auto-ID at top-level.
+    data1 = {"foo": "bar"}
+    update_time1, document_ref1 = await collection1.add(data1)
+    cleanup(document_ref1.delete)
+    assert set([i async for i in collection1.list_documents()]) == {document_ref1}
+    assert set([i async for i in collection2.list_documents()]) == set()
+    assert set([i async for i in collection3.list_documents()]) == set()
+    snapshot1 = await document_ref1.get()
+    assert snapshot1.to_dict() == data1
+    assert snapshot1.update_time == update_time1
+    assert RANDOM_ID_REGEX.match(document_ref1.id)
+
+    # Explicit ID at top-level.
+    data2 = {"baz": 999}
+    update_time2, document_ref2 = await collection1.add(
+        data2, document_id=explicit_doc_id
+    )
+    cleanup(document_ref2.delete)
+    assert set([i async for i in collection1.list_documents()]) == {
+        document_ref1,
+        document_ref2,
+    }
+    assert set([i async for i in collection2.list_documents()]) == set()
+    assert set([i async for i in collection3.list_documents()]) == set()
+    snapshot2 = await document_ref2.get()
+    assert snapshot2.to_dict() == data2
+    assert snapshot2.create_time == update_time2
+    assert snapshot2.update_time == update_time2
+    assert document_ref2.id == explicit_doc_id
+
+    nested_ref = collection1.document("doc")
+
+    # Auto-ID for nested collection.
+    data3 = {"quux": b"\x00\x01\x02\x03"}
+    update_time3, document_ref3 = await collection2.add(data3)
+    cleanup(document_ref3.delete)
+    assert set([i async for i in collection1.list_documents()]) == {
+        document_ref1,
+        document_ref2,
+        nested_ref,
+    }
+    assert set([i async for i in collection2.list_documents()]) == {document_ref3}
+    assert set([i async for i in collection3.list_documents()]) == set()
+    snapshot3 = await document_ref3.get()
+    assert snapshot3.to_dict() == data3
+    assert snapshot3.update_time == update_time3
+    assert RANDOM_ID_REGEX.match(document_ref3.id)
+
+    # Explicit for nested collection.
+ data4 = {"kazaam": None, "bad": False} + update_time4, document_ref4 = await collection2.add( + data4, document_id=explicit_doc_id + ) + cleanup(document_ref4.delete) + assert set([i async for i in collection1.list_documents()]) == { + document_ref1, + document_ref2, + nested_ref, + } + assert set([i async for i in collection2.list_documents()]) == { + document_ref3, + document_ref4, + } + assert set([i async for i in collection3.list_documents()]) == set() + snapshot4 = await document_ref4.get() + assert snapshot4.to_dict() == data4 + assert snapshot4.create_time == update_time4 + assert snapshot4.update_time == update_time4 + assert document_ref4.id == explicit_doc_id + + # Exercise "missing" document (no doc, but subcollection). + data5 = {"bam": 123, "folyk": False} + update_time5, document_ref5 = await collection3.add(data5) + cleanup(document_ref5.delete) + missing_ref = collection1.document("table") + assert set([i async for i in collection1.list_documents()]) == { + document_ref1, + document_ref2, + nested_ref, + missing_ref, + } + assert set([i async for i in collection2.list_documents()]) == { + document_ref3, + document_ref4, + } + assert set([i async for i in collection3.list_documents()]) == {document_ref5} + + +@pytest.fixture +async def query_docs(client): + collection_id = "qs" + UNIQUE_RESOURCE_ID + sub_collection = "child" + UNIQUE_RESOURCE_ID + collection = client.collection(collection_id, "doc", sub_collection) + + cleanup = [] + stored = {} + num_vals = 5 + allowed_vals = range(num_vals) + for a_val in allowed_vals: + for b_val in allowed_vals: + document_data = { + "a": a_val, + "b": b_val, + "c": [a_val, num_vals * 100], + "stats": {"sum": a_val + b_val, "product": a_val * b_val}, + } + _, doc_ref = await collection.add(document_data) + # Add to clean-up. 
+            cleanup.append(doc_ref.delete)
+            stored[doc_ref.id] = document_data
+
+    yield collection, stored, allowed_vals
+
+    for operation in cleanup:
+        await operation()
+
+
+async def test_query_stream_w_simple_field_eq_op(query_docs):
+    collection, stored, allowed_vals = query_docs
+    query = collection.where("a", "==", 1)
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    assert len(values) == len(allowed_vals)
+    for key, value in values.items():
+        assert stored[key] == value
+        assert value["a"] == 1
+
+
+async def test_query_stream_w_simple_field_array_contains_op(query_docs):
+    collection, stored, allowed_vals = query_docs
+    query = collection.where("c", "array_contains", 1)
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    assert len(values) == len(allowed_vals)
+    for key, value in values.items():
+        assert stored[key] == value
+        assert value["a"] == 1
+
+
+async def test_query_stream_w_simple_field_in_op(query_docs):
+    collection, stored, allowed_vals = query_docs
+    num_vals = len(allowed_vals)
+    query = collection.where("a", "in", [1, num_vals + 100])
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    assert len(values) == len(allowed_vals)
+    for key, value in values.items():
+        assert stored[key] == value
+        assert value["a"] == 1
+
+
+async def test_query_stream_w_simple_field_array_contains_any_op(query_docs):
+    collection, stored, allowed_vals = query_docs
+    num_vals = len(allowed_vals)
+    query = collection.where("c", "array_contains_any", [1, num_vals * 200])
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    assert len(values) == len(allowed_vals)
+    for key, value in values.items():
+        assert stored[key] == value
+        assert value["a"] == 1
+
+
+async def test_query_stream_w_order_by(query_docs):
+    collection, stored, allowed_vals = query_docs
+    query = collection.order_by("b", direction=firestore.Query.DESCENDING)
+    values = [(snapshot.id, snapshot.to_dict()) async for snapshot in query.stream()]
+    assert len(values) == len(stored)
+    b_vals = []
+    for key, value in values:
+        assert stored[key] == value
+        b_vals.append(value["b"])
+    # Make sure the ``b``-values are in DESCENDING order.
+    assert sorted(b_vals, reverse=True) == b_vals
+
+
+async def test_query_stream_w_field_path(query_docs):
+    collection, stored, allowed_vals = query_docs
+    query = collection.where("stats.sum", ">", 4)
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    assert len(values) == 10
+    ab_pairs2 = set()
+    for key, value in values.items():
+        assert stored[key] == value
+        ab_pairs2.add((value["a"], value["b"]))
+
+    expected_ab_pairs = set(
+        [
+            (a_val, b_val)
+            for a_val in allowed_vals
+            for b_val in allowed_vals
+            if a_val + b_val > 4
+        ]
+    )
+    assert expected_ab_pairs == ab_pairs2
+
+
+async def test_query_stream_w_start_end_cursor(query_docs):
+    collection, stored, allowed_vals = query_docs
+    num_vals = len(allowed_vals)
+    query = (
+        collection.order_by("a")
+        .start_at({"a": num_vals - 2})
+        .end_before({"a": num_vals - 1})
+    )
+    values = [(snapshot.id, snapshot.to_dict()) async for snapshot in query.stream()]
+    assert len(values) == num_vals
+    for key, value in values:
+        assert stored[key] == value
+        assert value["a"] == num_vals - 2
+
+
+async def test_query_stream_wo_results(query_docs):
+    collection, stored, allowed_vals = query_docs
+    num_vals = len(allowed_vals)
+    query = collection.where("b", "==", num_vals + 100)
+    values = [i async for i in query.stream()]
+    assert len(values) == 0
+
+
+async def test_query_stream_w_projection(query_docs):
+    collection, stored, allowed_vals = query_docs
+    num_vals = len(allowed_vals)
+    query = collection.where("b", "<=", 1).select(["a", "stats.product"])
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    assert len(values) == num_vals * 2  # a ANY, b in (0, 1)
+    for key, value in values.items():
+        expected = {
+            "a": stored[key]["a"],
+            "stats": {"product": stored[key]["stats"]["product"]},
+        }
+        assert expected == value
+
+
+async def test_query_stream_w_multiple_filters(query_docs):
+    collection, stored, allowed_vals = query_docs
+    query = collection.where("stats.product", ">", 5).where("stats.product", "<", 10)
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    matching_pairs = [
+        (a_val, b_val)
+        for a_val in allowed_vals
+        for b_val in allowed_vals
+        if 5 < a_val * b_val < 10
+    ]
+    assert len(values) == len(matching_pairs)
+    for key, value in values.items():
+        assert stored[key] == value
+        pair = (value["a"], value["b"])
+        assert pair in matching_pairs
+
+
+async def test_query_stream_w_offset(query_docs):
+    collection, stored, allowed_vals = query_docs
+    num_vals = len(allowed_vals)
+    offset = 3
+    query = collection.where("b", "==", 2).offset(offset)
+    values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()}
+    # NOTE: We don't check the ``a``-values, since that would require
+    #       an ``order_by('a')``, which combined with the ``b == 2``
+    #       filter would necessitate an index.
+    assert len(values) == num_vals - offset
+    for key, value in values.items():
+        assert stored[key] == value
+        assert value["b"] == 2
+
+
+async def test_query_with_order_dot_key(client, cleanup):
+    db = client
+    collection_id = "collek" + UNIQUE_RESOURCE_ID
+    collection = db.collection(collection_id)
+    for index in range(100, -1, -1):
+        doc = collection.document("test_{:09d}".format(index))
+        data = {"count": 10 * index, "wordcount": {"page1": index * 10 + 100}}
+        await doc.set(data)
+        cleanup(doc.delete)
+    query = collection.order_by("wordcount.page1").limit(3)
+    data = [doc.to_dict()["wordcount"]["page1"] async for doc in query.stream()]
+    assert [100, 110, 120] == data
+    async for snapshot in collection.order_by("wordcount.page1").limit(3).stream():
+        last_value = snapshot.get("wordcount.page1")
+    cursor_with_nested_keys = {"wordcount": {"page1": last_value}}
+    found = [
+        i
+        async for i in collection.order_by("wordcount.page1")
+        .start_after(cursor_with_nested_keys)
+        .limit(3)
+        .stream()
+    ]
+    found_data = [
+        {u"count": 30, u"wordcount": {u"page1": 130}},
+        {u"count": 40, u"wordcount": {u"page1": 140}},
+        {u"count": 50, u"wordcount": {u"page1": 150}},
+    ]
+    assert found_data == [snap.to_dict() for snap in found]
+    cursor_with_dotted_paths = {"wordcount.page1": last_value}
+    cursor_with_key_data = [
+        i
+        async for i in collection.order_by("wordcount.page1")
+        .start_after(cursor_with_dotted_paths)
+        .limit(3)
+        .stream()
+    ]
+    assert found_data == [snap.to_dict() for snap in cursor_with_key_data]
+
+
+async def test_query_unary(client, cleanup):
+    collection_name = "unary" + UNIQUE_RESOURCE_ID
+    collection = client.collection(collection_name)
+    field_name = "foo"
+
+    _, document0 = await collection.add({field_name: None})
+    # Add to clean-up.
+    cleanup(document0.delete)
+
+    nan_val = float("nan")
+    _, document1 = await collection.add({field_name: nan_val})
+    # Add to clean-up.
+    cleanup(document1.delete)
+
+    # 0. Query for null.
+    query0 = collection.where(field_name, "==", None)
+    values0 = [i async for i in query0.stream()]
+    assert len(values0) == 1
+    snapshot0 = values0[0]
+    assert snapshot0.reference._path == document0._path
+    assert snapshot0.to_dict() == {field_name: None}
+
+    # 1. Query for a NAN.
+    query1 = collection.where(field_name, "==", nan_val)
+    values1 = [i async for i in query1.stream()]
+    assert len(values1) == 1
+    snapshot1 = values1[0]
+    assert snapshot1.reference._path == document1._path
+    data1 = snapshot1.to_dict()
+    assert len(data1) == 1
+    assert math.isnan(data1[field_name])
+
+
+async def test_collection_group_queries(client, cleanup):
+    collection_group = "b" + UNIQUE_RESOURCE_ID
+
+    doc_paths = [
+        "abc/123/" + collection_group + "/cg-doc1",
+        "abc/123/" + collection_group + "/cg-doc2",
+        collection_group + "/cg-doc3",
+        collection_group + "/cg-doc4",
+        "def/456/" + collection_group + "/cg-doc5",
+        collection_group + "/virtual-doc/nested-coll/not-cg-doc",
+        "x" + collection_group + "/not-cg-doc",
+        collection_group + "x/not-cg-doc",
+        "abc/123/" + collection_group + "x/not-cg-doc",
+        "abc/123/x" + collection_group + "/not-cg-doc",
+        "abc/" + collection_group,
+    ]
+
+    batch = client.batch()
+    for doc_path in doc_paths:
+        doc_ref = client.document(doc_path)
+        batch.set(doc_ref, {"x": 1})
+        cleanup(doc_ref.delete)
+
+    await batch.commit()
+
+    query = client.collection_group(collection_group)
+    snapshots = [i async for i in query.stream()]
+    found = [snapshot.id for snapshot in snapshots]
+    expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"]
+    assert found == expected
+
+
+async def test_collection_group_queries_startat_endat(client, cleanup):
+    collection_group = "b" + UNIQUE_RESOURCE_ID
+
+    doc_paths = [
+        "a/a/" + collection_group + "/cg-doc1",
+        "a/b/a/b/" + collection_group + "/cg-doc2",
+        "a/b/" + collection_group + "/cg-doc3",
+        "a/b/c/d/" + collection_group + "/cg-doc4",
+        "a/c/" + collection_group + "/cg-doc5",
+        collection_group + "/cg-doc6",
+        "a/b/nope/nope",
+    ]
+
+    batch = client.batch()
+    for doc_path in doc_paths:
+        doc_ref = client.document(doc_path)
+        batch.set(doc_ref, {"x": doc_path})
+        cleanup(doc_ref.delete)
+
+    await batch.commit()
+
+    query = (
+        client.collection_group(collection_group)
+        .order_by("__name__")
+        .start_at([client.document("a/b")])
+        .end_at([client.document("a/b0")])
+    )
+    snapshots = [i async for i in query.stream()]
+    found = set(snapshot.id for snapshot in snapshots)
+    assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])
+
+    query = (
+        client.collection_group(collection_group)
+        .order_by("__name__")
+        .start_after([client.document("a/b")])
+        .end_before([client.document("a/b/" + collection_group + "/cg-doc3")])
+    )
+    snapshots = [i async for i in query.stream()]
+    found = set(snapshot.id for snapshot in snapshots)
+    assert found == set(["cg-doc2"])
+
+
+async def test_collection_group_queries_filters(client, cleanup):
+    collection_group = "b" + UNIQUE_RESOURCE_ID
+
+    doc_paths = [
+        "a/a/" + collection_group + "/cg-doc1",
+        "a/b/a/b/" + collection_group + "/cg-doc2",
+        "a/b/" + collection_group + "/cg-doc3",
+        "a/b/c/d/" + collection_group + "/cg-doc4",
+        "a/c/" + collection_group + "/cg-doc5",
+        collection_group + "/cg-doc6",
+        "a/b/nope/nope",
+    ]
+
+    batch = client.batch()
+
+    for index, doc_path in enumerate(doc_paths):
+        doc_ref = client.document(doc_path)
+        batch.set(doc_ref, {"x": index})
+        cleanup(doc_ref.delete)
+
+    await batch.commit()
+
+    query = (
+        client.collection_group(collection_group)
+        .where(
+            firestore.field_path.FieldPath.document_id(), ">=", client.document("a/b")
+        )
+        .where(
+            firestore.field_path.FieldPath.document_id(), "<=", client.document("a/b0")
+        )
+    )
+    snapshots = [i async for i in query.stream()]
+    found = set(snapshot.id for snapshot in snapshots)
+    assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])
set(["cg-doc2", "cg-doc3", "cg-doc4"]) + + query = ( + client.collection_group(collection_group) + .where( + firestore.field_path.FieldPath.document_id(), ">", client.document("a/b") + ) + .where( + firestore.field_path.FieldPath.document_id(), + "<", + client.document("a/b/{}/cg-doc3".format(collection_group)), + ) + ) + snapshots = [i async for i in query.stream()] + found = set(snapshot.id for snapshot in snapshots) + assert found == set(["cg-doc2"]) + + +async def test_get_all(client, cleanup): + collection_name = "get-all" + UNIQUE_RESOURCE_ID + + document1 = client.document(collection_name, "a") + document2 = client.document(collection_name, "b") + document3 = client.document(collection_name, "c") + # Add to clean-up before API requests (in case ``create()`` fails). + cleanup(document1.delete) + cleanup(document3.delete) + + data1 = {"a": {"b": 2, "c": 3}, "d": 4, "e": 0} + write_result1 = await document1.create(data1) + data3 = {"a": {"b": 5, "c": 6}, "d": 7, "e": 100} + write_result3 = await document3.create(data3) + + # 0. Get 3 unique documents, one of which is missing. + snapshots = [i async for i in client.get_all([document1, document2, document3])] + + assert snapshots[0].exists + assert snapshots[1].exists + assert not snapshots[2].exists + + snapshots = [snapshot for snapshot in snapshots if snapshot.exists] + id_attr = operator.attrgetter("id") + snapshots.sort(key=id_attr) + + snapshot1, snapshot3 = snapshots + check_snapshot(snapshot1, document1, data1, write_result1) + check_snapshot(snapshot3, document3, data3, write_result3) + + # 1. Get 2 colliding documents. + document1_also = client.document(collection_name, "a") + snapshots = [i async for i in client.get_all([document1, document1_also])] + + assert len(snapshots) == 1 + assert document1 is not document1_also + check_snapshot(snapshots[0], document1_also, data1, write_result1) + + # 2. Use ``field_paths`` / projection in ``get_all()``. + snapshots = [ + i + async for i in client.get_all([document1, document3], field_paths=["a.b", "d"]) + ] + + assert len(snapshots) == 2 + snapshots.sort(key=id_attr) + + snapshot1, snapshot3 = snapshots + restricted1 = {"a": {"b": data1["a"]["b"]}, "d": data1["d"]} + check_snapshot(snapshot1, document1, restricted1, write_result1) + restricted3 = {"a": {"b": data3["a"]["b"]}, "d": data3["d"]} + check_snapshot(snapshot3, document3, restricted3, write_result3) + + +async def test_batch(client, cleanup): + collection_name = "batch" + UNIQUE_RESOURCE_ID + + document1 = client.document(collection_name, "abc") + document2 = client.document(collection_name, "mno") + document3 = client.document(collection_name, "xyz") + # Add to clean-up before API request (in case ``create()`` fails). 
+    cleanup(document1.delete)
+    cleanup(document2.delete)
+    cleanup(document3.delete)
+
+    data2 = {"some": {"deep": "stuff", "and": "here"}, "water": 100.0}
+    await document2.create(data2)
+    await document3.create({"other": 19})
+
+    batch = client.batch()
+    data1 = {"all": True}
+    batch.create(document1, data1)
+    new_value = "there"
+    batch.update(document2, {"some.and": new_value})
+    batch.delete(document3)
+    write_results = await batch.commit()
+
+    assert len(write_results) == 3
+
+    write_result1 = write_results[0]
+    write_result2 = write_results[1]
+    write_result3 = write_results[2]
+    assert not write_result3._pb.HasField("update_time")
+
+    snapshot1 = await document1.get()
+    assert snapshot1.to_dict() == data1
+    assert snapshot1.create_time == write_result1.update_time
+    assert snapshot1.update_time == write_result1.update_time
+
+    snapshot2 = await document2.get()
+    assert snapshot2.to_dict() != data2
+    data2["some"]["and"] = new_value
+    assert snapshot2.to_dict() == data2
+    assert_timestamp_less(snapshot2.create_time, write_result2.update_time)
+    assert snapshot2.update_time == write_result2.update_time
+
+    assert not (await document3.get()).exists
diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py
index 0beb0157c..8a6527175 100644
--- a/tests/unit/v1/test_async_client.py
+++ b/tests/unit/v1/test_async_client.py
@@ -236,7 +236,7 @@ def _next_page(self):
 
     async def _get_all_helper(self, client, references, document_pbs, **kwargs):
         # Create a minimal fake GAPIC with a dummy response.
-        firestore_api = mock.Mock(spec=["batch_get_documents"])
+        firestore_api = AsyncMock(spec=["batch_get_documents"])
         response_iterator = AsyncIter(document_pbs)
         firestore_api.batch_get_documents.return_value = response_iterator
diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py
index 742a381db..5649561e0 100644
--- a/tests/unit/v1/test_async_collection.py
+++ b/tests/unit/v1/test_async_collection.py
@@ -185,17 +185,17 @@ async def test_add_explicit_id(self):
 
     @pytest.mark.asyncio
     async def _list_documents_helper(self, page_size=None):
-        from google.api_core.page_iterator import Iterator
+        from google.api_core.page_iterator_async import AsyncIterator
         from google.api_core.page_iterator import Page
         from google.cloud.firestore_v1.async_document import AsyncDocumentReference
        from google.cloud.firestore_v1.types.document import Document
 
-        class _Iterator(Iterator):
+        class _AsyncIterator(AsyncIterator):
             def __init__(self, pages):
-                super(_Iterator, self).__init__(client=None)
+                super(_AsyncIterator, self).__init__(client=None)
                 self._pages = pages
 
-            def _next_page(self):
+            async def _next_page(self):
                 if self._pages:
                     page, self._pages = self._pages[0], self._pages[1:]
                     return Page(self, page, self.item_to_value)
@@ -206,7 +206,7 @@ def _next_page(self):
         documents = [
             Document(name=template.format(document_id)) for document_id in document_ids
         ]
-        iterator = _Iterator(pages=[documents])
+        iterator = _AsyncIterator(pages=[documents])
         firestore_api = AsyncMock()
         firestore_api.mock_add_spec(spec=["list_documents"])
         firestore_api.list_documents.return_value = iterator
@@ -214,9 +214,11 @@ def _next_page(self):
         collection = self._make_one("collection", client=client)
 
         if page_size is not None:
-            documents = list(await collection.list_documents(page_size=page_size))
+            documents = [
+                i async for i in collection.list_documents(page_size=page_size)
+            ]
         else:
-            documents = list(await collection.list_documents())
+            documents = [i async for i in collection.list_documents()]
 
         # Verify the response and the mocks.
         self.assertEqual(len(documents), len(document_ids))
@@ -320,12 +322,6 @@ async def test_stream_with_transaction(self, query_class):
         query_instance = query_class.return_value
         query_instance.stream.assert_called_once_with(transaction=transaction)
 
-    @mock.patch("google.cloud.firestore_v1.async_collection.Watch", autospec=True)
-    def test_on_snapshot(self, watch):
-        collection = self._make_one("collection")
-        collection.on_snapshot(None)
-        watch.for_query.assert_called_once()
-
 
 def _make_credentials():
     import google.auth.credentials
diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py
index 816f3b6b7..79a89d4ab 100644
--- a/tests/unit/v1/test_async_document.py
+++ b/tests/unit/v1/test_async_document.py
@@ -477,13 +477,6 @@ async def test_collections_wo_page_size(self):
     async def test_collections_w_page_size(self):
         await self._collections_helper(page_size=10)
 
-    @mock.patch("google.cloud.firestore_v1.async_document.Watch", autospec=True)
-    def test_on_snapshot(self, watch):
-        client = mock.Mock(_database_string="sprinklez", spec=["_database_string"])
-        document = self._make_one("yellow", "mellow", client=client)
-        document.on_snapshot(None)
-        watch.for_document.assert_called_once()
-
 
 def _make_credentials():
     import google.auth.credentials
diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py
index 1bbbf9ff7..be9c34358 100644
--- a/tests/unit/v1/test_async_query.py
+++ b/tests/unit/v1/test_async_query.py
@@ -17,7 +17,7 @@
 import aiounittest
 import mock
 
-from tests.unit.v1.test__helpers import AsyncIter
+from tests.unit.v1.test__helpers import AsyncMock, AsyncIter
 from tests.unit.v1.test_base_query import _make_credentials, _make_query_response
 
 
@@ -62,7 +62,7 @@ async def test_get(self):
         stream_mock.return_value = AsyncIter(range(3))
 
         # Create a minimal fake GAPIC.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
 
         # Attach the fake GAPIC to a real client.
         client = _make_client()
@@ -90,7 +90,7 @@ async def test_get(self):
     @pytest.mark.asyncio
     async def test_stream_simple(self):
         # Create a minimal fake GAPIC.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
 
         # Attach the fake GAPIC to a real client.
         client = _make_client()
@@ -130,7 +130,7 @@ async def test_stream_simple(self):
     @pytest.mark.asyncio
     async def test_stream_with_transaction(self):
         # Create a minimal fake GAPIC.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
 
         # Attach the fake GAPIC to a real client.
         client = _make_client()
@@ -174,7 +174,7 @@ async def test_stream_with_transaction(self):
     @pytest.mark.asyncio
     async def test_stream_no_results(self):
         # Create a minimal fake GAPIC with a dummy response.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
         empty_response = _make_query_response()
         run_query_response = AsyncIter([empty_response])
         firestore_api.run_query.return_value = run_query_response
@@ -205,7 +205,7 @@ async def test_stream_no_results(self):
     @pytest.mark.asyncio
     async def test_stream_second_response_in_empty_stream(self):
         # Create a minimal fake GAPIC with a dummy response.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
         empty_response1 = _make_query_response()
         empty_response2 = _make_query_response()
         run_query_response = AsyncIter([empty_response1, empty_response2])
@@ -237,7 +237,7 @@ async def test_stream_second_response_in_empty_stream(self):
     @pytest.mark.asyncio
     async def test_stream_with_skipped_results(self):
         # Create a minimal fake GAPIC.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
 
         # Attach the fake GAPIC to a real client.
         client = _make_client()
@@ -278,7 +278,7 @@ async def test_stream_with_skipped_results(self):
     @pytest.mark.asyncio
     async def test_stream_empty_after_first_response(self):
         # Create a minimal fake GAPIC.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
 
         # Attach the fake GAPIC to a real client.
         client = _make_client()
@@ -319,7 +319,7 @@ async def test_stream_empty_after_first_response(self):
     @pytest.mark.asyncio
     async def test_stream_w_collection_group(self):
         # Create a minimal fake GAPIC.
-        firestore_api = mock.Mock(spec=["run_query"])
+        firestore_api = AsyncMock(spec=["run_query"])
 
         # Attach the fake GAPIC to a real client.
         client = _make_client()
@@ -360,12 +360,6 @@ async def test_stream_w_collection_group(self):
             metadata=client._rpc_metadata,
         )
 
-    @mock.patch("google.cloud.firestore_v1.async_query.Watch", autospec=True)
-    def test_on_snapshot(self, watch):
-        query = self._make_one(mock.sentinel.parent)
-        query.on_snapshot(None)
-        watch.for_query.assert_called_once()
-
 
 def _make_client(project="project-project"):
     from google.cloud.firestore_v1.async_client import AsyncClient
diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py
index 6f12c3394..a7774a28c 100644
--- a/tests/unit/v1/test_async_transaction.py
+++ b/tests/unit/v1/test_async_transaction.py
@@ -755,12 +755,12 @@ async def test___call__failure(self):
         )
 
 
-class Test_transactional(aiounittest.AsyncTestCase):
+class Test_async_transactional(aiounittest.AsyncTestCase):
     @staticmethod
     def _call_fut(to_wrap):
-        from google.cloud.firestore_v1.async_transaction import transactional
+        from google.cloud.firestore_v1.async_transaction import async_transactional
 
-        return transactional(to_wrap)
+        return async_transactional(to_wrap)
 
     def test_it(self):
         from google.cloud.firestore_v1.async_transaction import _AsyncTransactional
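Example usage (reviewer sketch, not part of the patch): the snippet below shows how the
pieces added in this change are intended to fit together -- ``AsyncClient`` with awaited
document operations, ``stream()`` as an async generator, and the renamed
``async_transactional`` decorator. It assumes default credentials and a running event
loop; the ``users``/``alovelace`` names and the ``visits`` field are illustrative only,
and the transactional read follows the sync API's ``get(transaction=...)`` convention.

.. code-block:: python

    import asyncio

    from google.cloud import firestore


    async def main():
        client = firestore.AsyncClient()

        # Document reads and writes are awaited instead of called synchronously.
        doc_ref = client.document("users", "alovelace")
        await doc_ref.set({"first": "Ada", "visits": 0})
        snapshot = await doc_ref.get()
        print(snapshot.to_dict())

        # Queries stream results as an async generator.
        query = client.collection("users").where("first", "==", "Ada")
        async for doc in query.stream():
            print(doc.id, doc.to_dict())

        # Transactions use the renamed ``async_transactional`` decorator.
        transaction = client.transaction()

        @firestore.async_transactional
        async def bump_visits(transaction, ref):
            # Reads happen through the transaction; buffered writes are not awaited.
            snapshot = await ref.get(transaction=transaction)
            transaction.update(ref, {"visits": snapshot.get("visits") + 1})

        await bump_visits(transaction, doc_ref)


    asyncio.get_event_loop().run_until_complete(main())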