From d30fff8e42621d42d169e354948c26ee3e0d16f0 Mon Sep 17 00:00:00 2001 From: Christopher Wilcox Date: Tue, 18 Aug 2020 21:45:09 -0700 Subject: [PATCH] fix: type hint improvements (#144) --- google/cloud/firestore.py | 5 +- google/cloud/firestore_v1/__init__.py | 6 +- google/cloud/firestore_v1/_helpers.py | 1 - google/cloud/firestore_v1/async_client.py | 6 +- google/cloud/firestore_v1/async_document.py | 6 +- .../cloud/firestore_v1/async_transaction.py | 14 +---- google/cloud/firestore_v1/base_client.py | 61 ++++++++++++------- google/cloud/firestore_v1/base_collection.py | 57 ++++++++++++----- google/cloud/firestore_v1/base_document.py | 2 +- google/cloud/firestore_v1/base_transaction.py | 25 ++++---- google/cloud/firestore_v1/client.py | 6 -- google/cloud/firestore_v1/transaction.py | 9 --- 12 files changed, 106 insertions(+), 92 deletions(-) diff --git a/google/cloud/firestore.py b/google/cloud/firestore.py index 8484b110a..904aedc00 100644 --- a/google/cloud/firestore.py +++ b/google/cloud/firestore.py @@ -48,11 +48,8 @@ from google.cloud.firestore_v1 import WriteOption from typing import List -__all__: List[str] -__version__: str - -__all__ = [ +__all__: List[str] = [ "__version__", "ArrayRemove", "ArrayUnion", diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index 684bdcd3a..23588e4a8 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -22,7 +22,6 @@ __version__ = get_distribution("google-cloud-firestore").version - from google.cloud.firestore_v1 import types from google.cloud.firestore_v1._helpers import GeoPoint from google.cloud.firestore_v1._helpers import ExistsOption @@ -99,15 +98,12 @@ from .types.write import DocumentTransform from typing import List -__all__: List[str] -__version__: str # from .types.write import ExistenceFilter # from .types.write import Write # from .types.write import WriteResult - -__all__ = [ +__all__: List[str] = [ "__version__", "ArrayRemove", "ArrayUnion", diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 77ae74d1f..f9f01e7b9 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -35,7 +35,6 @@ _EmptyDict: transforms.Sentinel _GRPC_ERROR_MAPPING: dict -_datetime_to_pb_timestamp: Any BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}." diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 44e07f272..9cdab62b4 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -49,9 +49,7 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) -from typing import Any, AsyncGenerator, NoReturn - -_CLIENT_INFO: Any +from typing import Any, AsyncGenerator class AsyncClient(BaseClient): @@ -152,7 +150,7 @@ def collection(self, *collection_path) -> AsyncCollectionReference: """ return AsyncCollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id) -> NoReturn: + def collection_group(self, collection_id) -> AsyncQuery: """ Creates and returns a new AsyncQuery that includes all documents in the database that are contained in a collection or subcollection with the diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index f387707c9..d33b76a46 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -23,7 +23,7 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common -from typing import AsyncGenerator, Coroutine +from typing import Any, AsyncGenerator, Coroutine, Union class AsyncDocumentReference(BaseDocumentReference): @@ -281,7 +281,9 @@ async def delete(self, option=None) -> Coroutine: return commit_response.commit_time - async def get(self, field_paths=None, transaction=None) -> DocumentSnapshot: + async def get( + self, field_paths=None, transaction=None + ) -> Union[DocumentSnapshot, Coroutine[Any, Any, DocumentSnapshot]]: """Retrieve a snapshot of the current document. See :meth:`~google.cloud.firestore_v1.base_client.BaseClient.field_path` for diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 4793e216c..0a1f6a936 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -37,17 +37,9 @@ from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.async_document import AsyncDocumentReference +from google.cloud.firestore_v1.async_document import DocumentSnapshot from google.cloud.firestore_v1.async_query import AsyncQuery -from typing import Coroutine - -_CANT_BEGIN: str -_CANT_COMMIT: str -_CANT_ROLLBACK: str -_EXCEED_ATTEMPTS_TEMPLATE: str -_INITIAL_SLEEP: float -_MAX_SLEEP: float -_MULTIPLIER: float -_WRITE_READ_ONLY: str +from typing import Any, AsyncGenerator, Coroutine class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): @@ -162,7 +154,7 @@ async def get_all(self, references) -> Coroutine: """ return await self._client.get_all(references, transaction=self) - async def get(self, ref_or_query) -> Coroutine: + async def get(self, ref_or_query) -> AsyncGenerator[DocumentSnapshot, Any]: """ Retrieve a document or a query result from the database. Args: diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 06ec6b8e2..8ad6d1441 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -23,6 +23,7 @@ * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.document.DocumentReference` """ + import os import google.api_core.client_options # type: ignore @@ -34,29 +35,38 @@ from google.cloud.firestore_v1 import __version__ from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.field_path import render_field_path -from typing import Any, List, NoReturn, Optional, Tuple, Union +from typing import ( + Any, + AsyncGenerator, + Generator, + List, + Optional, + Tuple, + Union, +) + +# Types needed only for Type Hints +from google.cloud.firestore_v1.base_collection import BaseCollectionReference +from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.base_transaction import BaseTransaction +from google.cloud.firestore_v1.base_batch import BaseWriteBatch +from google.cloud.firestore_v1.base_query import BaseQuery -_ACTIVE_TXN: str -_BAD_DOC_TEMPLATE: str -_BAD_OPTION_ERR: str -_CLIENT_INFO: Any -_FIRESTORE_EMULATOR_HOST: str -_INACTIVE_TXN: str -__version__: str DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" _BAD_OPTION_ERR = ( "Exactly one of ``last_update_time`` or ``exists`` " "must be provided." ) -_BAD_DOC_TEMPLATE = ( +_BAD_DOC_TEMPLATE: str = ( "Document {!r} appeared in response but was not present among references" ) -_ACTIVE_TXN = "There is already an active transaction." -_INACTIVE_TXN = "There is no active transaction." -_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) -_FIRESTORE_EMULATOR_HOST = "FIRESTORE_EMULATOR_HOST" +_ACTIVE_TXN: str = "There is already an active transaction." +_INACTIVE_TXN: str = "There is no active transaction." +_CLIENT_INFO: Any = client_info.ClientInfo(client_library_version=__version__) +_FIRESTORE_EMULATOR_HOST: str = "FIRESTORE_EMULATOR_HOST" class BaseClient(ClientWithProject): @@ -214,13 +224,13 @@ def _rpc_metadata(self): return self._rpc_metadata_internal - def collection(self, *collection_path) -> NoReturn: + def collection(self, *collection_path) -> BaseCollectionReference: raise NotImplementedError - def collection_group(self, collection_id) -> NoReturn: + def collection_group(self, collection_id) -> BaseQuery: raise NotImplementedError - def _get_collection_reference(self, collection_id) -> NoReturn: + def _get_collection_reference(self, collection_id) -> BaseCollectionReference: """Checks validity of collection_id and then uses subclasses collection implementation. Args: @@ -241,7 +251,7 @@ def _get_collection_reference(self, collection_id) -> NoReturn: return self.collection(collection_id) - def document(self, *document_path) -> NoReturn: + def document(self, *document_path) -> BaseDocumentReference: raise NotImplementedError def _document_path_helper(self, *document_path) -> List[str]: @@ -342,16 +352,25 @@ def write_option( extra = "{!r} was provided".format(name) raise TypeError(_BAD_OPTION_ERR, extra) - def get_all(self, references, field_paths=None, transaction=None) -> NoReturn: + def get_all( + self, references, field_paths=None, transaction=None + ) -> Union[ + AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any] + ]: raise NotImplementedError - def collections(self) -> NoReturn: + def collections( + self, + ) -> Union[ + AsyncGenerator[BaseCollectionReference, Any], + Generator[BaseCollectionReference, Any, Any], + ]: raise NotImplementedError - def batch(self) -> NoReturn: + def batch(self) -> BaseWriteBatch: raise NotImplementedError - def transaction(self, **kwargs) -> NoReturn: + def transaction(self, **kwargs) -> BaseTransaction: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 0c2fe0e94..67dfc36d5 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -17,8 +17,21 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.document import DocumentReference -from typing import Any, NoReturn, Tuple - +from typing import ( + Any, + AsyncGenerator, + Coroutine, + Generator, + AsyncIterator, + Iterator, + NoReturn, + Tuple, + Union, +) + +# Types needed only for Type Hints +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_query import BaseQuery _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -87,7 +100,7 @@ def parent(self): parent_path = self._path[:-1] return self._client.document(*parent_path) - def _query(self) -> NoReturn: + def _query(self) -> BaseQuery: raise NotImplementedError def document(self, document_id=None) -> Any: @@ -131,13 +144,19 @@ def _parent_info(self) -> Tuple[Any, str]: expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) return parent_path, expected_prefix - def add(self, document_data, document_id=None) -> NoReturn: + def add( + self, document_data, document_id=None + ) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]: raise NotImplementedError - def list_documents(self, page_size=None) -> NoReturn: + def list_documents( + self, page_size=None + ) -> Union[ + Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] + ]: raise NotImplementedError - def select(self, field_paths) -> NoReturn: + def select(self, field_paths) -> BaseQuery: """Create a "select" query with this collection as parent. See @@ -156,7 +175,7 @@ def select(self, field_paths) -> NoReturn: query = self._query() return query.select(field_paths) - def where(self, field_path, op_string, value) -> NoReturn: + def where(self, field_path, op_string, value) -> BaseQuery: """Create a "where" query with this collection as parent. See @@ -180,7 +199,7 @@ def where(self, field_path, op_string, value) -> NoReturn: query = self._query() return query.where(field_path, op_string, value) - def order_by(self, field_path, **kwargs) -> NoReturn: + def order_by(self, field_path, **kwargs) -> BaseQuery: """Create an "order by" query with this collection as parent. See @@ -202,7 +221,7 @@ def order_by(self, field_path, **kwargs) -> NoReturn: query = self._query() return query.order_by(field_path, **kwargs) - def limit(self, count) -> NoReturn: + def limit(self, count) -> BaseQuery: """Create a limited query with this collection as parent. .. note:: @@ -242,7 +261,7 @@ def limit_to_last(self, count): query = self._query() return query.limit_to_last(count) - def offset(self, num_to_skip) -> NoReturn: + def offset(self, num_to_skip) -> BaseQuery: """Skip to an offset in a query with this collection as parent. See @@ -260,7 +279,7 @@ def offset(self, num_to_skip) -> NoReturn: query = self._query() return query.offset(num_to_skip) - def start_at(self, document_fields) -> NoReturn: + def start_at(self, document_fields) -> BaseQuery: """Start query at a cursor with this collection as parent. See @@ -281,7 +300,7 @@ def start_at(self, document_fields) -> NoReturn: query = self._query() return query.start_at(document_fields) - def start_after(self, document_fields) -> NoReturn: + def start_after(self, document_fields) -> BaseQuery: """Start query after a cursor with this collection as parent. See @@ -302,7 +321,7 @@ def start_after(self, document_fields) -> NoReturn: query = self._query() return query.start_after(document_fields) - def end_before(self, document_fields) -> NoReturn: + def end_before(self, document_fields) -> BaseQuery: """End query before a cursor with this collection as parent. See @@ -323,7 +342,7 @@ def end_before(self, document_fields) -> NoReturn: query = self._query() return query.end_before(document_fields) - def end_at(self, document_fields) -> NoReturn: + def end_at(self, document_fields) -> BaseQuery: """End query at a cursor with this collection as parent. See @@ -344,10 +363,16 @@ def end_at(self, document_fields) -> NoReturn: query = self._query() return query.end_at(document_fields) - def get(self, transaction=None) -> NoReturn: + def get( + self, transaction=None + ) -> Union[ + Generator[DocumentSnapshot, Any, Any], AsyncGenerator[DocumentSnapshot, Any] + ]: raise NotImplementedError - def stream(self, transaction=None) -> NoReturn: + def stream( + self, transaction=None + ) -> Union[Iterator[DocumentSnapshot], AsyncIterator[DocumentSnapshot]]: raise NotImplementedError def on_snapshot(self, callback) -> NoReturn: diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index c0a81d739..f11546cac 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -190,7 +190,7 @@ def update(self, field_updates, option=None) -> NoReturn: def delete(self, option=None) -> NoReturn: raise NotImplementedError - def get(self, field_paths=None, transaction=None) -> NoReturn: + def get(self, field_paths=None, transaction=None) -> "DocumentSnapshot": raise NotImplementedError def collections(self, page_size=None) -> NoReturn: diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index b26eb3f5e..9f2eff0ec 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -16,7 +16,7 @@ from google.cloud.firestore_v1 import types -from typing import NoReturn, Optional +from typing import Any, Coroutine, NoReturn, Optional, Union _CANT_BEGIN: str _CANT_COMMIT: str @@ -29,21 +29,22 @@ _MULTIPLIER: float _WRITE_READ_ONLY: str + MAX_ATTEMPTS = 5 """int: Default number of transaction attempts (with retries).""" -_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." -_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." -_CANT_ROLLBACK = _MISSING_ID_TEMPLATE.format("rolled back") -_CANT_COMMIT = _MISSING_ID_TEMPLATE.format("committed") -_WRITE_READ_ONLY = "Cannot perform write operation in read-only transaction." -_INITIAL_SLEEP = 1.0 +_CANT_BEGIN: str = "The transaction has already begun. Current transaction ID: {!r}." +_MISSING_ID_TEMPLATE: str = "The transaction has no transaction ID, so it cannot be {}." +_CANT_ROLLBACK: str = _MISSING_ID_TEMPLATE.format("rolled back") +_CANT_COMMIT: str = _MISSING_ID_TEMPLATE.format("committed") +_WRITE_READ_ONLY: str = "Cannot perform write operation in read-only transaction." +_INITIAL_SLEEP: float = 1.0 """float: Initial "max" for sleep interval. To be used in :func:`_sleep`.""" -_MAX_SLEEP = 30.0 +_MAX_SLEEP: float = 30.0 """float: Eventual "max" sleep time. To be used in :func:`_sleep`.""" -_MULTIPLIER = 2.0 +_MULTIPLIER: float = 2.0 """float: Multiplier for exponential backoff. To be used in :func:`_sleep`.""" -_EXCEED_ATTEMPTS_TEMPLATE = "Failed to commit transaction in {:d} attempts." -_CANT_RETRY_READ_ONLY = "Only read-write transactions can be retried." +_EXCEED_ATTEMPTS_TEMPLATE: str = "Failed to commit transaction in {:d} attempts." +_CANT_RETRY_READ_ONLY: str = "Only read-write transactions can be retried." class BaseTransaction(object): @@ -135,7 +136,7 @@ def _begin(self, retry_id=None) -> NoReturn: def _rollback(self) -> NoReturn: raise NotImplementedError - def _commit(self) -> NoReturn: + def _commit(self) -> Union[list, Coroutine[Any, Any, list]]: raise NotImplementedError def get_all(self, references) -> NoReturn: diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index a2e2eb14e..30d6bd1cd 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -46,12 +46,6 @@ ) from typing import Any, Generator -_CLIENT_INFO: Any -_get_doc_mask: Any -_parse_batch_get: Any -_path_helper: Any -_reference_info: Any - class Client(BaseClient): """Client for interacting with Google Cloud Firestore API. diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 93a91099c..a93f3c62e 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -38,15 +38,6 @@ from google.cloud.firestore_v1.query import Query from typing import Any, Optional -_CANT_BEGIN: str -_CANT_COMMIT: str -_CANT_ROLLBACK: str -_EXCEED_ATTEMPTS_TEMPLATE: str -_INITIAL_SLEEP: float -_MAX_SLEEP: float -_MULTIPLIER: float -_WRITE_READ_ONLY: str - class Transaction(batch.WriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction.