From 8a702d2c6a8af95af9848c8a9e6e330f8ea05206 Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Tue, 22 Sep 2020 21:57:59 -0400 Subject: [PATCH 1/3] feat: Implement SinglePartitionSubscriber. This handles mapping a single partition to a Cloud Pub/Sub Like asynchronous subscriber. --- .../cloudpubsub/flow_control_settings.py | 11 ++ .../internal/ack_set_tracker_impl.py | 1 + .../internal/async_publisher_impl.py | 1 + .../cloudpubsub/internal/publisher_impl.py | 1 + .../internal/single_partition_subscriber.py | 119 ++++++++++++ .../cloudpubsub/message_transformer.py | 30 +++ .../pubsublite/cloudpubsub/nack_handler.py | 23 +++ .../pubsublite/cloudpubsub/subscriber.py | 25 +++ setup.py | 2 +- .../single_partition_subscriber_test.py | 177 ++++++++++++++++++ 10 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 google/cloud/pubsublite/cloudpubsub/flow_control_settings.py create mode 100644 google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py create mode 100644 google/cloud/pubsublite/cloudpubsub/message_transformer.py create mode 100644 google/cloud/pubsublite/cloudpubsub/nack_handler.py create mode 100644 google/cloud/pubsublite/cloudpubsub/subscriber.py create mode 100644 tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py diff --git a/google/cloud/pubsublite/cloudpubsub/flow_control_settings.py b/google/cloud/pubsublite/cloudpubsub/flow_control_settings.py new file mode 100644 index 00000000..7837bb80 --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/flow_control_settings.py @@ -0,0 +1,11 @@ +from typing import NamedTuple + + +class FlowControlSettings(NamedTuple): + messages_outstanding: int + bytes_outstanding: int + + +_MAX_INT64 = 0x7FFFFFFFFFFFFFFF + +DISABLED_FLOW_CONTROL = FlowControlSettings(_MAX_INT64, _MAX_INT64) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py index 45f0cd56..99232d45 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py @@ -15,6 +15,7 @@ class AckSetTrackerImpl(AckSetTracker): _acks: "queue.PriorityQueue[int]" def __init__(self, committer: Committer): + super().__init__() self._committer = committer self._receipts = deque() self._acks = queue.PriorityQueue() diff --git a/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py index 7828ee8d..a05d90f2 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py @@ -11,6 +11,7 @@ class AsyncPublisherImpl(AsyncPublisher): _publisher: Publisher def __init__(self, publisher: Publisher): + super().__init__() self._publisher = publisher async def publish(self, data: bytes, ordering_key: str = "", **attrs: Mapping[str, str]) -> str: diff --git a/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py index 25419580..9e760008 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py @@ -10,6 +10,7 @@ class PublisherImpl(Publisher): _underlying: AsyncPublisher def __init__(self, underlying: AsyncPublisher): + super().__init__() self._managed_loop = ManagedEventLoop() self._underlying = underlying diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py new file mode 100644 index 00000000..51c5eab9 --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py @@ -0,0 +1,119 @@ +import asyncio +from typing import Union, Dict, NamedTuple +import queue + +from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError +from google.cloud.pubsub_v1.subscriber.message import Message +from google.pubsub_v1 import PubsubMessage + +from google.cloud.pubsublite.cloudpubsub.flow_control_settings import FlowControlSettings +from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker +from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer +from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler +from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber +from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable +from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage +from google.cloud.pubsub_v1.subscriber._protocol import requests + + +class _SizedMessage(NamedTuple): + message: PubsubMessage + size_bytes: int + + +class SinglePartitionSubscriber(PermanentFailable, AsyncSubscriber): + _underlying: Subscriber + _flow_control_settings: FlowControlSettings + _ack_set_tracker: AckSetTracker + _nack_handler: NackHandler + _transformer: MessageTransformer + + _queue: queue.Queue + _messages_by_offset: Dict[int, _SizedMessage] + _looper_future: asyncio.Future + + def __init__(self, underlying: Subscriber, flow_control_settings: FlowControlSettings, ack_set_tracker: AckSetTracker, + nack_handler: NackHandler, transformer: MessageTransformer): + super().__init__() + self._underlying = underlying + self._flow_control_settings = flow_control_settings + self._ack_set_tracker = ack_set_tracker + self._nack_handler = nack_handler + self._transformer = transformer + + self._queue = queue.Queue() + self._messages_by_offset = {} + + async def read(self) -> Message: + message: SequencedMessage = await self.await_unless_failed(self._underlying.read()) + try: + cps_message = self._transformer.transform(message) + offset = message.cursor.offset + self._ack_set_tracker.track(offset) + self._messages_by_offset[offset] = _SizedMessage(cps_message, message.size_bytes) + wrapped_message = Message(cps_message._pb, ack_id=str(offset), delivery_attempt=0, request_queue=self._queue) + return wrapped_message + except GoogleAPICallError as e: + self.fail(e) + raise e + + async def _handle_ack(self, message: requests.AckRequest): + offset = int(message.ack_id) + await self._underlying.allow_flow( + FlowControlRequest(allowed_messages=1, allowed_bytes=self._messages_by_offset[offset].size_bytes)) + del self._messages_by_offset[offset] + try: + await self._ack_set_tracker.ack(offset) + except GoogleAPICallError as e: + self.fail(e) + + def _handle_nack(self, message: requests.NackRequest): + offset = int(message.ack_id) + sized_message = self._messages_by_offset[offset] + try: + self._nack_handler.on_nack(sized_message.message, + lambda: self._queue.put(requests.AckRequest( + ack_id=message.ack_id, + byte_size=0, # Ignored + time_to_ack=0, # Ignored + ordering_key="" # Ignored + ))) + except GoogleAPICallError as e: + self.fail(e) + + async def _handle_queue_message(self, message: Union[ + requests.AckRequest, requests.DropRequest, requests.ModAckRequest, requests.NackRequest]): + if isinstance(message, requests.DropRequest) or isinstance(message, requests.ModAckRequest): + self.fail(FailedPrecondition("Called internal method of google.cloud.pubsub_v1.subscriber.message.Message " + f"Pub/Sub Lite does not support: {message}")) + elif isinstance(message, requests.AckRequest): + await self._handle_ack(message) + else: + self._handle_nack(message) + + async def _looper(self): + while True: + try: + queue_message = self._queue.get_nowait() + await self._handle_queue_message(queue_message) + except queue.Empty: + await asyncio.sleep(.1) + + async def __aenter__(self): + await self._ack_set_tracker.__aenter__() + await self._underlying.__aenter__() + self._looper_future = asyncio.ensure_future(self._looper()) + await self._underlying.allow_flow(FlowControlRequest( + allowed_messages=self._flow_control_settings.messages_outstanding, + allowed_bytes=self._flow_control_settings.bytes_outstanding)) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self._looper_future.cancel() + try: + await self._looper_future + except asyncio.CancelledError: + pass + await self._underlying.__aexit__(exc_type, exc_value, traceback) + await self._ack_set_tracker.__aexit__(exc_type, exc_value, traceback) diff --git a/google/cloud/pubsublite/cloudpubsub/message_transformer.py b/google/cloud/pubsublite/cloudpubsub/message_transformer.py new file mode 100644 index 00000000..a147c8d0 --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/message_transformer.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from collections import Callable + +from google.pubsub_v1 import PubsubMessage + +from google.cloud.pubsublite.cloudpubsub.message_transforms import to_cps_subscribe_message +from google.cloud.pubsublite_v1 import SequencedMessage + + +class MessageTransformer(ABC): + """ + A MessageTransformer turns Pub/Sub Lite message protos into Pub/Sub message protos. + """ + + @abstractmethod + def transform(self, source: SequencedMessage) -> PubsubMessage: + """Transform a SequencedMessage to a PubsubMessage. + + Args: + source: The message to transform. + + Raises: + GoogleAPICallError: To fail the client if raised inline. + """ + pass + + +class DefaultMessageTransformer(MessageTransformer): + def transform(self, source: SequencedMessage) -> PubsubMessage: + return to_cps_subscribe_message(source) diff --git a/google/cloud/pubsublite/cloudpubsub/nack_handler.py b/google/cloud/pubsublite/cloudpubsub/nack_handler.py new file mode 100644 index 00000000..2138880c --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/nack_handler.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod +from typing import Callable + +from google.pubsub_v1 import PubsubMessage + + +class NackHandler(ABC): + """ + A NackHandler handles calls to the nack() method which is not expressible in Pub/Sub Lite. + """ + + @abstractmethod + def on_nack(self, message: PubsubMessage, ack: Callable[[], None]): + """Handle a negative acknowledgement. ack must eventually be called. + + Args: + message: The nacked message. + ack: A callable to acknowledge the underlying message. This must eventually be called. + + Raises: + GoogleAPICallError: To fail the client if raised inline. + """ + pass diff --git a/google/cloud/pubsublite/cloudpubsub/subscriber.py b/google/cloud/pubsublite/cloudpubsub/subscriber.py new file mode 100644 index 00000000..2c8922fd --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/subscriber.py @@ -0,0 +1,25 @@ +from abc import abstractmethod +from typing import AsyncContextManager + +from google.cloud.pubsub_v1.subscriber.message import Message + + +class AsyncSubscriber(AsyncContextManager): + """ + A Cloud Pub/Sub asynchronous subscriber. + """ + @abstractmethod + async def read(self) -> Message: + """ + Read the next message off of the stream. + + Returns: + The next message. ack() or nack() must eventually be called exactly once. + + Pub/Sub Lite does not support nack() by default- if you do call nack(), it will immediately fail the client + unless you have a NackHandler installed. + + Raises: + GoogleAPICallError: On a permanent error. + """ + raise NotImplementedError() diff --git a/setup.py b/setup.py index 2810f54d..fde78f97 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ "google-api-core >= 1.22.0", "absl-py >= 0.9.0", "proto-plus >= 0.4.0", - "google-cloud-pubsub >= 1.7.0", + "google-cloud-pubsub >= 2.1.0", "grpcio", "setuptools" ] diff --git a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py new file mode 100644 index 00000000..be39f4dc --- /dev/null +++ b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py @@ -0,0 +1,177 @@ +import asyncio +import datetime +from typing import Callable + +from asynctest.mock import MagicMock, call +import pytest +from google.api_core.exceptions import FailedPrecondition +from google.cloud.pubsub_v1.subscriber.message import Message +from google.protobuf.timestamp_pb2 import Timestamp +from google.pubsub_v1 import PubsubMessage + +from google.cloud.pubsublite.cloudpubsub.flow_control_settings import FlowControlSettings +from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker +from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import SinglePartitionSubscriber +from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer +from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler +from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber +from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite.testing.test_utils import make_queue_waiter +from google.cloud.pubsublite_v1 import Cursor, FlowControlRequest, SequencedMessage + +# All test coroutines will be treated as marked. +pytestmark = pytest.mark.asyncio + + +def mock_async_context_manager(cm): + cm.__aenter__.return_value = cm + return cm + + +@pytest.fixture() +def underlying(): + return mock_async_context_manager(MagicMock(spec=Subscriber)) + + +@pytest.fixture() +def flow_control_settings(): + return FlowControlSettings(1000, 1000) + + +@pytest.fixture() +def initial_flow_request(flow_control_settings): + return FlowControlRequest( + allowed_messages=flow_control_settings.messages_outstanding, + allowed_bytes=flow_control_settings.bytes_outstanding) + + +@pytest.fixture() +def ack_set_tracker(): + return mock_async_context_manager(MagicMock(spec=AckSetTracker)) + + +@pytest.fixture() +def nack_handler(): + return MagicMock(spec=NackHandler) + + +@pytest.fixture() +def transformer(): + result = MagicMock(spec=MessageTransformer) + result.transform.side_effect = lambda source: PubsubMessage(message_id=str(source.cursor.offset)) + return result + + +@pytest.fixture() +def subscriber(underlying, flow_control_settings, ack_set_tracker, nack_handler, transformer): + return SinglePartitionSubscriber(underlying, flow_control_settings, ack_set_tracker, nack_handler, transformer) + + +async def test_init(subscriber, underlying, ack_set_tracker, initial_flow_request): + async with subscriber: + underlying.__aenter__.assert_called_once() + ack_set_tracker.__aenter__.assert_called_once() + underlying.allow_flow.assert_called_once_with(initial_flow_request) + underlying.__aexit__.assert_called_once() + ack_set_tracker.__aexit__.assert_called_once() + + +async def test_failed_transform(subscriber, underlying, transformer): + async with subscriber: + transformer.transform.side_effect = FailedPrecondition("Bad message") + underlying.read.return_value = SequencedMessage() + with pytest.raises(FailedPrecondition): + await subscriber.read() + + +async def test_ack(subscriber: AsyncSubscriber, underlying, transformer, ack_set_tracker): + ack_called_queue = asyncio.Queue() + ack_result_queue = asyncio.Queue() + ack_set_tracker.ack.side_effect = make_queue_waiter(ack_called_queue, ack_result_queue) + async with subscriber: + message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) + message_2 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10) + underlying.read.return_value = message_1 + read_1: Message = await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1)]) + assert read_1.message_id == "1" + underlying.read.return_value = message_2 + read_2: Message = await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1), call(2)]) + assert read_2.message_id == "2" + read_2.ack() + await ack_called_queue.get() + await ack_result_queue.put(None) + ack_set_tracker.ack.assert_has_calls([call(2)]) + read_1.ack() + await ack_called_queue.get() + await ack_result_queue.put(None) + ack_set_tracker.ack.assert_has_calls([call(2), call(1)]) + + +async def test_track_failure(subscriber: SinglePartitionSubscriber, underlying, transformer, ack_set_tracker): + async with subscriber: + ack_set_tracker.track.side_effect = FailedPrecondition("Bad track") + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) + underlying.read.return_value = message + with pytest.raises(FailedPrecondition): + await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1)]) + + +async def test_ack_failure(subscriber: SinglePartitionSubscriber, underlying, transformer, ack_set_tracker): + ack_called_queue = asyncio.Queue() + ack_result_queue = asyncio.Queue() + ack_set_tracker.ack.side_effect = make_queue_waiter(ack_called_queue, ack_result_queue) + async with subscriber: + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) + underlying.read.return_value = message + read: Message = await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1)]) + read.ack() + await ack_called_queue.get() + ack_set_tracker.ack.assert_has_calls([call(1)]) + await ack_result_queue.put(FailedPrecondition("Bad ack")) + + async def sleep_forever(): + await asyncio.sleep(float("inf")) + underlying.read.side_effect = sleep_forever + with pytest.raises(FailedPrecondition): + await subscriber.read() + + +async def test_nack_failure(subscriber: SinglePartitionSubscriber, underlying, transformer, ack_set_tracker, nack_handler): + async with subscriber: + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) + underlying.read.return_value = message + read: Message = await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1)]) + nack_handler.on_nack.side_effect = FailedPrecondition("Bad nack") + read.nack() + + async def sleep_forever(): + await asyncio.sleep(float("inf")) + underlying.read.side_effect = sleep_forever + with pytest.raises(FailedPrecondition): + await subscriber.read() + + +async def test_nack_calls_ack(subscriber: SinglePartitionSubscriber, underlying, transformer, ack_set_tracker, nack_handler): + ack_called_queue = asyncio.Queue() + ack_result_queue = asyncio.Queue() + ack_set_tracker.ack.side_effect = make_queue_waiter(ack_called_queue, ack_result_queue) + async with subscriber: + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) + underlying.read.return_value = message + read: Message = await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1)]) + + def on_nack(nacked: PubsubMessage, ack: Callable[[], None]): + assert nacked.message_id == "1" + ack() + nack_handler.on_nack.side_effect = on_nack + read.nack() + await ack_called_queue.get() + await ack_result_queue.put(None) + ack_set_tracker.ack.assert_has_calls([call(1)]) + From 9fa4c3726e145308b691923e8bb336b84907679d Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Tue, 22 Sep 2020 22:01:36 -0400 Subject: [PATCH 2/3] feat: Add DefaultNackHandler. --- google/cloud/pubsublite/cloudpubsub/nack_handler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/google/cloud/pubsublite/cloudpubsub/nack_handler.py b/google/cloud/pubsublite/cloudpubsub/nack_handler.py index 2138880c..6455efa5 100644 --- a/google/cloud/pubsublite/cloudpubsub/nack_handler.py +++ b/google/cloud/pubsublite/cloudpubsub/nack_handler.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Callable +from google.api_core.exceptions import FailedPrecondition from google.pubsub_v1 import PubsubMessage @@ -21,3 +22,10 @@ def on_nack(self, message: PubsubMessage, ack: Callable[[], None]): GoogleAPICallError: To fail the client if raised inline. """ pass + + +class DefaultNackHandler(NackHandler): + def on_nack(self, message: PubsubMessage, ack: Callable[[], None]): + raise FailedPrecondition( + "You may not nack messages by default when using a PubSub Lite client. See NackHandler for how to customize" + " this.") From c89b8639dc7fdba583172f3030a35d3e51f2c02d Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Wed, 23 Sep 2020 11:55:10 -0400 Subject: [PATCH 3/3] feat: Add AssigningSubscriber. This handles changing partition assignments and creates AsyncSubscribers per-partition. --- .../internal/assigning_subscriber.py | 73 ++++++++++ .../internal/managed_event_loop.py | 2 +- .../internal/wait_ignore_cancelled.py | 9 ++ .../internal/wire/permanent_failable.py | 15 ++- google/cloud/pubsublite/testing/test_utils.py | 17 +++ .../internal/assigning_subscriber_test.py | 127 ++++++++++++++++++ 6 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py create mode 100644 google/cloud/pubsublite/internal/wait_ignore_cancelled.py create mode 100644 tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py diff --git a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py new file mode 100644 index 00000000..19877d75 --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py @@ -0,0 +1,73 @@ +from asyncio import Future, Queue, ensure_future +from typing import Callable, NamedTuple, Dict, Set + +from google.cloud.pubsub_v1.subscriber.message import Message + +from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable +from google.cloud.pubsublite.partition import Partition + +_PartitionSubscriberFactory = Callable[[Partition], AsyncSubscriber] + + +class _RunningSubscriber(NamedTuple): + subscriber: AsyncSubscriber + poller: Future + + +class AssigningSubscriber(AsyncSubscriber, PermanentFailable): + _assigner: Assigner + _subscriber_factory: _PartitionSubscriberFactory + + _subscribers: Dict[Partition, _RunningSubscriber] + _messages: "Queue[Message]" + _assign_poller: Future + + def __init__(self, assigner: Assigner, subscriber_factory: _PartitionSubscriberFactory): + super().__init__() + self._assigner = assigner + self._subscriber_factory = subscriber_factory + self._subscribers = {} + self._messages = Queue() + + async def read(self) -> Message: + return await self.await_unless_failed(self._messages.get()) + + async def _subscribe_action(self, subscriber: AsyncSubscriber): + message = await subscriber.read() + await self._messages.put(message) + + async def _start_subscriber(self, partition: Partition): + new_subscriber = self._subscriber_factory(partition) + await new_subscriber.__aenter__() + poller = ensure_future(self.run_poller(lambda: self._subscribe_action(new_subscriber))) + self._subscribers[partition] = _RunningSubscriber(new_subscriber, poller) + + async def _stop_subscriber(self, running: _RunningSubscriber): + running.poller.cancel() + await wait_ignore_cancelled(running.poller) + await running.subscriber.__aexit__(None, None, None) + + async def _assign_action(self): + assignment: Set[Partition] = await self._assigner.get_assignment() + added_partitions = assignment - self._subscribers.keys() + removed_partitions = self._subscribers.keys() - assignment + for partition in added_partitions: + await self._start_subscriber(partition) + for partition in removed_partitions: + await self._stop_subscriber(self._subscribers[partition]) + del self._subscribers[partition] + + async def __aenter__(self): + await self._assigner.__aenter__() + self._assign_poller = ensure_future(self.run_poller(self._assign_action)) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self._assign_poller.cancel() + await wait_ignore_cancelled(self._assign_poller) + await self._assigner.__aexit__(exc_type, exc_value, traceback) + for running in self._subscribers.values(): + await self._stop_subscriber(running) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py b/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py index 7840bac3..7e787ca9 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py @@ -15,7 +15,7 @@ def __init__(self): def __enter__(self): self._thread.start() - def __exit__(self, __exc_type, __exc_value, __traceback): + def __exit__(self, exc_type, exc_value, traceback): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() diff --git a/google/cloud/pubsublite/internal/wait_ignore_cancelled.py b/google/cloud/pubsublite/internal/wait_ignore_cancelled.py new file mode 100644 index 00000000..d7e16499 --- /dev/null +++ b/google/cloud/pubsublite/internal/wait_ignore_cancelled.py @@ -0,0 +1,9 @@ +from asyncio import CancelledError +from typing import Awaitable + + +async def wait_ignore_cancelled(awaitable: Awaitable): + try: + await awaitable + except CancelledError: + pass diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py index 7d8a01f8..f3fadc66 100644 --- a/google/cloud/pubsublite/internal/wire/permanent_failable.py +++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py @@ -1,5 +1,5 @@ import asyncio -from typing import Awaitable, TypeVar, Optional +from typing import Awaitable, TypeVar, Optional, Callable from google.api_core.exceptions import GoogleAPICallError @@ -31,6 +31,19 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T: task.cancel() raise self._failure_task.exception() + async def run_poller(self, poll_action: Callable[[], Awaitable[None]]): + """ + Run a polling loop, which runs poll_action forever unless this is failed. + Args: + poll_action: A callable returning an awaitable to run in a loop. Note that async functions which return once + satisfy this. + """ + try: + while True: + await self.await_unless_failed(poll_action()) + except GoogleAPICallError as e: + self.fail(e) + def fail(self, err: GoogleAPICallError): if not self._failure_task.done(): self._failure_task.set_exception(err) diff --git a/google/cloud/pubsublite/testing/test_utils.py b/google/cloud/pubsublite/testing/test_utils.py index e93c16ba..ad199ca5 100644 --- a/google/cloud/pubsublite/testing/test_utils.py +++ b/google/cloud/pubsublite/testing/test_utils.py @@ -1,6 +1,8 @@ import asyncio from typing import List, Union, Any, TypeVar, Generic, Optional +from asynctest import CoroutineMock + T = TypeVar("T") @@ -27,5 +29,20 @@ async def waiter(*args, **kwargs): return waiter +class QueuePair: + called: asyncio.Queue + results: asyncio.Queue + + def __init__(self): + self.called = asyncio.Queue() + self.results = asyncio.Queue() + + +def wire_queues(mock: CoroutineMock) -> QueuePair: + queues = QueuePair() + mock.side_effect = make_queue_waiter(queues.called, queues.results) + return queues + + class Box(Generic[T]): val: Optional[T] diff --git a/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py new file mode 100644 index 00000000..a8c62d87 --- /dev/null +++ b/tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py @@ -0,0 +1,127 @@ +import asyncio +from typing import Callable, Set + +from asynctest.mock import MagicMock, call +import pytest +from google.api_core.exceptions import FailedPrecondition +from google.cloud.pubsub_v1.subscriber.message import Message +from google.pubsub_v1 import PubsubMessage + +from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import AssigningSubscriber +from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.partition import Partition +from google.cloud.pubsublite.testing.test_utils import make_queue_waiter, wire_queues + +# All test coroutines will be treated as marked. +pytestmark = pytest.mark.asyncio + + +def mock_async_context_manager(cm): + cm.__aenter__.return_value = cm + return cm + + +@pytest.fixture() +def assigner(): + return mock_async_context_manager(MagicMock(spec=Assigner)) + + +@pytest.fixture() +def subscriber_factory(): + return MagicMock(spec=Callable[[Partition], AsyncSubscriber]) + + +@pytest.fixture() +def subscriber(assigner, subscriber_factory): + return AssigningSubscriber(assigner, subscriber_factory) + + +async def test_init(subscriber, assigner): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + assigner.__aenter__.assert_called_once() + await assign_queues.called.get() + assigner.get_assignment.assert_called_once() + assigner.__aexit__.assert_called_once() + + +async def test_initial_assignment(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2 + await assign_queues.results.put({Partition(1), Partition(2)}) + await assign_queues.called.get() + subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True) + sub1.__aenter__.assert_called_once() + sub2.__aenter__.assert_called_once() + sub1.__aexit__.assert_called_once() + sub2.__aexit__.assert_called_once() + + +async def test_assigner_failure(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + await assign_queues.results.put(FailedPrecondition("bad assign")) + with pytest.raises(FailedPrecondition): + await subscriber.read() + + +async def test_assignment_change(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub3 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition( + 1) else sub2 if partition == Partition(2) else sub3 + await assign_queues.results.put({Partition(1), Partition(2)}) + await assign_queues.called.get() + subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True) + sub1.__aenter__.assert_called_once() + sub2.__aenter__.assert_called_once() + await assign_queues.results.put({Partition(1), Partition(3)}) + await assign_queues.called.get() + subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2)), call(Partition(3))], any_order=True) + sub3.__aenter__.assert_called_once() + sub2.__aexit__.assert_called_once() + sub1.__aexit__.assert_called_once() + sub2.__aexit__.assert_called_once() + sub3.__aexit__.assert_called_once() + + +async def test_subscriber_failure(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub1_queues = wire_queues(sub1.read) + subscriber_factory.return_value = sub1 + await assign_queues.results.put({Partition(1)}) + await sub1_queues.called.get() + await sub1_queues.results.put(FailedPrecondition("sub failed")) + with pytest.raises(FailedPrecondition): + await subscriber.read() + + +async def test_delivery_from_multiple(subscriber, assigner, subscriber_factory): + assign_queues = wire_queues(assigner.get_assignment) + async with subscriber: + await assign_queues.called.get() + sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber)) + sub1_queues = wire_queues(sub1.read) + sub2_queues = wire_queues(sub2.read) + subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2 + await assign_queues.results.put({Partition(1), Partition(2)}) + await sub1_queues.results.put(Message(PubsubMessage(message_id="1")._pb, "", 0, None)) + await sub2_queues.results.put(Message(PubsubMessage(message_id="2")._pb, "", 0, None)) + message_ids: Set[str] = set() + message_ids.add((await subscriber.read()).message_id) + message_ids.add((await subscriber.read()).message_id) + assert message_ids == {"1", "2"}