class StreamingPullManager(ABC):
    """The API expected by StreamingPullFuture."""

    @abstractmethod
    def add_close_callback(self, close_callback: CloseCallback):
        """Register a close callback on this manager.

        Args:
          close_callback: A callable matching CloseCallback, i.e. one accepting
            a StreamingPullManager and an optional GoogleAPICallError.
        """

    @abstractmethod
    def close(self):
        """Close this manager."""
class SubscriberImpl(ContextManager, StreamingPullManager):
    """Synchronous, callback-driven wrapper around an AsyncSubscriber.

    Messages are read from the underlying AsyncSubscriber on a ManagedEventLoop
    and handed to the user callback on the provided executor. When the
    subscriber shuts down, the registered close callback is invoked with this
    object and the failure (if any) that caused the shutdown.
    """

    # The async subscriber that messages are read from.
    _underlying: AsyncSubscriber
    # User callback invoked (on the executor) with each message read.
    _callback: MessageCallback
    # Executor on which user callbacks and failure handling run.
    _executor: ThreadPoolExecutor

    # Event loop on which the underlying subscriber and the poller run.
    _event_loop: ManagedEventLoop

    # Future for the running _poller() coroutine; set by __enter__.
    _poller_future: concurrent.futures.Future
    # Guards _close_callback registration and the _closed flag.
    _close_lock: threading.Lock
    # The error passed to _fail(), or None if no failure was recorded.
    _failure: Optional[GoogleAPICallError]
    # Callback invoked at the end of __exit__; set via add_close_callback.
    _close_callback: Optional[CloseCallback]
    # True once close() has started shutting down.
    _closed: bool

    def __init__(
        self,
        underlying: AsyncSubscriber,
        callback: MessageCallback,
        executor: ThreadPoolExecutor,
    ):
        """Store collaborators; nothing is started until __enter__ is called.

        Args:
          underlying: The async subscriber to read messages from.
          callback: Called with each message that is read.
          executor: The executor used to run `callback` and failure handling.
        """
        self._underlying = underlying
        self._callback = callback
        self._executor = executor
        self._event_loop = ManagedEventLoop()
        self._close_lock = threading.Lock()
        self._failure = None
        self._close_callback = None
        self._closed = False

    def add_close_callback(self, close_callback: CloseCallback):
        """
        A close callback must be set exactly once by the StreamingPullFuture managing this subscriber.

        This two-phase init model is made necessary by the requirements of StreamingPullFuture.
        """
        with self._close_lock:
            assert self._close_callback is None
            self._close_callback = close_callback

    def close(self):
        """Shut down at most once; later calls are no-ops."""
        with self._close_lock:
            if not self._closed:
                self._closed = True
                self.__exit__(None, None, None)

    def _fail(self, error: GoogleAPICallError):
        """Record `error` as the reason for shutdown, then close."""
        self._failure = error
        self.close()

    async def _poller(self):
        """Read messages forever, dispatching each to the callback on the executor.

        On a GoogleAPICallError from the underlying subscriber, failure handling
        is scheduled on the executor and the loop ends.
        """
        try:
            while True:
                message = await self._underlying.read()
                self._executor.submit(self._callback, message)
        except GoogleAPICallError as e:
            # Pass the exception as an argument rather than closing over `e`:
            # Python unbinds `e` when this except clause ends, so a lambda
            # referencing it could raise NameError once it runs on the
            # executor thread.
            self._executor.submit(self._fail, e)

    def __enter__(self):
        """Start the event loop, enter the underlying subscriber and start polling.

        A close callback must already have been registered.
        """
        assert self._close_callback is not None
        self._event_loop.__enter__()
        self._event_loop.submit(self._underlying.__aenter__()).result()
        self._poller_future = self._event_loop.submit(self._poller())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop polling, tear down the underlying subscriber and loop, then notify.

        The close callback is invoked last, with this object and the recorded
        failure (None if there was none).
        """
        try:
            self._poller_future.cancel()
            self._poller_future.result()
        except (CancelledError, concurrent.futures.CancelledError):
            # concurrent.futures.Future.result() raises
            # concurrent.futures.CancelledError, which is a different class
            # from asyncio's CancelledError on Python 3.8+. Catch both so a
            # cancelled poller never aborts the rest of the teardown.
            pass
        self._event_loop.submit(
            self._underlying.__aexit__(exc_type, exc_value, traceback)
        ).result()
        self._event_loop.__exit__(exc_type, exc_value, traceback)
        assert self._close_callback is not None
        self._executor.shutdown(wait=False)  # __exit__ may be called from the executor.
        self._close_callback(self, self._failure)
def make_subscriber(
    subscription: SubscriptionPath,
    per_partition_flow_control_settings: FlowControlSettings,
    callback: MessageCallback,
    nack_handler: Optional[NackHandler] = None,
    message_transformer: Optional[MessageTransformer] = None,
    fixed_partitions: Optional[Set[Partition]] = None,
    executor: Optional[ThreadPoolExecutor] = None,
    credentials: Optional[Credentials] = None,
    client_options: Optional[ClientOptions] = None,
    metadata: Optional[Mapping[str, str]] = None,
) -> StreamingPullFuture:
    """
    Make a Pub/Sub Lite Subscriber.

    Builds an async subscriber, wraps it in a callback-driven SubscriberImpl,
    attaches a StreamingPullFuture to it and then starts it.

    Args:
      subscription: The subscription to subscribe to.
      per_partition_flow_control_settings: Flow control settings applied to
        every subscribed partition individually, not in aggregate.
      callback: The callback to call with each message.
      nack_handler: An optional handler for when nack() is called on a Message.
        The default will fail the client.
      message_transformer: An optional transformer from Pub/Sub Lite messages
        to Cloud Pub/Sub messages.
      fixed_partitions: A fixed set of partitions to subscribe to. If not
        present, will instead use auto-assignment.
      executor: The executor to use for user callbacks. If not provided, a
        default constructed ThreadPoolExecutor is used. With a single threaded
        executor, messages will be ordered per-partition, but take care that
        the callback does not block for too long as it will impede forward
        progress on all partitions.
      credentials: The credentials to use to connect.
        GOOGLE_DEFAULT_CREDENTIALS is used if None.
      client_options: Other options to pass to the client. Note that if you
        pass any you must set api_endpoint.
      metadata: Additional metadata to send with the RPC.

    Returns:
      A StreamingPullFuture, managing the subscriber's lifetime.
    """
    async_subscriber = make_async_subscriber(
        subscription,
        per_partition_flow_control_settings,
        nack_handler,
        message_transformer,
        fixed_partitions,
        credentials,
        client_options,
        metadata,
    )
    callback_executor = ThreadPoolExecutor() if executor is None else executor
    manager = cps_subscriber.SubscriberImpl(
        async_subscriber, callback, callback_executor
    )
    # The future is constructed before the manager is entered, as in the
    # original ordering: SubscriberImpl.__enter__ asserts that a close
    # callback has already been registered.
    pull_future = StreamingPullFuture(manager)
    manager.__enter__()
    return pull_future
# Upper bound on any blocking wait below, so that a regression makes a test
# fail instead of hanging the whole test run.
_WAIT_TIMEOUT_SECONDS = 10


@pytest.fixture()
def async_subscriber():
    """A mock AsyncSubscriber whose __aenter__ returns itself."""
    subscriber = MagicMock(spec=AsyncSubscriber)
    subscriber.__aenter__.return_value = subscriber
    return subscriber


@pytest.fixture()
def message_callback():
    """A mock of the user message callback."""
    return MagicMock(spec=MessageCallback)


@pytest.fixture()
def close_callback():
    """A mock of the close callback registered on the subscriber."""
    return MagicMock(spec=CloseCallback)


@pytest.fixture()
def subscriber(async_subscriber, message_callback, close_callback):
    """The SubscriberImpl under test, backed by a single-threaded executor."""
    return SubscriberImpl(
        async_subscriber, message_callback, ThreadPoolExecutor(max_workers=1)
    )


async def sleep_forever(*args, **kwargs):
    """Never return; used as a read() side effect that yields no message."""
    await asyncio.sleep(float("inf"))


def test_init(subscriber: SubscriberImpl, async_subscriber, close_callback):
    """Entering then closing enters/exits the underlying subscriber and notifies."""
    async_subscriber.read.side_effect = sleep_forever
    subscriber.add_close_callback(close_callback)
    subscriber.__enter__()
    async_subscriber.__aenter__.assert_called_once()
    subscriber.close()
    async_subscriber.__aexit__.assert_called_once()
    close_callback.assert_called_once_with(subscriber, None)


def test_failed(subscriber: SubscriberImpl, async_subscriber, close_callback):
    """A read() error closes the subscriber and is passed to the close callback."""
    error = FailedPrecondition("bad read")
    async_subscriber.read.side_effect = error

    close_called = concurrent.futures.Future()
    close_callback.side_effect = lambda manager, err: close_called.set_result(None)

    subscriber.add_close_callback(close_callback)
    subscriber.__enter__()
    async_subscriber.__aenter__.assert_called_once()
    # Bounded wait: if the failure never reaches the close callback, fail
    # with a timeout rather than block forever.
    close_called.result(timeout=_WAIT_TIMEOUT_SECONDS)
    async_subscriber.__aexit__.assert_called_once()
    close_callback.assert_called_once_with(subscriber, error)


def test_messages_received(
    subscriber: SubscriberImpl, async_subscriber, message_callback, close_callback
):
    """Messages returned by read() are delivered to the callback in order."""
    message1 = Message(PubsubMessage(message_id="1")._pb, "", 0, None)
    message2 = Message(PubsubMessage(message_id="2")._pb, "", 0, None)

    counter = Box[int]()
    counter.val = 0

    async def on_read() -> Message:
        # Return the two messages on the first two reads, then block forever.
        counter.val += 1
        if counter.val == 1:
            return message1
        if counter.val == 2:
            return message2
        await sleep_forever()

    async_subscriber.read.side_effect = on_read

    results = Queue()
    message_callback.side_effect = lambda m: results.put(m.message_id)

    subscriber.add_close_callback(close_callback)
    subscriber.__enter__()
    try:
        # Bounded waits: a missing delivery raises queue.Empty instead of
        # hanging the test.
        assert results.get(timeout=_WAIT_TIMEOUT_SECONDS) == "1"
        assert results.get(timeout=_WAIT_TIMEOUT_SECONDS) == "2"
    finally:
        # Always shut the subscriber down, even when an assertion fails.
        subscriber.close()