diff --git a/google/cloud/pubsublite/cloudpubsub/make_subscriber.py b/google/cloud/pubsublite/cloudpubsub/make_subscriber.py index 850d6014..01e3c486 100644 --- a/google/cloud/pubsublite/cloudpubsub/make_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/make_subscriber.py @@ -88,14 +88,18 @@ def assignment_connection_factory( def _make_partition_subscriber_factory( subscription: SubscriptionPath, - subscribe_client: SubscriberServiceAsyncClient, - cursor_client: CursorServiceAsyncClient, + client_options: ClientOptions, + credentials: Optional[Credentials], base_metadata: Optional[Mapping[str, str]], flow_control_settings: FlowControlSettings, nack_handler: NackHandler, message_transformer: MessageTransformer, ) -> PartitionSubscriberFactory: def factory(partition: Partition) -> AsyncSubscriber: + subscribe_client = SubscriberServiceAsyncClient( + credentials=credentials, client_options=client_options + ) # type: ignore + cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore final_metadata = merge_metadata( base_metadata, subscription_routing_metadata(subscription, partition) ) @@ -174,25 +178,22 @@ def make_async_subscriber( if fixed_partitions: assigner_factory = lambda: FixedSetAssigner(fixed_partitions) # noqa: E731 else: - assignment_client = PartitionAssignmentServiceAsyncClient( - credentials=credentials, client_options=client_options - ) # type: ignore assigner_factory = lambda: _make_dynamic_assigner( # noqa: E731 - subscription, assignment_client, metadata + subscription, + PartitionAssignmentServiceAsyncClient( + credentials=credentials, client_options=client_options + ), + metadata, ) - subscribe_client = SubscriberServiceAsyncClient( - credentials=credentials, client_options=client_options - ) # type: ignore - cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore if nack_handler is None: nack_handler = DefaultNackHandler() if message_transformer is None: message_transformer = DefaultMessageTransformer() partition_subscriber_factory = _make_partition_subscriber_factory( subscription, - subscribe_client, - cursor_client, + client_options, + credentials, metadata, per_partition_flow_control_settings, nack_handler, diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index c5836750..a5a633cc 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -24,6 +24,7 @@ class RetryingConnection(Connection[Request, Response], PermanentFailable): _connection_factory: ConnectionFactory[Request, Response] _reinitializer: ConnectionReinitializer[Request, Response] + _initialized_once: asyncio.Event _loop_task: asyncio.Future @@ -38,11 +39,13 @@ def __init__( super().__init__() self._connection_factory = connection_factory self._reinitializer = reinitializer + self._initialized_once = asyncio.Event() self._write_queue = asyncio.Queue(maxsize=1) self._read_queue = asyncio.Queue(maxsize=1) async def __aenter__(self): self._loop_task = asyncio.ensure_future(self._run_loop()) + await self.await_unless_failed(self._initialized_once.wait()) return self async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -76,6 +79,7 @@ async def _run_loop(self): self._read_queue = asyncio.Queue(maxsize=1) self._write_queue = asyncio.Queue(maxsize=1) await self._reinitializer.reinitialize(connection) + self._initialized_once.set() bad_retries = 0 await self._loop_connection(connection) except GoogleAPICallError as e: diff --git a/tests/unit/pubsublite/internal/wire/retrying_connection_test.py b/tests/unit/pubsublite/internal/wire/retrying_connection_test.py index f1e36815..60b1f9f0 100644 --- a/tests/unit/pubsublite/internal/wire/retrying_connection_test.py +++ b/tests/unit/pubsublite/internal/wire/retrying_connection_test.py @@ -1,5 +1,4 @@ import asyncio -from typing import Union from asynctest.mock import MagicMock, CoroutineMock import pytest @@ -15,6 +14,7 @@ RetryingConnection, _MIN_BACKOFF_SECS, ) +from google.cloud.pubsublite.testing.test_utils import wire_queues # All test coroutines will be treated as marked. pytestmark = pytest.mark.asyncio @@ -41,7 +41,7 @@ def connection_factory(default_connection): @pytest.fixture() def retrying_connection(connection_factory, reinitializer): - return RetryingConnection[int, int](connection_factory, reinitializer) + return RetryingConnection(connection_factory, reinitializer) @pytest.fixture @@ -55,40 +55,27 @@ def asyncio_sleep(monkeypatch): async def test_permanent_error_on_reinitializer( retrying_connection: Connection[int, int], reinitializer, default_connection ): - fut = asyncio.Future() - reinitialize_called = asyncio.Future() - async def reinit_action(conn): assert conn == default_connection - reinitialize_called.set_result(None) - return await fut + raise InvalidArgument("abc") reinitializer.reinitialize.side_effect = reinit_action - async with retrying_connection as _: - await reinitialize_called - reinitializer.reinitialize.assert_called_once() - fut.set_exception(InvalidArgument("abc")) - with pytest.raises(InvalidArgument): - await retrying_connection.read() + with pytest.raises(InvalidArgument): + async with retrying_connection as _: + pass async def test_successful_reinitialize( retrying_connection: Connection[int, int], reinitializer, default_connection ): - fut = asyncio.Future() - reinitialize_called = asyncio.Future() - async def reinit_action(conn): assert conn == default_connection - reinitialize_called.set_result(None) - return await fut + return None + + default_connection.read.return_value = 1 reinitializer.reinitialize.side_effect = reinit_action async with retrying_connection as _: - await reinitialize_called - reinitializer.reinitialize.assert_called_once() - fut.set_result(None) - default_connection.read.return_value = 1 assert await retrying_connection.read() == 1 assert ( default_connection.read.call_count == 2 @@ -111,26 +98,15 @@ async def test_reinitialize_after_retryable( default_connection, asyncio_sleep, ): - reinit_called = asyncio.Queue() - reinit_results: "asyncio.Queue[Union[None, Exception]]" = asyncio.Queue() + reinit_queues = wire_queues(reinitializer.reinitialize) - async def reinit_action(conn): - assert conn == default_connection - await reinit_called.put(None) - result = await reinit_results.get() - if isinstance(result, Exception): - raise result + default_connection.read.return_value = 1 - reinitializer.reinitialize.side_effect = reinit_action + await reinit_queues.results.put(InternalServerError("abc")) + await reinit_queues.results.put(None) async with retrying_connection as _: - await reinit_called.get() - reinitializer.reinitialize.assert_called_once() - await reinit_results.put(InternalServerError("abc")) - await reinit_called.get() asyncio_sleep.assert_called_once_with(_MIN_BACKOFF_SECS) assert reinitializer.reinitialize.call_count == 2 - await reinit_results.put(None) - default_connection.read.return_value = 1 assert await retrying_connection.read() == 1 assert ( default_connection.read.call_count == 2