diff --git a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py index 092abeba..9d667141 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py @@ -1,5 +1,5 @@ from asyncio import Future, Queue, ensure_future -from typing import Callable, NamedTuple, Dict, Set +from typing import Callable, NamedTuple, Dict, Set, Optional from google.cloud.pubsub_v1.subscriber.message import Message @@ -18,21 +18,31 @@ class _RunningSubscriber(NamedTuple): class AssigningSubscriber(AsyncSubscriber, PermanentFailable): - _assigner: Assigner + _assigner_factory: Callable[[], Assigner] _subscriber_factory: PartitionSubscriberFactory _subscribers: Dict[Partition, _RunningSubscriber] - _messages: "Queue[Message]" + + # Lazily initialized to ensure they are initialized on the thread where __aenter__ is called. + _assigner: Optional[Assigner] + _messages: Optional["Queue[Message]"] _assign_poller: Future def __init__( - self, assigner: Assigner, subscriber_factory: PartitionSubscriberFactory + self, + assigner_factory: Callable[[], Assigner], + subscriber_factory: PartitionSubscriberFactory, ): + """ + Accepts a factory for an Assigner instead of an Assigner because GRPC asyncio uses the current thread's event + loop. + """ super().__init__() - self._assigner = assigner + self._assigner_factory = assigner_factory + self._assigner = None self._subscriber_factory = subscriber_factory self._subscribers = {} - self._messages = Queue() + self._messages = None async def read(self) -> Message: return await self.await_unless_failed(self._messages.get()) @@ -65,6 +75,8 @@ async def _assign_action(self): del self._subscribers[partition] async def __aenter__(self): + self._messages = Queue() + self._assigner = self._assigner_factory() await self._assigner.__aenter__() self._assign_poller = ensure_future(self.run_poller(self._assign_action)) return self diff --git a/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py index cc91e101..832f244e 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py @@ -1,4 +1,4 @@ -from typing import Mapping +from typing import Mapping, Callable, Optional from google.pubsub_v1 import PubsubMessage @@ -10,11 +10,17 @@ class AsyncPublisherImpl(AsyncPublisher): - _publisher: Publisher - - def __init__(self, publisher: Publisher): + _publisher_factory: Callable[[], Publisher] + _publisher: Optional[Publisher] + + def __init__(self, publisher_factory: Callable[[], Publisher]): + """ + Accepts a factory for a Publisher instead of a Publisher because GRPC asyncio uses the current thread's event + loop. + """ super().__init__() - self._publisher = publisher + self._publisher_factory = publisher_factory + self._publisher = None async def publish( self, data: bytes, ordering_key: str = "", **attrs: Mapping[str, str] @@ -26,6 +32,7 @@ async def publish( return (await self._publisher.publish(psl_message)).encode() async def __aenter__(self): + self._publisher = self._publisher_factory() await self._publisher.__aenter__() return self diff --git a/google/cloud/pubsublite/cloudpubsub/make_publisher.py b/google/cloud/pubsublite/cloudpubsub/make_publisher.py index 9c65cbf9..a1acc51e 100644 --- a/google/cloud/pubsublite/cloudpubsub/make_publisher.py +++ b/google/cloud/pubsublite/cloudpubsub/make_publisher.py @@ -40,10 +40,13 @@ def make_async_publisher( GoogleApiCallException on any error determining topic structure. """ metadata = merge_metadata(pubsub_context(framework="CLOUD_PUBSUB_SHIM"), metadata) - underlying = make_wire_publisher( - topic, batching_delay_secs, credentials, client_options, metadata - ) - return AsyncPublisherImpl(underlying) + + def underlying_factory(): + return make_wire_publisher( + topic, batching_delay_secs, credentials, client_options, metadata + ) + + return AsyncPublisherImpl(underlying_factory) def make_publisher( diff --git a/google/cloud/pubsublite/cloudpubsub/make_subscriber.py b/google/cloud/pubsublite/cloudpubsub/make_subscriber.py index 55b30618..850d6014 100644 --- a/google/cloud/pubsublite/cloudpubsub/make_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/make_subscriber.py @@ -1,5 +1,5 @@ from concurrent.futures.thread import ThreadPoolExecutor -from typing import Optional, Mapping, Set, AsyncIterator +from typing import Optional, Mapping, Set, AsyncIterator, Callable from uuid import uuid4 from google.api_core.client_options import ClientOptions @@ -170,14 +170,16 @@ def make_async_subscriber( client_options = ClientOptions( api_endpoint=regional_endpoint(subscription.location.region) ) - assigner: Assigner + assigner_factory: Callable[[], Assigner] if fixed_partitions: - assigner = FixedSetAssigner(fixed_partitions) + assigner_factory = lambda: FixedSetAssigner(fixed_partitions) # noqa: E731 else: assignment_client = PartitionAssignmentServiceAsyncClient( credentials=credentials, client_options=client_options ) # type: ignore - assigner = _make_dynamic_assigner(subscription, assignment_client, metadata) + assigner_factory = lambda: _make_dynamic_assigner( # noqa: E731 + subscription, assignment_client, metadata + ) subscribe_client = SubscriberServiceAsyncClient( credentials=credentials, client_options=client_options @@ -196,7 +198,7 @@ def make_async_subscriber( nack_handler, message_transformer, ) - return AssigningSubscriber(assigner, partition_subscriber_factory) + return AssigningSubscriber(assigner_factory, partition_subscriber_factory) def make_subscriber( diff --git a/google/cloud/pubsublite/internal/wire/connection.py b/google/cloud/pubsublite/internal/wire/connection.py index 89895ca9..6d96a1b7 100644 --- a/google/cloud/pubsublite/internal/wire/connection.py +++ b/google/cloud/pubsublite/internal/wire/connection.py @@ -34,5 +34,5 @@ async def read(self) -> Response: class ConnectionFactory(Generic[Request, Response]): """A factory for producing Connections.""" - def new(self) -> Connection[Request, Response]: + async def new(self) -> Connection[Request, Response]: raise NotImplementedError() diff --git a/google/cloud/pubsublite/internal/wire/gapic_connection.py b/google/cloud/pubsublite/internal/wire/gapic_connection.py index c63b69c3..41d720dd 100644 --- a/google/cloud/pubsublite/internal/wire/gapic_connection.py +++ b/google/cloud/pubsublite/internal/wire/gapic_connection.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable +from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable, Awaitable import asyncio from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition @@ -44,10 +44,10 @@ async def read(self) -> Response: self.fail(e) raise self.error() - def __aenter__(self): + async def __aenter__(self): return self - def __aexit__(self, exc_type, exc_value, traceback) -> None: + async def __aexit__(self, exc_type, exc_value, traceback) -> None: pass async def __anext__(self) -> Request: @@ -64,15 +64,19 @@ def __aiter__(self) -> AsyncIterator[Response]: class GapicConnectionFactory(ConnectionFactory[Request, Response]): """A ConnectionFactory that produces GapicConnections.""" - _producer = Callable[[AsyncIterator[Request]], AsyncIterable[Response]] + _producer = Callable[[AsyncIterator[Request]], Awaitable[AsyncIterable[Response]]] def __init__( - self, producer: Callable[[AsyncIterator[Request]], AsyncIterable[Response]] + self, + producer: Callable[ + [AsyncIterator[Request]], Awaitable[AsyncIterable[Response]] + ], ): self._producer = producer - def new(self) -> Connection[Request, Response]: + async def new(self) -> Connection[Request, Response]: conn = GapicConnection[Request, Response]() - response_iterable = self._producer(conn) + response_fut = self._producer(conn) + response_iterable = await response_fut conn.set_response_it(response_iterable.__aiter__()) return conn diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index a0323399..c5836750 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -65,7 +65,8 @@ async def _run_loop(self): bad_retries = 0 while True: try: - async with self._connection_factory.new() as connection: + conn_fut = self._connection_factory.new() + async with (await conn_fut) as connection: # Needs to happen prior to reinitialization to clear outstanding waiters. if last_failure is not None: while not self._write_queue.empty(): @@ -89,6 +90,11 @@ async def _run_loop(self): except asyncio.CancelledError: return + except Exception as e: + import traceback + + traceback.print_exc() + print(e) async def _loop_connection(self, connection: Connection[Request, Response]): read_task: Awaitable[Response] = asyncio.ensure_future(connection.read()) diff --git a/google/cloud/pubsublite_v1/services/cursor_service/async_client.py b/google/cloud/pubsublite_v1/services/cursor_service/async_client.py index d536f059..1ef57745 100644 --- a/google/cloud/pubsublite_v1/services/cursor_service/async_client.py +++ b/google/cloud/pubsublite_v1/services/cursor_service/async_client.py @@ -18,7 +18,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union +from typing import ( + Dict, + AsyncIterable, + Awaitable, + AsyncIterator, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources import google.api_core.client_options as ClientOptions # type: ignore @@ -103,7 +112,7 @@ def streaming_commit_cursor( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> AsyncIterable[cursor.StreamingCommitCursorResponse]: + ) -> Awaitable[AsyncIterable[cursor.StreamingCommitCursorResponse]]: r"""Establishes a stream with the server for managing committed cursors. diff --git a/google/cloud/pubsublite_v1/services/partition_assignment_service/async_client.py b/google/cloud/pubsublite_v1/services/partition_assignment_service/async_client.py index fe943d20..16cea993 100644 --- a/google/cloud/pubsublite_v1/services/partition_assignment_service/async_client.py +++ b/google/cloud/pubsublite_v1/services/partition_assignment_service/async_client.py @@ -18,7 +18,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union +from typing import ( + Dict, + AsyncIterable, + Awaitable, + AsyncIterator, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources import google.api_core.client_options as ClientOptions # type: ignore @@ -107,7 +116,7 @@ def assign_partitions( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> AsyncIterable[subscriber.PartitionAssignment]: + ) -> Awaitable[AsyncIterable[subscriber.PartitionAssignment]]: r"""Assign partitions for this client to handle for the specified subscription. The client must send an diff --git a/google/cloud/pubsublite_v1/services/publisher_service/async_client.py b/google/cloud/pubsublite_v1/services/publisher_service/async_client.py index 39fd6728..01488aba 100644 --- a/google/cloud/pubsublite_v1/services/publisher_service/async_client.py +++ b/google/cloud/pubsublite_v1/services/publisher_service/async_client.py @@ -18,7 +18,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union +from typing import ( + Dict, + AsyncIterable, + Awaitable, + AsyncIterator, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources import google.api_core.client_options as ClientOptions # type: ignore @@ -103,7 +112,7 @@ def publish( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> AsyncIterable[publisher.PublishResponse]: + ) -> Awaitable[AsyncIterable[publisher.PublishResponse]]: r"""Establishes a stream with the server for publishing messages. Once the stream is initialized, the client publishes messages by sending publish requests on the @@ -125,7 +134,7 @@ def publish( sent along with the request as metadata. Returns: - AsyncIterable[~.publisher.PublishResponse]: + Awaitable[AsyncIterable[~.publisher.PublishResponse]]: Response to a PublishRequest. """ diff --git a/google/cloud/pubsublite_v1/services/subscriber_service/async_client.py b/google/cloud/pubsublite_v1/services/subscriber_service/async_client.py index 260b45ba..28ab54c2 100644 --- a/google/cloud/pubsublite_v1/services/subscriber_service/async_client.py +++ b/google/cloud/pubsublite_v1/services/subscriber_service/async_client.py @@ -18,7 +18,16 @@ from collections import OrderedDict import functools import re -from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union +from typing import ( + Dict, + AsyncIterable, + Awaitable, + AsyncIterator, + Sequence, + Tuple, + Type, + Union, +) import pkg_resources import google.api_core.client_options as ClientOptions # type: ignore @@ -100,7 +109,7 @@ def subscribe( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> AsyncIterable[subscriber.SubscribeResponse]: + ) -> Awaitable[AsyncIterable[subscriber.SubscribeResponse]]: r"""Establishes a stream with the server for receiving messages. diff --git a/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py index 38fd66a6..8d4314e9 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py @@ -1,6 +1,7 @@ from typing import Set from asynctest.mock import MagicMock, call +import threading import pytest from google.api_core.exceptions import FailedPrecondition from google.cloud.pubsub_v1.subscriber.message import Message @@ -13,7 +14,7 @@ from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber from google.cloud.pubsublite.internal.wire.assigner import Assigner from google.cloud.pubsublite.partition import Partition -from google.cloud.pubsublite.testing.test_utils import wire_queues +from google.cloud.pubsublite.testing.test_utils import wire_queues, Box # All test coroutines will be treated as marked. pytestmark = pytest.mark.asyncio @@ -36,7 +37,16 @@ def subscriber_factory(): @pytest.fixture() def subscriber(assigner, subscriber_factory): - return AssigningSubscriber(assigner, subscriber_factory) + box = Box() + + def set_box(): + box.val = AssigningSubscriber(lambda: assigner, subscriber_factory) + + # Initialize AssigningSubscriber on another thread with a different event loop. + thread = threading.Thread(target=set_box) + thread.start() + thread.join() + return box.val async def test_init(subscriber, assigner):