Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
feat: add type hints for method params (#182)
Co-authored-by: Christopher Wilcox <crwilcox@google.com>
  • Loading branch information
HemangChothani and crwilcox committed Oct 9, 2020
1 parent c3acd4a commit 9b6c2f3
Show file tree
Hide file tree
Showing 17 changed files with 214 additions and 125 deletions.
10 changes: 5 additions & 5 deletions google/cloud/firestore_v1/async_client.py
Expand Up @@ -49,7 +49,7 @@
from google.cloud.firestore_v1.services.firestore.transports import (
grpc_asyncio as firestore_grpc_transport,
)
from typing import Any, AsyncGenerator
from typing import Any, AsyncGenerator, Iterable, Tuple


class AsyncClient(BaseClient):
Expand Down Expand Up @@ -119,7 +119,7 @@ def _target(self):
"""
return self._target_helper(firestore_client.FirestoreAsyncClient)

def collection(self, *collection_path) -> AsyncCollectionReference:
def collection(self, *collection_path: Tuple[str]) -> AsyncCollectionReference:
"""Get a reference to a collection.
For a top-level collection:
Expand Down Expand Up @@ -150,7 +150,7 @@ def collection(self, *collection_path) -> AsyncCollectionReference:
"""
return AsyncCollectionReference(*_path_helper(collection_path), client=self)

def collection_group(self, collection_id) -> AsyncCollectionGroup:
def collection_group(self, collection_id: str) -> AsyncCollectionGroup:
"""
Creates and returns a new AsyncQuery that includes all documents in the
database that are contained in a collection or subcollection with the
Expand All @@ -172,7 +172,7 @@ def collection_group(self, collection_id) -> AsyncCollectionGroup:
"""
return AsyncCollectionGroup(self._get_collection_reference(collection_id))

def document(self, *document_path) -> AsyncDocumentReference:
def document(self, *document_path: Tuple[str]) -> AsyncDocumentReference:
"""Get a reference to a document in a collection.
For a top-level document:
Expand Down Expand Up @@ -208,7 +208,7 @@ def document(self, *document_path) -> AsyncDocumentReference:
)

async def get_all(
self, references, field_paths=None, transaction=None
self, references: list, field_paths: Iterable[str] = None, transaction=None,
) -> AsyncGenerator[DocumentSnapshot, Any]:
"""Retrieve a batch of documents.
Expand Down
13 changes: 9 additions & 4 deletions google/cloud/firestore_v1/async_collection.py
Expand Up @@ -28,6 +28,9 @@
from typing import AsyncIterator
from typing import Any, AsyncGenerator, Tuple

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction


class AsyncCollectionReference(BaseCollectionReference):
"""A reference to a collection in a Firestore database.
Expand Down Expand Up @@ -66,7 +69,9 @@ def _query(self) -> async_query.AsyncQuery:
"""
return async_query.AsyncQuery(self)

async def add(self, document_data, document_id=None) -> Tuple[Any, Any]:
async def add(
self, document_data: dict, document_id: str = None
) -> Tuple[Any, Any]:
"""Create a document in the Firestore database with the provided data.
Args:
Expand Down Expand Up @@ -98,7 +103,7 @@ async def add(self, document_data, document_id=None) -> Tuple[Any, Any]:
return write_result.update_time, document_ref

async def list_documents(
self, page_size=None
self, page_size: int = None
) -> AsyncGenerator[DocumentReference, None]:
"""List all subdocuments of the current collection.
Expand Down Expand Up @@ -127,7 +132,7 @@ async def list_documents(
async for i in iterator:
yield _item_to_document_ref(self, i)

async def get(self, transaction=None) -> list:
async def get(self, transaction: Transaction = None) -> list:
"""Read the documents in this collection.
This sends a ``RunQuery`` RPC and returns a list of documents
Expand All @@ -149,7 +154,7 @@ async def get(self, transaction=None) -> list:
return await query.get(transaction=transaction)

async def stream(
self, transaction=None
self, transaction: Transaction = None
) -> AsyncIterator[async_document.DocumentSnapshot]:
"""Read the documents in this collection.
Expand Down
16 changes: 9 additions & 7 deletions google/cloud/firestore_v1/async_document.py
Expand Up @@ -23,7 +23,7 @@
from google.api_core import exceptions # type: ignore
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import common
from typing import Any, AsyncGenerator, Coroutine, Union
from typing import Any, AsyncGenerator, Coroutine, Iterable, Union


class AsyncDocumentReference(BaseDocumentReference):
Expand Down Expand Up @@ -54,7 +54,7 @@ class AsyncDocumentReference(BaseDocumentReference):
def __init__(self, *path, **kwargs) -> None:
super(AsyncDocumentReference, self).__init__(*path, **kwargs)

async def create(self, document_data) -> Coroutine:
async def create(self, document_data: dict) -> Coroutine:
"""Create the current document in the Firestore database.
Args:
Expand All @@ -75,7 +75,7 @@ async def create(self, document_data) -> Coroutine:
write_results = await batch.commit()
return _first_write_result(write_results)

