Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement CPS non-asyncio subscriber #25

Merged
merged 3 commits into from Sep 25, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -25,7 +25,7 @@ class AssigningSubscriber(AsyncSubscriber, PermanentFailable):
_messages: "Queue[Message]"
_assign_poller: Future

def __init__(self, assigner: Assigner, subscriber_factory: _PartitionSubscriberFactory):
def __init__(self, assigner: Assigner, subscriber_factory: PartitionSubscriberFactory):
super().__init__()
self._assigner = assigner
self._subscriber_factory = subscriber_factory
Expand Down
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from typing import Optional, Callable

from google.api_core.exceptions import GoogleAPICallError


# Signature of a close callback: it receives the StreamingPullManager and an optional
# GoogleAPICallError (None when there is no error to report), and returns nothing.
CloseCallback = Callable[["StreamingPullManager", Optional[GoogleAPICallError]], None]


class StreamingPullManager(ABC):
    """The API expected by StreamingPullFuture.

    Concrete managers must accept a close callback and support being closed.
    """

    @abstractmethod
    def add_close_callback(self, close_callback: CloseCallback):
        """Register `close_callback` with this manager.

        Args:
          close_callback: Callable taking this manager and an optional GoogleAPICallError.
        """

    @abstractmethod
    def close(self):
        """Close this manager."""
75 changes: 75 additions & 0 deletions google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py
@@ -0,0 +1,75 @@
import concurrent.futures
import threading
from asyncio import CancelledError
from concurrent.futures.thread import ThreadPoolExecutor
from typing import ContextManager, Optional
from google.api_core.exceptions import GoogleAPICallError
from google.cloud.pubsublite.cloudpubsub.internal.managed_event_loop import ManagedEventLoop
from google.cloud.pubsublite.cloudpubsub.internal.streaming_pull_manager import StreamingPullManager, CloseCallback
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber, MessageCallback


