Skip to content

Commit

Permalink
fix: remaining issues with subscriber client (#43)
Browse files Browse the repository at this point in the history
* fix: Remaining issues with subscriber client.

Fix make_subscriber to defer GRPC client creation.

Fix retrying_connection to not let __aenter__ return until successful initialization or permanent failure.

* chore: reformat
  • Loading branch information
dpcollins-google committed Oct 12, 2020
1 parent a037d0b commit ec19dfc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 49 deletions.
25 changes: 13 additions & 12 deletions google/cloud/pubsublite/cloudpubsub/make_subscriber.py
Expand Up @@ -88,14 +88,18 @@ def assignment_connection_factory(

def _make_partition_subscriber_factory(
subscription: SubscriptionPath,
subscribe_client: SubscriberServiceAsyncClient,
cursor_client: CursorServiceAsyncClient,
client_options: ClientOptions,
credentials: Optional[Credentials],
base_metadata: Optional[Mapping[str, str]],
flow_control_settings: FlowControlSettings,
nack_handler: NackHandler,
message_transformer: MessageTransformer,
) -> PartitionSubscriberFactory:
def factory(partition: Partition) -> AsyncSubscriber:
subscribe_client = SubscriberServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore
final_metadata = merge_metadata(
base_metadata, subscription_routing_metadata(subscription, partition)
)
Expand Down Expand Up @@ -174,25 +178,22 @@ def make_async_subscriber(
if fixed_partitions:
assigner_factory = lambda: FixedSetAssigner(fixed_partitions) # noqa: E731
else:
assignment_client = PartitionAssignmentServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
assigner_factory = lambda: _make_dynamic_assigner( # noqa: E731
subscription, assignment_client, metadata
subscription,
PartitionAssignmentServiceAsyncClient(
credentials=credentials, client_options=client_options
),
metadata,
)

subscribe_client = SubscriberServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore
if nack_handler is None:
nack_handler = DefaultNackHandler()
if message_transformer is None:
message_transformer = DefaultMessageTransformer()
partition_subscriber_factory = _make_partition_subscriber_factory(
subscription,
subscribe_client,
cursor_client,
client_options,
credentials,
metadata,
per_partition_flow_control_settings,
nack_handler,
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
Expand Up @@ -24,6 +24,7 @@ class RetryingConnection(Connection[Request, Response], PermanentFailable):

_connection_factory: ConnectionFactory[Request, Response]
_reinitializer: ConnectionReinitializer[Request, Response]
_initialized_once: asyncio.Event

_loop_task: asyncio.Future

Expand All @@ -38,11 +39,13 @@ def __init__(
super().__init__()
self._connection_factory = connection_factory
self._reinitializer = reinitializer
self._initialized_once = asyncio.Event()
self._write_queue = asyncio.Queue(maxsize=1)
self._read_queue = asyncio.Queue(maxsize=1)

async def __aenter__(self):
self._loop_task = asyncio.ensure_future(self._run_loop())
await self.await_unless_failed(self._initialized_once.wait())
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -76,6 +79,7 @@ async def _run_loop(self):
self._read_queue = asyncio.Queue(maxsize=1)
self._write_queue = asyncio.Queue(maxsize=1)
await self._reinitializer.reinitialize(connection)
self._initialized_once.set()
bad_retries = 0
await self._loop_connection(connection)
except GoogleAPICallError as e:
Expand Down
50 changes: 13 additions & 37 deletions tests/unit/pubsublite/internal/wire/retrying_connection_test.py
@@ -1,5 +1,4 @@
import asyncio
from typing import Union

from asynctest.mock import MagicMock, CoroutineMock
import pytest
Expand All @@ -15,6 +14,7 @@
RetryingConnection,
_MIN_BACKOFF_SECS,
)
from google.cloud.pubsublite.testing.test_utils import wire_queues

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio
Expand All @@ -41,7 +41,7 @@ def connection_factory(default_connection):

@pytest.fixture()
def retrying_connection(connection_factory, reinitializer):
return RetryingConnection[int, int](connection_factory, reinitializer)
return RetryingConnection(connection_factory, reinitializer)


@pytest.fixture
Expand All @@ -55,40 +55,27 @@ def asyncio_sleep(monkeypatch):
async def test_permanent_error_on_reinitializer(
retrying_connection: Connection[int, int], reinitializer, default_connection
):
fut = asyncio.Future()
reinitialize_called = asyncio.Future()

async def reinit_action(conn):
assert conn == default_connection
reinitialize_called.set_result(None)
return await fut
raise InvalidArgument("abc")

reinitializer.reinitialize.side_effect = reinit_action
async with retrying_connection as _:
await reinitialize_called
reinitializer.reinitialize.assert_called_once()
fut.set_exception(InvalidArgument("abc"))
with pytest.raises(InvalidArgument):
await retrying_connection.read()
with pytest.raises(InvalidArgument):
async with retrying_connection as _:
pass


async def test_successful_reinitialize(
retrying_connection: Connection[int, int], reinitializer, default_connection
):
fut = asyncio.Future()
reinitialize_called = asyncio.Future()

async def reinit_action(conn):
assert conn == default_connection
reinitialize_called.set_result(None)
return await fut
return None

default_connection.read.return_value = 1

reinitializer.reinitialize.side_effect = reinit_action
async with retrying_connection as _:
await reinitialize_called
reinitializer.reinitialize.assert_called_once()
fut.set_result(None)
default_connection.read.return_value = 1
assert await retrying_connection.read() == 1
assert (
default_connection.read.call_count == 2
Expand All @@ -111,26 +98,15 @@ async def test_reinitialize_after_retryable(
default_connection,
asyncio_sleep,
):
reinit_called = asyncio.Queue()
reinit_results: "asyncio.Queue[Union[None, Exception]]" = asyncio.Queue()
reinit_queues = wire_queues(reinitializer.reinitialize)

async def reinit_action(conn):
assert conn == default_connection
await reinit_called.put(None)
result = await reinit_results.get()
if isinstance(result, Exception):
raise result
default_connection.read.return_value = 1

reinitializer.reinitialize.side_effect = reinit_action
await reinit_queues.results.put(InternalServerError("abc"))
await reinit_queues.results.put(None)
async with retrying_connection as _:
await reinit_called.get()
reinitializer.reinitialize.assert_called_once()
await reinit_results.put(InternalServerError("abc"))
await reinit_called.get()
asyncio_sleep.assert_called_once_with(_MIN_BACKOFF_SECS)
assert reinitializer.reinitialize.call_count == 2
await reinit_results.put(None)
default_connection.read.return_value = 1
assert await retrying_connection.read() == 1
assert (
default_connection.read.call_count == 2
Expand Down

0 comments on commit ec19dfc

Please sign in to comment.