From 9b6c2f33351c65901ea648e4407b2817e5e70957 Mon Sep 17 00:00:00 2001 From: HemangChothani <50404902+HemangChothani@users.noreply.github.com> Date: Fri, 9 Oct 2020 13:03:11 -0400 Subject: [PATCH] feat: add type hints for method params (#182) Co-authored-by: Christopher Wilcox --- google/cloud/firestore_v1/async_client.py | 10 ++-- google/cloud/firestore_v1/async_collection.py | 13 +++-- google/cloud/firestore_v1/async_document.py | 16 +++--- google/cloud/firestore_v1/async_query.py | 7 ++- .../cloud/firestore_v1/async_transaction.py | 29 +++++++---- google/cloud/firestore_v1/base_batch.py | 26 ++++++++-- google/cloud/firestore_v1/base_client.py | 28 ++++++---- google/cloud/firestore_v1/base_collection.py | 40 +++++++++------ google/cloud/firestore_v1/base_document.py | 28 +++++----- google/cloud/firestore_v1/base_query.py | 51 ++++++++++++------- google/cloud/firestore_v1/base_transaction.py | 6 ++- google/cloud/firestore_v1/client.py | 13 +++-- google/cloud/firestore_v1/collection.py | 15 +++--- google/cloud/firestore_v1/document.py | 16 +++--- google/cloud/firestore_v1/field_path.py | 15 +++--- google/cloud/firestore_v1/query.py | 4 +- google/cloud/firestore_v1/transaction.py | 22 ++++---- 17 files changed, 214 insertions(+), 125 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index dafd1a28d..b1376170e 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -49,7 +49,7 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) -from typing import Any, AsyncGenerator +from typing import Any, AsyncGenerator, Iterable, Tuple class AsyncClient(BaseClient): @@ -119,7 +119,7 @@ def _target(self): """ return self._target_helper(firestore_client.FirestoreAsyncClient) - def collection(self, *collection_path) -> AsyncCollectionReference: + def collection(self, *collection_path: Tuple[str]) -> AsyncCollectionReference: """Get a reference to a collection. For a top-level collection: @@ -150,7 +150,7 @@ def collection(self, *collection_path) -> AsyncCollectionReference: """ return AsyncCollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id) -> AsyncCollectionGroup: + def collection_group(self, collection_id: str) -> AsyncCollectionGroup: """ Creates and returns a new AsyncQuery that includes all documents in the database that are contained in a collection or subcollection with the @@ -172,7 +172,7 @@ def collection_group(self, collection_id) -> AsyncCollectionGroup: """ return AsyncCollectionGroup(self._get_collection_reference(collection_id)) - def document(self, *document_path) -> AsyncDocumentReference: + def document(self, *document_path: Tuple[str]) -> AsyncDocumentReference: """Get a reference to a document in a collection. For a top-level document: @@ -208,7 +208,7 @@ def document(self, *document_path) -> AsyncDocumentReference: ) async def get_all( - self, references, field_paths=None, transaction=None + self, references: list, field_paths: Iterable[str] = None, transaction=None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieve a batch of documents. diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 2a37353fd..f0d41985b 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -28,6 +28,9 @@ from typing import AsyncIterator from typing import Any, AsyncGenerator, Tuple +# Types needed only for Type Hints +from google.cloud.firestore_v1.transaction import Transaction + class AsyncCollectionReference(BaseCollectionReference): """A reference to a collection in a Firestore database. @@ -66,7 +69,9 @@ def _query(self) -> async_query.AsyncQuery: """ return async_query.AsyncQuery(self) - async def add(self, document_data, document_id=None) -> Tuple[Any, Any]: + async def add( + self, document_data: dict, document_id: str = None + ) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. Args: @@ -98,7 +103,7 @@ async def add(self, document_data, document_id=None) -> Tuple[Any, Any]: return write_result.update_time, document_ref async def list_documents( - self, page_size=None + self, page_size: int = None ) -> AsyncGenerator[DocumentReference, None]: """List all subdocuments of the current collection. @@ -127,7 +132,7 @@ async def list_documents( async for i in iterator: yield _item_to_document_ref(self, i) - async def get(self, transaction=None) -> list: + async def get(self, transaction: Transaction = None) -> list: """Read the documents in this collection. This sends a ``RunQuery`` RPC and returns a list of documents @@ -149,7 +154,7 @@ async def get(self, transaction=None) -> list: return await query.get(transaction=transaction) async def stream( - self, transaction=None + self, transaction: Transaction = None ) -> AsyncIterator[async_document.DocumentSnapshot]: """Read the documents in this collection. diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index d33b76a46..064797f6d 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -23,7 +23,7 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common -from typing import Any, AsyncGenerator, Coroutine, Union +from typing import Any, AsyncGenerator, Coroutine, Iterable, Union class AsyncDocumentReference(BaseDocumentReference): @@ -54,7 +54,7 @@ class AsyncDocumentReference(BaseDocumentReference): def __init__(self, *path, **kwargs) -> None: super(AsyncDocumentReference, self).__init__(*path, **kwargs) - async def create(self, document_data) -> Coroutine: + async def create(self, document_data: dict) -> Coroutine: """Create the current document in the Firestore database. Args: @@ -75,7 +75,7 @@ async def create(self, document_data) -> Coroutine: write_results = await batch.commit() return _first_write_result(write_results) - async def set(self, document_data, merge=False) -> Coroutine: + async def set(self, document_data: dict, merge: bool = False) -> Coroutine: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -106,7 +106,9 @@ async def set(self, document_data, merge=False) -> Coroutine: write_results = await batch.commit() return _first_write_result(write_results) - async def update(self, field_updates, option=None) -> Coroutine: + async def update( + self, field_updates: dict, option: _helpers.WriteOption = None + ) -> Coroutine: """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -254,7 +256,7 @@ async def update(self, field_updates, option=None) -> Coroutine: write_results = await batch.commit() return _first_write_result(write_results) - async def delete(self, option=None) -> Coroutine: + async def delete(self, option: _helpers.WriteOption = None) -> Coroutine: """Delete the current document in the Firestore database. Args: @@ -282,7 +284,7 @@ async def delete(self, option=None) -> Coroutine: return commit_response.commit_time async def get( - self, field_paths=None, transaction=None + self, field_paths: Iterable[str] = None, transaction=None ) -> Union[DocumentSnapshot, Coroutine[Any, Any, DocumentSnapshot]]: """Retrieve a snapshot of the current document. @@ -348,7 +350,7 @@ async def get( update_time=update_time, ) - async def collections(self, page_size=None) -> AsyncGenerator: + async def collections(self, page_size: int = None) -> AsyncGenerator: """List subcollections of the current document. Args: diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 8c5302db7..2750f290f 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -31,6 +31,9 @@ from google.cloud.firestore_v1 import async_document from typing import AsyncGenerator +# Types needed only for Type Hints +from google.cloud.firestore_v1.transaction import Transaction + class AsyncQuery(BaseQuery): """Represents a query to the Firestore API. @@ -114,7 +117,7 @@ def __init__( all_descendants=all_descendants, ) - async def get(self, transaction=None) -> list: + async def get(self, transaction: Transaction = None) -> list: """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and returns a list of documents @@ -154,7 +157,7 @@ async def get(self, transaction=None) -> list: return result async def stream( - self, transaction=None + self, transaction: Transaction = None ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: """Read the documents in the collection that match this query. diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 0a1f6a936..81316b8e6 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -39,7 +39,10 @@ from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_document import DocumentSnapshot from google.cloud.firestore_v1.async_query import AsyncQuery -from typing import Any, AsyncGenerator, Coroutine +from typing import Any, AsyncGenerator, Callable, Coroutine + +# Types needed only for Type Hints +from google.cloud.firestore_v1.client import Client class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): @@ -60,7 +63,7 @@ def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: super(AsyncTransaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs) -> None: + def _add_write_pbs(self, write_pbs: list) -> None: """Add `Write`` protobufs to this transaction. Args: @@ -75,7 +78,7 @@ def _add_write_pbs(self, write_pbs) -> None: super(AsyncTransaction, self)._add_write_pbs(write_pbs) - async def _begin(self, retry_id=None) -> None: + async def _begin(self, retry_id: bytes = None) -> None: """Begin the transaction. Args: @@ -141,7 +144,7 @@ async def _commit(self) -> list: self._clean_up() return list(commit_response.write_results) - async def get_all(self, references) -> Coroutine: + async def get_all(self, references: list) -> Coroutine: """Retrieves multiple documents from Firestore. Args: @@ -187,7 +190,9 @@ class _AsyncTransactional(_BaseTransactional): def __init__(self, to_wrap) -> None: super(_AsyncTransactional, self).__init__(to_wrap) - async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine: + async def _pre_commit( + self, transaction: AsyncTransaction, *args, **kwargs + ) -> Coroutine: """Begin transaction and call the wrapped coroutine. If the coroutine raises an exception, the transaction will be rolled @@ -225,7 +230,7 @@ async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine: await transaction._rollback() raise - async def _maybe_commit(self, transaction) -> bool: + async def _maybe_commit(self, transaction: AsyncTransaction) -> bool: """Try to commit the transaction. If the transaction is read-write and the ``Commit`` fails with the @@ -291,7 +296,9 @@ async def __call__(self, transaction, *args, **kwargs): raise ValueError(msg) -def async_transactional(to_wrap) -> _AsyncTransactional: +def async_transactional( + to_wrap: Callable[[AsyncTransaction], Any] +) -> _AsyncTransactional: """Decorate a callable so that it runs in a transaction. Args: @@ -307,7 +314,9 @@ def async_transactional(to_wrap) -> _AsyncTransactional: # TODO(crwilcox): this was 'coroutine' from pytype merge-pyi... -async def _commit_with_retry(client, write_pbs, transaction_id) -> types.CommitResponse: +async def _commit_with_retry( + client: Client, write_pbs: list, transaction_id: bytes +) -> types.CommitResponse: """Call ``Commit`` on the GAPIC client with retry / sleep. Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level @@ -350,7 +359,9 @@ async def _commit_with_retry(client, write_pbs, transaction_id) -> types.CommitR current_sleep = await _sleep(current_sleep) -async def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER) -> float: +async def _sleep( + current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER +) -> float: """Sleep and produce a new sleep time. .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index dadcb0ec0..f84af4b3d 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -17,6 +17,10 @@ from google.cloud.firestore_v1 import _helpers +# Types needed only for Type Hints +from google.cloud.firestore_v1.document import DocumentReference +from typing import Union + class BaseWriteBatch(object): """Accumulate write operations to be sent in a batch. @@ -36,7 +40,7 @@ def __init__(self, client) -> None: self.write_results = None self.commit_time = None - def _add_write_pbs(self, write_pbs) -> None: + def _add_write_pbs(self, write_pbs: list) -> None: """Add `Write`` protobufs to this transaction. This method intended to be over-ridden by subclasses. @@ -47,7 +51,7 @@ def _add_write_pbs(self, write_pbs) -> None: """ self._write_pbs.extend(write_pbs) - def create(self, reference, document_data) -> None: + def create(self, reference: DocumentReference, document_data: dict) -> None: """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -62,7 +66,12 @@ def create(self, reference, document_data) -> None: write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) self._add_write_pbs(write_pbs) - def set(self, reference, document_data, merge=False) -> None: + def set( + self, + reference: DocumentReference, + document_data: dict, + merge: Union[bool, list] = False, + ) -> None: """Add a "change" to replace a document. See @@ -90,7 +99,12 @@ def set(self, reference, document_data, merge=False) -> None: self._add_write_pbs(write_pbs) - def update(self, reference, field_updates, option=None) -> None: + def update( + self, + reference: DocumentReference, + field_updates: dict, + option: _helpers.WriteOption = None, + ) -> None: """Add a "change" to update a document. See @@ -113,7 +127,9 @@ def update(self, reference, field_updates, option=None) -> None: ) self._add_write_pbs(write_pbs) - def delete(self, reference, option=None) -> None: + def delete( + self, reference: DocumentReference, option: _helpers.WriteOption = None + ) -> None: """Add a "change" to delete a document. See diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 8ad6d1441..b2a422291 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -41,6 +41,7 @@ Any, AsyncGenerator, Generator, + Iterable, List, Optional, Tuple, @@ -227,10 +228,10 @@ def _rpc_metadata(self): def collection(self, *collection_path) -> BaseCollectionReference: raise NotImplementedError - def collection_group(self, collection_id) -> BaseQuery: + def collection_group(self, collection_id: str) -> BaseQuery: raise NotImplementedError - def _get_collection_reference(self, collection_id) -> BaseCollectionReference: + def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference: """Checks validity of collection_id and then uses subclasses collection implementation. Args: @@ -271,7 +272,7 @@ def _document_path_helper(self, *document_path) -> List[str]: return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) @staticmethod - def field_path(*field_names) -> Any: + def field_path(*field_names: Tuple[str]) -> Any: """Create a **field path** from a list of nested field names. A **field path** is a ``.``-delimited concatenation of the field @@ -353,7 +354,10 @@ def write_option( raise TypeError(_BAD_OPTION_ERR, extra) def get_all( - self, references, field_paths=None, transaction=None + self, + references: list, + field_paths: Iterable[str] = None, + transaction: BaseTransaction = None, ) -> Union[ AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any] ]: @@ -374,7 +378,7 @@ def transaction(self, **kwargs) -> BaseTransaction: raise NotImplementedError -def _reference_info(references) -> Tuple[list, dict]: +def _reference_info(references: list) -> Tuple[list, dict]: """Get information about document references. Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`. @@ -401,7 +405,7 @@ def _reference_info(references) -> Tuple[list, dict]: return document_paths, reference_map -def _get_reference(document_path, reference_map) -> Any: +def _get_reference(document_path: str, reference_map: dict) -> Any: """Get a document reference from a dictionary. This just wraps a simple dictionary look-up with a helpful error that is @@ -427,7 +431,11 @@ def _get_reference(document_path, reference_map) -> Any: raise ValueError(msg) -def _parse_batch_get(get_doc_response, reference_map, client) -> DocumentSnapshot: +def _parse_batch_get( + get_doc_response: types.BatchGetDocumentsResponse, + reference_map: dict, + client: BaseClient, +) -> DocumentSnapshot: """Parse a `BatchGetDocumentsResponse` protobuf. Args: @@ -477,7 +485,7 @@ def _parse_batch_get(get_doc_response, reference_map, client) -> DocumentSnapsho return snapshot -def _get_doc_mask(field_paths,) -> Optional[types.common.DocumentMask]: +def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentMask]: """Get a document mask if field paths are provided. Args: @@ -495,7 +503,7 @@ def _get_doc_mask(field_paths,) -> Optional[types.common.DocumentMask]: return types.DocumentMask(field_paths=field_paths) -def _item_to_collection_ref(iterator, item) -> Any: +def _item_to_collection_ref(iterator, item: str) -> Any: """Convert collection ID to collection ref. Args: @@ -506,7 +514,7 @@ def _item_to_collection_ref(iterator, item) -> Any: return iterator.client.collection(item) -def _path_helper(path) -> Any: +def _path_helper(path: tuple) -> Any: """Standardize path into a tuple of path segments. Args: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 67dfc36d5..72480a911 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -24,6 +24,7 @@ Generator, AsyncIterator, Iterator, + Iterable, NoReturn, Tuple, Union, @@ -32,6 +33,7 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_v1.transaction import Transaction _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -103,7 +105,7 @@ def parent(self): def _query(self) -> BaseQuery: raise NotImplementedError - def document(self, document_id=None) -> Any: + def document(self, document_id: str = None) -> Any: """Create a sub-document underneath the current collection. Args: @@ -145,18 +147,18 @@ def _parent_info(self) -> Tuple[Any, str]: return parent_path, expected_prefix def add( - self, document_data, document_id=None + self, document_data: dict, document_id: str = None ) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]: raise NotImplementedError def list_documents( - self, page_size=None + self, page_size: int = None ) -> Union[ Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] ]: raise NotImplementedError - def select(self, field_paths) -> BaseQuery: + def select(self, field_paths: Iterable[str]) -> BaseQuery: """Create a "select" query with this collection as parent. See @@ -175,7 +177,7 @@ def select(self, field_paths) -> BaseQuery: query = self._query() return query.select(field_paths) - def where(self, field_path, op_string, value) -> BaseQuery: + def where(self, field_path: str, op_string: str, value) -> BaseQuery: """Create a "where" query with this collection as parent. See @@ -199,7 +201,7 @@ def where(self, field_path, op_string, value) -> BaseQuery: query = self._query() return query.where(field_path, op_string, value) - def order_by(self, field_path, **kwargs) -> BaseQuery: + def order_by(self, field_path: str, **kwargs) -> BaseQuery: """Create an "order by" query with this collection as parent. See @@ -221,7 +223,7 @@ def order_by(self, field_path, **kwargs) -> BaseQuery: query = self._query() return query.order_by(field_path, **kwargs) - def limit(self, count) -> BaseQuery: + def limit(self, count: int) -> BaseQuery: """Create a limited query with this collection as parent. .. note:: @@ -243,7 +245,7 @@ def limit(self, count) -> BaseQuery: query = self._query() return query.limit(count) - def limit_to_last(self, count): + def limit_to_last(self, count: int): """Create a limited to last query with this collection as parent. .. note:: `limit` and `limit_to_last` are mutually exclusive. @@ -261,7 +263,7 @@ def limit_to_last(self, count): query = self._query() return query.limit_to_last(count) - def offset(self, num_to_skip) -> BaseQuery: + def offset(self, num_to_skip: int) -> BaseQuery: """Skip to an offset in a query with this collection as parent. See @@ -279,7 +281,9 @@ def offset(self, num_to_skip) -> BaseQuery: query = self._query() return query.offset(num_to_skip) - def start_at(self, document_fields) -> BaseQuery: + def start_at( + self, document_fields: Union[DocumentSnapshot, dict, list, tuple] + ) -> BaseQuery: """Start query at a cursor with this collection as parent. See @@ -300,7 +304,9 @@ def start_at(self, document_fields) -> BaseQuery: query = self._query() return query.start_at(document_fields) - def start_after(self, document_fields) -> BaseQuery: + def start_after( + self, document_fields: Union[DocumentSnapshot, dict, list, tuple] + ) -> BaseQuery: """Start query after a cursor with this collection as parent. See @@ -321,7 +327,9 @@ def start_after(self, document_fields) -> BaseQuery: query = self._query() return query.start_after(document_fields) - def end_before(self, document_fields) -> BaseQuery: + def end_before( + self, document_fields: Union[DocumentSnapshot, dict, list, tuple] + ) -> BaseQuery: """End query before a cursor with this collection as parent. See @@ -342,7 +350,9 @@ def end_before(self, document_fields) -> BaseQuery: query = self._query() return query.end_before(document_fields) - def end_at(self, document_fields) -> BaseQuery: + def end_at( + self, document_fields: Union[DocumentSnapshot, dict, list, tuple] + ) -> BaseQuery: """End query at a cursor with this collection as parent. See @@ -364,14 +374,14 @@ def end_at(self, document_fields) -> BaseQuery: return query.end_at(document_fields) def get( - self, transaction=None + self, transaction: Transaction = None ) -> Union[ Generator[DocumentSnapshot, Any, Any], AsyncGenerator[DocumentSnapshot, Any] ]: raise NotImplementedError def stream( - self, transaction=None + self, transaction: Transaction = None ) -> Union[Iterator[DocumentSnapshot], AsyncIterator[DocumentSnapshot]]: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index f11546cac..68534c471 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -18,7 +18,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import field_path as field_path_module -from typing import Any, NoReturn +from typing import Any, Iterable, NoReturn, Tuple class BaseDocumentReference(object): @@ -164,7 +164,7 @@ def parent(self): parent_path = self._path[:-1] return self._client.collection(*parent_path) - def collection(self, collection_id) -> Any: + def collection(self, collection_id: str) -> Any: """Create a sub-collection underneath the current document. Args: @@ -178,22 +178,26 @@ def collection(self, collection_id) -> Any: child_path = self._path + (collection_id,) return self._client.collection(*child_path) - def create(self, document_data) -> NoReturn: + def create(self, document_data: dict) -> NoReturn: raise NotImplementedError - def set(self, document_data, merge=False) -> NoReturn: + def set(self, document_data: dict, merge: bool = False) -> NoReturn: raise NotImplementedError - def update(self, field_updates, option=None) -> NoReturn: + def update( + self, field_updates: dict, option: _helpers.WriteOption = None + ) -> NoReturn: raise NotImplementedError - def delete(self, option=None) -> NoReturn: + def delete(self, option: _helpers.WriteOption = None) -> NoReturn: raise NotImplementedError - def get(self, field_paths=None, transaction=None) -> "DocumentSnapshot": + def get( + self, field_paths: Iterable[str] = None, transaction=None + ) -> "DocumentSnapshot": raise NotImplementedError - def collections(self, page_size=None) -> NoReturn: + def collections(self, page_size: int = None) -> NoReturn: raise NotImplementedError def on_snapshot(self, callback) -> NoReturn: @@ -291,7 +295,7 @@ def reference(self): """ return self._reference - def get(self, field_path) -> Any: + def get(self, field_path: str) -> Any: """Get a value from the snapshot data. If the data is nested, for example: @@ -371,7 +375,7 @@ def to_dict(self) -> Any: return copy.deepcopy(self._data) -def _get_document_path(client, path) -> str: +def _get_document_path(client, path: Tuple[str]) -> str: """Convert a path tuple into a full path string. Of the form: @@ -423,7 +427,7 @@ def _consume_single_get(response_iterator) -> Any: return all_responses[0] -def _first_write_result(write_results) -> Any: +def _first_write_result(write_results: list) -> Any: """Get first write result from list. For cases where ``len(write_results) > 1``, this assumes the writes @@ -449,7 +453,7 @@ def _first_write_result(write_results) -> Any: return write_results[0] -def _item_to_collection_ref(iterator, item) -> Any: +def _item_to_collection_ref(iterator, item: str) -> Any: """Convert collection ID to collection ref. Args: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 1f7d9fdb7..188c15b6a 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -30,8 +30,12 @@ from google.cloud.firestore_v1.types import StructuredQuery from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import Cursor +from google.cloud.firestore_v1.types import RunQueryResponse from google.cloud.firestore_v1.order import Order -from typing import Any, Dict, NoReturn, Optional, Tuple +from typing import Any, Dict, Iterable, NoReturn, Optional, Tuple, Union + +# Types needed only for Type Hints +from google.cloud.firestore_v1.base_document import DocumentSnapshot _BAD_DIR_STRING: str _BAD_OP_NAN_NULL: str @@ -191,7 +195,7 @@ def _client(self): """ return self._parent._client - def select(self, field_paths) -> "BaseQuery": + def select(self, field_paths: Iterable[str]) -> "BaseQuery": """Project documents matching query to a limited set of fields. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -236,7 +240,7 @@ def select(self, field_paths) -> "BaseQuery": all_descendants=self._all_descendants, ) - def where(self, field_path, op_string, value) -> "BaseQuery": + def where(self, field_path: str, op_string: str, value) -> "BaseQuery": """Filter the query on a field. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -314,7 +318,7 @@ def _make_order(field_path, direction) -> Any: direction=_enum_from_direction(direction), ) - def order_by(self, field_path, direction=ASCENDING) -> "BaseQuery": + def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery": """Modify the query to add an order clause on a specific field. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -359,7 +363,7 @@ def order_by(self, field_path, direction=ASCENDING) -> "BaseQuery": all_descendants=self._all_descendants, ) - def limit(self, count) -> "BaseQuery": + def limit(self, count: int) -> "BaseQuery": """Limit a query to return at most `count` matching results. If the current query already has a `limit` set, this will override it. @@ -387,7 +391,7 @@ def limit(self, count) -> "BaseQuery": all_descendants=self._all_descendants, ) - def limit_to_last(self, count): + def limit_to_last(self, count: int): """Limit a query to return the last `count` matching results. If the current query already has a `limit_to_last` set, this will override it. @@ -415,7 +419,7 @@ def limit_to_last(self, count): all_descendants=self._all_descendants, ) - def offset(self, num_to_skip) -> "BaseQuery": + def offset(self, num_to_skip: int) -> "BaseQuery": """Skip to an offset in a query. If the current query already has specified an offset, this will @@ -456,7 +460,12 @@ def _check_snapshot(self, document_snapshot) -> None: if document_snapshot.reference._path[:-1] != self._parent._path: raise ValueError("Cannot use snapshot from another collection as a cursor.") - def _cursor_helper(self, document_fields_or_snapshot, before, start) -> "BaseQuery": + def _cursor_helper( + self, + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + before: bool, + start: bool, + ) -> "BaseQuery": """Set values to be used for a ``start_at`` or ``end_at`` cursor. The values will later be used in a query protobuf. @@ -508,7 +517,9 @@ def _cursor_helper(self, document_fields_or_snapshot, before, start) -> "BaseQue return self.__class__(self._parent, **query_kwargs) - def start_at(self, document_fields_or_snapshot) -> "BaseQuery": + def start_at( + self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] + ) -> "BaseQuery": """Start query results at a particular document value. The result set will **include** the document specified by @@ -538,7 +549,9 @@ def start_at(self, document_fields_or_snapshot) -> "BaseQuery": """ return self._cursor_helper(document_fields_or_snapshot, before=True, start=True) - def start_after(self, document_fields_or_snapshot) -> "BaseQuery": + def start_after( + self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] + ) -> "BaseQuery": """Start query results after a particular document value. The result set will **exclude** the document specified by @@ -569,7 +582,9 @@ def start_after(self, document_fields_or_snapshot) -> "BaseQuery": document_fields_or_snapshot, before=False, start=True ) - def end_before(self, document_fields_or_snapshot) -> "BaseQuery": + def end_before( + self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] + ) -> "BaseQuery": """End query results before a particular document value. The result set will **exclude** the document specified by @@ -600,7 +615,9 @@ def end_before(self, document_fields_or_snapshot) -> "BaseQuery": document_fields_or_snapshot, before=True, start=False ) - def end_at(self, document_fields_or_snapshot) -> "BaseQuery": + def end_at( + self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] + ) -> "BaseQuery": """End query results at a particular document value. The result set will **include** the document specified by @@ -839,7 +856,7 @@ def _comparator(self, doc1, doc2) -> Any: return 0 -def _enum_from_op_string(op_string) -> Any: +def _enum_from_op_string(op_string: str) -> Any: """Convert a string representation of a binary operator to an enum. These enums come from the protobuf message definition @@ -882,7 +899,7 @@ def _isnan(value) -> bool: return False -def _enum_from_direction(direction) -> Any: +def _enum_from_direction(direction: str) -> Any: """Convert a string representation of a direction to an enum. Args: @@ -934,7 +951,7 @@ def _filter_pb(field_or_unary) -> Any: raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) -def _cursor_pb(cursor_pair) -> Optional[Cursor]: +def _cursor_pb(cursor_pair: Tuple[list, bool]) -> Optional[Cursor]: """Convert a cursor pair to a protobuf. If ``cursor_pair`` is :data:`None`, just returns :data:`None`. @@ -956,7 +973,7 @@ def _cursor_pb(cursor_pair) -> Optional[Cursor]: def _query_response_to_snapshot( - response_pb, collection, expected_prefix + response_pb: RunQueryResponse, collection, expected_prefix: str ) -> Optional[document.DocumentSnapshot]: """Parse a query response protobuf to a document snapshot. @@ -992,7 +1009,7 @@ def _query_response_to_snapshot( def _collection_group_query_response_to_snapshot( - response_pb, collection + response_pb: RunQueryResponse, collection ) -> Optional[document.DocumentSnapshot]: """Parse a query response protobuf to a document snapshot. diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index 9f2eff0ec..c676d3d7a 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -67,7 +67,9 @@ def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: def _add_write_pbs(self, write_pbs) -> NoReturn: raise NotImplementedError - def _options_protobuf(self, retry_id) -> Optional[types.common.TransactionOptions]: + def _options_protobuf( + self, retry_id: Union[bytes, None] + ) -> Optional[types.common.TransactionOptions]: """Convert the current object to protobuf. The ``retry_id`` value is used when retrying a transaction that @@ -139,7 +141,7 @@ def _rollback(self) -> NoReturn: def _commit(self) -> Union[list, Coroutine[Any, Any, list]]: raise NotImplementedError - def get_all(self, references) -> NoReturn: + def get_all(self, references: list) -> NoReturn: raise NotImplementedError def get(self, ref_or_query) -> NoReturn: diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 448a8f4fb..e6c9f45c9 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -44,7 +44,7 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc as firestore_grpc_transport, ) -from typing import Any, Generator +from typing import Any, Generator, Iterable, Tuple class Client(BaseClient): @@ -114,7 +114,7 @@ def _target(self): """ return self._target_helper(firestore_client.FirestoreClient) - def collection(self, *collection_path) -> CollectionReference: + def collection(self, *collection_path: Tuple[str]) -> CollectionReference: """Get a reference to a collection. For a top-level collection: @@ -145,7 +145,7 @@ def collection(self, *collection_path) -> CollectionReference: """ return CollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id) -> CollectionGroup: + def collection_group(self, collection_id: str) -> CollectionGroup: """ Creates and returns a new Query that includes all documents in the database that are contained in a collection or subcollection with the @@ -167,7 +167,7 @@ def collection_group(self, collection_id) -> CollectionGroup: """ return CollectionGroup(self._get_collection_reference(collection_id)) - def document(self, *document_path) -> DocumentReference: + def document(self, *document_path: Tuple[str]) -> DocumentReference: """Get a reference to a document in a collection. For a top-level document: @@ -203,7 +203,10 @@ def document(self, *document_path) -> DocumentReference: ) def get_all( - self, references, field_paths=None, transaction=None + self, + references: list, + field_paths: Iterable[str] = None, + transaction: Transaction = None, ) -> Generator[Any, Any, None]: """Retrieve a batch of documents. diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 43f2d8fc8..4cd857095 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -21,7 +21,10 @@ from google.cloud.firestore_v1 import query as query_mod from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import document -from typing import Any, Generator, Tuple +from typing import Any, Callable, Generator, Tuple + +# Types needed only for Type Hints +from google.cloud.firestore_v1.transaction import Transaction class CollectionReference(BaseCollectionReference): @@ -61,7 +64,7 @@ def _query(self) -> query_mod.Query: """ return query_mod.Query(self) - def add(self, document_data, document_id=None) -> Tuple[Any, Any]: + def add(self, document_data: dict, document_id: str = None) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. Args: @@ -92,7 +95,7 @@ def add(self, document_data, document_id=None) -> Tuple[Any, Any]: write_result = document_ref.create(document_data) return write_result.update_time, document_ref - def list_documents(self, page_size=None) -> Generator[Any, Any, None]: + def list_documents(self, page_size: int = None) -> Generator[Any, Any, None]: """List all subdocuments of the current collection. Args: @@ -119,7 +122,7 @@ def list_documents(self, page_size=None) -> Generator[Any, Any, None]: ) return (_item_to_document_ref(self, i) for i in iterator) - def get(self, transaction=None) -> list: + def get(self, transaction: Transaction = None) -> list: """Read the documents in this collection. This sends a ``RunQuery`` RPC and returns a list of documents @@ -141,7 +144,7 @@ def get(self, transaction=None) -> list: return query.get(transaction=transaction) def stream( - self, transaction=None + self, transaction: Transaction = None ) -> Generator[document.DocumentSnapshot, Any, None]: """Read the documents in this collection. @@ -172,7 +175,7 @@ def stream( query = query_mod.Query(self) return query.stream(transaction=transaction) - def on_snapshot(self, callback) -> Watch: + def on_snapshot(self, callback: Callable) -> Watch: """Monitor the documents in this collection. This starts a watch on this collection using a background thread. The diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index f4f08ee71..ca5fc8378 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -24,7 +24,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.watch import Watch -from typing import Any, Generator +from typing import Any, Callable, Generator, Iterable class DocumentReference(BaseDocumentReference): @@ -76,7 +76,7 @@ def create(self, document_data) -> Any: write_results = batch.commit() return _first_write_result(write_results) - def set(self, document_data, merge=False) -> Any: + def set(self, document_data: dict, merge: bool = False) -> Any: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -107,7 +107,7 @@ def set(self, document_data, merge=False) -> Any: write_results = batch.commit() return _first_write_result(write_results) - def update(self, field_updates, option=None) -> Any: + def update(self, field_updates: dict, option: _helpers.WriteOption = None) -> Any: """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -255,7 +255,7 @@ def update(self, field_updates, option=None) -> Any: write_results = batch.commit() return _first_write_result(write_results) - def delete(self, option=None) -> Any: + def delete(self, option: _helpers.WriteOption = None) -> Any: """Delete the current document in the Firestore database. Args: @@ -282,7 +282,9 @@ def delete(self, option=None) -> Any: return commit_response.commit_time - def get(self, field_paths=None, transaction=None) -> DocumentSnapshot: + def get( + self, field_paths: Iterable[str] = None, transaction=None + ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. See :meth:`~google.cloud.firestore_v1.base_client.BaseClient.field_path` for @@ -347,7 +349,7 @@ def get(self, field_paths=None, transaction=None) -> DocumentSnapshot: update_time=update_time, ) - def collections(self, page_size=None) -> Generator[Any, Any, None]: + def collections(self, page_size: int = None) -> Generator[Any, Any, None]: """List subcollections of the current document. Args: @@ -387,7 +389,7 @@ def collections(self, page_size=None) -> Generator[Any, Any, None]: # iterator.item_to_value = _item_to_collection_ref # return iterator - def on_snapshot(self, callback) -> Watch: + def on_snapshot(self, callback: Callable) -> Watch: """Watch this document. This starts a watch on this document using a background thread. The diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index b1bfa860d..610d8ffd8 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -17,6 +17,7 @@ from collections import abc import re +from typing import Iterable _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data" @@ -42,7 +43,7 @@ TOKENS_REGEX = re.compile(TOKENS_PATTERN) -def _tokenize_field_path(path): +def _tokenize_field_path(path: str): """Lex a field path into tokens (including dots). Args: @@ -63,7 +64,7 @@ def _tokenize_field_path(path): raise ValueError("Path {} not consumed, residue: {}".format(path, path[pos:])) -def split_field_path(path): +def split_field_path(path: str): """Split a field path into valid elements (without dots). Args: @@ -98,7 +99,7 @@ def split_field_path(path): return elements -def parse_field_path(api_repr): +def parse_field_path(api_repr: str): """Parse a **field path** from into a list of nested field names. See :func:`field_path` for more on **field paths**. @@ -127,7 +128,7 @@ def parse_field_path(api_repr): return field_names -def render_field_path(field_names): +def render_field_path(field_names: Iterable[str]): """Create a **field path** from a list of nested field names. A **field path** is a ``.``-delimited concatenation of the field @@ -171,7 +172,7 @@ def render_field_path(field_names): get_field_path = render_field_path # backward-compatibility -def get_nested_value(field_path, data): +def get_nested_value(field_path: str, data: dict): """Get a (potentially nested) value from a dictionary. If the data is nested, for example: @@ -272,7 +273,7 @@ def __init__(self, *parts): self.parts = tuple(parts) @classmethod - def from_api_repr(cls, api_repr): + def from_api_repr(cls, api_repr: str): """Factory: create a FieldPath from the string formatted per the API. Args: @@ -289,7 +290,7 @@ def from_api_repr(cls, api_repr): return cls(*parse_field_path(api_repr)) @classmethod - def from_string(cls, path_string): + def from_string(cls, path_string: str): """Factory: create a FieldPath from a unicode string representation. This method splits on the character `.` and disallows the diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 09f8dc47b..ef38b68f4 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -30,7 +30,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Generator +from typing import Any, Callable, Generator class Query(BaseQuery): @@ -209,7 +209,7 @@ def stream( if snapshot is not None: yield snapshot - def on_snapshot(self, callback) -> Watch: + def on_snapshot(self, callback: Callable) -> Watch: """Monitor the documents in this collection that match this query. This starts a watch on this query using a background thread. The diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index a93f3c62e..1549fcf7d 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -36,7 +36,7 @@ from google.cloud.firestore_v1 import batch from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.query import Query -from typing import Any, Optional +from typing import Any, Callable, Optional class Transaction(batch.WriteBatch, BaseTransaction): @@ -57,7 +57,7 @@ def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: super(Transaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs) -> None: + def _add_write_pbs(self, write_pbs: list) -> None: """Add `Write`` protobufs to this transaction. Args: @@ -72,7 +72,7 @@ def _add_write_pbs(self, write_pbs) -> None: super(Transaction, self)._add_write_pbs(write_pbs) - def _begin(self, retry_id=None) -> None: + def _begin(self, retry_id: bytes = None) -> None: """Begin the transaction. Args: @@ -136,7 +136,7 @@ def _commit(self) -> list: self._clean_up() return list(commit_response.write_results) - def get_all(self, references) -> Any: + def get_all(self, references: list) -> Any: """Retrieves multiple documents from Firestore. Args: @@ -182,7 +182,7 @@ class _Transactional(_BaseTransactional): def __init__(self, to_wrap) -> None: super(_Transactional, self).__init__(to_wrap) - def _pre_commit(self, transaction, *args, **kwargs) -> Any: + def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any: """Begin transaction and call the wrapped callable. If the callable raises an exception, the transaction will be rolled @@ -220,7 +220,7 @@ def _pre_commit(self, transaction, *args, **kwargs) -> Any: transaction._rollback() raise - def _maybe_commit(self, transaction) -> Optional[bool]: + def _maybe_commit(self, transaction: Transaction) -> Optional[bool]: """Try to commit the transaction. If the transaction is read-write and the ``Commit`` fails with the @@ -248,7 +248,7 @@ def _maybe_commit(self, transaction) -> Optional[bool]: else: raise - def __call__(self, transaction, *args, **kwargs): + def __call__(self, transaction: Transaction, *args, **kwargs): """Execute the wrapped callable within a transaction. Args: @@ -286,7 +286,7 @@ def __call__(self, transaction, *args, **kwargs): raise ValueError(msg) -def transactional(to_wrap) -> _Transactional: +def transactional(to_wrap: Callable) -> _Transactional: """Decorate a callable so that it runs in a transaction. Args: @@ -301,7 +301,7 @@ def transactional(to_wrap) -> _Transactional: return _Transactional(to_wrap) -def _commit_with_retry(client, write_pbs, transaction_id) -> Any: +def _commit_with_retry(client, write_pbs: list, transaction_id: bytes) -> Any: """Call ``Commit`` on the GAPIC client with retry / sleep. Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level @@ -344,7 +344,9 @@ def _commit_with_retry(client, write_pbs, transaction_id) -> Any: current_sleep = _sleep(current_sleep) -def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER) -> Any: +def _sleep( + current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER +) -> Any: """Sleep and produce a new sleep time. .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\