class SubscriberImpl(ContextManager, StreamingPullManager):
    """Callback-driven, non-asyncio subscriber built on top of an AsyncSubscriber.

    While entered, a poller coroutine submitted to a ManagedEventLoop repeatedly reads
    messages from the underlying AsyncSubscriber and hands each one to the user callback
    on the provided executor.

    A close callback is mandatory: add_close_callback() must be called exactly once
    before __enter__(), and the callback is invoked at the end of __exit__() with this
    object and the GoogleAPICallError that caused the shutdown (None if there was none).
    """

    _underlying: AsyncSubscriber
    _callback: MessageCallback
    _executor: ThreadPoolExecutor

    _event_loop: ManagedEventLoop

    _poller_future: concurrent.futures.Future
    _close_lock: threading.Lock
    _failure: Optional[GoogleAPICallError]
    _close_callback: Optional[CloseCallback]
    _closed: bool

    def __init__(self, underlying: AsyncSubscriber, callback: MessageCallback, executor: ThreadPoolExecutor):
        """Store the collaborators; no work starts until __enter__().

        Args:
          underlying: The async subscriber that messages are read from.
          callback: The user callback run on `executor` for each message.
          executor: Executor used for user callbacks and failure handling. It is shut
            down (without waiting) when this subscriber exits.
        """
        self._underlying = underlying
        self._callback = callback
        self._executor = executor
        self._event_loop = ManagedEventLoop()
        self._close_lock = threading.Lock()
        self._failure = None
        self._close_callback = None
        self._closed = False

    def add_close_callback(self, close_callback: CloseCallback):
        """Register the callback invoked when this subscriber shuts down.

        Must be called exactly once, before __enter__(); a second registration trips
        the assertion below.
        """
        with self._close_lock:
            assert self._close_callback is None
            self._close_callback = close_callback

    def close(self):
        """Shut the subscriber down; only the first call performs the teardown."""
        with self._close_lock:
            if self._closed:
                return
            self._closed = True
        # Run the teardown after releasing the lock: __exit__ invokes the user-supplied
        # close callback, and holding the non-reentrant lock across it would deadlock
        # if that callback called close() or add_close_callback().
        self.__exit__(None, None, None)

    def _fail(self, error: GoogleAPICallError):
        """Record `error` as the reason for shutdown and close the subscriber."""
        self._failure = error
        self.close()

    async def _poller(self):
        """Read messages until a GoogleAPICallError occurs (or the task is cancelled)."""
        try:
            while True:
                message = await self._underlying.read()
                self._executor.submit(self._callback, message)
        except GoogleAPICallError as e:
            # Pass the error as an argument rather than closing over `e` in a lambda:
            # Python unbinds the `except ... as` name when the handler exits, so a
            # closure executed later on the executor could raise NameError.
            self._executor.submit(self._fail, e)

    def __enter__(self):
        """Start the event loop, open the underlying subscriber and start polling.

        add_close_callback() must already have been called.
        """
        assert self._close_callback is not None
        self._event_loop.__enter__()
        self._event_loop.submit(self._underlying.__aenter__()).result()
        self._poller_future = self._event_loop.submit(self._poller())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop polling, close the underlying subscriber and notify the close callback.

        Requires that a close callback was registered via add_close_callback(); the
        callback receives this object and the recorded failure (or None).
        """
        try:
            self._poller_future.cancel()
            self._poller_future.result()
        except (concurrent.futures.CancelledError, CancelledError):
            # _poller_future is annotated as a concurrent.futures.Future, whose result()
            # raises concurrent.futures.CancelledError after cancel(). On Python 3.8+
            # that is a different class from asyncio.CancelledError, so accept both.
            pass
        self._event_loop.submit(self._underlying.__aexit__(exc_type, exc_value, traceback)).result()
        self._event_loop.__exit__(exc_type, exc_value, traceback)
        assert self._close_callback is not None
        self._executor.shutdown(wait=False)  # __exit__ may be called from the executor.
        self._close_callback(self, self._failure)
55 changes: 50 additions & 5 deletions google/cloud/pubsublite/cloudpubsub/make_subscriber.py
@@ -1,17 +1,19 @@
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Mapping, Set, AsyncIterator
from uuid import uuid4

from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials

from google.cloud.pubsub_v1.subscriber.futures import StreamingPullFuture
from google.cloud.pubsublite.cloudpubsub.flow_control_settings import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import AckSetTrackerImpl
from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import PartitionSubscriberFactory, \
AssigningSubscriber
from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import SinglePartitionSubscriber
import google.cloud.pubsublite.cloudpubsub.internal.subscriber_impl as cps_subscriber
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer, DefaultMessageTransformer
from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler, DefaultNackHandler
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber, MessageCallback
from google.cloud.pubsublite.endpoints import regional_endpoint
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.assigner_impl import AssignerImpl
Expand All @@ -20,7 +22,7 @@
from google.cloud.pubsublite.internal.wire.gapic_connection import GapicConnectionFactory
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
from google.cloud.pubsublite.internal.wire.pubsub_context import pubsub_context
from google.cloud.pubsublite.internal.wire.subscriber_impl import SubscriberImpl
import google.cloud.pubsublite.internal.wire.subscriber_impl as wire_subscriber
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite.paths import SubscriptionPath
from google.cloud.pubsublite.routing_metadata import subscription_routing_metadata
Expand Down Expand Up @@ -63,14 +65,14 @@ def subscribe_connection_factory(requests: AsyncIterator[SubscribeRequest]):
def cursor_connection_factory(requests: AsyncIterator[StreamingCommitCursorRequest]):
return cursor_client.streaming_commit_cursor(requests, metadata=list(final_metadata.items()))

wire_subscriber = SubscriberImpl(
subscriber = wire_subscriber.SubscriberImpl(
InitialSubscribeRequest(subscription=str(subscription), partition=partition.value),
_DEFAULT_FLUSH_SECONDS, GapicConnectionFactory(subscribe_connection_factory))
committer = CommitterImpl(
InitialCommitCursorRequest(subscription=str(subscription), partition=partition.value),
_DEFAULT_FLUSH_SECONDS, GapicConnectionFactory(cursor_connection_factory))
ack_set_tracker = AckSetTrackerImpl(committer)
return SinglePartitionSubscriber(wire_subscriber, flow_control_settings, ack_set_tracker, nack_handler,
return SinglePartitionSubscriber(subscriber, flow_control_settings, ack_set_tracker, nack_handler,
message_transformer)

return factory
Expand Down Expand Up @@ -124,3 +126,46 @@ def make_async_subscriber(
metadata, per_partition_flow_control_settings,
nack_handler, message_transformer)
return AssigningSubscriber(assigner, partition_subscriber_factory)


def make_subscriber(
        subscription: SubscriptionPath,
        per_partition_flow_control_settings: FlowControlSettings,
        callback: MessageCallback,
        nack_handler: Optional[NackHandler] = None,
        message_transformer: Optional[MessageTransformer] = None,
        fixed_partitions: Optional[Set[Partition]] = None,
        executor: Optional[ThreadPoolExecutor] = None,
        credentials: Optional[Credentials] = None,
        client_options: Optional[ClientOptions] = None,
        metadata: Optional[Mapping[str, str]] = None) -> StreamingPullFuture:
    """
    Build and start a callback-based Pub/Sub Lite subscriber.

    Args:
      subscription: The subscription to subscribe to.
      per_partition_flow_control_settings: Flow control settings applied to every subscribed partition
        individually (not in aggregate).
      callback: Invoked with each received message.
      nack_handler: Optional handler for nack() calls on a Message. The default will fail the client.
      message_transformer: Optional transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages.
      fixed_partitions: A fixed set of partitions to subscribe to. When absent, auto-assignment is used.
      executor: Executor on which user callbacks run. When omitted, a default-constructed ThreadPoolExecutor is
        used. With a single threaded executor, messages will be ordered per-partition, but a callback that blocks
        for too long will impede forward progress on all partitions.
      credentials: Credentials used to connect. GOOGLE_DEFAULT_CREDENTIALS is used if None.
      client_options: Other options to pass to the client. If any are passed, api_endpoint must be set.
      metadata: Additional metadata to send with the RPC.

    Returns:
      A StreamingPullFuture, managing the subscriber's lifetime.
    """
    async_subscriber = make_async_subscriber(
        subscription, per_partition_flow_control_settings, nack_handler, message_transformer, fixed_partitions,
        credentials, client_options, metadata)
    callback_executor = ThreadPoolExecutor() if executor is None else executor
    impl = cps_subscriber.SubscriberImpl(async_subscriber, callback, callback_executor)
    # The future wraps the subscriber before it is entered, matching the required setup order.
    pull_future = StreamingPullFuture(impl)
    impl.__enter__()
    return pull_future
5 changes: 4 additions & 1 deletion google/cloud/pubsublite/cloudpubsub/subscriber.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import AsyncContextManager
from typing import AsyncContextManager, Callable

from google.cloud.pubsub_v1.subscriber.message import Message

Expand All @@ -23,3 +23,6 @@ async def read(self) -> Message:
GoogleAPICallError: On a permanent error.
"""
raise NotImplementedError()


