Skip to content

Commit

Permalink
fix: update cps async clients (#41)
Browse files Browse the repository at this point in the history
* fix: Update cps async clients

These async clients can now be constructed on a different thread than __aenter__ is called on.

Also backport typing changes from https://github.com/googleapis/gapic-generator-python/pull/641/files

* chore: fix lint errors
  • Loading branch information
dpcollins-google committed Oct 8, 2020
1 parent 4276882 commit f41c228
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 40 deletions.
@@ -1,5 +1,5 @@
from asyncio import Future, Queue, ensure_future
from typing import Callable, NamedTuple, Dict, Set
from typing import Callable, NamedTuple, Dict, Set, Optional

from google.cloud.pubsub_v1.subscriber.message import Message

Expand All @@ -18,21 +18,31 @@ class _RunningSubscriber(NamedTuple):


class AssigningSubscriber(AsyncSubscriber, PermanentFailable):
_assigner: Assigner
_assigner_factory: Callable[[], Assigner]
_subscriber_factory: PartitionSubscriberFactory

_subscribers: Dict[Partition, _RunningSubscriber]
_messages: "Queue[Message]"

# Lazily initialized to ensure they are initialized on the thread where __aenter__ is called.
_assigner: Optional[Assigner]
_messages: Optional["Queue[Message]"]
_assign_poller: Future

def __init__(
self, assigner: Assigner, subscriber_factory: PartitionSubscriberFactory
self,
assigner_factory: Callable[[], Assigner],
subscriber_factory: PartitionSubscriberFactory,
):
"""
Accepts a factory for an Assigner instead of an Assigner because GRPC asyncio uses the current thread's event
loop.
"""
super().__init__()
self._assigner = assigner
self._assigner_factory = assigner_factory
self._assigner = None
self._subscriber_factory = subscriber_factory
self._subscribers = {}
self._messages = Queue()
self._messages = None

async def read(self) -> Message:
return await self.await_unless_failed(self._messages.get())
Expand Down Expand Up @@ -65,6 +75,8 @@ async def _assign_action(self):
del self._subscribers[partition]

async def __aenter__(self):
self._messages = Queue()
self._assigner = self._assigner_factory()
await self._assigner.__aenter__()
self._assign_poller = ensure_future(self.run_poller(self._assign_action))
return self
Expand Down
@@ -1,4 +1,4 @@
from typing import Mapping
from typing import Mapping, Callable, Optional

from google.pubsub_v1 import PubsubMessage

Expand All @@ -10,11 +10,17 @@


class AsyncPublisherImpl(AsyncPublisher):
_publisher: Publisher

def __init__(self, publisher: Publisher):
_publisher_factory: Callable[[], Publisher]
_publisher: Optional[Publisher]

def __init__(self, publisher_factory: Callable[[], Publisher]):
"""
Accepts a factory for a Publisher instead of a Publisher because GRPC asyncio uses the current thread's event
loop.
"""
super().__init__()
self._publisher = publisher
self._publisher_factory = publisher_factory
self._publisher = None

async def publish(
self, data: bytes, ordering_key: str = "", **attrs: Mapping[str, str]
Expand All @@ -26,6 +32,7 @@ async def publish(
return (await self._publisher.publish(psl_message)).encode()

async def __aenter__(self):
self._publisher = self._publisher_factory()
await self._publisher.__aenter__()
return self

Expand Down
11 changes: 7 additions & 4 deletions google/cloud/pubsublite/cloudpubsub/make_publisher.py
Expand Up @@ -40,10 +40,13 @@ def make_async_publisher(
GoogleApiCallException on any error determining topic structure.
"""
metadata = merge_metadata(pubsub_context(framework="CLOUD_PUBSUB_SHIM"), metadata)
underlying = make_wire_publisher(
topic, batching_delay_secs, credentials, client_options, metadata
)
return AsyncPublisherImpl(underlying)

def underlying_factory():
return make_wire_publisher(
topic, batching_delay_secs, credentials, client_options, metadata
)

return AsyncPublisherImpl(underlying_factory)


def make_publisher(
Expand Down
12 changes: 7 additions & 5 deletions google/cloud/pubsublite/cloudpubsub/make_subscriber.py
@@ -1,5 +1,5 @@
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Mapping, Set, AsyncIterator
from typing import Optional, Mapping, Set, AsyncIterator, Callable
from uuid import uuid4

from google.api_core.client_options import ClientOptions
Expand Down Expand Up @@ -170,14 +170,16 @@ def make_async_subscriber(
client_options = ClientOptions(
api_endpoint=regional_endpoint(subscription.location.region)
)
assigner: Assigner
assigner_factory: Callable[[], Assigner]
if fixed_partitions:
assigner = FixedSetAssigner(fixed_partitions)
assigner_factory = lambda: FixedSetAssigner(fixed_partitions) # noqa: E731
else:
assignment_client = PartitionAssignmentServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
assigner = _make_dynamic_assigner(subscription, assignment_client, metadata)
assigner_factory = lambda: _make_dynamic_assigner( # noqa: E731
subscription, assignment_client, metadata
)

subscribe_client = SubscriberServiceAsyncClient(
credentials=credentials, client_options=client_options
Expand All @@ -196,7 +198,7 @@ def make_async_subscriber(
nack_handler,
message_transformer,
)
return AssigningSubscriber(assigner, partition_subscriber_factory)
return AssigningSubscriber(assigner_factory, partition_subscriber_factory)


def make_subscriber(
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/pubsublite/internal/wire/connection.py
Expand Up @@ -34,5 +34,5 @@ async def read(self) -> Response:
class ConnectionFactory(Generic[Request, Response]):
"""A factory for producing Connections."""

def new(self) -> Connection[Request, Response]:
async def new(self) -> Connection[Request, Response]:
raise NotImplementedError()
18 changes: 11 additions & 7 deletions google/cloud/pubsublite/internal/wire/gapic_connection.py
@@ -1,4 +1,4 @@
from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable
from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable, Awaitable
import asyncio

from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition
Expand Down Expand Up @@ -44,10 +44,10 @@ async def read(self) -> Response:
self.fail(e)
raise self.error()

def __aenter__(self):
async def __aenter__(self):
return self

def __aexit__(self, exc_type, exc_value, traceback) -> None:
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass

async def __anext__(self) -> Request:
Expand All @@ -64,15 +64,19 @@ def __aiter__(self) -> AsyncIterator[Response]:
class GapicConnectionFactory(ConnectionFactory[Request, Response]):
"""A ConnectionFactory that produces GapicConnections."""

_producer = Callable[[AsyncIterator[Request]], AsyncIterable[Response]]
_producer = Callable[[AsyncIterator[Request]], Awaitable[AsyncIterable[Response]]]

def __init__(
self, producer: Callable[[AsyncIterator[Request]], AsyncIterable[Response]]
self,
producer: Callable[
[AsyncIterator[Request]], Awaitable[AsyncIterable[Response]]
],
):
self._producer = producer

def new(self) -> Connection[Request, Response]:
async def new(self) -> Connection[Request, Response]:
conn = GapicConnection[Request, Response]()
response_iterable = self._producer(conn)
response_fut = self._producer(conn)
response_iterable = await response_fut
conn.set_response_it(response_iterable.__aiter__())
return conn
8 changes: 7 additions & 1 deletion google/cloud/pubsublite/internal/wire/retrying_connection.py
Expand Up @@ -65,7 +65,8 @@ async def _run_loop(self):
bad_retries = 0
while True:
try:
async with self._connection_factory.new() as connection:
conn_fut = self._connection_factory.new()
async with (await conn_fut) as connection:
# Needs to happen prior to reinitialization to clear outstanding waiters.
if last_failure is not None:
while not self._write_queue.empty():
Expand All @@ -89,6 +90,11 @@ async def _run_loop(self):

except asyncio.CancelledError:
return
except Exception as e:
import traceback

traceback.print_exc()
print(e)

async def _loop_connection(self, connection: Connection[Request, Response]):
read_task: Awaitable[Response] = asyncio.ensure_future(connection.read())
Expand Down
13 changes: 11 additions & 2 deletions google/cloud/pubsublite_v1/services/cursor_service/async_client.py
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -103,7 +112,7 @@ def streaming_commit_cursor(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[cursor.StreamingCommitCursorResponse]:
) -> Awaitable[AsyncIterable[cursor.StreamingCommitCursorResponse]]:
r"""Establishes a stream with the server for managing
committed cursors.
Expand Down
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -107,7 +116,7 @@ def assign_partitions(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[subscriber.PartitionAssignment]:
) -> Awaitable[AsyncIterable[subscriber.PartitionAssignment]]:
r"""Assign partitions for this client to handle for the
specified subscription.
The client must send an
Expand Down
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -103,7 +112,7 @@ def publish(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[publisher.PublishResponse]:
) -> Awaitable[AsyncIterable[publisher.PublishResponse]]:
r"""Establishes a stream with the server for publishing
messages. Once the stream is initialized, the client
publishes messages by sending publish requests on the
Expand All @@ -125,7 +134,7 @@ def publish(
sent along with the request as metadata.
Returns:
AsyncIterable[~.publisher.PublishResponse]:
Awaitable[AsyncIterable[~.publisher.PublishResponse]]:
Response to a PublishRequest.
"""

Expand Down
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -100,7 +109,7 @@ def subscribe(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[subscriber.SubscribeResponse]:
) -> Awaitable[AsyncIterable[subscriber.SubscribeResponse]]:
r"""Establishes a stream with the server for receiving
messages.
Expand Down
@@ -1,6 +1,7 @@
from typing import Set

from asynctest.mock import MagicMock, call
import threading
import pytest
from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsub_v1.subscriber.message import Message
Expand All @@ -13,7 +14,7 @@
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite.testing.test_utils import wire_queues
from google.cloud.pubsublite.testing.test_utils import wire_queues, Box

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio
Expand All @@ -36,7 +37,16 @@ def subscriber_factory():

@pytest.fixture()
def subscriber(assigner, subscriber_factory):
return AssigningSubscriber(assigner, subscriber_factory)
box = Box()

def set_box():
box.val = AssigningSubscriber(lambda: assigner, subscriber_factory)

# Initialize AssigningSubscriber on another thread with a different event loop.
thread = threading.Thread(target=set_box)
thread.start()
thread.join()
return box.val


async def test_init(subscriber, assigner):
Expand Down

0 comments on commit f41c228

Please sign in to comment.