async def set(self, document_data, merge=False) -> Coroutine:
async def set(self, document_data: dict, merge: bool = False) -> Coroutine:
"""Replace the current document in the Firestore database.
A write ``option`` can be specified to indicate preconditions of
Expand Down Expand Up @@ -106,7 +106,9 @@ async def set(self, document_data, merge=False) -> Coroutine:
write_results = await batch.commit()
return _first_write_result(write_results)

async def update(self, field_updates, option=None) -> Coroutine:
async def update(
self, field_updates: dict, option: _helpers.WriteOption = None
) -> Coroutine:
"""Update an existing document in the Firestore database.
By default, this method verifies that the document exists on the
Expand Down Expand Up @@ -254,7 +256,7 @@ async def update(self, field_updates, option=None) -> Coroutine:
write_results = await batch.commit()
return _first_write_result(write_results)

async def delete(self, option=None) -> Coroutine:
async def delete(self, option: _helpers.WriteOption = None) -> Coroutine:
"""Delete the current document in the Firestore database.
Args:
Expand Down Expand Up @@ -282,7 +284,7 @@ async def delete(self, option=None) -> Coroutine:
return commit_response.commit_time

async def get(
self, field_paths=None, transaction=None
self, field_paths: Iterable[str] = None, transaction=None
) -> Union[DocumentSnapshot, Coroutine[Any, Any, DocumentSnapshot]]:
"""Retrieve a snapshot of the current document.
Expand Down Expand Up @@ -348,7 +350,7 @@ async def get(
update_time=update_time,
)

async def collections(self, page_size=None) -> AsyncGenerator:
async def collections(self, page_size: int = None) -> AsyncGenerator:
"""List subcollections of the current document.
Args:
Expand Down
7 changes: 5 additions & 2 deletions google/cloud/firestore_v1/async_query.py
Expand Up @@ -31,6 +31,9 @@
from google.cloud.firestore_v1 import async_document
from typing import AsyncGenerator

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction


class AsyncQuery(BaseQuery):
"""Represents a query to the Firestore API.
Expand Down Expand Up @@ -114,7 +117,7 @@ def __init__(
all_descendants=all_descendants,
)

async def get(self, transaction=None) -> list:
async def get(self, transaction: Transaction = None) -> list:
"""Read the documents in the collection that match this query.
This sends a ``RunQuery`` RPC and returns a list of documents
Expand Down Expand Up @@ -154,7 +157,7 @@ async def get(self, transaction=None) -> list:
return result

async def stream(
self, transaction=None
self, transaction: Transaction = None
) -> AsyncGenerator[async_document.DocumentSnapshot, None]:
"""Read the documents in the collection that match this query.
Expand Down
29 changes: 20 additions & 9 deletions google/cloud/firestore_v1/async_transaction.py
Expand Up @@ -39,7 +39,10 @@
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.async_document import DocumentSnapshot
from google.cloud.firestore_v1.async_query import AsyncQuery
from typing import Any, AsyncGenerator, Coroutine
from typing import Any, AsyncGenerator, Callable, Coroutine

# Types needed only for Type Hints
from google.cloud.firestore_v1.client import Client


class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
Expand All @@ -60,7 +63,7 @@ def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None:
super(AsyncTransaction, self).__init__(client)
BaseTransaction.__init__(self, max_attempts, read_only)

def _add_write_pbs(self, write_pbs) -> None:
def _add_write_pbs(self, write_pbs: list) -> None:
"""Add `Write`` protobufs to this transaction.
Args:
Expand All @@ -75,7 +78,7 @@ def _add_write_pbs(self, write_pbs) -> None:

super(AsyncTransaction, self)._add_write_pbs(write_pbs)

async def _begin(self, retry_id=None) -> None:
async def _begin(self, retry_id: bytes = None) -> None:
"""Begin the transaction.
Args:
Expand Down Expand Up @@ -141,7 +144,7 @@ async def _commit(self) -> list:
self._clean_up()
return list(commit_response.write_results)