# Signature of a message callback: it takes a single Message and returns nothing.
MessageCallback = Callable[[Message], None]
@@ -1,5 +1,4 @@
import asyncio
from typing import Callable, Set
from typing import Set

from asynctest.mock import MagicMock, call
import pytest
Expand Down
93 changes: 93 additions & 0 deletions tests/unit/pubsublite/cloudpubsub/internal/subscriber_impl_test.py
@@ -0,0 +1,93 @@
import asyncio
import concurrent
from concurrent.futures.thread import ThreadPoolExecutor
from queue import Queue

from asynctest.mock import MagicMock
import pytest
from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.cloudpubsub.internal.streaming_pull_manager import CloseCallback
from google.cloud.pubsublite.cloudpubsub.internal.subscriber_impl import SubscriberImpl
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber, MessageCallback
from google.cloud.pubsublite.testing.test_utils import Box


@pytest.fixture()
def async_subscriber():
    """Mock AsyncSubscriber whose __aenter__ returns the mock itself."""
    mock_subscriber = MagicMock(spec=AsyncSubscriber)
    mock_subscriber.__aenter__.return_value = mock_subscriber
    return mock_subscriber


@pytest.fixture()
def message_callback():
    """Mock standing in for the user's per-message callback."""
    callback_mock = MagicMock(spec=MessageCallback)
    return callback_mock


