From afff842a3356cbe5b0342be57341c12b2d601fda Mon Sep 17 00:00:00 2001 From: Christopher Wilcox Date: Wed, 5 Aug 2020 20:18:32 -0700 Subject: [PATCH] feat: add inline type hints and pytype ci (#134) --- google/cloud/firestore.py | 4 + .../services/firestore_admin/async_client.py | 4 +- .../services/firestore_admin/client.py | 6 +- .../firestore_admin/transports/base.py | 2 +- google/cloud/firestore_v1/__init__.py | 4 + google/cloud/firestore_v1/_helpers.py | 105 ++++++++++-------- google/cloud/firestore_v1/async_batch.py | 4 +- google/cloud/firestore_v1/async_client.py | 30 +++-- google/cloud/firestore_v1/async_collection.py | 32 ++++-- google/cloud/firestore_v1/async_document.py | 17 +-- google/cloud/firestore_v1/async_query.py | 12 +- .../cloud/firestore_v1/async_transaction.py | 41 ++++--- google/cloud/firestore_v1/base_batch.py | 12 +- google/cloud/firestore_v1/base_client.py | 63 ++++++----- google/cloud/firestore_v1/base_collection.py | 45 ++++---- google/cloud/firestore_v1/base_document.py | 35 +++--- google/cloud/firestore_v1/base_query.py | 77 ++++++++----- google/cloud/firestore_v1/base_transaction.py | 38 ++++--- google/cloud/firestore_v1/batch.py | 4 +- google/cloud/firestore_v1/client.py | 25 +++-- google/cloud/firestore_v1/collection.py | 17 +-- google/cloud/firestore_v1/document.py | 19 ++-- google/cloud/firestore_v1/order.py | 23 ++-- google/cloud/firestore_v1/query.py | 11 +- .../services/firestore/transports/base.py | 2 +- google/cloud/firestore_v1/transaction.py | 38 ++++--- google/cloud/firestore_v1/transforms.py | 6 +- google/cloud/firestore_v1/types/__init__.py | 48 ++++++++ google/cloud/firestore_v1/types/common.py | 3 + google/cloud/firestore_v1/types/document.py | 3 + google/cloud/firestore_v1/types/firestore.py | 3 + google/cloud/firestore_v1/types/query.py | 3 + google/cloud/firestore_v1/types/write.py | 3 + google/cloud/firestore_v1/watch.py | 8 +- noxfile.py | 10 +- setup.cfg | 11 ++ 36 files changed, 497 insertions(+), 271 deletions(-) diff --git a/google/cloud/firestore.py b/google/cloud/firestore.py index 4c5cb3fe2..8484b110a 100644 --- a/google/cloud/firestore.py +++ b/google/cloud/firestore.py @@ -46,6 +46,10 @@ from google.cloud.firestore_v1 import Watch from google.cloud.firestore_v1 import WriteBatch from google.cloud.firestore_v1 import WriteOption +from typing import List + +__all__: List[str] +__version__: str __all__ = [ diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py index 4957e3cc8..7e7dcc3f6 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py @@ -28,8 +28,8 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation -from google.api_core import operation_async +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.firestore_admin_v1.services.firestore_admin import pagers from google.cloud.firestore_admin_v1.types import field from google.cloud.firestore_admin_v1.types import field as gfa_field diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py index 4b3373fc9..b88b18dfb 100644 --- 
a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py @@ -30,9 +30,9 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.firestore_admin_v1.services.firestore_admin import pagers from google.cloud.firestore_admin_v1.types import field from google.cloud.firestore_admin_v1.types import field as gfa_field diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py index 56d98021f..ee9ce819e 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py @@ -18,7 +18,7 @@ import abc import typing -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index 74652de3e..684bdcd3a 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -97,6 +97,10 @@ # from .types.write import DocumentDelete # from .types.write import DocumentRemove from .types.write import DocumentTransform +from typing import List + +__all__: List[str] +__version__: str # from .types.write import ExistenceFilter # from .types.write import Write diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index e6aeb734b..77ae74d1f 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -17,12 +17,12 @@ import datetime from google.protobuf import struct_pb2 -from google.type import latlng_pb2 -import grpc +from google.type import latlng_pb2 # type: ignore +import grpc # type: ignore -from google.cloud import exceptions -from google.cloud._helpers import _datetime_to_pb_timestamp -from google.api_core.datetime_helpers import DatetimeWithNanoseconds +from google.cloud import exceptions # type: ignore +from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore +from google.api_core.datetime_helpers import DatetimeWithNanoseconds # type: ignore from google.cloud.firestore_v1.types.write import DocumentTransform from google.cloud.firestore_v1 import transforms from google.cloud.firestore_v1 import types @@ -31,6 +31,11 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import write +from typing import Any, Generator, List, NoReturn, Optional, Tuple + +_EmptyDict: transforms.Sentinel +_GRPC_ERROR_MAPPING: dict +_datetime_to_pb_timestamp: Any BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}." @@ -60,11 +65,11 @@ class GeoPoint(object): longitude (float): Longitude of a point. 
""" - def __init__(self, latitude, longitude): + def __init__(self, latitude, longitude) -> None: self.latitude = latitude self.longitude = longitude - def to_protobuf(self): + def to_protobuf(self) -> Any: """Convert the current object to protobuf. Returns: @@ -100,7 +105,7 @@ def __ne__(self, other): return not equality_val -def verify_path(path, is_collection): +def verify_path(path, is_collection) -> None: """Verifies that a ``path`` has the correct form. Checks that all of the elements in ``path`` are strings. @@ -136,7 +141,7 @@ def verify_path(path, is_collection): raise ValueError(msg) -def encode_value(value): +def encode_value(value) -> types.document.Value: """Converts a native Python value into a Firestore protobuf ``Value``. Args: @@ -200,7 +205,7 @@ def encode_value(value): ) -def encode_dict(values_dict): +def encode_dict(values_dict) -> dict: """Encode a dictionary into protobuf ``Value``-s. Args: @@ -214,7 +219,7 @@ def encode_dict(values_dict): return {key: encode_value(value) for key, value in values_dict.items()} -def reference_value_to_document(reference_value, client): +def reference_value_to_document(reference_value, client) -> Any: """Convert a reference value string to a document. Args: @@ -248,7 +253,7 @@ def reference_value_to_document(reference_value, client): return document -def decode_value(value, client): +def decode_value(value, client) -> Any: """Converts a Firestore protobuf ``Value`` to a native Python value. Args: @@ -294,7 +299,7 @@ def decode_value(value, client): raise ValueError("Unknown ``value_type``", value_type) -def decode_dict(value_fields, client): +def decode_dict(value_fields, client) -> dict: """Converts a protobuf map of Firestore ``Value``-s. Args: @@ -311,7 +316,7 @@ def decode_dict(value_fields, client): return {key: decode_value(value, client) for key, value in value_fields.items()} -def get_doc_id(document_pb, expected_prefix): +def get_doc_id(document_pb, expected_prefix) -> Any: """Parse a document ID from a document protobuf. Args: @@ -342,7 +347,9 @@ def get_doc_id(document_pb, expected_prefix): _EmptyDict = transforms.Sentinel("Marker for an empty dict value") -def extract_fields(document_data, prefix_path, expand_dots=False): +def extract_fields( + document_data, prefix_path: FieldPath, expand_dots=False +) -> Generator[Tuple[Any, Any], Any, None]: """Do depth-first walk of tree, yielding field_path, value""" if not document_data: yield prefix_path, _EmptyDict @@ -363,7 +370,7 @@ def extract_fields(document_data, prefix_path, expand_dots=False): yield field_path, value -def set_field_value(document_data, field_path, value): +def set_field_value(document_data, field_path, value) -> None: """Set a value into a document for a field_path""" current = document_data for element in field_path.parts[:-1]: @@ -373,7 +380,7 @@ def set_field_value(document_data, field_path, value): current[field_path.parts[-1]] = value -def get_field_value(document_data, field_path): +def get_field_value(document_data, field_path) -> Any: if not field_path.parts: raise ValueError("Empty path") @@ -394,7 +401,7 @@ class DocumentExtractor(object): a document. 
""" - def __init__(self, document_data): + def __init__(self, document_data) -> None: self.document_data = document_data self.field_paths = [] self.deleted_fields = [] @@ -440,7 +447,9 @@ def __init__(self, document_data): self.field_paths.append(field_path) set_field_value(self.set_fields, field_path, value) - def _get_document_iterator(self, prefix_path): + def _get_document_iterator( + self, prefix_path: FieldPath + ) -> Generator[Tuple[Any, Any], Any, None]: return extract_fields(self.document_data, prefix_path) @property @@ -465,10 +474,12 @@ def transform_paths(self): + list(self.minimums) ) - def _get_update_mask(self, allow_empty_mask=False): + def _get_update_mask(self, allow_empty_mask=False) -> None: return None - def get_update_pb(self, document_path, exists=None, allow_empty_mask=False): + def get_update_pb( + self, document_path, exists=None, allow_empty_mask=False + ) -> types.write.Write: if exists is not None: current_document = common.Precondition(exists=exists) @@ -485,7 +496,7 @@ def get_update_pb(self, document_path, exists=None, allow_empty_mask=False): return update_pb - def get_transform_pb(self, document_path, exists=None): + def get_transform_pb(self, document_path, exists=None) -> types.write.Write: def make_array_value(values): value_list = [encode_value(element) for element in values] return document.ArrayValue(values=value_list) @@ -565,7 +576,7 @@ def make_array_value(values): return transform_pb -def pbs_for_create(document_path, document_data): +def pbs_for_create(document_path, document_data) -> List[types.write.Write]: """Make ``Write`` protobufs for ``create()`` methods. Args: @@ -597,7 +608,7 @@ def pbs_for_create(document_path, document_data): return write_pbs -def pbs_for_set_no_merge(document_path, document_data): +def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]: """Make ``Write`` protobufs for ``set()`` methods. Args: @@ -632,7 +643,7 @@ class DocumentExtractorForMerge(DocumentExtractor): """ Break document data up into actual data and transforms. """ - def __init__(self, document_data): + def __init__(self, document_data) -> None: super(DocumentExtractorForMerge, self).__init__(document_data) self.data_merge = [] self.transform_merge = [] @@ -652,20 +663,20 @@ def has_updates(self): return bool(update_paths) - def _apply_merge_all(self): + def _apply_merge_all(self) -> None: self.data_merge = sorted(self.field_paths + self.deleted_fields) # TODO: other transforms self.transform_merge = self.transform_paths self.merge = sorted(self.data_merge + self.transform_paths) - def _construct_merge_paths(self, merge): + def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]: for merge_field in merge: if isinstance(merge_field, FieldPath): yield merge_field else: yield FieldPath(*parse_field_path(merge_field)) - def _normalize_merge_paths(self, merge): + def _normalize_merge_paths(self, merge) -> list: merge_paths = sorted(self._construct_merge_paths(merge)) # Raise if any merge path is a parent of another. 
Leverage sorting @@ -685,7 +696,7 @@ def _normalize_merge_paths(self, merge): return merge_paths - def _apply_merge_paths(self, merge): + def _apply_merge_paths(self, merge) -> None: if self.empty_document: raise ValueError("Cannot merge specific fields with empty document.") @@ -749,13 +760,15 @@ def _apply_merge_paths(self, merge): if path in merged_transform_paths } - def apply_merge(self, merge): + def apply_merge(self, merge) -> None: if merge is True: # merge all fields self._apply_merge_all() else: self._apply_merge_paths(merge) - def _get_update_mask(self, allow_empty_mask=False): + def _get_update_mask( + self, allow_empty_mask=False + ) -> Optional[types.common.DocumentMask]: # Mask uses dotted / quoted paths. mask_paths = [ field_path.to_api_repr() @@ -767,7 +780,9 @@ def _get_update_mask(self, allow_empty_mask=False): return common.DocumentMask(field_paths=mask_paths) -def pbs_for_set_with_merge(document_path, document_data, merge): +def pbs_for_set_with_merge( + document_path, document_data, merge +) -> List[types.write.Write]: """Make ``Write`` protobufs for ``set()`` methods. Args: @@ -804,7 +819,7 @@ class DocumentExtractorForUpdate(DocumentExtractor): """ Break document data up into actual data and transforms. """ - def __init__(self, document_data): + def __init__(self, document_data) -> None: super(DocumentExtractorForUpdate, self).__init__(document_data) self.top_level_paths = sorted( [FieldPath.from_string(key) for key in document_data] @@ -825,10 +840,12 @@ def __init__(self, document_data): "Cannot update with nest delete: {}".format(field_path) ) - def _get_document_iterator(self, prefix_path): + def _get_document_iterator( + self, prefix_path: FieldPath + ) -> Generator[Tuple[Any, Any], Any, None]: return extract_fields(self.document_data, prefix_path, expand_dots=True) - def _get_update_mask(self, allow_empty_mask=False): + def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask: mask_paths = [] for field_path in self.top_level_paths: if field_path not in self.transform_paths: @@ -837,7 +854,7 @@ def _get_update_mask(self, allow_empty_mask=False): return common.DocumentMask(field_paths=mask_paths) -def pbs_for_update(document_path, field_updates, option): +def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]: """Make ``Write`` protobufs for ``update()`` methods. Args: @@ -878,7 +895,7 @@ def pbs_for_update(document_path, field_updates, option): return write_pbs -def pb_for_delete(document_path, option): +def pb_for_delete(document_path, option) -> types.write.Write: """Make a ``Write`` protobuf for ``delete()`` methods. Args: @@ -905,7 +922,7 @@ class ReadAfterWriteError(Exception): """ -def get_transaction_id(transaction, read_operation=True): +def get_transaction_id(transaction, read_operation=True) -> Any: """Get the transaction ID from a ``Transaction`` object. Args: @@ -935,7 +952,7 @@ def get_transaction_id(transaction, read_operation=True): return transaction.id -def metadata_with_prefix(prefix, **kw): +def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]: """Create RPC metadata containing a prefix. Args: @@ -950,7 +967,7 @@ def metadata_with_prefix(prefix, **kw): class WriteOption(object): """Option used to assert a condition on a write operation.""" - def modify_write(self, write, no_create_msg=None): + def modify_write(self, write, no_create_msg=None) -> NoReturn: """Modify a ``Write`` protobuf based on the state of this write option. 
This is a virtual method intended to be implemented by subclasses. @@ -982,7 +999,7 @@ class LastUpdateOption(WriteOption): as part of a "write result" protobuf or directly. """ - def __init__(self, last_update_time): + def __init__(self, last_update_time) -> None: self._last_update_time = last_update_time def __eq__(self, other): @@ -990,7 +1007,7 @@ def __eq__(self, other): return NotImplemented return self._last_update_time == other._last_update_time - def modify_write(self, write, **unused_kwargs): + def modify_write(self, write, **unused_kwargs) -> None: """Modify a ``Write`` protobuf based on the state of this write option. The ``last_update_time`` is added to ``write_pb`` as an "update time" @@ -1019,7 +1036,7 @@ class ExistsOption(WriteOption): should already exist. """ - def __init__(self, exists): + def __init__(self, exists) -> None: self._exists = exists def __eq__(self, other): @@ -1027,7 +1044,7 @@ def __eq__(self, other): return NotImplemented return self._exists == other._exists - def modify_write(self, write, **unused_kwargs): + def modify_write(self, write, **unused_kwargs) -> None: """Modify a ``Write`` protobuf based on the state of this write option. If: diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 983a3bd98..cc359d6b5 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -30,10 +30,10 @@ class AsyncWriteBatch(BaseWriteBatch): The client that created this batch. """ - def __init__(self, client): + def __init__(self, client) -> None: super(AsyncWriteBatch, self).__init__(client=client) - async def commit(self): + async def commit(self) -> list: """Commit the changes accumulated in this batch. Returns: diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index e6e9656ae..44e07f272 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -28,8 +28,8 @@ BaseClient, DEFAULT_DATABASE, _CLIENT_INFO, - _reference_info, - _parse_batch_get, + _reference_info, # type: ignore + _parse_batch_get, # type: ignore _get_doc_mask, _path_helper, ) @@ -38,7 +38,10 @@ from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference -from google.cloud.firestore_v1.async_document import AsyncDocumentReference +from google.cloud.firestore_v1.async_document import ( + AsyncDocumentReference, + DocumentSnapshot, +) from google.cloud.firestore_v1.async_transaction import AsyncTransaction from google.cloud.firestore_v1.services.firestore import ( async_client as firestore_client, @@ -46,6 +49,9 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) +from typing import Any, AsyncGenerator, NoReturn + +_CLIENT_INFO: Any class AsyncClient(BaseClient): @@ -83,7 +89,7 @@ def __init__( database=DEFAULT_DATABASE, client_info=_CLIENT_INFO, client_options=None, - ): + ) -> None: super(AsyncClient, self).__init__( project=project, credentials=credentials, @@ -115,7 +121,7 @@ def _target(self): """ return self._target_helper(firestore_client.FirestoreAsyncClient) - def collection(self, *collection_path): + def collection(self, *collection_path) -> AsyncCollectionReference: """Get a reference to a collection. 
For a top-level collection: @@ -146,7 +152,7 @@ def collection(self, *collection_path): """ return AsyncCollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id): + def collection_group(self, collection_id) -> NoReturn: """ Creates and returns a new AsyncQuery that includes all documents in the database that are contained in a collection or subcollection with the @@ -170,7 +176,7 @@ def collection_group(self, collection_id): self._get_collection_reference(collection_id), all_descendants=True ) - def document(self, *document_path): + def document(self, *document_path) -> AsyncDocumentReference: """Get a reference to a document in a collection. For a top-level document: @@ -205,7 +211,9 @@ def document(self, *document_path): *self._document_path_helper(*document_path), client=self ) - async def get_all(self, references, field_paths=None, transaction=None): + async def get_all( + self, references, field_paths=None, transaction=None + ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieve a batch of documents. .. note:: @@ -255,7 +263,7 @@ async def get_all(self, references, field_paths=None, transaction=None): async for get_doc_response in response_iterator: yield _parse_batch_get(get_doc_response, reference_map, self) - async def collections(self): + async def collections(self) -> AsyncGenerator[AsyncCollectionReference, Any]: """List top-level collections of the client's database. Returns: @@ -288,7 +296,7 @@ async def collections(self): # iterator.item_to_value = _item_to_collection_ref # return iterator - def batch(self): + def batch(self) -> AsyncWriteBatch: """Get a batch instance from this client. Returns: @@ -298,7 +306,7 @@ def batch(self): """ return AsyncWriteBatch(self) - def transaction(self, **kwargs): + def transaction(self, **kwargs) -> AsyncTransaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 95967b294..bd9aef5e5 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -21,7 +21,15 @@ _auto_id, _item_to_document_ref, ) -from google.cloud.firestore_v1 import async_query +from google.cloud.firestore_v1 import ( + async_query, + async_document, +) + +from google.cloud.firestore_v1.document import DocumentReference + +from typing import AsyncIterator +from typing import Any, AsyncGenerator, Tuple class AsyncCollectionReference(BaseCollectionReference): @@ -50,10 +58,10 @@ class AsyncCollectionReference(BaseCollectionReference): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs): + def __init__(self, *path, **kwargs) -> None: super(AsyncCollectionReference, self).__init__(*path, **kwargs) - def _query(self): + def _query(self) -> async_query.AsyncQuery: """Query factory. Returns: @@ -61,7 +69,7 @@ def _query(self): """ return async_query.AsyncQuery(self) - async def add(self, document_data, document_id=None): + async def add(self, document_data, document_id=None) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. 
Args: @@ -92,7 +100,9 @@ async def add(self, document_data, document_id=None): write_result = await document_ref.create(document_data) return write_result.update_time, document_ref - async def list_documents(self, page_size=None): + async def list_documents( + self, page_size=None + ) -> AsyncGenerator[DocumentReference, None]: """List all subdocuments of the current collection. Args: @@ -120,7 +130,9 @@ async def list_documents(self, page_size=None): async for i in iterator: yield _item_to_document_ref(self, i) - async def get(self, transaction=None): + async def get( + self, transaction=None + ) -> AsyncGenerator[async_document.DocumentSnapshot, Any]: """Deprecated alias for :meth:`stream`.""" warnings.warn( "'Collection.get' is deprecated: please use 'Collection.stream' instead.", @@ -128,9 +140,11 @@ async def get(self, transaction=None): stacklevel=2, ) async for d in self.stream(transaction=transaction): - yield d + yield d # pytype: disable=name-error - async def stream(self, transaction=None): + async def stream( + self, transaction=None + ) -> AsyncIterator[async_document.DocumentSnapshot]: """Read the documents in this collection. This sends a ``RunQuery`` RPC and then returns an iterator which @@ -159,4 +173,4 @@ async def stream(self, transaction=None): """ query = async_query.AsyncQuery(self) async for d in query.stream(transaction=transaction): - yield d + yield d # pytype: disable=name-error diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index a36d8894a..f387707c9 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -20,9 +20,10 @@ _first_write_result, ) -from google.api_core import exceptions +from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common +from typing import AsyncGenerator, Coroutine class AsyncDocumentReference(BaseDocumentReference): @@ -50,10 +51,10 @@ class AsyncDocumentReference(BaseDocumentReference): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs): + def __init__(self, *path, **kwargs) -> None: super(AsyncDocumentReference, self).__init__(*path, **kwargs) - async def create(self, document_data): + async def create(self, document_data) -> Coroutine: """Create the current document in the Firestore database. Args: @@ -74,7 +75,7 @@ async def create(self, document_data): write_results = await batch.commit() return _first_write_result(write_results) - async def set(self, document_data, merge=False): + async def set(self, document_data, merge=False) -> Coroutine: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -105,7 +106,7 @@ async def set(self, document_data, merge=False): write_results = await batch.commit() return _first_write_result(write_results) - async def update(self, field_updates, option=None): + async def update(self, field_updates, option=None) -> Coroutine: """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -253,7 +254,7 @@ async def update(self, field_updates, option=None): write_results = await batch.commit() return _first_write_result(write_results) - async def delete(self, option=None): + async def delete(self, option=None) -> Coroutine: """Delete the current document in the Firestore database. 
Args: @@ -280,7 +281,7 @@ async def delete(self, option=None): return commit_response.commit_time - async def get(self, field_paths=None, transaction=None): + async def get(self, field_paths=None, transaction=None) -> DocumentSnapshot: """Retrieve a snapshot of the current document. See :meth:`~google.cloud.firestore_v1.base_client.BaseClient.field_path` for @@ -345,7 +346,7 @@ async def get(self, field_paths=None, transaction=None): update_time=update_time, ) - async def collections(self, page_size=None): + async def collections(self, page_size=None) -> AsyncGenerator: """List subcollections of the current document. Args: diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 14e17e71a..f556c1206 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -27,6 +27,8 @@ ) from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import async_document +from typing import AsyncGenerator class AsyncQuery(BaseQuery): @@ -96,7 +98,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, - ): + ) -> None: super(AsyncQuery, self).__init__( parent=parent, projection=projection, @@ -109,7 +111,9 @@ def __init__( all_descendants=all_descendants, ) - async def get(self, transaction=None): + async def get( + self, transaction=None + ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: """Deprecated alias for :meth:`stream`.""" warnings.warn( "'AsyncQuery.get' is deprecated: please use 'AsyncQuery.stream' instead.", @@ -119,7 +123,9 @@ async def get(self, transaction=None): async for d in self.stream(transaction=transaction): yield d - async def stream(self, transaction=None): + async def stream( + self, transaction=None + ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and then returns an iterator which diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 33a81a292..19a436b0b 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -32,10 +32,22 @@ _EXCEED_ATTEMPTS_TEMPLATE, ) -from google.api_core import exceptions +from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import async_batch +from google.cloud.firestore_v1 import types + from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_query import AsyncQuery +from typing import Coroutine + +_CANT_BEGIN: str +_CANT_COMMIT: str +_CANT_ROLLBACK: str +_EXCEED_ATTEMPTS_TEMPLATE: str +_INITIAL_SLEEP: float +_MAX_SLEEP: float +_MULTIPLIER: float +_WRITE_READ_ONLY: str class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): @@ -52,11 +64,11 @@ class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): :data:`False`. """ - def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: super(AsyncTransaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs): + def _add_write_pbs(self, write_pbs) -> None: """Add `Write`` protobufs to this transaction. 
Args: @@ -71,7 +83,7 @@ def _add_write_pbs(self, write_pbs): super(AsyncTransaction, self)._add_write_pbs(write_pbs) - async def _begin(self, retry_id=None): + async def _begin(self, retry_id=None) -> None: """Begin the transaction. Args: @@ -94,7 +106,7 @@ async def _begin(self, retry_id=None): ) self._id = transaction_response.transaction - async def _rollback(self): + async def _rollback(self) -> None: """Roll back the transaction. Raises: @@ -115,7 +127,7 @@ async def _rollback(self): finally: self._clean_up() - async def _commit(self): + async def _commit(self) -> list: """Transactionally commit the changes accumulated. Returns: @@ -137,7 +149,7 @@ async def _commit(self): self._clean_up() return list(commit_response.write_results) - async def get_all(self, references): + async def get_all(self, references) -> Coroutine: """Retrieves multiple documents from Firestore. Args: @@ -150,7 +162,7 @@ async def get_all(self, references): """ return await self._client.get_all(references, transaction=self) - async def get(self, ref_or_query): + async def get(self, ref_or_query) -> Coroutine: """ Retrieve a document or a query result from the database. Args: @@ -180,10 +192,10 @@ class _AsyncTransactional(_BaseTransactional): A callable that should be run (and retried) in a transaction. """ - def __init__(self, to_wrap): + def __init__(self, to_wrap) -> None: super(_AsyncTransactional, self).__init__(to_wrap) - async def _pre_commit(self, transaction, *args, **kwargs): + async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine: """Begin transaction and call the wrapped callable. If the callable raises an exception, the transaction will be rolled @@ -221,7 +233,7 @@ async def _pre_commit(self, transaction, *args, **kwargs): await transaction._rollback() raise - async def _maybe_commit(self, transaction): + async def _maybe_commit(self, transaction) -> bool: """Try to commit the transaction. If the transaction is read-write and the ``Commit`` fails with the @@ -287,7 +299,7 @@ async def __call__(self, transaction, *args, **kwargs): raise ValueError(msg) -def async_transactional(to_wrap): +def async_transactional(to_wrap) -> _AsyncTransactional: """Decorate a callable so that it runs in a transaction. Args: @@ -302,7 +314,8 @@ def async_transactional(to_wrap): return _AsyncTransactional(to_wrap) -async def _commit_with_retry(client, write_pbs, transaction_id): +# TODO(crwilcox): this was 'coroutine' from pytype merge-pyi... +async def _commit_with_retry(client, write_pbs, transaction_id) -> types.CommitResponse: """Call ``Commit`` on the GAPIC client with retry / sleep. Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level @@ -345,7 +358,7 @@ async def _commit_with_retry(client, write_pbs, transaction_id): current_sleep = await _sleep(current_sleep) -async def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER): +async def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER) -> float: """Sleep and produce a new sleep time. .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index 45f8c49d9..dadcb0ec0 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -30,13 +30,13 @@ class BaseWriteBatch(object): The client that created this batch. 
""" - def __init__(self, client): + def __init__(self, client) -> None: self._client = client self._write_pbs = [] self.write_results = None self.commit_time = None - def _add_write_pbs(self, write_pbs): + def _add_write_pbs(self, write_pbs) -> None: """Add `Write`` protobufs to this transaction. This method intended to be over-ridden by subclasses. @@ -47,7 +47,7 @@ def _add_write_pbs(self, write_pbs): """ self._write_pbs.extend(write_pbs) - def create(self, reference, document_data): + def create(self, reference, document_data) -> None: """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -62,7 +62,7 @@ def create(self, reference, document_data): write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) self._add_write_pbs(write_pbs) - def set(self, reference, document_data, merge=False): + def set(self, reference, document_data, merge=False) -> None: """Add a "change" to replace a document. See @@ -90,7 +90,7 @@ def set(self, reference, document_data, merge=False): self._add_write_pbs(write_pbs) - def update(self, reference, field_updates, option=None): + def update(self, reference, field_updates, option=None) -> None: """Add a "change" to update a document. See @@ -113,7 +113,7 @@ def update(self, reference, field_updates, option=None): ) self._add_write_pbs(write_pbs) - def delete(self, reference, option=None): + def delete(self, reference, option=None) -> None: """Add a "change" to delete a document. See diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 538cafefa..e88a141a8 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -27,14 +27,23 @@ import google.api_core.client_options import google.api_core.path_template -from google.api_core.gapic_v1 import client_info -from google.cloud.client import ClientWithProject +from google.api_core.gapic_v1 import client_info # type: ignore +from google.cloud.client import ClientWithProject # type: ignore from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import __version__ from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.field_path import render_field_path +from typing import Any, List, NoReturn, Optional, Tuple, Union + +_ACTIVE_TXN: str +_BAD_DOC_TEMPLATE: str +_BAD_OPTION_ERR: str +_CLIENT_INFO: Any +_FIRESTORE_EMULATOR_HOST: str +_INACTIVE_TXN: str +__version__: str DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" @@ -95,7 +104,7 @@ def __init__( database=DEFAULT_DATABASE, client_info=_CLIENT_INFO, client_options=None, - ): + ) -> None: # NOTE: This API has no use for the _http argument, but sending it # will have no impact since the _http() @property only lazily # creates a working HTTP object. 
@@ -105,7 +114,7 @@ def __init__( self._client_info = client_info if client_options: if type(client_options) == dict: - client_options = google.api_core.client_options.from_dict( + client_options = google.api_core.client_options.from_dict( # type: ignore client_options ) self._client_options = client_options @@ -113,7 +122,7 @@ def __init__( self._database = database self._emulator_host = os.getenv(_FIRESTORE_EMULATOR_HOST) - def _firestore_api_helper(self, transport, client_class, client_module): + def _firestore_api_helper(self, transport, client_class, client_module) -> Any: """Lazy-loading getter GAPIC Firestore API. Returns: The GAPIC client with the credentials of the current client. @@ -142,7 +151,7 @@ def _firestore_api_helper(self, transport, client_class, client_module): return self._firestore_api_internal - def _target_helper(self, client_class): + def _target_helper(self, client_class) -> Any: """Return the target (where the API is). Eg. "firestore.googleapis.com" @@ -173,7 +182,7 @@ def _database_string(self): project. (The default database is also in this string.) """ if self._database_string_internal is None: - db_str = google.api_core.path_template.expand( + db_str = google.api_core.path_template.expand( # type: ignore "projects/{project}/databases/{database}", project=self.project, database=self._database, @@ -202,13 +211,13 @@ def _rpc_metadata(self): return self._rpc_metadata_internal - def collection(self, *collection_path): + def collection(self, *collection_path) -> NoReturn: raise NotImplementedError - def collection_group(self, collection_id): + def collection_group(self, collection_id) -> NoReturn: raise NotImplementedError - def _get_collection_reference(self, collection_id): + def _get_collection_reference(self, collection_id) -> NoReturn: """Checks validity of collection_id and then uses subclasses collection implementation. Args: @@ -229,10 +238,10 @@ def _get_collection_reference(self, collection_id): return self.collection(collection_id) - def document(self, *document_path): + def document(self, *document_path) -> NoReturn: raise NotImplementedError - def _document_path_helper(self, *document_path): + def _document_path_helper(self, *document_path) -> List[str]: """Standardize the format of path to tuple of path segments and strip the database string from path if present. Args: @@ -249,7 +258,7 @@ def _document_path_helper(self, *document_path): return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) @staticmethod - def field_path(*field_names): + def field_path(*field_names) -> Any: """Create a **field path** from a list of nested field names. A **field path** is a ``.``-delimited concatenation of the field @@ -278,7 +287,11 @@ def field_path(*field_names): return render_field_path(field_names) @staticmethod - def write_option(**kwargs): + def write_option( + **kwargs, + ) -> Union[ + _helpers.ExistsOption, _helpers.LastUpdateOption, + ]: """Create a write option for write operations. 
Write operations include :meth:`~google.cloud.DocumentReference.set`, @@ -326,20 +339,20 @@ def write_option(**kwargs): extra = "{!r} was provided".format(name) raise TypeError(_BAD_OPTION_ERR, extra) - def get_all(self, references, field_paths=None, transaction=None): + def get_all(self, references, field_paths=None, transaction=None) -> NoReturn: raise NotImplementedError - def collections(self): + def collections(self) -> NoReturn: raise NotImplementedError - def batch(self): + def batch(self) -> NoReturn: raise NotImplementedError - def transaction(self, **kwargs): + def transaction(self, **kwargs) -> NoReturn: raise NotImplementedError -def _reference_info(references): +def _reference_info(references) -> Tuple[list, dict]: """Get information about document references. Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`. @@ -366,7 +379,7 @@ def _reference_info(references): return document_paths, reference_map -def _get_reference(document_path, reference_map): +def _get_reference(document_path, reference_map) -> Any: """Get a document reference from a dictionary. This just wraps a simple dictionary look-up with a helpful error that is @@ -392,7 +405,7 @@ def _get_reference(document_path, reference_map): raise ValueError(msg) -def _parse_batch_get(get_doc_response, reference_map, client): +def _parse_batch_get(get_doc_response, reference_map, client) -> DocumentSnapshot: """Parse a `BatchGetDocumentsResponse` protobuf. Args: @@ -442,7 +455,7 @@ def _parse_batch_get(get_doc_response, reference_map, client): return snapshot -def _get_doc_mask(field_paths): +def _get_doc_mask(field_paths,) -> Optional[types.common.DocumentMask]: """Get a document mask if field paths are provided. Args: @@ -451,7 +464,7 @@ def _get_doc_mask(field_paths): projection of document fields in the returned results. Returns: - Optional[google.cloud.firestore_v1.types.DocumentMask]: A mask + Optional[google.cloud.firestore_v1.types.common.DocumentMask]: A mask to project documents to a restricted set of field paths. """ if field_paths is None: @@ -460,7 +473,7 @@ def _get_doc_mask(field_paths): return types.DocumentMask(field_paths=field_paths) -def _item_to_collection_ref(iterator, item): +def _item_to_collection_ref(iterator, item) -> Any: """Convert collection ID to collection ref. Args: @@ -471,7 +484,7 @@ def _item_to_collection_ref(iterator, item): return iterator.client.collection(item) -def _path_helper(path): +def _path_helper(path) -> Any: """Standardize path into a tuple of path segments. Args: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index f7fc0e552..8ce40bd1b 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -16,6 +16,9 @@ import random from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.document import DocumentReference +from typing import Any, NoReturn, Tuple + _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -46,7 +49,7 @@ class BaseCollectionReference(object): TypeError: If a keyword other than ``client`` is used. 
""" - def __init__(self, *path, **kwargs): + def __init__(self, *path, **kwargs) -> None: _helpers.verify_path(path, is_collection=True) self._path = path self._client = kwargs.pop("client", None) @@ -84,10 +87,10 @@ def parent(self): parent_path = self._path[:-1] return self._client.document(*parent_path) - def _query(self): + def _query(self) -> NoReturn: raise NotImplementedError - def document(self, document_id=None): + def document(self, document_id=None) -> Any: """Create a sub-document underneath the current collection. Args: @@ -106,7 +109,7 @@ def document(self, document_id=None): child_path = self._path + (document_id,) return self._client.document(*child_path) - def _parent_info(self): + def _parent_info(self) -> Tuple[Any, str]: """Get fully-qualified parent path and prefix for this collection. Returns: @@ -128,13 +131,13 @@ def _parent_info(self): expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) return parent_path, expected_prefix - def add(self, document_data, document_id=None): + def add(self, document_data, document_id=None) -> NoReturn: raise NotImplementedError - def list_documents(self, page_size=None): + def list_documents(self, page_size=None) -> NoReturn: raise NotImplementedError - def select(self, field_paths): + def select(self, field_paths) -> NoReturn: """Create a "select" query with this collection as parent. See @@ -153,7 +156,7 @@ def select(self, field_paths): query = self._query() return query.select(field_paths) - def where(self, field_path, op_string, value): + def where(self, field_path, op_string, value) -> NoReturn: """Create a "where" query with this collection as parent. See @@ -177,7 +180,7 @@ def where(self, field_path, op_string, value): query = self._query() return query.where(field_path, op_string, value) - def order_by(self, field_path, **kwargs): + def order_by(self, field_path, **kwargs) -> NoReturn: """Create an "order by" query with this collection as parent. See @@ -199,7 +202,7 @@ def order_by(self, field_path, **kwargs): query = self._query() return query.order_by(field_path, **kwargs) - def limit(self, count): + def limit(self, count) -> NoReturn: """Create a limited query with this collection as parent. See @@ -217,7 +220,7 @@ def limit(self, count): query = self._query() return query.limit(count) - def offset(self, num_to_skip): + def offset(self, num_to_skip) -> NoReturn: """Skip to an offset in a query with this collection as parent. See @@ -235,7 +238,7 @@ def offset(self, num_to_skip): query = self._query() return query.offset(num_to_skip) - def start_at(self, document_fields): + def start_at(self, document_fields) -> NoReturn: """Start query at a cursor with this collection as parent. See @@ -256,7 +259,7 @@ def start_at(self, document_fields): query = self._query() return query.start_at(document_fields) - def start_after(self, document_fields): + def start_after(self, document_fields) -> NoReturn: """Start query after a cursor with this collection as parent. See @@ -277,7 +280,7 @@ def start_after(self, document_fields): query = self._query() return query.start_after(document_fields) - def end_before(self, document_fields): + def end_before(self, document_fields) -> NoReturn: """End query before a cursor with this collection as parent. See @@ -298,7 +301,7 @@ def end_before(self, document_fields): query = self._query() return query.end_before(document_fields) - def end_at(self, document_fields): + def end_at(self, document_fields) -> NoReturn: """End query at a cursor with this collection as parent. 
See @@ -319,17 +322,17 @@ def end_at(self, document_fields): query = self._query() return query.end_at(document_fields) - def get(self, transaction=None): + def get(self, transaction=None) -> NoReturn: raise NotImplementedError - def stream(self, transaction=None): + def stream(self, transaction=None) -> NoReturn: raise NotImplementedError - def on_snapshot(self, callback): + def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError -def _auto_id(): +def _auto_id() -> str: """Generate a "random" automatically generated ID. Returns: @@ -339,11 +342,11 @@ def _auto_id(): return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20)) -def _item_to_document_ref(collection_reference, item): +def _item_to_document_ref(collection_reference, item) -> DocumentReference: """Convert Document resource to document ref. Args: - iterator (google.api_core.page_iterator.GRPCIterator): + collection_reference (google.api_core.page_iterator.GRPCIterator): iterator response item (dict): document resource """ diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 196e3cb5e..c0a81d739 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -18,6 +18,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import field_path as field_path_module +from typing import Any, NoReturn class BaseDocumentReference(object): @@ -47,7 +48,7 @@ class BaseDocumentReference(object): _document_path_internal = None - def __init__(self, *path, **kwargs): + def __init__(self, *path, **kwargs) -> None: _helpers.verify_path(path, is_collection=False) self._path = path self._client = kwargs.pop("client", None) @@ -163,7 +164,7 @@ def parent(self): parent_path = self._path[:-1] return self._client.collection(*parent_path) - def collection(self, collection_id): + def collection(self, collection_id) -> Any: """Create a sub-collection underneath the current document. Args: @@ -177,25 +178,25 @@ def collection(self, collection_id): child_path = self._path + (collection_id,) return self._client.collection(*child_path) - def create(self, document_data): + def create(self, document_data) -> NoReturn: raise NotImplementedError - def set(self, document_data, merge=False): + def set(self, document_data, merge=False) -> NoReturn: raise NotImplementedError - def update(self, field_updates, option=None): + def update(self, field_updates, option=None) -> NoReturn: raise NotImplementedError - def delete(self, option=None): + def delete(self, option=None) -> NoReturn: raise NotImplementedError - def get(self, field_paths=None, transaction=None): + def get(self, field_paths=None, transaction=None) -> NoReturn: raise NotImplementedError - def collections(self, page_size=None): + def collections(self, page_size=None) -> NoReturn: raise NotImplementedError - def on_snapshot(self, callback): + def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError @@ -227,7 +228,9 @@ class DocumentSnapshot(object): The time that this document was last updated. """ - def __init__(self, reference, data, exists, read_time, create_time, update_time): + def __init__( + self, reference, data, exists, read_time, create_time, update_time + ) -> None: self._reference = reference # We want immutable data, so callers can't modify this value # out from under us. 
@@ -288,7 +291,7 @@ def reference(self): """ return self._reference - def get(self, field_path): + def get(self, field_path) -> Any: """Get a value from the snapshot data. If the data is nested, for example: @@ -352,7 +355,7 @@ def get(self, field_path): nested_data = field_path_module.get_nested_value(field_path, self._data) return copy.deepcopy(nested_data) - def to_dict(self): + def to_dict(self) -> Any: """Retrieve the data contained in this snapshot. A copy is returned since the data may contain mutable values, @@ -368,7 +371,7 @@ def to_dict(self): return copy.deepcopy(self._data) -def _get_document_path(client, path): +def _get_document_path(client, path) -> str: """Convert a path tuple into a full path string. Of the form: @@ -389,7 +392,7 @@ def _get_document_path(client, path): return _helpers.DOCUMENT_PATH_DELIMITER.join(parts) -def _consume_single_get(response_iterator): +def _consume_single_get(response_iterator) -> Any: """Consume a gRPC stream that should contain a single response. The stream will correspond to a ``BatchGetDocuments`` request made @@ -420,7 +423,7 @@ def _consume_single_get(response_iterator): return all_responses[0] -def _first_write_result(write_results): +def _first_write_result(write_results) -> Any: """Get first write result from list. For cases where ``len(write_results) > 1``, this assumes the writes @@ -446,7 +449,7 @@ def _first_write_result(write_results): return write_results[0] -def _item_to_collection_ref(iterator, item): +def _item_to_collection_ref(iterator, item) -> Any: """Convert collection ID to collection ref. Args: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 16925f7ea..0522ac89a 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -29,7 +29,22 @@ from google.cloud.firestore_v1 import transforms from google.cloud.firestore_v1.types import StructuredQuery from google.cloud.firestore_v1.types import query +from google.cloud.firestore_v1.types import Cursor from google.cloud.firestore_v1.order import Order +from typing import Any, Dict, NoReturn, Optional, Tuple + +_BAD_DIR_STRING: str +_BAD_OP_NAN_NULL: str +_BAD_OP_STRING: str +_COMPARISON_OPERATORS: Dict[str, Any] +_EQ_OP: str +_INVALID_CURSOR_TRANSFORM: str +_INVALID_WHERE_TRANSFORM: str +_MISMATCH_CURSOR_W_ORDER_BY: str +_MISSING_ORDER_BY: str +_NO_ORDERS_FOR_CURSOR: str +_operator_enum: Any + _EQ_OP = "==" _operator_enum = StructuredQuery.FieldFilter.Operator @@ -135,7 +150,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, - ): + ) -> None: self._parent = parent self._projection = projection self._field_filters = field_filters @@ -171,7 +186,7 @@ def _client(self): """ return self._parent._client - def select(self, field_paths): + def select(self, field_paths) -> "BaseQuery": """Project documents matching query to a limited set of fields. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -215,7 +230,7 @@ def select(self, field_paths): all_descendants=self._all_descendants, ) - def where(self, field_path, op_string, value): + def where(self, field_path, op_string, value) -> "BaseQuery": """Filter the query on a field. 
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -285,14 +300,14 @@ def where(self, field_path, op_string, value): ) @staticmethod - def _make_order(field_path, direction): + def _make_order(field_path, direction) -> Any: """Helper for :meth:`order_by`.""" return query.StructuredQuery.Order( field=query.StructuredQuery.FieldReference(field_path=field_path), direction=_enum_from_direction(direction), ) - def order_by(self, field_path, direction=ASCENDING): + def order_by(self, field_path, direction=ASCENDING) -> "BaseQuery": """Modify the query to add an order clause on a specific field. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -336,7 +351,7 @@ def order_by(self, field_path, direction=ASCENDING): all_descendants=self._all_descendants, ) - def limit(self, count): + def limit(self, count) -> "BaseQuery": """Limit a query to return a fixed number of results. If the current query already has a limit set, this will overwrite it. @@ -362,7 +377,7 @@ def limit(self, count): all_descendants=self._all_descendants, ) - def offset(self, num_to_skip): + def offset(self, num_to_skip) -> "BaseQuery": """Skip to an offset in a query. If the current query already has specified an offset, this will @@ -389,7 +404,7 @@ def offset(self, num_to_skip): all_descendants=self._all_descendants, ) - def _check_snapshot(self, document_fields): + def _check_snapshot(self, document_fields) -> None: """Validate local snapshots for non-collection-group queries. Raises: @@ -402,7 +417,7 @@ def _check_snapshot(self, document_fields): if document_fields.reference._path[:-1] != self._parent._path: raise ValueError("Cannot use snapshot from another collection as a cursor.") - def _cursor_helper(self, document_fields, before, start): + def _cursor_helper(self, document_fields, before, start) -> "BaseQuery": """Set values to be used for a ``start_at`` or ``end_at`` cursor. The values will later be used in a query protobuf. @@ -454,7 +469,7 @@ def _cursor_helper(self, document_fields, before, start): return self.__class__(self._parent, **query_kwargs) - def start_at(self, document_fields): + def start_at(self, document_fields) -> "BaseQuery": """Start query results at a particular document value. The result set will **include** the document specified by @@ -484,7 +499,7 @@ def start_at(self, document_fields): """ return self._cursor_helper(document_fields, before=True, start=True) - def start_after(self, document_fields): + def start_after(self, document_fields) -> "BaseQuery": """Start query results after a particular document value. The result set will **exclude** the document specified by @@ -513,7 +528,7 @@ def start_after(self, document_fields): """ return self._cursor_helper(document_fields, before=False, start=True) - def end_before(self, document_fields): + def end_before(self, document_fields) -> "BaseQuery": """End query results before a particular document value. The result set will **exclude** the document specified by @@ -542,7 +557,7 @@ def end_before(self, document_fields): """ return self._cursor_helper(document_fields, before=True, start=False) - def end_at(self, document_fields): + def end_at(self, document_fields) -> "BaseQuery": """End query results at a particular document value. 
The result set will **include** the document specified by @@ -571,7 +586,7 @@ def end_at(self, document_fields): """ return self._cursor_helper(document_fields, before=False, start=False) - def _filters_pb(self): + def _filters_pb(self) -> Any: """Convert all the filters into a single generic Filter protobuf. This may be a lone field filter or unary filter, may be a composite @@ -594,7 +609,7 @@ def _filters_pb(self): return query.StructuredQuery.Filter(composite_filter=composite_filter) @staticmethod - def _normalize_projection(projection): + def _normalize_projection(projection) -> Any: """Helper: convert field paths to message.""" if projection is not None: @@ -606,7 +621,7 @@ def _normalize_projection(projection): return projection - def _normalize_orders(self): + def _normalize_orders(self) -> list: """Helper: adjust orders based on cursors, where clauses.""" orders = list(self._orders) _has_snapshot_cursor = False @@ -640,7 +655,7 @@ def _normalize_orders(self): return orders - def _normalize_cursor(self, cursor, orders): + def _normalize_cursor(self, cursor, orders) -> Optional[Tuple[Any, Any]]: """Helper: convert cursor to a list of values based on orders.""" if cursor is None: return @@ -692,7 +707,7 @@ def _normalize_cursor(self, cursor, orders): return document_fields, before - def _to_protobuf(self): + def _to_protobuf(self) -> StructuredQuery: """Convert the current query into the equivalent protobuf. Returns: @@ -723,16 +738,16 @@ def _to_protobuf(self): return query.StructuredQuery(**query_kwargs) - def get(self, transaction=None): + def get(self, transaction=None) -> NoReturn: raise NotImplementedError - def stream(self, transaction=None): + def stream(self, transaction=None) -> NoReturn: raise NotImplementedError - def on_snapshot(self, callback): + def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError - def _comparator(self, doc1, doc2): + def _comparator(self, doc1, doc2) -> Any: _orders = self._orders # Add implicit sorting by name, using the last specified direction. @@ -779,7 +794,7 @@ def _comparator(self, doc1, doc2): return 0 -def _enum_from_op_string(op_string): +def _enum_from_op_string(op_string) -> Any: """Convert a string representation of a binary operator to an enum. These enums come from the protobuf message definition @@ -804,7 +819,7 @@ def _enum_from_op_string(op_string): raise ValueError(msg) -def _isnan(value): +def _isnan(value) -> bool: """Check if a value is NaN. This differs from ``math.isnan`` in that **any** input type is @@ -822,7 +837,7 @@ def _isnan(value): return False -def _enum_from_direction(direction): +def _enum_from_direction(direction) -> Any: """Convert a string representation of a direction to an enum. Args: @@ -850,7 +865,7 @@ def _enum_from_direction(direction): raise ValueError(msg) -def _filter_pb(field_or_unary): +def _filter_pb(field_or_unary) -> Any: """Convert a specific protobuf filter to the generic filter type. Args: @@ -874,7 +889,7 @@ def _filter_pb(field_or_unary): raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) -def _cursor_pb(cursor_pair): +def _cursor_pb(cursor_pair) -> Optional[Cursor]: """Convert a cursor pair to a protobuf. If ``cursor_pair`` is :data:`None`, just returns :data:`None`. 
@@ -895,7 +910,9 @@ def _cursor_pb(cursor_pair): return query.Cursor(values=value_pbs, before=before) -def _query_response_to_snapshot(response_pb, collection, expected_prefix): +def _query_response_to_snapshot( + response_pb, collection, expected_prefix +) -> Optional[document.DocumentSnapshot]: """Parse a query response protobuf to a document snapshot. Args: @@ -929,7 +946,9 @@ def _query_response_to_snapshot(response_pb, collection, expected_prefix): return snapshot -def _collection_group_query_response_to_snapshot(response_pb, collection): +def _collection_group_query_response_to_snapshot( + response_pb, collection +) -> Optional[document.DocumentSnapshot]: """Parse a query response protobuf to a document snapshot. Args: diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index f477fb0fe..b26eb3f5e 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -16,6 +16,18 @@ from google.cloud.firestore_v1 import types +from typing import NoReturn, Optional + +_CANT_BEGIN: str +_CANT_COMMIT: str +_CANT_RETRY_READ_ONLY: str +_CANT_ROLLBACK: str +_EXCEED_ATTEMPTS_TEMPLATE: str +_INITIAL_SLEEP: float +_MAX_SLEEP: float +_MISSING_ID_TEMPLATE: str +_MULTIPLIER: float +_WRITE_READ_ONLY: str MAX_ATTEMPTS = 5 """int: Default number of transaction attempts (with retries).""" @@ -46,15 +58,15 @@ class BaseTransaction(object): :data:`False`. """ - def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False): + def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: self._max_attempts = max_attempts self._read_only = read_only self._id = None - def _add_write_pbs(self, write_pbs): + def _add_write_pbs(self, write_pbs) -> NoReturn: raise NotImplementedError - def _options_protobuf(self, retry_id): + def _options_protobuf(self, retry_id) -> Optional[types.common.TransactionOptions]: """Convert the current object to protobuf. The ``retry_id`` value is used when retrying a transaction that @@ -109,7 +121,7 @@ def id(self): """ return self._id - def _clean_up(self): + def _clean_up(self) -> None: """Clean up the instance after :meth:`_rollback`` or :meth:`_commit``. This intended to occur on success or failure of the associated RPCs. @@ -117,19 +129,19 @@ def _clean_up(self): self._write_pbs = [] self._id = None - def _begin(self, retry_id=None): + def _begin(self, retry_id=None) -> NoReturn: raise NotImplementedError - def _rollback(self): + def _rollback(self) -> NoReturn: raise NotImplementedError - def _commit(self): + def _commit(self) -> NoReturn: raise NotImplementedError - def get_all(self, references): + def get_all(self, references) -> NoReturn: raise NotImplementedError - def get(self, ref_or_query): + def get(self, ref_or_query) -> NoReturn: raise NotImplementedError @@ -144,22 +156,22 @@ class _BaseTransactional(object): A callable that should be run (and retried) in a transaction. 
""" - def __init__(self, to_wrap): + def __init__(self, to_wrap) -> None: self.to_wrap = to_wrap self.current_id = None """Optional[bytes]: The current transaction ID.""" self.retry_id = None """Optional[bytes]: The ID of the first attempted transaction.""" - def _reset(self): + def _reset(self) -> None: """Unset the transaction IDs.""" self.current_id = None self.retry_id = None - def _pre_commit(self, transaction, *args, **kwargs): + def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn: raise NotImplementedError - def _maybe_commit(self, transaction): + def _maybe_commit(self, transaction) -> NoReturn: raise NotImplementedError def __call__(self, transaction, *args, **kwargs): diff --git a/google/cloud/firestore_v1/batch.py b/google/cloud/firestore_v1/batch.py index 1c47ffb48..c4e5c7a6f 100644 --- a/google/cloud/firestore_v1/batch.py +++ b/google/cloud/firestore_v1/batch.py @@ -30,10 +30,10 @@ class WriteBatch(BaseWriteBatch): The client that created this batch. """ - def __init__(self, client): + def __init__(self, client) -> None: super(WriteBatch, self).__init__(client=client) - def commit(self): + def commit(self) -> list: """Commit the changes accumulated in this batch. Returns: diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 829c4285e..a2e2eb14e 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -44,6 +44,13 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc as firestore_grpc_transport, ) +from typing import Any, Generator + +_CLIENT_INFO: Any +_get_doc_mask: Any +_parse_batch_get: Any +_path_helper: Any +_reference_info: Any class Client(BaseClient): @@ -81,7 +88,7 @@ def __init__( database=DEFAULT_DATABASE, client_info=_CLIENT_INFO, client_options=None, - ): + ) -> None: super(Client, self).__init__( project=project, credentials=credentials, @@ -113,7 +120,7 @@ def _target(self): """ return self._target_helper(firestore_client.FirestoreClient) - def collection(self, *collection_path): + def collection(self, *collection_path) -> CollectionReference: """Get a reference to a collection. For a top-level collection: @@ -144,7 +151,7 @@ def collection(self, *collection_path): """ return CollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id): + def collection_group(self, collection_id) -> Query: """ Creates and returns a new Query that includes all documents in the database that are contained in a collection or subcollection with the @@ -168,7 +175,7 @@ def collection_group(self, collection_id): self._get_collection_reference(collection_id), all_descendants=True ) - def document(self, *document_path): + def document(self, *document_path) -> DocumentReference: """Get a reference to a document in a collection. For a top-level document: @@ -203,7 +210,9 @@ def document(self, *document_path): *self._document_path_helper(*document_path), client=self ) - def get_all(self, references, field_paths=None, transaction=None): + def get_all( + self, references, field_paths=None, transaction=None + ) -> Generator[Any, Any, None]: """Retrieve a batch of documents. .. note:: @@ -253,7 +262,7 @@ def get_all(self, references, field_paths=None, transaction=None): for get_doc_response in response_iterator: yield _parse_batch_get(get_doc_response, reference_map, self) - def collections(self): + def collections(self) -> Generator[Any, Any, None]: """List top-level collections of the client's database. 
Returns: @@ -286,7 +295,7 @@ def collections(self): # iterator.item_to_value = _item_to_collection_ref # return iterator - def batch(self): + def batch(self) -> WriteBatch: """Get a batch instance from this client. Returns: @@ -296,7 +305,7 @@ def batch(self): """ return WriteBatch(self) - def transaction(self, **kwargs): + def transaction(self, **kwargs) -> Transaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.transaction.Transaction` for diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 50b2ae453..67144b0f7 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -23,6 +23,7 @@ from google.cloud.firestore_v1 import query as query_mod from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import document +from typing import Any, Generator, Tuple class CollectionReference(BaseCollectionReference): @@ -51,10 +52,10 @@ class CollectionReference(BaseCollectionReference): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs): + def __init__(self, *path, **kwargs) -> None: super(CollectionReference, self).__init__(*path, **kwargs) - def _query(self): + def _query(self) -> query_mod.Query: """Query factory. Returns: @@ -62,7 +63,7 @@ def _query(self): """ return query_mod.Query(self) - def add(self, document_data, document_id=None): + def add(self, document_data, document_id=None) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. Args: @@ -93,7 +94,7 @@ def add(self, document_data, document_id=None): write_result = document_ref.create(document_data) return write_result.update_time, document_ref - def list_documents(self, page_size=None): + def list_documents(self, page_size=None) -> Generator[Any, Any, None]: """List all subdocuments of the current collection. Args: @@ -120,7 +121,7 @@ def list_documents(self, page_size=None): ) return (_item_to_document_ref(self, i) for i in iterator) - def get(self, transaction=None): + def get(self, transaction=None) -> Generator[document.DocumentSnapshot, Any, None]: """Deprecated alias for :meth:`stream`.""" warnings.warn( "'Collection.get' is deprecated: please use 'Collection.stream' instead.", @@ -129,7 +130,9 @@ def get(self, transaction=None): ) return self.stream(transaction=transaction) - def stream(self, transaction=None): + def stream( + self, transaction=None + ) -> Generator[document.DocumentSnapshot, Any, None]: """Read the documents in this collection. This sends a ``RunQuery`` RPC and then returns an iterator which @@ -159,7 +162,7 @@ def stream(self, transaction=None): query = query_mod.Query(self) return query.stream(transaction=transaction) - def on_snapshot(self, callback): + def on_snapshot(self, callback) -> Watch: """Monitor the documents in this collection. This starts a watch on this collection using a background thread. 
The diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 4d5d42aa4..f4f08ee71 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -20,10 +20,11 @@ _first_write_result, ) -from google.api_core import exceptions +from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.watch import Watch +from typing import Any, Generator class DocumentReference(BaseDocumentReference): @@ -51,10 +52,10 @@ class DocumentReference(BaseDocumentReference): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs): + def __init__(self, *path, **kwargs) -> None: super(DocumentReference, self).__init__(*path, **kwargs) - def create(self, document_data): + def create(self, document_data) -> Any: """Create the current document in the Firestore database. Args: @@ -75,7 +76,7 @@ def create(self, document_data): write_results = batch.commit() return _first_write_result(write_results) - def set(self, document_data, merge=False): + def set(self, document_data, merge=False) -> Any: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -106,7 +107,7 @@ def set(self, document_data, merge=False): write_results = batch.commit() return _first_write_result(write_results) - def update(self, field_updates, option=None): + def update(self, field_updates, option=None) -> Any: """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -254,7 +255,7 @@ def update(self, field_updates, option=None): write_results = batch.commit() return _first_write_result(write_results) - def delete(self, option=None): + def delete(self, option=None) -> Any: """Delete the current document in the Firestore database. Args: @@ -281,7 +282,7 @@ def delete(self, option=None): return commit_response.commit_time - def get(self, field_paths=None, transaction=None): + def get(self, field_paths=None, transaction=None) -> DocumentSnapshot: """Retrieve a snapshot of the current document. See :meth:`~google.cloud.firestore_v1.base_client.BaseClient.field_path` for @@ -346,7 +347,7 @@ def get(self, field_paths=None, transaction=None): update_time=update_time, ) - def collections(self, page_size=None): + def collections(self, page_size=None) -> Generator[Any, Any, None]: """List subcollections of the current document. Args: @@ -386,7 +387,7 @@ def collections(self, page_size=None): # iterator.item_to_value = _item_to_collection_ref # return iterator - def on_snapshot(self, callback): + def on_snapshot(self, callback) -> Watch: """Watch this document. This starts a watch on this document using a background thread. 
The diff --git a/google/cloud/firestore_v1/order.py b/google/cloud/firestore_v1/order.py index 427e797e8..5d1e3345d 100644 --- a/google/cloud/firestore_v1/order.py +++ b/google/cloud/firestore_v1/order.py @@ -15,6 +15,7 @@ from enum import Enum from google.cloud.firestore_v1._helpers import decode_value import math +from typing import Any class TypeOrder(Enum): @@ -31,7 +32,7 @@ class TypeOrder(Enum): OBJECT = 9 @staticmethod - def from_value(value): + def from_value(value) -> Any: v = value._pb.WhichOneof("value_type") lut = { @@ -59,7 +60,7 @@ class Order(object): """ @classmethod - def compare(cls, left, right): + def compare(cls, left, right) -> Any: """ Main comparison function for all Firestore types. @return -1 is left < right, 0 if left == right, otherwise 1 @@ -101,14 +102,14 @@ def compare(cls, left, right): raise ValueError(f"Unknown ``value_type`` {value_type}") @staticmethod - def compare_blobs(left, right): + def compare_blobs(left, right) -> Any: left_bytes = left.bytes_value right_bytes = right.bytes_value return Order._compare_to(left_bytes, right_bytes) @staticmethod - def compare_timestamps(left, right): + def compare_timestamps(left, right) -> Any: left = left._pb.timestamp_value right = right._pb.timestamp_value @@ -119,7 +120,7 @@ def compare_timestamps(left, right): return Order._compare_to(left.nanos or 0, right.nanos or 0) @staticmethod - def compare_geo_points(left, right): + def compare_geo_points(left, right) -> Any: left_value = decode_value(left, None) right_value = decode_value(right, None) cmp = (left_value.latitude > right_value.latitude) - ( @@ -133,7 +134,7 @@ def compare_geo_points(left, right): ) @staticmethod - def compare_resource_paths(left, right): + def compare_resource_paths(left, right) -> int: left = left.reference_value right = right.reference_value @@ -152,7 +153,7 @@ def compare_resource_paths(left, right): return (left_length > right_length) - (left_length < right_length) @staticmethod - def compare_arrays(left, right): + def compare_arrays(left, right) -> Any: l_values = left.array_value.values r_values = right.array_value.values @@ -165,7 +166,7 @@ def compare_arrays(left, right): return Order._compare_to(len(l_values), len(r_values)) @staticmethod - def compare_objects(left, right): + def compare_objects(left, right) -> Any: left_fields = left.map_value.fields right_fields = right.map_value.fields @@ -183,13 +184,13 @@ def compare_objects(left, right): return Order._compare_to(len(left_fields), len(right_fields)) @staticmethod - def compare_numbers(left, right): + def compare_numbers(left, right) -> Any: left_value = decode_value(left, None) right_value = decode_value(right, None) return Order.compare_doubles(left_value, right_value) @staticmethod - def compare_doubles(left, right): + def compare_doubles(left, right) -> Any: if math.isnan(left): if math.isnan(right): return 0 @@ -200,7 +201,7 @@ def compare_doubles(left, right): return Order._compare_to(left, right) @staticmethod - def _compare_to(left, right): + def _compare_to(left, right) -> Any: # We can't just use cmp(left, right) because cmp doesn't exist # in Python 3, so this is an equivalent suggested by # https://docs.python.org/3.0/whatsnew/3.0.html#ordering-comparisons diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 90996b8a4..4523cc71b 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -29,6 +29,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import document 
from google.cloud.firestore_v1.watch import Watch +from typing import Any, Generator class Query(BaseQuery): @@ -98,7 +99,7 @@ def __init__( start_at=None, end_at=None, all_descendants=False, - ): + ) -> None: super(Query, self).__init__( parent=parent, projection=projection, @@ -111,7 +112,7 @@ def __init__( all_descendants=all_descendants, ) - def get(self, transaction=None): + def get(self, transaction=None) -> Generator[document.DocumentSnapshot, Any, None]: """Deprecated alias for :meth:`stream`.""" warnings.warn( "'Query.get' is deprecated: please use 'Query.stream' instead.", @@ -120,7 +121,9 @@ def get(self, transaction=None): ) return self.stream(transaction=transaction) - def stream(self, transaction=None): + def stream( + self, transaction=None + ) -> Generator[document.DocumentSnapshot, Any, None]: """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and then returns an iterator which @@ -169,7 +172,7 @@ def stream(self, transaction=None): if snapshot is not None: yield snapshot - def on_snapshot(self, callback): + def on_snapshot(self, callback) -> Watch: """Monitor the documents in this collection that match this query. This starts a watch on this query using a background thread. The diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index 87edcbcda..857997f44 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -18,7 +18,7 @@ import abc import typing -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.auth import credentials # type: ignore diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index cfe396c74..93a91099c 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -32,10 +32,20 @@ _EXCEED_ATTEMPTS_TEMPLATE, ) -from google.api_core import exceptions +from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import batch from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.query import Query +from typing import Any, Optional + +_CANT_BEGIN: str +_CANT_COMMIT: str +_CANT_ROLLBACK: str +_EXCEED_ATTEMPTS_TEMPLATE: str +_INITIAL_SLEEP: float +_MAX_SLEEP: float +_MULTIPLIER: float +_WRITE_READ_ONLY: str class Transaction(batch.WriteBatch, BaseTransaction): @@ -52,11 +62,11 @@ class Transaction(batch.WriteBatch, BaseTransaction): :data:`False`. """ - def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: super(Transaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs): + def _add_write_pbs(self, write_pbs) -> None: """Add `Write`` protobufs to this transaction. Args: @@ -71,7 +81,7 @@ def _add_write_pbs(self, write_pbs): super(Transaction, self)._add_write_pbs(write_pbs) - def _begin(self, retry_id=None): + def _begin(self, retry_id=None) -> None: """Begin the transaction. Args: @@ -94,7 +104,7 @@ def _begin(self, retry_id=None): ) self._id = transaction_response.transaction - def _rollback(self): + def _rollback(self) -> None: """Roll back the transaction. 
Raises: @@ -115,7 +125,7 @@ def _rollback(self): finally: self._clean_up() - def _commit(self): + def _commit(self) -> list: """Transactionally commit the changes accumulated. Returns: @@ -135,7 +145,7 @@ def _commit(self): self._clean_up() return list(commit_response.write_results) - def get_all(self, references): + def get_all(self, references) -> Any: """Retrieves multiple documents from Firestore. Args: @@ -148,7 +158,7 @@ def get_all(self, references): """ return self._client.get_all(references, transaction=self) - def get(self, ref_or_query): + def get(self, ref_or_query) -> Any: """ Retrieve a document or a query result from the database. Args: @@ -178,10 +188,10 @@ class _Transactional(_BaseTransactional): A callable that should be run (and retried) in a transaction. """ - def __init__(self, to_wrap): + def __init__(self, to_wrap) -> None: super(_Transactional, self).__init__(to_wrap) - def _pre_commit(self, transaction, *args, **kwargs): + def _pre_commit(self, transaction, *args, **kwargs) -> Any: """Begin transaction and call the wrapped callable. If the callable raises an exception, the transaction will be rolled @@ -219,7 +229,7 @@ def _pre_commit(self, transaction, *args, **kwargs): transaction._rollback() raise - def _maybe_commit(self, transaction): + def _maybe_commit(self, transaction) -> Optional[bool]: """Try to commit the transaction. If the transaction is read-write and the ``Commit`` fails with the @@ -285,7 +295,7 @@ def __call__(self, transaction, *args, **kwargs): raise ValueError(msg) -def transactional(to_wrap): +def transactional(to_wrap) -> _Transactional: """Decorate a callable so that it runs in a transaction. Args: @@ -300,7 +310,7 @@ def transactional(to_wrap): return _Transactional(to_wrap) -def _commit_with_retry(client, write_pbs, transaction_id): +def _commit_with_retry(client, write_pbs, transaction_id) -> Any: """Call ``Commit`` on the GAPIC client with retry / sleep. Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level @@ -343,7 +353,7 @@ def _commit_with_retry(client, write_pbs, transaction_id): current_sleep = _sleep(current_sleep) -def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER): +def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER) -> Any: """Sleep and produce a new sleep time. .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ diff --git a/google/cloud/firestore_v1/transforms.py b/google/cloud/firestore_v1/transforms.py index ea2eeec9a..e9aa87606 100644 --- a/google/cloud/firestore_v1/transforms.py +++ b/google/cloud/firestore_v1/transforms.py @@ -20,7 +20,7 @@ class Sentinel(object): __slots__ = ("description",) - def __init__(self, description): + def __init__(self, description) -> None: self.description = description def __repr__(self): @@ -44,7 +44,7 @@ class _ValueList(object): slots = ("_values",) - def __init__(self, values): + def __init__(self, values) -> None: if not isinstance(values, (list, tuple)): raise ValueError("'values' must be a list or tuple.") @@ -97,7 +97,7 @@ class _NumericValue(object): value (int | float): value held in the helper. 
""" - def __init__(self, value): + def __init__(self, value) -> None: if not isinstance(value, (int, float)): raise ValueError("Pass an integer / float value.") diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index 137c3130a..465a2d92e 100644 --- a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -68,6 +68,54 @@ BatchWriteRequest, BatchWriteResponse, ) +from typing import Tuple + + +__all__: Tuple[ + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, + str, +] __all__ = ( diff --git a/google/cloud/firestore_v1/types/common.py b/google/cloud/firestore_v1/types/common.py index b03242a4a..f7bd22a3d 100644 --- a/google/cloud/firestore_v1/types/common.py +++ b/google/cloud/firestore_v1/types/common.py @@ -19,6 +19,9 @@ from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from typing import Any + +__protobuf__: Any __protobuf__ = proto.module( diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 7104bfc61..b2111b34f 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -21,6 +21,9 @@ from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.type import latlng_pb2 as latlng # type: ignore +from typing import Any + +__protobuf__: Any __protobuf__ = proto.module( diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index cb0fa75dc..909a782c8 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -24,6 +24,9 @@ from google.cloud.firestore_v1.types import write from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as gr_status # type: ignore +from typing import Any + +__protobuf__: Any __protobuf__ = proto.module( diff --git a/google/cloud/firestore_v1/types/query.py b/google/cloud/firestore_v1/types/query.py index a65b0191b..bea9a10a5 100644 --- a/google/cloud/firestore_v1/types/query.py +++ b/google/cloud/firestore_v1/types/query.py @@ -20,6 +20,9 @@ from google.cloud.firestore_v1.types import document from google.protobuf import wrappers_pb2 as wrappers # type: ignore +from typing import Any + +__protobuf__: Any __protobuf__ = proto.module( diff --git a/google/cloud/firestore_v1/types/write.py b/google/cloud/firestore_v1/types/write.py index 6b3f49b53..12cdf99b6 100644 --- a/google/cloud/firestore_v1/types/write.py +++ b/google/cloud/firestore_v1/types/write.py @@ -21,6 +21,9 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document as gf_document from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from typing import Any + +__protobuf__: Any __protobuf__ = proto.module( diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index d3499e649..466821bb5 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -18,14 +18,14 @@ from enum import Enum import functools -from google.api_core.bidi import ResumableBidiRpc -from google.api_core.bidi import 
BackgroundConsumer +from google.api_core.bidi import ResumableBidiRpc # type: ignore +from google.api_core.bidi import BackgroundConsumer # type: ignore from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1 import _helpers -from google.api_core import exceptions +from google.api_core import exceptions # type: ignore -import grpc +import grpc # type: ignore """Python client for Google Cloud Firestore Watch.""" diff --git a/noxfile.py b/noxfile.py index 55f2da88e..82daad6af 100644 --- a/noxfile.py +++ b/noxfile.py @@ -22,7 +22,7 @@ import nox - +PYTYPE_VERSION = "pytype==2020.7.24" BLACK_VERSION = "black==19.10b0" BLACK_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] @@ -61,6 +61,14 @@ def blacken(session): ) +@nox.session(python="3.7") +def pytype(session): + """Run pytype + """ + session.install(PYTYPE_VERSION) + session.run("pytype",) + + @nox.session(python=DEFAULT_PYTHON_VERSION) def lint_setup_py(session): """Verify that setup.py is valid (including RST check).""" diff --git a/setup.cfg b/setup.cfg index c3a2b39f6..f0c722b1e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,3 +17,14 @@ # Generated by synthtool. DO NOT EDIT! [bdist_wheel] universal = 1 + +[pytype] +python_version = 3.8 +inputs = + google/cloud/ +exclude = + tests/ +output = .pytype/ +# Workaround for https://github.com/google/pytype/issues/150 +disable = pyi-error +
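
A minimal illustration (not part of the patch) of what the inline return-type hints above give downstream users and the pytype check: editors and type checkers can now infer concrete types from the public client surface. The collection name and field filter below are hypothetical, and the sketch assumes a project with working credentials; the calls themselves (`Client.collection`, `Client.batch`, `Client.transaction`, `Query.where` / `order_by` / `limit` / `stream`) are the methods annotated in this change.

    from google.cloud import firestore

    client = firestore.Client()

    # batch() and transaction() are now annotated to return WriteBatch / Transaction.
    batch = client.batch()
    txn = client.transaction()

    # collection() -> CollectionReference; where()/order_by()/limit() each return a
    # Query, so the chained expression keeps its type through every call.
    users = client.collection("users")  # hypothetical collection name
    adults = users.where("age", ">=", 21).order_by("age").limit(10)

    # stream() is annotated as a generator of DocumentSnapshot.
    for snapshot in adults.stream():
        print(snapshot.id, snapshot.to_dict())

Locally, the new static check can be exercised with `nox -s pytype`, per the session added to noxfile.py and the `[pytype]` configuration in setup.cfg above.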