Skip to content

Commit

Permalink
feat: improve type information (#176)
Browse files Browse the repository at this point in the history

Co-authored-by: Tres Seaver <tseaver@palladion.com>
  • Loading branch information
HemangChothani and tseaver committed Oct 23, 2020
1 parent e8f6c4d commit 30bb3fb
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 48 deletions.
12 changes: 7 additions & 5 deletions google/cloud/firestore_v1/_helpers.py
Expand Up @@ -32,7 +32,7 @@
from google.cloud.firestore_v1.types import common
from google.cloud.firestore_v1.types import document
from google.cloud.firestore_v1.types import write
from typing import Any, Generator, List, NoReturn, Optional, Tuple
from typing import Any, Generator, List, NoReturn, Optional, Tuple, Union

_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, latitude, longitude) -> None:
self.latitude = latitude
self.longitude = longitude

def to_protobuf(self) -> Any:
def to_protobuf(self) -> latlng_pb2.LatLng:
"""Convert the current object to protobuf.
Returns:
Expand Down Expand Up @@ -253,7 +253,9 @@ def reference_value_to_document(reference_value, client) -> Any:
return document


def decode_value(value, client) -> Any:
def decode_value(
value, client
) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]:
"""Converts a Firestore protobuf ``Value`` to a native Python value.
Args:
Expand Down Expand Up @@ -316,7 +318,7 @@ def decode_dict(value_fields, client) -> dict:
return {key: decode_value(value, client) for key, value in value_fields.items()}


def get_doc_id(document_pb, expected_prefix) -> Any:
def get_doc_id(document_pb, expected_prefix) -> str:
"""Parse a document ID from a document protobuf.
Args:
Expand Down Expand Up @@ -887,7 +889,7 @@ class ReadAfterWriteError(Exception):
"""


def get_transaction_id(transaction, read_operation=True) -> Any:
def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
"""Get the transaction ID from a ``Transaction`` object.
Args:
Expand Down
10 changes: 6 additions & 4 deletions google/cloud/firestore_v1/async_document.py
Expand Up @@ -25,6 +25,8 @@

from google.api_core import exceptions # type: ignore
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import write
from google.protobuf import timestamp_pb2
from typing import Any, AsyncGenerator, Coroutine, Iterable, Union


Expand Down Expand Up @@ -61,7 +63,7 @@ async def create(
document_data: dict,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Coroutine:
) -> write.WriteResult:
"""Create the current document in the Firestore database.
Args:
Expand Down Expand Up @@ -91,7 +93,7 @@ async def set(
merge: bool = False,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Coroutine:
) -> write.WriteResult:
"""Replace the current document in the Firestore database.
A write ``option`` can be specified to indicate preconditions of
Expand Down Expand Up @@ -131,7 +133,7 @@ async def update(
option: _helpers.WriteOption = None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Coroutine:
) -> write.WriteResult:
"""Update an existing document in the Firestore database.
By default, this method verifies that the document exists on the
Expand Down Expand Up @@ -287,7 +289,7 @@ async def delete(
option: _helpers.WriteOption = None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Coroutine:
) -> timestamp_pb2.Timestamp:
"""Delete the current document in the Firestore database.
Args:
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/async_transaction.py
Expand Up @@ -153,7 +153,7 @@ async def get_all(
references: list,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Coroutine:
) -> AsyncGenerator[DocumentSnapshot, Any]:
"""Retrieves multiple documents from Firestore.
Args:
Expand Down
19 changes: 15 additions & 4 deletions google/cloud/firestore_v1/base_client.py
Expand Up @@ -166,7 +166,7 @@ def _firestore_api_helper(self, transport, client_class, client_module) -> Any:

return self._firestore_api_internal

def _target_helper(self, client_class) -> Any:
def _target_helper(self, client_class) -> str:
"""Return the target (where the API is).
Eg. "firestore.googleapis.com"
Expand Down Expand Up @@ -273,7 +273,7 @@ def _document_path_helper(self, *document_path) -> List[str]:
return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER)

@staticmethod
def field_path(*field_names: Tuple[str]) -> Any:
def field_path(*field_names: Tuple[str]) -> str:
"""Create a **field path** from a list of nested field names.
A **field path** is a ``.``-delimited concatenation of the field
Expand Down Expand Up @@ -438,7 +438,7 @@ def _reference_info(references: list) -> Tuple[list, dict]:
return document_paths, reference_map


