Skip to content

Commit

Permalink
fix: type hint improvements (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
crwilcox committed Aug 19, 2020
1 parent f3bedc1 commit d30fff8
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 92 deletions.
5 changes: 1 addition & 4 deletions google/cloud/firestore.py
Expand Up @@ -48,11 +48,8 @@
from google.cloud.firestore_v1 import WriteOption
from typing import List

__all__: List[str]
__version__: str


__all__ = [
__all__: List[str] = [
"__version__",
"ArrayRemove",
"ArrayUnion",
Expand Down
6 changes: 1 addition & 5 deletions google/cloud/firestore_v1/__init__.py
Expand Up @@ -22,7 +22,6 @@

__version__ = get_distribution("google-cloud-firestore").version


from google.cloud.firestore_v1 import types
from google.cloud.firestore_v1._helpers import GeoPoint
from google.cloud.firestore_v1._helpers import ExistsOption
Expand Down Expand Up @@ -99,15 +98,12 @@
from .types.write import DocumentTransform
from typing import List

__all__: List[str]
__version__: str

# from .types.write import ExistenceFilter
# from .types.write import Write
# from .types.write import WriteResult


__all__ = [
__all__: List[str] = [
"__version__",
"ArrayRemove",
"ArrayUnion",
Expand Down
1 change: 0 additions & 1 deletion google/cloud/firestore_v1/_helpers.py
Expand Up @@ -35,7 +35,6 @@

_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict
_datetime_to_pb_timestamp: Any


BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}."
Expand Down
6 changes: 2 additions & 4 deletions google/cloud/firestore_v1/async_client.py
Expand Up @@ -49,9 +49,7 @@
from google.cloud.firestore_v1.services.firestore.transports import (
grpc_asyncio as firestore_grpc_transport,
)
from typing import Any, AsyncGenerator, NoReturn

_CLIENT_INFO: Any
from typing import Any, AsyncGenerator


class AsyncClient(BaseClient):
Expand Down Expand Up @@ -152,7 +150,7 @@ def collection(self, *collection_path) -> AsyncCollectionReference:
"""
return AsyncCollectionReference(*_path_helper(collection_path), client=self)

def collection_group(self, collection_id) -> NoReturn:
def collection_group(self, collection_id) -> AsyncQuery:
"""
Creates and returns a new AsyncQuery that includes all documents in the
database that are contained in a collection or subcollection with the
Expand Down
6 changes: 4 additions & 2 deletions google/cloud/firestore_v1/async_document.py
Expand Up @@ -23,7 +23,7 @@
from google.api_core import exceptions # type: ignore
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import common
from typing import AsyncGenerator, Coroutine
from typing import Any, AsyncGenerator, Coroutine, Union


class AsyncDocumentReference(BaseDocumentReference):
Expand Down Expand Up @@ -281,7 +281,9 @@ async def delete(self, option=None) -> Coroutine:

return commit_response.commit_time

async def get(self, field_paths=None, transaction=None) -> DocumentSnapshot:
async def get(
self, field_paths=None, transaction=None
) -> Union[DocumentSnapshot, Coroutine[Any, Any, DocumentSnapshot]]:
"""Retrieve a snapshot of the current document.
See :meth:`~google.cloud.firestore_v1.base_client.BaseClient.field_path` for
Expand Down
14 changes: 3 additions & 11 deletions google/cloud/firestore_v1/async_transaction.py
Expand Up @@ -37,17 +37,9 @@
from google.cloud.firestore_v1 import types

from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.async_document import DocumentSnapshot
from google.cloud.firestore_v1.async_query import AsyncQuery
from typing import Coroutine

_CANT_BEGIN: str
_CANT_COMMIT: str
_CANT_ROLLBACK: str
_EXCEED_ATTEMPTS_TEMPLATE: str
_INITIAL_SLEEP: float
_MAX_SLEEP: float
_MULTIPLIER: float
_WRITE_READ_ONLY: str
from typing import Any, AsyncGenerator, Coroutine


class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
Expand Down Expand Up @@ -162,7 +154,7 @@ async def get_all(self, references) -> Coroutine:
"""
return await self._client.get_all(references, transaction=self)

async def get(self, ref_or_query) -> Coroutine:
async def get(self, ref_or_query) -> AsyncGenerator[DocumentSnapshot, Any]:
"""
Retrieve a document or a query result from the database.
Args:
Expand Down
61 changes: 40 additions & 21 deletions google/cloud/firestore_v1/base_client.py
Expand Up @@ -23,6 +23,7 @@
* a :class:`~google.cloud.firestore_v1.client.Client` owns a
:class:`~google.cloud.firestore_v1.document.DocumentReference`
"""

import os

import google.api_core.client_options # type: ignore
Expand All @@ -34,29 +35,38 @@
from google.cloud.firestore_v1 import __version__
from google.cloud.firestore_v1 import types
from google.cloud.firestore_v1.base_document import DocumentSnapshot

from google.cloud.firestore_v1.field_path import render_field_path
from typing import Any, List, NoReturn, Optional, Tuple, Union
from typing import (
Any,
AsyncGenerator,
Generator,
List,
Optional,
Tuple,
Union,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_collection import BaseCollectionReference
from google.cloud.firestore_v1.base_document import BaseDocumentReference
from google.cloud.firestore_v1.base_transaction import BaseTransaction
from google.cloud.firestore_v1.base_batch import BaseWriteBatch
from google.cloud.firestore_v1.base_query import BaseQuery

_ACTIVE_TXN: str
_BAD_DOC_TEMPLATE: str
_BAD_OPTION_ERR: str
_CLIENT_INFO: Any
_FIRESTORE_EMULATOR_HOST: str
_INACTIVE_TXN: str
__version__: str

DEFAULT_DATABASE = "(default)"
"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""
_BAD_OPTION_ERR = (
"Exactly one of ``last_update_time`` or ``exists`` " "must be provided."
)
_BAD_DOC_TEMPLATE = (
_BAD_DOC_TEMPLATE: str = (
"Document {!r} appeared in response but was not present among references"
)
_ACTIVE_TXN = "There is already an active transaction."
_INACTIVE_TXN = "There is no active transaction."
_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
_FIRESTORE_EMULATOR_HOST = "FIRESTORE_EMULATOR_HOST"
_ACTIVE_TXN: str = "There is already an active transaction."
_INACTIVE_TXN: str = "There is no active transaction."
_CLIENT_INFO: Any = client_info.ClientInfo(client_library_version=__version__)
_FIRESTORE_EMULATOR_HOST: str = "FIRESTORE_EMULATOR_HOST"


class BaseClient(ClientWithProject):
Expand Down Expand Up @@ -214,13 +224,13 @@ def _rpc_metadata(self):

return self._rpc_metadata_internal

def collection(self, *collection_path) -> NoReturn:
def collection(self, *collection_path) -> BaseCollectionReference:
raise NotImplementedError

def collection_group(self, collection_id) -> NoReturn:
def collection_group(self, collection_id) -> BaseQuery:
raise NotImplementedError

def _get_collection_reference(self, collection_id) -> NoReturn:
def _get_collection_reference(self, collection_id) -> BaseCollectionReference:
"""Checks validity of collection_id and then uses subclasses collection implementation.
Args:
Expand All @@ -241,7 +251,7 @@ def _get_collection_reference(self, collection_id) -> NoReturn:

return self.collection(collection_id)

def document(self, *document_path) -> NoReturn:
def document(self, *document_path) -> BaseDocumentReference:
raise NotImplementedError

def _document_path_helper(self, *document_path) -> List[str]:
Expand Down Expand Up @@ -342,16 +352,25 @@ def write_option(
extra = "{!r} was provided".format(name)
raise TypeError(_BAD_OPTION_ERR, extra)

def get_all(self, references, field_paths=None, transaction=None) -> NoReturn:
def get_all(
self, references, field_paths=None, transaction=None
) -> Union[
AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any]
]:
raise NotImplementedError

def collections(self) -> NoReturn:
def collections(
self,
) -> Union[
AsyncGenerator[BaseCollectionReference, Any],
Generator[BaseCollectionReference, Any, Any],
]:
raise NotImplementedError

def batch(self) -> NoReturn:
def batch(self) -> BaseWriteBatch:
raise NotImplementedError

def transaction(self, **kwargs) -> NoReturn:
def transaction(self, **kwargs) -> BaseTransaction:
raise NotImplementedError


Expand Down
57 changes: 41 additions & 16 deletions google/cloud/firestore_v1/base_collection.py
Expand Up @@ -17,8 +17,21 @@

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.document import DocumentReference
from typing import Any, NoReturn, Tuple

from typing import (
Any,
AsyncGenerator,
Coroutine,
Generator,
AsyncIterator,
Iterator,
NoReturn,
Tuple,
Union,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import BaseQuery

_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

Expand Down Expand Up @@ -87,7 +100,7 @@ def parent(self):
parent_path = self._path[:-1]
return self._client.document(*parent_path)

def _query(self) -> NoReturn:
def _query(self) -> BaseQuery:
raise NotImplementedError

def document(self, document_id=None) -> Any:
Expand Down Expand Up @@ -131,13 +144,19 @@ def _parent_info(self) -> Tuple[Any, str]:
expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id))
return parent_path, expected_prefix

def add(self, document_data, document_id=None) -> NoReturn:
def add(
self, document_data, document_id=None
) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]:
raise NotImplementedError

def list_documents(self, page_size=None) -> NoReturn:
def list_documents(
self, page_size=None
) -> Union[
Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any]
]:
raise NotImplementedError

def select(self, field_paths) -> NoReturn:
def select(self, field_paths) -> BaseQuery:
"""Create a "select" query with this collection as parent.
See
Expand All @@ -156,7 +175,7 @@ def select(self, field_paths) -> NoReturn:
query = self._query()
return query.select(field_paths)

def where(self, field_path, op_string, value) -> NoReturn:
def where(self, field_path, op_string, value) -> BaseQuery:
"""Create a "where" query with this collection as parent.
See
Expand All @@ -180,7 +199,7 @@ def where(self, field_path, op_string, value) -> NoReturn:
query = self._query()
return query.where(field_path, op_string, value)

def order_by(self, field_path, **kwargs) -> NoReturn:
def order_by(self, field_path, **kwargs) -> BaseQuery:
"""Create an "order by" query with this collection as parent.
See
Expand All @@ -202,7 +221,7 @@ def order_by(self, field_path, **kwargs) -> NoReturn:
query = self._query()
return query.order_by(field_path, **kwargs)

def limit(self, count) -> NoReturn:
def limit(self, count) -> BaseQuery:
"""Create a limited query with this collection as parent.
.. note::
Expand Down Expand Up @@ -242,7 +261,7 @@ def limit_to_last(self, count):
query = self._query()
return query.limit_to_last(count)

def offset(self, num_to_skip) -> NoReturn:
def offset(self, num_to_skip) -> BaseQuery:
"""Skip to an offset in a query with this collection as parent.
See
Expand All @@ -260,7 +279,7 @@ def offset(self, num_to_skip) -> NoReturn:
query = self._query()
return query.offset(num_to_skip)

def start_at(self, document_fields) -> NoReturn:
def start_at(self, document_fields) -> BaseQuery:
"""Start query at a cursor with this collection as parent.
See
Expand All @@ -281,7 +300,7 @@ def start_at(self, document_fields) -> NoReturn:
query = self._query()
return query.start_at(document_fields)

def start_after(self, document_fields) -> NoReturn:
def start_after(self, document_fields) -> BaseQuery:
"""Start query after a cursor with this collection as parent.
See
Expand All @@ -302,7 +321,7 @@ def start_after(self, document_fields) -> NoReturn:
query = self._query()
return query.start_after(document_fields)

def end_before(self, document_fields) -> NoReturn:
def end_before(self, document_fields) -> BaseQuery:
"""End query before a cursor with this collection as parent.
See
Expand All @@ -323,7 +342,7 @@ def end_before(self, document_fields) -> NoReturn:
query = self._query()
return query.end_before(document_fields)

def end_at(self, document_fields) -> NoReturn:
def end_at(self, document_fields) -> BaseQuery:
"""End query at a cursor with this collection as parent.
See
Expand All @@ -344,10 +363,16 @@ def end_at(self, document_fields) -> NoReturn:
query = self._query()
return query.end_at(document_fields)

def get(self, transaction=None) -> NoReturn:
def get(
self, transaction=None
) -> Union[
Generator[DocumentSnapshot, Any, Any], AsyncGenerator[DocumentSnapshot, Any]
]:
raise NotImplementedError

def stream(self, transaction=None) -> NoReturn:
def stream(
self, transaction=None
) -> Union[Iterator[DocumentSnapshot], AsyncIterator[DocumentSnapshot]]:
raise NotImplementedError

def on_snapshot(self, callback) -> NoReturn:
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/base_document.py
Expand Up @@ -190,7 +190,7 @@ def update(self, field_updates, option=None) -> NoReturn:
def delete(self, option=None) -> NoReturn:
raise NotImplementedError

def get(self, field_paths=None, transaction=None) -> NoReturn:
def get(self, field_paths=None, transaction=None) -> "DocumentSnapshot":
raise NotImplementedError

def collections(self, page_size=None) -> NoReturn:
Expand Down

0 comments on commit d30fff8

Please sign in to comment.