diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index c1213e243..89cf3b002 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -32,7 +32,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import write -from typing import Any, Generator, List, NoReturn, Optional, Tuple +from typing import Any, Generator, List, NoReturn, Optional, Tuple, Union _EmptyDict: transforms.Sentinel _GRPC_ERROR_MAPPING: dict @@ -69,7 +69,7 @@ def __init__(self, latitude, longitude) -> None: self.latitude = latitude self.longitude = longitude - def to_protobuf(self) -> Any: + def to_protobuf(self) -> latlng_pb2.LatLng: """Convert the current object to protobuf. Returns: @@ -253,7 +253,9 @@ def reference_value_to_document(reference_value, client) -> Any: return document -def decode_value(value, client) -> Any: +def decode_value( + value, client +) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]: """Converts a Firestore protobuf ``Value`` to a native Python value. Args: @@ -316,7 +318,7 @@ def decode_dict(value_fields, client) -> dict: return {key: decode_value(value, client) for key, value in value_fields.items()} -def get_doc_id(document_pb, expected_prefix) -> Any: +def get_doc_id(document_pb, expected_prefix) -> str: """Parse a document ID from a document protobuf. Args: @@ -887,7 +889,7 @@ class ReadAfterWriteError(Exception): """ -def get_transaction_id(transaction, read_operation=True) -> Any: +def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]: """Get the transaction ID from a ``Transaction`` object. Args: diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index a90227c1f..11dec64b0 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -25,6 +25,8 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.types import write +from google.protobuf import timestamp_pb2 from typing import Any, AsyncGenerator, Coroutine, Iterable, Union @@ -61,7 +63,7 @@ async def create( document_data: dict, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Coroutine: + ) -> write.WriteResult: """Create the current document in the Firestore database. Args: @@ -91,7 +93,7 @@ async def set( merge: bool = False, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Coroutine: + ) -> write.WriteResult: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -131,7 +133,7 @@ async def update( option: _helpers.WriteOption = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Coroutine: + ) -> write.WriteResult: """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -287,7 +289,7 @@ async def delete( option: _helpers.WriteOption = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Coroutine: + ) -> timestamp_pb2.Timestamp: """Delete the current document in the Firestore database. Args: diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index fd639e1ed..aae40b468 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -153,7 +153,7 @@ async def get_all( references: list, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Coroutine: + ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieves multiple documents from Firestore. Args: diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 64e38d0e0..22afb09de 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -166,7 +166,7 @@ def _firestore_api_helper(self, transport, client_class, client_module) -> Any: return self._firestore_api_internal - def _target_helper(self, client_class) -> Any: + def _target_helper(self, client_class) -> str: """Return the target (where the API is). Eg. "firestore.googleapis.com" @@ -273,7 +273,7 @@ def _document_path_helper(self, *document_path) -> List[str]: return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) @staticmethod - def field_path(*field_names: Tuple[str]) -> Any: + def field_path(*field_names: Tuple[str]) -> str: """Create a **field path** from a list of nested field names. A **field path** is a ``.``-delimited concatenation of the field @@ -438,7 +438,7 @@ def _reference_info(references: list) -> Tuple[list, dict]: return document_paths, reference_map -def _get_reference(document_path: str, reference_map: dict) -> Any: +def _get_reference(document_path: str, reference_map: dict) -> BaseDocumentReference: """Get a document reference from a dictionary. This just wraps a simple dictionary look-up with a helpful error that is @@ -536,7 +536,18 @@ def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentM return types.DocumentMask(field_paths=field_paths) -def _path_helper(path: tuple) -> Any: +def _item_to_collection_ref(iterator, item: str) -> BaseCollectionReference: + """Convert collection ID to collection ref. + + Args: + iterator (google.api_core.page_iterator.GRPCIterator): + iterator response + item (str): ID of the collection + """ + return iterator.client.collection(item) + + +def _path_helper(path: tuple) -> Tuple[str]: """Standardize path into a tuple of path segments. Args: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index ae58fe820..956c4b4b1 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -107,7 +107,7 @@ def parent(self): def _query(self) -> BaseQuery: raise NotImplementedError - def document(self, document_id: str = None) -> Any: + def document(self, document_id: str = None) -> DocumentReference: """Create a sub-document underneath the current collection. Args: diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index f06d5a8c4..441a30b51 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -22,10 +22,10 @@ from google.cloud.firestore_v1 import field_path as field_path_module from google.cloud.firestore_v1.types import common -from typing import Any -from typing import Iterable -from typing import NoReturn -from typing import Tuple +# Types needed only for Type Hints +from google.cloud.firestore_v1.types import firestore +from google.cloud.firestore_v1.types import write +from typing import Any, Dict, Iterable, NoReturn, Union, Tuple class BaseDocumentReference(object): @@ -475,7 +475,7 @@ def get(self, field_path: str) -> Any: nested_data = field_path_module.get_nested_value(field_path, self._data) return copy.deepcopy(nested_data) - def to_dict(self) -> Any: + def to_dict(self) -> Union[Dict[str, Any], None]: """Retrieve the data contained in this snapshot. A copy is returned since the data may contain mutable values, @@ -512,7 +512,7 @@ def _get_document_path(client, path: Tuple[str]) -> str: return _helpers.DOCUMENT_PATH_DELIMITER.join(parts) -def _consume_single_get(response_iterator) -> Any: +def _consume_single_get(response_iterator) -> firestore.BatchGetDocumentsResponse: """Consume a gRPC stream that should contain a single response. The stream will correspond to a ``BatchGetDocuments`` request made @@ -543,7 +543,7 @@ def _consume_single_get(response_iterator) -> Any: return all_responses[0] -def _first_write_result(write_results: list) -> Any: +def _first_write_result(write_results: list) -> write.WriteResult: """Get first write result from list. For cases where ``len(write_results) > 1``, this assumes the writes diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 2393d3711..6e0671907 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -314,7 +314,7 @@ def where(self, field_path: str, op_string: str, value) -> "BaseQuery": ) @staticmethod - def _make_order(field_path, direction) -> Any: + def _make_order(field_path, direction) -> StructuredQuery.Order: """Helper for :meth:`order_by`.""" return query.StructuredQuery.Order( field=query.StructuredQuery.FieldReference(field_path=field_path), @@ -394,7 +394,7 @@ def limit(self, count: int) -> "BaseQuery": all_descendants=self._all_descendants, ) - def limit_to_last(self, count: int): + def limit_to_last(self, count: int) -> "BaseQuery": """Limit a query to return the last `count` matching results. If the current query already has a `limit_to_last` set, this will override it. @@ -651,7 +651,7 @@ def end_at( document_fields_or_snapshot, before=False, start=False ) - def _filters_pb(self) -> Any: + def _filters_pb(self) -> StructuredQuery.Filter: """Convert all the filters into a single generic Filter protobuf. This may be a lone field filter or unary filter, may be a composite @@ -674,7 +674,7 @@ def _filters_pb(self) -> Any: return query.StructuredQuery.Filter(composite_filter=composite_filter) @staticmethod - def _normalize_projection(projection) -> Any: + def _normalize_projection(projection) -> StructuredQuery.Projection: """Helper: convert field paths to message.""" if projection is not None: @@ -836,7 +836,7 @@ def stream( def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError - def _comparator(self, doc1, doc2) -> Any: + def _comparator(self, doc1, doc2) -> int: _orders = self._orders # Add implicit sorting by name, using the last specified direction. @@ -883,7 +883,7 @@ def _comparator(self, doc1, doc2) -> Any: return 0 -def _enum_from_op_string(op_string: str) -> Any: +def _enum_from_op_string(op_string: str) -> int: """Convert a string representation of a binary operator to an enum. These enums come from the protobuf message definition @@ -926,7 +926,7 @@ def _isnan(value) -> bool: return False -def _enum_from_direction(direction: str) -> Any: +def _enum_from_direction(direction: str) -> int: """Convert a string representation of a direction to an enum. Args: @@ -954,7 +954,7 @@ def _enum_from_direction(direction: str) -> Any: raise ValueError(msg) -def _filter_pb(field_or_unary) -> Any: +def _filter_pb(field_or_unary) -> StructuredQuery.Filter: """Convert a specific protobuf filter to the generic filter type. Args: diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 9ab945ef6..6ad5f76e6 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -46,6 +46,9 @@ ) from typing import Any, Generator, Iterable, Tuple +# Types needed only for Type Hints +from google.cloud.firestore_v1.base_document import DocumentSnapshot + class Client(BaseClient): """Client for interacting with Google Cloud Firestore API. @@ -209,7 +212,7 @@ def get_all( transaction: Transaction = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Generator[Any, Any, None]: + ) -> Generator[DocumentSnapshot, Any, None]: """Retrieve a batch of documents. .. note:: diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 42fd523d7..bdb5c7943 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -25,7 +25,9 @@ from google.api_core import exceptions # type: ignore from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.watch import Watch +from google.protobuf import timestamp_pb2 from typing import Any, Callable, Generator, Iterable @@ -62,7 +64,7 @@ def create( document_data: dict, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Any: + ) -> write.WriteResult: """Create the current document in the Firestore database. Args: @@ -92,7 +94,7 @@ def set( merge: bool = False, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Any: + ) -> write.WriteResult: """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -132,7 +134,7 @@ def update( option: _helpers.WriteOption = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Any: + ) -> write.WriteResult: """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -288,7 +290,7 @@ def delete( option: _helpers.WriteOption = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Any: + ) -> timestamp_pb2.Timestamp: """Delete the current document in the Firestore database. Args: @@ -339,7 +341,7 @@ def get( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this reference will be retrieved in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, + retry (google.api_core.retry.Retry): Designation of what errors, if an y, should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. diff --git a/google/cloud/firestore_v1/order.py b/google/cloud/firestore_v1/order.py index 5d1e3345d..37052f9f5 100644 --- a/google/cloud/firestore_v1/order.py +++ b/google/cloud/firestore_v1/order.py @@ -60,7 +60,7 @@ class Order(object): """ @classmethod - def compare(cls, left, right) -> Any: + def compare(cls, left, right) -> int: """ Main comparison function for all Firestore types. @return -1 is left < right, 0 if left == right, otherwise 1 @@ -102,7 +102,7 @@ def compare(cls, left, right) -> Any: raise ValueError(f"Unknown ``value_type`` {value_type}") @staticmethod - def compare_blobs(left, right) -> Any: + def compare_blobs(left, right) -> int: left_bytes = left.bytes_value right_bytes = right.bytes_value @@ -153,7 +153,7 @@ def compare_resource_paths(left, right) -> int: return (left_length > right_length) - (left_length < right_length) @staticmethod - def compare_arrays(left, right) -> Any: + def compare_arrays(left, right) -> int: l_values = left.array_value.values r_values = right.array_value.values @@ -166,7 +166,7 @@ def compare_arrays(left, right) -> Any: return Order._compare_to(len(l_values), len(r_values)) @staticmethod - def compare_objects(left, right) -> Any: + def compare_objects(left, right) -> int: left_fields = left.map_value.fields right_fields = right.map_value.fields @@ -184,13 +184,13 @@ def compare_objects(left, right) -> Any: return Order._compare_to(len(left_fields), len(right_fields)) @staticmethod - def compare_numbers(left, right) -> Any: + def compare_numbers(left, right) -> int: left_value = decode_value(left, None) right_value = decode_value(right, None) return Order.compare_doubles(left_value, right_value) @staticmethod - def compare_doubles(left, right) -> Any: + def compare_doubles(left, right) -> int: if math.isnan(left): if math.isnan(right): return 0 @@ -201,7 +201,7 @@ def compare_doubles(left, right) -> Any: return Order._compare_to(left, right) @staticmethod - def _compare_to(left, right) -> Any: + def _compare_to(left, right) -> int: # We can't just use cmp(left, right) because cmp doesn't exist # in Python 3, so this is an equivalent suggested by # https://docs.python.org/3.0/whatsnew/3.0.html#ordering-comparisons diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 7bab4b595..f4719f712 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -40,7 +40,11 @@ from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.query import Query -from typing import Any, Callable, Optional + +# Types needed only for Type Hints +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.types import CommitResponse +from typing import Any, Callable, Generator, Optional class Transaction(batch.WriteBatch, BaseTransaction): @@ -145,7 +149,7 @@ def get_all( references: list, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Any: + ) -> Generator[DocumentSnapshot, Any, None]: """Retrieves multiple documents from Firestore. Args: @@ -168,7 +172,7 @@ def get( ref_or_query, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, - ) -> Any: + ) -> Generator[DocumentSnapshot, Any, None]: """Retrieve a document or a query result from the database. Args: @@ -326,7 +330,9 @@ def transactional(to_wrap: Callable) -> _Transactional: return _Transactional(to_wrap) -def _commit_with_retry(client, write_pbs: list, transaction_id: bytes) -> Any: +def _commit_with_retry( + client, write_pbs: list, transaction_id: bytes +) -> CommitResponse: """Call ``Commit`` on the GAPIC client with retry / sleep. Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level @@ -371,7 +377,7 @@ def _commit_with_retry(client, write_pbs: list, transaction_id: bytes) -> Any: def _sleep( current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER -) -> Any: +) -> float: """Sleep and produce a new sleep time. .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\