def _get_reference(document_path: str, reference_map: dict) -> Any:
def _get_reference(document_path: str, reference_map: dict) -> BaseDocumentReference:
"""Get a document reference from a dictionary.
This just wraps a simple dictionary look-up with a helpful error that is
Expand Down Expand Up @@ -536,7 +536,18 @@ def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentM
return types.DocumentMask(field_paths=field_paths)


def _path_helper(path: tuple) -> Any:
def _item_to_collection_ref(iterator, item: str) -> BaseCollectionReference:
"""Convert collection ID to collection ref.
Args:
iterator (google.api_core.page_iterator.GRPCIterator):
iterator response
item (str): ID of the collection
"""
return iterator.client.collection(item)


def _path_helper(path: tuple) -> Tuple[str]:
"""Standardize path into a tuple of path segments.
Args:
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/base_collection.py
Expand Up @@ -107,7 +107,7 @@ def parent(self):
def _query(self) -> BaseQuery:
raise NotImplementedError

def document(self, document_id: str = None) -> Any:
def document(self, document_id: str = None) -> DocumentReference:
"""Create a sub-document underneath the current collection.
Args:
Expand Down
14 changes: 7 additions & 7 deletions google/cloud/firestore_v1/base_document.py
Expand Up @@ -22,10 +22,10 @@
from google.cloud.firestore_v1 import field_path as field_path_module
from google.cloud.firestore_v1.types import common

from typing import Any
from typing import Iterable
from typing import NoReturn
from typing import Tuple
# Types needed only for Type Hints
from google.cloud.firestore_v1.types import firestore
from google.cloud.firestore_v1.types import write
from typing import Any, Dict, Iterable, NoReturn, Union, Tuple


class BaseDocumentReference(object):
Expand Down Expand Up @@ -475,7 +475,7 @@ def get(self, field_path: str) -> Any:
nested_data = field_path_module.get_nested_value(field_path, self._data)
return copy.deepcopy(nested_data)

def to_dict(self) -> Any:
def to_dict(self) -> Union[Dict[str, Any], None]:
"""Retrieve the data contained in this snapshot.
A copy is returned since the data may contain mutable values,
Expand Down Expand Up @@ -512,7 +512,7 @@ def _get_document_path(client, path: Tuple[str]) -> str:
return _helpers.DOCUMENT_PATH_DELIMITER.join(parts)


def _consume_single_get(response_iterator) -> Any:
def _consume_single_get(response_iterator) -> firestore.BatchGetDocumentsResponse:
"""Consume a gRPC stream that should contain a single response.
The stream will correspond to a ``BatchGetDocuments`` request made
Expand Down Expand Up @@ -543,7 +543,7 @@ def _consume_single_get(response_iterator) -> Any:
return all_responses[0]


def _first_write_result(write_results: list) -> Any:
def _first_write_result(write_results: list) -> write.WriteResult:
"""Get first write result from list.
For cases where ``len(write_results) > 1``, this assumes the writes
Expand Down
16 changes: 8 additions & 8 deletions google/cloud/firestore_v1/base_query.py
Expand Up @@ -314,7 +314,7 @@ def where(self, field_path: str, op_string: str, value) -> "BaseQuery":
)

@staticmethod
def _make_order(field_path, direction) -> Any:
def _make_order(field_path, direction) -> StructuredQuery.Order:
"""Helper for :meth:`order_by`."""
return query.StructuredQuery.Order(
field=query.StructuredQuery.FieldReference(field_path=field_path),
Expand Down Expand Up @@ -394,7 +394,7 @@ def limit(self, count: int) -> "BaseQuery":
all_descendants=self._all_descendants,
)

