From 6afd477e2f17cc534b8bf8a2f4fc30cca951e248 Mon Sep 17 00:00:00 2001 From: dpcollins-google <40498610+dpcollins-google@users.noreply.github.com> Date: Thu, 24 Sep 2020 15:27:23 -0400 Subject: [PATCH] feat: implement assigning subscriber (#23) * feat: Implement SinglePartitionSubscriber. This handles mapping a single partition to a Cloud Pub/Sub Like asynchronous subscriber. * feat: Add DefaultNackHandler. * feat: Add AssigningSubscriber. This handles changing partition assignments and creates AsyncSubscribers per-partition. --- .../internal/assigning_subscriber.py | 73 ++++++++++ .../internal/managed_event_loop.py | 2 +- .../internal/wait_ignore_cancelled.py | 9 ++ .../internal/wire/permanent_failable.py | 15 ++- google/cloud/pubsublite/testing/test_utils.py | 17 +++ .../internal/assigning_subscriber_test.py | 127 ++++++++++++++++++ 6 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py create mode 100644 google/cloud/pubsublite/internal/wait_ignore_cancelled.py create mode 100644 tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py diff --git a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py new file mode 100644 index 00000000..19877d75 --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py @@ -0,0 +1,73 @@ +from asyncio import Future, Queue, ensure_future +from typing import Callable, NamedTuple, Dict, Set + +from google.cloud.pubsub_v1.subscriber.message import Message + +from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable +from google.cloud.pubsublite.partition import Partition + +_PartitionSubscriberFactory = Callable[[Partition], AsyncSubscriber] + + +class _RunningSubscriber(NamedTuple): + subscriber: AsyncSubscriber + poller: Future + + +class AssigningSubscriber(AsyncSubscriber, PermanentFailable): + _assigner: Assigner + _subscriber_factory: _PartitionSubscriberFactory + + _subscribers: Dict[Partition, _RunningSubscriber] + _messages: "Queue[Message]" + _assign_poller: Future + + def __init__(self, assigner: Assigner, subscriber_factory: _PartitionSubscriberFactory): + super().__init__() + self._assigner = assigner + self._subscriber_factory = subscriber_factory + self._subscribers = {} + self._messages = Queue() + + async def read(self) -> Message: + return await self.await_unless_failed(self._messages.get()) + + async def _subscribe_action(self, subscriber: AsyncSubscriber): + message = await subscriber.read() + await self._messages.put(message) + + async def _start_subscriber(self, partition: Partition): + new_subscriber = self._subscriber_factory(partition) + await new_subscriber.__aenter__() + poller = ensure_future(self.run_poller(lambda: self._subscribe_action(new_subscriber))) + self._subscribers[partition] = _RunningSubscriber(new_subscriber, poller) + + async def _stop_subscriber(self, running: _RunningSubscriber): + running.poller.cancel() + await wait_ignore_cancelled(running.poller) + await running.subscriber.__aexit__(None, None, None) + + async def _assign_action(self): + assignment: Set[Partition] = await self._assigner.get_assignment() + added_partitions = assignment - self._subscribers.keys() + removed_partitions = self._subscribers.keys() - assignment + for partition in added_partitions: + await self._start_subscriber(partition) + for partition in removed_partitions: + await self._stop_subscriber(self._subscribers[partition]) + del self._subscribers[partition] + + async def __aenter__(self): + await self._assigner.__aenter__() + self._assign_poller = ensure_future(self.run_poller(self._assign_action)) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self._assign_poller.cancel() + await wait_ignore_cancelled(self._assign_poller) + await self._assigner.__aexit__(exc_type, exc_value, traceback) + for running in self._subscribers.values(): + await self._stop_subscriber(running) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py b/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py index 7840bac3..7e787ca9 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py @@ -15,7 +15,7 @@ def __init__(self): def __enter__(self): self._thread.start() - def __exit__(self, __exc_type, __exc_value, __traceback): + def __exit__(self, exc_type, exc_value, traceback): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() diff --git a/google/cloud/pubsublite/internal/wait_ignore_cancelled.py b/google/cloud/pubsublite/internal/wait_ignore_cancelled.py new file mode 100644 index 00000000..d7e16499 --- /dev/null +++ b/google/cloud/pubsublite/internal/wait_ignore_cancelled.py @@ -0,0 +1,9 @@ +from asyncio import CancelledError +from typing import Awaitable + + +async def wait_ignore_cancelled(awaitable: Awaitable): + try: + await awaitable + except CancelledError: + pass diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py index 7d8a01f8..f3fadc66 100644 --- a/google/cloud/pubsublite/internal/wire/permanent_failable.py +++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py @@ -1,5 +1,5 @@ import asyncio -from typing import Awaitable, TypeVar, Optional +from typing import Awaitable, TypeVar, Optional, Callable from google.api_core.exceptions import GoogleAPICallError @@ -31,6 +31,19 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T: task.cancel() raise self._failure_task.exception() + async def run_poller(self, poll_action: Callable[[], Awaitable[None]]): + """ + Run a polling loop, which runs poll_action forever unless this is failed. + Args: + poll_action: A callable returning an awaitable to run in a loop. Note that async functions which return once + satisfy this. + """ + try: + while True: + await self.await_unless_failed(poll_action()) + except GoogleAPICallError as e: + self.fail(e) + def fail(self, err: GoogleAPICallError): if not self._failure_task.done(): self._failure_task.set_exception(err) diff --git a/google/cloud/pubsublite/testing/test_utils.py b/google/cloud/pubsublite/testing/test_utils.py index e93c16ba..ad199ca5 100644 --- a/google/cloud/pubsublite/testing/test_utils.py +++ b/google/cloud/pubsublite/testing/test_utils.py @@ -1,6 +1,8 @@ import asyncio from typing import List, Union, Any, TypeVar, Generic, Optional +from asynctest import CoroutineMock + T = TypeVar("T") @@ -27,5 +29,20 @@ async def waiter(*args, **kwargs): return waiter +class QueuePair: + called: asyncio.Queue + results: asyncio.Queue + + def __init__(self): + self.called = asyncio.Queue() + self.results = asyncio.Queue() + + +def wire_queues(mock: CoroutineMock) -> QueuePair: + queues = QueuePair() + mock.side_effect = make_queue_waiter(queues.called, queues.results) + return queues + + class Box(Generic[T]): val: Optional[T] diff --git a/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py new file mode 100644 index 00000000..a8c62d87 --- /dev/null +++ b/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py @@ -0,0 +1,127 @@ +import asyncio +from typing import Callable, Set + +from asynctest.mock import MagicMock, call +import pytest +from google.api_core.exceptions import FailedPrecondition +from google.cloud.pubsub_v1.subscriber.message import Message +from google.pubsub_v1 import PubsubMessage + +from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import AssigningSubscriber +from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.partition import Partition +from google.cloud.pubsublite.testing.test_utils import make_queue_waiter, wire_queues + +# All test coroutines will be treated as marked. +pytestmark = pytest.mark.asyncio + + +def mock_async_context_manager(cm): + cm.__aenter__.return_value = cm + return cm + + +@pytest.fixture() +def assigner(): + return mock_async_context_manager(MagicMock(spec=Assigner)) + + +@pytest.fixture() +def subscriber_factory(): + return MagicMock(spec=Callable[[Partition], AsyncSubscriber]) + + +@pytest.fixture() +def subscriber(assigner, subscriber_factory): + return AssigningSubscriber(assigner, subscriber_factory) + + +async def test_init(subscriber, assigner): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + assigner.__aenter__.assert_called_once() + await assign_queues.called.get() + assigner.get_assignment.assert_called_once() + assigner.__aexit__.assert_called_once() + + +async def test_initial_assignment(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2 + await assign_queues.results.put({Partition(1), Partition(2)}) + await assign_queues.called.get() + subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True) + sub1.__aenter__.assert_called_once() + sub2.__aenter__.assert_called_once() + sub1.__aexit__.assert_called_once() + sub2.__aexit__.assert_called_once() + + +async def test_assigner_failure(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + await assign_queues.results.put(FailedPrecondition("bad assign")) + with pytest.raises(FailedPrecondition): + await subscriber.read() + + +async def test_assignment_change(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub3 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition( + 1) else sub2 if partition == Partition(2) else sub3 + await assign_queues.results.put({Partition(1), Partition(2)}) + await assign_queues.called.get() + subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True) + sub1.__aenter__.assert_called_once() + sub2.__aenter__.assert_called_once() + await assign_queues.results.put({Partition(1), Partition(3)}) + await assign_queues.called.get() + subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2)), call(Partition(3))], any_order=True) + sub3.__aenter__.assert_called_once() + sub2.__aexit__.assert_called_once() + sub1.__aexit__.assert_called_once() + sub2.__aexit__.assert_called_once() + sub3.__aexit__.assert_called_once() + + +async def test_subscriber_failure(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub1_queues = wire_queues(sub1.read) + subscriber_factory.return_value = sub1 + await assign_queues.results.put({Partition(1)}) + await sub1_queues.called.get() + await sub1_queues.results.put(FailedPrecondition("sub failed")) + with pytest.raises(FailedPrecondition): + await subscriber.read() + + +async def test_delivery_from_multiple(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub1_queues = wire_queues(sub1.read) + sub2_queues = wire_queues(sub2.read) + subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2 + await assign_queues.results.put({Partition(1), Partition(2)}) + await sub1_queues.results.put(Message(PubsubMessage(message_id="1")._pb, "", 0, None)) + await sub2_queues.results.put(Message(PubsubMessage(message_id="2")._pb, "", 0, None)) + message_ids: Set[str] = set() + message_ids.add((await subscriber.read()).message_id) + message_ids.add((await subscriber.read()).message_id) + assert message_ids == {"1", "2"}