Skip to content

Commit

Permalink
feat: asyncio microgen client (#118)
Browse files Browse the repository at this point in the history
* refactor: move generated client instantiation out of base class

* feat: integrate microgen async client to client

* feat: make collections call backed by async

* fix: failing asyncmock assertion

* refactor: remove unused install

* fix: lint

* refactor: shared functionality in client to base class

* refactor: move AsyncMock to test helpers

* fix: return type in client docs

* fix: add target example
  • Loading branch information
rafilong committed Jul 22, 2020
1 parent d82687d commit de4cc44
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 24 deletions.
33 changes: 31 additions & 2 deletions google/cloud/firestore_v1/async_client.py
Expand Up @@ -40,6 +40,12 @@
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.async_transaction import AsyncTransaction
from google.cloud.firestore_v1.services.firestore import (
async_client as firestore_client,
)
from google.cloud.firestore_v1.services.firestore.transports import (
grpc_asyncio as firestore_grpc_transport,
)


class AsyncClient(BaseClient):
Expand Down Expand Up @@ -86,6 +92,29 @@ def __init__(
client_options=client_options,
)

@property
def _firestore_api(self):
"""Lazy-loading getter GAPIC Firestore API.
Returns:
:class:`~google.cloud.gapic.firestore.v1`.async_firestore_client.FirestoreAsyncClient:
The GAPIC client with the credentials of the current client.
"""
return self._firestore_api_helper(
firestore_grpc_transport.FirestoreGrpcAsyncIOTransport,
firestore_client.FirestoreAsyncClient,
firestore_client,
)

@property
def _target(self):
"""Return the target (where the API is).
Eg. "firestore.googleapis.com"
Returns:
str: The location of the API.
"""
return self._target_helper(firestore_client.FirestoreAsyncClient)

def collection(self, *collection_path):
"""Get a reference to a collection.
Expand Down Expand Up @@ -233,7 +262,7 @@ async def collections(self):
Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]:
iterator of subcollections of the current document.
"""
iterator = self._firestore_api.list_collection_ids(
iterator = await self._firestore_api.list_collection_ids(
request={"parent": "{}/documents".format(self._database_string)},
metadata=self._rpc_metadata,
)
Expand All @@ -242,7 +271,7 @@ async def collections(self):
for i in iterator.collection_ids:
yield self.collection(i)
if iterator.next_page_token:
iterator = self._firestore_api.list_collection_ids(
iterator = await self._firestore_api.list_collection_ids(
request={
"parent": "{}/documents".format(self._database_string),
"page_token": iterator.next_page_token,
Expand Down
30 changes: 10 additions & 20 deletions google/cloud/firestore_v1/base_client.py
Expand Up @@ -35,10 +35,6 @@
from google.cloud.firestore_v1 import types
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.field_path import render_field_path
from google.cloud.firestore_v1.services.firestore import client as firestore_client
from google.cloud.firestore_v1.services.firestore.transports import (
grpc as firestore_grpc_transport,
)

DEFAULT_DATABASE = "(default)"
"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""
Expand Down Expand Up @@ -117,12 +113,10 @@ def __init__(
self._database = database
self._emulator_host = os.getenv(_FIRESTORE_EMULATOR_HOST)

@property
def _firestore_api(self):
def _firestore_api_helper(self, transport, client_class, client_module):
"""Lazy-loading getter GAPIC Firestore API.
Returns:
:class:`~google.cloud.gapic.firestore.v1`.firestore_client.FirestoreClient:
<The GAPIC client with the credentials of the current client.
The GAPIC client with the credentials of the current client.
"""
if self._firestore_api_internal is None:
# Use a custom channel.
Expand All @@ -131,30 +125,26 @@ def _firestore_api(self):
if self._emulator_host is not None:
# TODO(microgen): this likely needs to be adapted to use insecure_channel
# on new generated surface.
channel = firestore_grpc_transport.FirestoreGrpcTransport.create_channel(
host=self._emulator_host
)
channel = transport.create_channel(host=self._emulator_host)
else:
channel = firestore_grpc_transport.FirestoreGrpcTransport.create_channel(
channel = transport.create_channel(
self._target,
credentials=self._credentials,
options={"grpc.keepalive_time_ms": 30000}.items(),
)

self._transport = firestore_grpc_transport.FirestoreGrpcTransport(
host=self._target, channel=channel
)
self._transport = transport(host=self._target, channel=channel)

self._firestore_api_internal = firestore_client.FirestoreClient(
self._firestore_api_internal = client_class(
transport=self._transport, client_options=self._client_options
)
firestore_client._client_info = self._client_info
client_module._client_info = self._client_info

return self._firestore_api_internal

@property
def _target(self):
def _target_helper(self, client_class):
"""Return the target (where the API is).
Eg. "firestore.googleapis.com"
Returns:
str: The location of the API.
Expand All @@ -164,7 +154,7 @@ def _target(self):
elif self._client_options and self._client_options.api_endpoint:
return self._client_options.api_endpoint
else:
return firestore_client.FirestoreClient.DEFAULT_ENDPOINT
return client_class.DEFAULT_ENDPOINT

@property
def _database_string(self):
Expand Down
27 changes: 27 additions & 0 deletions google/cloud/firestore_v1/client.py
Expand Up @@ -40,6 +40,10 @@
from google.cloud.firestore_v1.collection import CollectionReference
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.transaction import Transaction
from google.cloud.firestore_v1.services.firestore import client as firestore_client
from google.cloud.firestore_v1.services.firestore.transports import (
grpc as firestore_grpc_transport,
)


class Client(BaseClient):
Expand Down Expand Up @@ -86,6 +90,29 @@ def __init__(
client_options=client_options,
)

@property
def _firestore_api(self):
"""Lazy-loading getter GAPIC Firestore API.
Returns:
:class:`~google.cloud.gapic.firestore.v1`.firestore_client.FirestoreClient:
The GAPIC client with the credentials of the current client.
"""
return self._firestore_api_helper(
firestore_grpc_transport.FirestoreGrpcTransport,
firestore_client.FirestoreClient,
firestore_client,
)

@property
def _target(self):
"""Return the target (where the API is).
Eg. "firestore.googleapis.com"
Returns:
str: The location of the API.
"""
return self._target_helper(firestore_client.FirestoreClient)

def collection(self, *collection_path):
"""Get a reference to a collection.
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Expand Up @@ -70,7 +70,7 @@ def lint_setup_py(session):

def default(session, test_dir, ignore_dir=None):
# Install all test dependencies, then install this package in-place.
session.install("asyncmock", "pytest-asyncio", "aiounittest")
session.install("pytest-asyncio", "aiounittest")

session.install("mock", "pytest", "pytest-cov")
session.install("-e", ".")
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/v1/test__helpers.py
Expand Up @@ -20,6 +20,11 @@
import mock


class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
return super(AsyncMock, self).__call__(*args, **kwargs)


class TestGeoPoint(unittest.TestCase):
@staticmethod
def _get_target_class():
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/v1/test_async_client.py
Expand Up @@ -18,6 +18,7 @@
import aiounittest

import mock
from tests.unit.v1.test__helpers import AsyncMock


class TestAsyncClient(aiounittest.AsyncTestCase):
Expand Down Expand Up @@ -200,7 +201,8 @@ async def test_collections(self):

collection_ids = ["users", "projects"]
client = self._make_default_one()
firestore_api = mock.Mock(spec=["list_collection_ids"])
firestore_api = AsyncMock()
firestore_api.mock_add_spec(spec=["list_collection_ids"])
client._firestore_api_internal = firestore_api

# TODO(microgen): list_collection_ids isn't a pager.
Expand Down

0 comments on commit de4cc44

Please sign in to comment.