def limit_to_last(self, count: int):
def limit_to_last(self, count: int) -> "BaseQuery":
"""Limit a query to return the last `count` matching results.
If the current query already has a `limit_to_last`
set, this will override it.
Expand Down Expand Up @@ -651,7 +651,7 @@ def end_at(
document_fields_or_snapshot, before=False, start=False
)

def _filters_pb(self) -> Any:
def _filters_pb(self) -> StructuredQuery.Filter:
"""Convert all the filters into a single generic Filter protobuf.
This may be a lone field filter or unary filter, may be a composite
Expand All @@ -674,7 +674,7 @@ def _filters_pb(self) -> Any:
return query.StructuredQuery.Filter(composite_filter=composite_filter)

@staticmethod
def _normalize_projection(projection) -> Any:
def _normalize_projection(projection) -> StructuredQuery.Projection:
"""Helper: convert field paths to message."""
if projection is not None:

Expand Down Expand Up @@ -836,7 +836,7 @@ def stream(
def on_snapshot(self, callback) -> NoReturn:
raise NotImplementedError

def _comparator(self, doc1, doc2) -> Any:
def _comparator(self, doc1, doc2) -> int:
_orders = self._orders

# Add implicit sorting by name, using the last specified direction.
Expand Down Expand Up @@ -883,7 +883,7 @@ def _comparator(self, doc1, doc2) -> Any:
return 0


def _enum_from_op_string(op_string: str) -> Any:
def _enum_from_op_string(op_string: str) -> int:
"""Convert a string representation of a binary operator to an enum.
These enums come from the protobuf message definition
Expand Down Expand Up @@ -926,7 +926,7 @@ def _isnan(value) -> bool:
return False


def _enum_from_direction(direction: str) -> Any:
def _enum_from_direction(direction: str) -> int:
"""Convert a string representation of a direction to an enum.
Args:
Expand Down Expand Up @@ -954,7 +954,7 @@ def _enum_from_direction(direction: str) -> Any:
raise ValueError(msg)


def _filter_pb(field_or_unary) -> Any:
def _filter_pb(field_or_unary) -> StructuredQuery.Filter:
"""Convert a specific protobuf filter to the generic filter type.
Args:
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/firestore_v1/client.py
Expand Up @@ -46,6 +46,9 @@
)
from typing import Any, Generator, Iterable, Tuple

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot


class Client(BaseClient):
"""Client for interacting with Google Cloud Firestore API.
Expand Down Expand Up @@ -209,7 +212,7 @@ def get_all(
transaction: Transaction = None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Generator[Any, Any, None]:
) -> Generator[DocumentSnapshot, Any, None]:
"""Retrieve a batch of documents.
.. note::
Expand Down
12 changes: 7 additions & 5 deletions google/cloud/firestore_v1/document.py
Expand Up @@ -25,7 +25,9 @@

from google.api_core import exceptions # type: ignore
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.watch import Watch
from google.protobuf import timestamp_pb2
from typing import Any, Callable, Generator, Iterable


Expand Down Expand Up @@ -62,7 +64,7 @@ def create(
document_data: dict,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Any:
) -> write.WriteResult:
"""Create the current document in the Firestore database.
Args:
Expand Down Expand Up @@ -92,7 +94,7 @@ def set(
merge: bool = False,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Any:
) -> write.WriteResult:
"""Replace the current document in the Firestore database.
A write ``option`` can be specified to indicate preconditions of
Expand Down Expand Up @@ -132,7 +134,7 @@ def update(
option: _helpers.WriteOption = None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Any:
) -> write.WriteResult:
"""Update an existing document in the Firestore database.
By default, this method verifies that the document exists on the
Expand Down Expand Up @@ -288,7 +290,7 @@ def delete(
option: _helpers.WriteOption = None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Any:
) -> timestamp_pb2.Timestamp:
"""Delete the current document in the Firestore database.
Args:
Expand Down Expand Up @@ -339,7 +341,7 @@ def get(
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this reference
will be retrieved in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
retry (google.api_core.retry.Retry): Designation of what errors, if an y,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
Expand Down

0 comments on commit 30bb3fb

Please sign in to comment.