async def get_all(self, references) -> Coroutine:
async def get_all(self, references: list) -> Coroutine:
"""Retrieves multiple documents from Firestore.
Args:
Expand Down Expand Up @@ -187,7 +190,9 @@ class _AsyncTransactional(_BaseTransactional):
def __init__(self, to_wrap) -> None:
super(_AsyncTransactional, self).__init__(to_wrap)

async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine:
async def _pre_commit(
self, transaction: AsyncTransaction, *args, **kwargs
) -> Coroutine:
"""Begin transaction and call the wrapped coroutine.
If the coroutine raises an exception, the transaction will be rolled
Expand Down Expand Up @@ -225,7 +230,7 @@ async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine:
await transaction._rollback()
raise

async def _maybe_commit(self, transaction) -> bool:
async def _maybe_commit(self, transaction: AsyncTransaction) -> bool:
"""Try to commit the transaction.
If the transaction is read-write and the ``Commit`` fails with the
Expand Down Expand Up @@ -291,7 +296,9 @@ async def __call__(self, transaction, *args, **kwargs):
raise ValueError(msg)


def async_transactional(to_wrap) -> _AsyncTransactional:
def async_transactional(
to_wrap: Callable[[AsyncTransaction], Any]
) -> _AsyncTransactional:
"""Decorate a callable so that it runs in a transaction.
Args:
Expand All @@ -307,7 +314,9 @@ def async_transactional(to_wrap) -> _AsyncTransactional:


# TODO(crwilcox): this was 'coroutine' from pytype merge-pyi...
async def _commit_with_retry(client, write_pbs, transaction_id) -> types.CommitResponse:
async def _commit_with_retry(
client: Client, write_pbs: list, transaction_id: bytes
) -> types.CommitResponse:
"""Call ``Commit`` on the GAPIC client with retry / sleep.
Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level
Expand Down Expand Up @@ -350,7 +359,9 @@ async def _commit_with_retry(client, write_pbs, transaction_id) -> types.CommitR
current_sleep = await _sleep(current_sleep)


async def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER) -> float:
async def _sleep(
current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER
) -> float:
"""Sleep and produce a new sleep time.
.. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\
Expand Down
26 changes: 21 additions & 5 deletions google/cloud/firestore_v1/base_batch.py
Expand Up @@ -17,6 +17,10 @@

from google.cloud.firestore_v1 import _helpers

# Types needed only for Type Hints
from google.cloud.firestore_v1.document import DocumentReference
from typing import Union


class BaseWriteBatch(object):
"""Accumulate write operations to be sent in a batch.
Expand All @@ -36,7 +40,7 @@ def __init__(self, client) -> None:
self.write_results = None
self.commit_time = None

def _add_write_pbs(self, write_pbs) -> None:
def _add_write_pbs(self, write_pbs: list) -> None:
"""Add `Write`` protobufs to this transaction.
This method intended to be over-ridden by subclasses.
Expand All @@ -47,7 +51,7 @@ def _add_write_pbs(self, write_pbs) -> None:
"""
self._write_pbs.extend(write_pbs)

def create(self, reference, document_data) -> None:
def create(self, reference: DocumentReference, document_data: dict) -> None:
"""Add a "change" to this batch to create a document.
If the document given by ``reference`` already exists, then this
Expand All @@ -62,7 +66,12 @@ def create(self, reference, document_data) -> None:
write_pbs = _helpers.pbs_for_create(reference._document_path, document_data)
self._add_write_pbs(write_pbs)

def set(self, reference, document_data, merge=False) -> None:
def set(
self,
reference: DocumentReference,
document_data: dict,
merge: Union[bool, list] = False,
) -> None:
"""Add a "change" to replace a document.
See
Expand Down Expand Up @@ -90,7 +99,12 @@ def set(self, reference, document_data, merge=False) -> None:

self._add_write_pbs(write_pbs)

def update(self, reference, field_updates, option=None) -> None:
def update(
self,
reference: DocumentReference,
field_updates: dict,
option: _helpers.WriteOption = None,
) -> None:
"""Add a "change" to update a document.
See
Expand All @@ -113,7 +127,9 @@ def update(self, reference, field_updates, option=None) -> None:
)
self._add_write_pbs(write_pbs)

def delete(self, reference, option=None) -> None:
def delete(
self, reference: DocumentReference, option: _helpers.WriteOption = None
) -> None:
"""Add a "change" to delete a document.
See
Expand Down

0 comments on commit 9b6c2f3

Please sign in to comment.