@pytest.fixture()
def close_callback():
    """Mock standing in for the close callback."""
    callback_mock = MagicMock(spec=CloseCallback)
    return callback_mock


@pytest.fixture()
def subscriber(async_subscriber, message_callback, close_callback):
    """SubscriberImpl under test, wired to the mocks and a one-worker executor."""
    single_worker = ThreadPoolExecutor(max_workers=1)
    return SubscriberImpl(async_subscriber, message_callback, single_worker)


async def sleep_forever(*args, **kwargs):
    """Never complete on its own; any arguments are accepted and ignored."""
    unbounded_delay = float("inf")
    await asyncio.sleep(unbounded_delay)


def test_init(subscriber: SubscriberImpl, async_subscriber, close_callback):
    """Entering and closing enters/exits the underlying subscriber once, with no error reported."""
    # Keep the poller parked so that only the explicit close() ends the subscriber.
    async_subscriber.read.side_effect = sleep_forever
    subscriber.add_close_callback(close_callback)

    subscriber.__enter__()
    async_subscriber.__aenter__.assert_called_once()

    subscriber.close()
    async_subscriber.__aexit__.assert_called_once()
    close_callback.assert_called_once_with(subscriber, None)


def test_failed(subscriber: SubscriberImpl, async_subscriber, close_callback):
    """A failing read() closes the subscriber and reports the error to the close callback."""
    read_failure = FailedPrecondition("bad read")
    async_subscriber.read.side_effect = read_failure

    # Resolved from the close callback so the test can block until shutdown has happened.
    closed_signal = concurrent.futures.Future()

    def mark_closed(manager, err):
        return closed_signal.set_result(None)

    close_callback.side_effect = mark_closed

    subscriber.add_close_callback(close_callback)
    subscriber.__enter__()
    async_subscriber.__aenter__.assert_called_once()

    closed_signal.result()
    async_subscriber.__aexit__.assert_called_once()
    close_callback.assert_called_once_with(subscriber, read_failure)


def test_messages_received(subscriber: SubscriberImpl, async_subscriber, message_callback, close_callback):
    """Messages returned by read() are delivered to the message callback."""
    first = Message(PubsubMessage(message_id="1")._pb, "", 0, None)
    second = Message(PubsubMessage(message_id="2")._pb, "", 0, None)
    # Maps the 1-based read() call number to the message that call returns.
    scripted = {1: first, 2: second}

    read_count = Box[int]()
    read_count.val = 0

    async def on_read() -> Message:
        read_count.val += 1
        if read_count.val in scripted:
            return scripted[read_count.val]
        # Every read after the scripted ones parks forever.
        await sleep_forever()

    async_subscriber.read.side_effect = on_read

    delivered_ids = Queue()

    def record_id(m):
        return delivered_ids.put(m.message_id)

    message_callback.side_effect = record_id

    subscriber.add_close_callback(close_callback)
    subscriber.__enter__()
    assert delivered_ids.get() == "1"
    assert delivered_ids.get() == "2"
    subscriber.close()