Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remaining issues with subscriber client #43

Merged
merged 2 commits into from Oct 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 13 additions & 12 deletions google/cloud/pubsublite/cloudpubsub/make_subscriber.py
Expand Up @@ -88,14 +88,18 @@ def assignment_connection_factory(

def _make_partition_subscriber_factory(
subscription: SubscriptionPath,
subscribe_client: SubscriberServiceAsyncClient,
cursor_client: CursorServiceAsyncClient,
client_options: ClientOptions,
credentials: Optional[Credentials],
base_metadata: Optional[Mapping[str, str]],
flow_control_settings: FlowControlSettings,
nack_handler: NackHandler,
message_transformer: MessageTransformer,
) -> PartitionSubscriberFactory:
def factory(partition: Partition) -> AsyncSubscriber:
subscribe_client = SubscriberServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore
final_metadata = merge_metadata(
base_metadata, subscription_routing_metadata(subscription, partition)
)
Expand Down Expand Up @@ -174,25 +178,22 @@ def make_async_subscriber(
if fixed_partitions:
assigner_factory = lambda: FixedSetAssigner(fixed_partitions) # noqa: E731
else:
assignment_client = PartitionAssignmentServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
assigner_factory = lambda: _make_dynamic_assigner( # noqa: E731
subscription, assignment_client, metadata
subscription,
PartitionAssignmentServiceAsyncClient(
credentials=credentials, client_options=client_options
),
metadata,
)

subscribe_client = SubscriberServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore
if nack_handler is None:
nack_handler = DefaultNackHandler()
if message_transformer is None:
message_transformer = DefaultMessageTransformer()
partition_subscriber_factory = _make_partition_subscriber_factory(
subscription,
subscribe_client,
cursor_client,
client_options,
credentials,
metadata,
per_partition_flow_control_settings,
nack_handler,
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
Expand Up @@ -24,6 +24,7 @@ class RetryingConnection(Connection[Request, Response], PermanentFailable):

_connection_factory: ConnectionFactory[Request, Response]
_reinitializer: ConnectionReinitializer[Request, Response]
_initialized_once: asyncio.Event

_loop_task: asyncio.Future

Expand All @@ -38,11 +39,13 @@ def __init__(
super().__init__()
self._connection_factory = connection_factory
self._reinitializer = reinitializer
self._initialized_once = asyncio.Event()
self._write_queue = asyncio.Queue(maxsize=1)
self._read_queue = asyncio.Queue(maxsize=1)

async def __aenter__(self):
self._loop_task = asyncio.ensure_future(self._run_loop())
await self.await_unless_failed(self._initialized_once.wait())
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -76,6 +79,7 @@ async def _run_loop(self):
self._read_queue = asyncio.Queue(maxsize=1)
self._write_queue = asyncio.Queue(maxsize=1)
await self._reinitializer.reinitialize(connection)
self._initialized_once.set()
bad_retries = 0
await self._loop_connection(connection)
except GoogleAPICallError as e:
Expand Down
50 changes: 13 additions & 37 deletions tests/unit/pubsublite/internal/wire/retrying_connection_test.py
@@ -1,5 +1,4 @@
import asyncio
from typing import Union

from asynctest.mock import MagicMock, CoroutineMock
import pytest
Expand All @@ -15,6 +14,7 @@
RetryingConnection,
_MIN_BACKOFF_SECS,
)
from google.cloud.pubsublite.testing.test_utils import wire_queues

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio
Expand All @@ -41,7 +41,7 @@ def connection_factory(default_connection):

@pytest.fixture()
def retrying_connection(connection_factory, reinitializer):
return RetryingConnection[int, int](connection_factory, reinitializer)
return RetryingConnection(connection_factory, reinitializer)


@pytest.fixture
Expand All @@ -55,40 +55,27 @@ def asyncio_sleep(monkeypatch):
async def test_permanent_error_on_reinitializer(
retrying_connection: Connection[int, int], reinitializer, default_connection
):
fut = asyncio.Future()
reinitialize_called = asyncio.Future()

async def reinit_action(conn):
assert conn == default_connection
reinitialize_called.set_result(None)
return await fut
raise InvalidArgument("abc")

reinitializer.reinitialize.side_effect = reinit_action
async with retrying_connection as _:
await reinitialize_called
reinitializer.reinitialize.assert_called_once()
fut.set_exception(InvalidArgument("abc"))
with pytest.raises(InvalidArgument):
await retrying_connection.read()
with pytest.raises(InvalidArgument):
async with retrying_connection as _:
pass


async def test_successful_reinitialize(
retrying_connection: Connection[int, int], reinitializer, default_connection
):
fut = asyncio.Future()
reinitialize_called = asyncio.Future()

async def reinit_action(conn):
assert conn == default_connection
reinitialize_called.set_result(None)
return await fut
return None

default_connection.read.return_value = 1

reinitializer.reinitialize.side_effect = reinit_action
async with retrying_connection as _:
await reinitialize_called
reinitializer.reinitialize.assert_called_once()
fut.set_result(None)
default_connection.read.return_value = 1
assert await retrying_connection.read() == 1
assert (
default_connection.read.call_count == 2
Expand All @@ -111,26 +98,15 @@ async def test_reinitialize_after_retryable(
default_connection,
asyncio_sleep,
):
reinit_called = asyncio.Queue()
reinit_results: "asyncio.Queue[Union[None, Exception]]" = asyncio.Queue()
reinit_queues = wire_queues(reinitializer.reinitialize)

async def reinit_action(conn):
assert conn == default_connection
await reinit_called.put(None)
result = await reinit_results.get()
if isinstance(result, Exception):
raise result
default_connection.read.return_value = 1

reinitializer.reinitialize.side_effect = reinit_action
await reinit_queues.results.put(InternalServerError("abc"))
await reinit_queues.results.put(None)
async with retrying_connection as _:
await reinit_called.get()
reinitializer.reinitialize.assert_called_once()
await reinit_results.put(InternalServerError("abc"))
await reinit_called.get()
asyncio_sleep.assert_called_once_with(_MIN_BACKOFF_SECS)
assert reinitializer.reinitialize.call_count == 2
await reinit_results.put(None)
default_connection.read.return_value = 1
assert await retrying_connection.read() == 1
assert (
default_connection.read.call_count == 2
Expand Down