From b2d0d36ee08249caa7a1d7f16aa7eb3bdb454cd0 Mon Sep 17 00:00:00 2001 From: dpcollins-google <40498610+dpcollins-google@users.noreply.github.com> Date: Tue, 15 Sep 2020 16:07:34 -0400 Subject: [PATCH] feat: Implement assigner, which handles partition-subscriber assignment. (#14) * feat: Implement assigner which generates subscription-partition assignments. Also slightly change the semantics of PermanentFailable to not fail a RetryingConnection on retryable errors from a watched awaitable. --- .../pubsublite/internal/wire/assigner.py | 15 ++ .../pubsublite/internal/wire/assigner_impl.py | 88 +++++++++ .../internal/wire/gapic_connection.py | 16 +- .../internal/wire/permanent_failable.py | 15 +- .../internal/wire/retrying_connection.py | 6 +- .../internal/wire/assigner_impl_test.py | 183 ++++++++++++++++++ 6 files changed, 311 insertions(+), 12 deletions(-) create mode 100644 google/cloud/pubsublite/internal/wire/assigner.py create mode 100644 google/cloud/pubsublite/internal/wire/assigner_impl.py create mode 100644 tests/unit/pubsublite/internal/wire/assigner_impl_test.py diff --git a/google/cloud/pubsublite/internal/wire/assigner.py b/google/cloud/pubsublite/internal/wire/assigner.py new file mode 100644 index 00000000..6307850a --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/assigner.py @@ -0,0 +1,15 @@ +from abc import abstractmethod +from typing import AsyncContextManager, Set + +from google.cloud.pubsublite.partition import Partition + + +class Assigner(AsyncContextManager): + """ + An assigner will deliver a continuous stream of assignments when called into. Perform all necessary work with the + assignment before attempting to get the next one. 
+ """ + + @abstractmethod + async def get_assignment(self) -> Set[Partition]: + raise NotImplementedError() diff --git a/google/cloud/pubsublite/internal/wire/assigner_impl.py b/google/cloud/pubsublite/internal/wire/assigner_impl.py new file mode 100644 index 00000000..65a0117c --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/assigner_impl.py @@ -0,0 +1,88 @@ +import asyncio +from typing import Optional, Set + +from absl import logging +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection, ConnectionFactory +from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError +from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer +from google.cloud.pubsublite.internal.wire.connection import Connection +from google.cloud.pubsublite.partition import Partition +from google.cloud.pubsublite_v1.types import PartitionAssignmentRequest, PartitionAssignment, \ + InitialPartitionAssignmentRequest, PartitionAssignmentAck + +# Maximum bytes per batch at 3.5 MiB to avoid GRPC limit of 4 MiB +_MAX_BYTES = int(3.5 * 1024 * 1024) + +# Maximum messages per batch at 1000 +_MAX_MESSAGES = 1000 + + +class AssignerImpl(Assigner, ConnectionReinitializer[PartitionAssignmentRequest, PartitionAssignment]): + _initial: InitialPartitionAssignmentRequest + _connection: RetryingConnection[PartitionAssignmentRequest, PartitionAssignment] + + _outstanding_assignment: bool + + _receiver: Optional[asyncio.Future] + + # A queue that may only hold one element with the next assignment. 
+ _new_assignment: 'asyncio.Queue[Set[Partition]]' + + def __init__(self, initial: InitialPartitionAssignmentRequest, + factory: ConnectionFactory[PartitionAssignmentRequest, PartitionAssignment]): + self._initial = initial + self._connection = RetryingConnection(factory, self) + self._outstanding_assignment = False + self._receiver = None + self._new_assignment = asyncio.Queue(maxsize=1) + + async def __aenter__(self): + await self._connection.__aenter__() + + def _start_receiver(self): + assert self._receiver is None + self._receiver = asyncio.ensure_future(self._receive_loop()) + + async def _stop_receiver(self): + if self._receiver: + self._receiver.cancel() + await self._receiver + self._receiver = None + + async def _receive_loop(self): + try: + while True: + response = await self._connection.read() + if self._outstanding_assignment or not self._new_assignment.empty(): + self._connection.fail(FailedPrecondition( + "Received a duplicate assignment on the stream while one was outstanding.")) + return + self._outstanding_assignment = True + partitions = set() + for partition in response.partitions: + partitions.add(Partition(partition)) + self._new_assignment.put_nowait(partitions) + except asyncio.CancelledError: + return + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._connection.__aexit__(exc_type, exc_val, exc_tb) + + async def reinitialize(self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment]): + self._outstanding_assignment = False + while not self._new_assignment.empty(): + self._new_assignment.get_nowait() + await self._stop_receiver() + await connection.write(PartitionAssignmentRequest(initial=self._initial)) + self._start_receiver() + + async def get_assignment(self) -> Set[Partition]: + if self._outstanding_assignment: + try: + await self._connection.write(PartitionAssignmentRequest(ack=PartitionAssignmentAck())) + self._outstanding_assignment = False + except GoogleAPICallError as e: + # If there is a 
failure to ack, keep going. The stream likely restarted. + logging.debug(f"Assignment ack attempt failed due to stream failure: {e}") + return await self._connection.await_unless_failed(self._new_assignment.get()) diff --git a/google/cloud/pubsublite/internal/wire/gapic_connection.py b/google/cloud/pubsublite/internal/wire/gapic_connection.py index 0ae193a7..5e87d914 100644 --- a/google/cloud/pubsublite/internal/wire/gapic_connection.py +++ b/google/cloud/pubsublite/internal/wire/gapic_connection.py @@ -1,6 +1,8 @@ from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable import asyncio +from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition + from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response, ConnectionFactory from google.cloud.pubsublite.internal.wire.work_item import WorkItem from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable @@ -22,11 +24,17 @@ def set_response_it(self, response_it: AsyncIterator[Response]): async def write(self, request: Request) -> None: item = WorkItem(request) - await self.await_or_fail(self._write_queue.put(item)) - await self.await_or_fail(item.response_future) + await self.await_unless_failed(self._write_queue.put(item)) + await self.await_unless_failed(item.response_future) async def read(self) -> Response: - return await self.await_or_fail(self._response_it.__anext__()) + try: + return await self.await_unless_failed(self._response_it.__anext__()) + except StopAsyncIteration: + self.fail(FailedPrecondition("Server sent unprompted half close.")) + except GoogleAPICallError as e: + self.fail(e) + raise self.error() def __aenter__(self): return self @@ -35,7 +43,7 @@ def __aexit__(self, exc_type, exc_value, traceback) -> None: pass async def __anext__(self) -> Request: - item: WorkItem[Request] = await self.await_or_fail(self._write_queue.get()) + item: WorkItem[Request] = await 
self.await_unless_failed(self._write_queue.get()) item.response_future.set_result(None) return item.request diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py index 3efa1c99..7d8a01f8 100644 --- a/google/cloud/pubsublite/internal/wire/permanent_failable.py +++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py @@ -13,16 +13,21 @@ class PermanentFailable: def __init__(self): self._failure_task = asyncio.Future() - async def await_or_fail(self, awaitable: Awaitable[T]) -> T: + async def await_unless_failed(self, awaitable: Awaitable[T]) -> T: + """ + Await the awaitable, unless fail() is called first. + Args: + awaitable: An awaitable + + Returns: The result of the awaitable + Raises: The permanent error if fail() is called or the awaitable raises one. + """ if self._failure_task.done(): raise self._failure_task.exception() task = asyncio.ensure_future(awaitable) done, _ = await asyncio.wait([task, self._failure_task], return_when=asyncio.FIRST_COMPLETED) if task in done: - try: - return await task - except GoogleAPICallError as e: - self.fail(e) + return await task task.cancel() raise self._failure_task.exception() diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index e97ebd9b..dfc7eb87 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -38,11 +38,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def write(self, request: Request) -> None: item = WorkItem(request) - await self.await_or_fail(self._write_queue.put(item)) - return await self.await_or_fail(item.response_future) + await self.await_unless_failed(self._write_queue.put(item)) + return await self.await_unless_failed(item.response_future) async def read(self) -> Response: - return await 
self.await_or_fail(self._read_queue.get()) + return await self.await_unless_failed(self._read_queue.get()) async def _run_loop(self): """ diff --git a/tests/unit/pubsublite/internal/wire/assigner_impl_test.py b/tests/unit/pubsublite/internal/wire/assigner_impl_test.py new file mode 100644 index 00000000..14dc8abf --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/assigner_impl_test.py @@ -0,0 +1,183 @@ +import asyncio +from unittest.mock import call +from collections import defaultdict +from typing import Dict, Set + +from asynctest.mock import MagicMock, CoroutineMock +import pytest + +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.internal.wire.assigner_impl import AssignerImpl +from google.cloud.pubsublite.internal.wire.connection import Connection, ConnectionFactory +from google.api_core.exceptions import InternalServerError + +from google.cloud.pubsublite.partition import Partition +from google.cloud.pubsublite_v1.types.subscriber import PartitionAssignmentRequest, InitialPartitionAssignmentRequest, \ + PartitionAssignment, PartitionAssignmentAck +from google.cloud.pubsublite.testing.test_utils import make_queue_waiter +from google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS + +# All test coroutines will be treated as marked. 
+pytestmark = pytest.mark.asyncio + + +@pytest.fixture() +def default_connection(): + conn = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment]) + conn.__aenter__.return_value = conn + return conn + + +@pytest.fixture() +def connection_factory(default_connection): + factory = MagicMock(spec=ConnectionFactory[PartitionAssignmentRequest, PartitionAssignment]) + factory.new.return_value = default_connection + return factory + + +@pytest.fixture() +def initial_request(): + return PartitionAssignmentRequest(initial=InitialPartitionAssignmentRequest(subscription="mysub")) + + +class QueuePair: + called: asyncio.Queue + results: asyncio.Queue + + def __init__(self): + self.called = asyncio.Queue() + self.results = asyncio.Queue() + + +@pytest.fixture +def sleep_queues() -> Dict[float, QueuePair]: + return defaultdict(QueuePair) + + +@pytest.fixture +def asyncio_sleep(monkeypatch, sleep_queues): + """Patch asyncio.sleep with a mock that routes each call through the sleep_queues QueuePair keyed by its delay.""" + mock = CoroutineMock() + monkeypatch.setattr(asyncio, "sleep", mock) + + async def sleeper(delay: float): + await make_queue_waiter(sleep_queues[delay].called, sleep_queues[delay].results)(delay) + + mock.side_effect = sleeper + return mock + + +@pytest.fixture() +def assigner(connection_factory, initial_request): + return AssignerImpl(initial_request.initial, connection_factory) + + +def as_response(partitions: Set[Partition]): + req = PartitionAssignment() + req.partitions = [partition.value for partition in partitions] + return req + + +def ack_request(): + return PartitionAssignmentRequest(ack=PartitionAssignmentAck()) + + +async def test_basic_assign( + assigner: Assigner, default_connection, initial_request): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + 
default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + write_result_queue.put_nowait(None) + async with assigner: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Wait for the first assignment + assign_fut1 = asyncio.ensure_future(assigner.get_assignment()) + assert not assign_fut1.done() + + partitions = {Partition(2), Partition(7)} + + # Send the first assignment. + await read_result_queue.put(as_response(partitions=partitions)) + assert (await assign_fut1) == partitions + + # Get the next assignment: should send an ack on the stream + assign_fut2 = asyncio.ensure_future(assigner.get_assignment()) + await write_called_queue.get() + await write_result_queue.put(None) + default_connection.write.assert_has_calls([call(initial_request), call(ack_request())]) + + partitions = {Partition(5)} + + # Send the second assignment. + await read_called_queue.get() + await read_result_queue.put(as_response(partitions=partitions)) + assert (await assign_fut2) == partitions + + +async def test_restart( + assigner: Assigner, default_connection, connection_factory, initial_request, asyncio_sleep, sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + write_result_queue.put_nowait(None) + async with assigner: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Wait for the first assignment + assign_fut1 = asyncio.ensure_future(assigner.get_assignment()) + assert not assign_fut1.done() + + partitions = {Partition(2), 
Partition(7)} + + # Send the first assignment. + await read_result_queue.put(as_response(partitions=partitions)) + await read_called_queue.get() + assert (await assign_fut1) == partitions + + # Get the next assignment: should attempt to send an ack on the stream + assign_fut2 = asyncio.ensure_future(assigner.get_assignment()) + await write_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request), call(ack_request())]) + + # Set up the next connection + conn2 = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment]) + conn2.__aenter__.return_value = conn2 + connection_factory.new.return_value = conn2 + write_called_queue_2 = asyncio.Queue() + write_result_queue_2 = asyncio.Queue() + conn2.write.side_effect = make_queue_waiter(write_called_queue_2, write_result_queue_2) + read_called_queue_2 = asyncio.Queue() + read_result_queue_2 = asyncio.Queue() + conn2.read.side_effect = make_queue_waiter(read_called_queue_2, read_result_queue_2) + + # Fail the connection by failing the write call. + await write_result_queue.put(InternalServerError("failed")) + await sleep_queues[_MIN_BACKOFF_SECS].called.get() + await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) + + # Reinitialize + await write_called_queue_2.get() + write_result_queue_2.put_nowait(None) + conn2.write.assert_has_calls([call(initial_request)]) + + partitions = {Partition(5)} + + # Send the second assignment on the new connection. + await read_called_queue_2.get() + await read_result_queue_2.put(as_response(partitions=partitions)) + assert (await assign_fut2) == partitions + # No ack call ever made. + conn2.write.assert_has_calls([call(initial_request)])