From 697df4a604c5b03378fbb9327f4f041c2d1949ce Mon Sep 17 00:00:00 2001 From: dpcollins-google <40498610+dpcollins-google@users.noreply.github.com> Date: Tue, 15 Sep 2020 16:38:04 -0400 Subject: [PATCH] feat: Implement Subscriber, which handles flow control and batch message processing. (#16) * feat: Implement Subscriber, which handles flow control and batch message processing. Also ensure all asynchronous loopers are torn down when their underlying objects are. --- .../pubsublite/internal/wire/assigner_impl.py | 4 +- .../internal/wire/committer_impl.py | 10 +- .../internal/wire/routing_publisher.py | 1 + .../wire/single_partition_publisher.py | 4 +- .../pubsublite/internal/wire/subscriber.py | 28 ++ .../internal/wire/subscriber_impl.py | 135 ++++++++ .../internal/wire/subscriber_impl_test.py | 296 ++++++++++++++++++ 7 files changed, 473 insertions(+), 5 deletions(-) create mode 100644 google/cloud/pubsublite/internal/wire/subscriber.py create mode 100644 google/cloud/pubsublite/internal/wire/subscriber_impl.py create mode 100644 tests/unit/pubsublite/internal/wire/subscriber_impl_test.py diff --git a/google/cloud/pubsublite/internal/wire/assigner_impl.py b/google/cloud/pubsublite/internal/wire/assigner_impl.py index 65a0117c..c754d0c2 100644 --- a/google/cloud/pubsublite/internal/wire/assigner_impl.py +++ b/google/cloud/pubsublite/internal/wire/assigner_impl.py @@ -39,6 +39,7 @@ def __init__(self, initial: InitialPartitionAssignmentRequest, async def __aenter__(self): await self._connection.__aenter__() + return self def _start_receiver(self): assert self._receiver is None @@ -63,10 +64,11 @@ async def _receive_loop(self): for partition in response.partitions: partitions.add(Partition(partition)) self._new_assignment.put_nowait(partitions) - except asyncio.CancelledError: + except (asyncio.CancelledError, GoogleAPICallError): return async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._stop_receiver() await self._connection.__aexit__(exc_type, 
exc_val, exc_tb) async def reinitialize(self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment]): diff --git a/google/cloud/pubsublite/internal/wire/committer_impl.py b/google/cloud/pubsublite/internal/wire/committer_impl.py index 36a84a1c..dd3a907e 100644 --- a/google/cloud/pubsublite/internal/wire/committer_impl.py +++ b/google/cloud/pubsublite/internal/wire/committer_impl.py @@ -10,11 +10,13 @@ from google.cloud.pubsublite.internal.wire.connection import Connection from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher, BatchTester from google.cloud.pubsublite_v1 import Cursor -from google.cloud.pubsublite_v1.types import StreamingCommitCursorRequest, StreamingCommitCursorResponse, InitialCommitCursorRequest +from google.cloud.pubsublite_v1.types import StreamingCommitCursorRequest, StreamingCommitCursorResponse, \ + InitialCommitCursorRequest from google.cloud.pubsublite.internal.wire.work_item import WorkItem -class CommitterImpl(Committer, ConnectionReinitializer[StreamingCommitCursorRequest, StreamingCommitCursorResponse], BatchTester[Cursor]): +class CommitterImpl(Committer, ConnectionReinitializer[StreamingCommitCursorRequest, StreamingCommitCursorResponse], + BatchTester[Cursor]): _initial: InitialCommitCursorRequest _flush_seconds: float _connection: RetryingConnection[StreamingCommitCursorRequest, StreamingCommitCursorResponse] @@ -38,6 +40,7 @@ def __init__(self, initial: InitialCommitCursorRequest, flush_seconds: float, async def __aenter__(self): await self._connection.__aenter__() + return self def _start_loopers(self): assert self._receiver is None @@ -71,7 +74,7 @@ async def _receive_loop(self): while True: response = await self._connection.read() self._handle_response(response) - except asyncio.CancelledError: + except (asyncio.CancelledError, GoogleAPICallError): return async def _flush_loop(self): @@ -83,6 +86,7 @@ async def _flush_loop(self): return async def __aexit__(self, exc_type, exc_val, 
exc_tb): + await self._stop_loopers() if self._connection.error(): self._fail_if_retrying_failed() else: diff --git a/google/cloud/pubsublite/internal/wire/routing_publisher.py b/google/cloud/pubsublite/internal/wire/routing_publisher.py index 3f269550..6b8c0b77 100644 --- a/google/cloud/pubsublite/internal/wire/routing_publisher.py +++ b/google/cloud/pubsublite/internal/wire/routing_publisher.py @@ -18,6 +18,7 @@ def __init__(self, routing_policy: RoutingPolicy, publishers: Mapping[Partition, async def __aenter__(self): for publisher in self._publishers.values(): await publisher.__aenter__() + return self async def __aexit__(self, exc_type, exc_val, exc_tb): for publisher in self._publishers.values(): diff --git a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py index 081afa70..6c6edcc6 100644 --- a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py +++ b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py @@ -48,6 +48,7 @@ def _partition(self) -> Partition: async def __aenter__(self): await self._connection.__aenter__() + return self def _start_loopers(self): assert self._receiver is None @@ -82,7 +83,7 @@ async def _receive_loop(self): while True: response = await self._connection.read() self._handle_response(response) - except asyncio.CancelledError: + except (asyncio.CancelledError, GoogleAPICallError): return async def _flush_loop(self): @@ -98,6 +99,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self._fail_if_retrying_failed() else: await self._flush() + await self._stop_loopers() await self._connection.__aexit__(exc_type, exc_val, exc_tb) def _fail_if_retrying_failed(self): diff --git a/google/cloud/pubsublite/internal/wire/subscriber.py b/google/cloud/pubsublite/internal/wire/subscriber.py new file mode 100644 index 00000000..0e02024d --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/subscriber.py @@ -0,0 +1,28 @@ 
+from abc import abstractmethod +from typing import AsyncContextManager +from google.cloud.pubsublite_v1.types import SequencedMessage, FlowControlRequest + + +class Subscriber(AsyncContextManager): + """ + A Pub/Sub Lite asynchronous wire protocol subscriber. + """ + @abstractmethod + async def read(self) -> SequencedMessage: + """ + Read the next message off of the stream. + + Returns: + The next message. + + Raises: + GoogleAPICallError: On a permanent error. + """ + raise NotImplementedError() + + @abstractmethod + async def allow_flow(self, request: FlowControlRequest): + """ + Allow an additional amount of messages and bytes to be sent to this client. + """ + raise NotImplementedError() diff --git a/google/cloud/pubsublite/internal/wire/subscriber_impl.py b/google/cloud/pubsublite/internal/wire/subscriber_impl.py new file mode 100644 index 00000000..4a28a8be --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/subscriber_impl.py @@ -0,0 +1,135 @@ +import asyncio +from typing import Optional + +from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition + +from google.cloud.pubsublite.internal.wire.connection import Request, Connection, Response, ConnectionFactory +from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer +from google.cloud.pubsublite.internal.wire.flow_control_batcher import FlowControlBatcher +from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection +from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite_v1 import SubscribeRequest, SubscribeResponse, FlowControlRequest, SequencedMessage, \ + InitialSubscribeRequest, SeekRequest, Cursor + + +class SubscriberImpl(Subscriber, ConnectionReinitializer[SubscribeRequest, SubscribeResponse]): + _initial: InitialSubscribeRequest + _token_flush_seconds: float + _connection: RetryingConnection[SubscribeRequest, SubscribeResponse] + + _outstanding_flow_control: 
FlowControlBatcher + + _reinitializing: bool + _last_received_offset: Optional[int] + + _message_queue: 'asyncio.Queue[SequencedMessage]' + + _receiver: Optional[asyncio.Future] + _flusher: Optional[asyncio.Future] + + def __init__(self, initial: InitialSubscribeRequest, token_flush_seconds: float, + factory: ConnectionFactory[SubscribeRequest, SubscribeResponse]): + self._initial = initial + self._token_flush_seconds = token_flush_seconds + self._connection = RetryingConnection(factory, self) + self._outstanding_flow_control = FlowControlBatcher() + self._reinitializing = False + self._last_received_offset = None + self._message_queue = asyncio.Queue() + self._receiver = None + self._flusher = None + + async def __aenter__(self): + await self._connection.__aenter__() + return self + + def _start_loopers(self): + assert self._receiver is None + assert self._flusher is None + self._receiver = asyncio.ensure_future(self._receive_loop()) + self._flusher = asyncio.ensure_future(self._flush_loop()) + + async def _stop_loopers(self): + if self._receiver: + self._receiver.cancel() + await self._receiver + self._receiver = None + if self._flusher: + self._flusher.cancel() + await self._flusher + self._flusher = None + + def _handle_response(self, response: SubscribeResponse): + if "messages" not in response: + self._connection.fail(FailedPrecondition("Received an invalid subsequent response on the subscribe stream.")) + return + self._outstanding_flow_control.on_messages(response.messages.messages) + for message in response.messages.messages: + if self._last_received_offset is not None and message.cursor.offset <= self._last_received_offset: + self._connection.fail(FailedPrecondition( + "Received an invalid out of order message from the server. 
Message is {}, previous last received is {}.".format( + message.cursor.offset, self._last_received_offset))) + return + self._last_received_offset = message.cursor.offset + for message in response.messages.messages: + # queue is unbounded. + self._message_queue.put_nowait(message) + + async def _receive_loop(self): + try: + while True: + response = await self._connection.read() + self._handle_response(response) + except (asyncio.CancelledError, GoogleAPICallError): + return + + async def _try_send_tokens(self): + req = self._outstanding_flow_control.release_pending_request() + if req is None: + return + try: + await self._connection.write(SubscribeRequest(flow_control=req)) + except GoogleAPICallError: + # May be transient, in which case these tokens will be resent. + pass + + async def _flush_loop(self): + try: + while True: + await asyncio.sleep(self._token_flush_seconds) + await self._try_send_tokens() + except asyncio.CancelledError: + return + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._stop_loopers() + await self._connection.__aexit__(exc_type, exc_val, exc_tb) + + async def reinitialize(self, connection: Connection[SubscribeRequest, SubscribeResponse]): + self._reinitializing = True + await self._stop_loopers() + await connection.write(SubscribeRequest(initial=self._initial)) + response = await connection.read() + if "initial" not in response: + self._connection.fail(FailedPrecondition("Received an invalid initial response on the subscribe stream.")) + return + if self._last_received_offset is not None: + # Perform a seek to get the next message after the one we received. 
+ await connection.write(SubscribeRequest(seek=SeekRequest(cursor=Cursor(offset=self._last_received_offset + 1)))) + seek_response = await connection.read() + if "seek" not in seek_response: + self._connection.fail(FailedPrecondition("Received an invalid seek response on the subscribe stream.")) + return + tokens = self._outstanding_flow_control.request_for_restart() + if tokens is not None: + await connection.write(SubscribeRequest(flow_control=tokens)) + self._reinitializing = False + self._start_loopers() + + async def read(self) -> SequencedMessage: + return await self._connection.await_unless_failed(self._message_queue.get()) + + async def allow_flow(self, request: FlowControlRequest): + self._outstanding_flow_control.add(request) + if not self._reinitializing and self._outstanding_flow_control.should_expedite(): + await self._try_send_tokens() diff --git a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py new file mode 100644 index 00000000..25564f48 --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py @@ -0,0 +1,296 @@ +import asyncio +from unittest.mock import call +from collections import defaultdict +from typing import Dict, List + +from asynctest.mock import MagicMock, CoroutineMock +import pytest +from grpc import StatusCode + +from google.cloud.pubsublite.internal.wire.connection import Connection, ConnectionFactory +from google.api_core.exceptions import InternalServerError, GoogleAPICallError + +from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite.internal.wire.subscriber_impl import SubscriberImpl +from google.cloud.pubsublite_v1 import SubscribeRequest, SubscribeResponse, InitialSubscribeRequest, FlowControlRequest, \ + SeekRequest +from google.cloud.pubsublite_v1.types.common import Cursor, SequencedMessage +from google.cloud.pubsublite.testing.test_utils import make_queue_waiter +from 
google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS + +FLUSH_SECONDS = 100000 + +# All test coroutines will be treated as marked. +pytestmark = pytest.mark.asyncio + + +@pytest.fixture() +def default_connection(): + conn = MagicMock(spec=Connection[SubscribeRequest, SubscribeResponse]) + conn.__aenter__.return_value = conn + return conn + + +@pytest.fixture() +def connection_factory(default_connection): + factory = MagicMock(spec=ConnectionFactory[SubscribeRequest, SubscribeResponse]) + factory.new.return_value = default_connection + return factory + + +@pytest.fixture() +def initial_request(): + return SubscribeRequest(initial=InitialSubscribeRequest(subscription="mysub")) + + +class QueuePair: + called: asyncio.Queue + results: asyncio.Queue + + def __init__(self): + self.called = asyncio.Queue() + self.results = asyncio.Queue() + + +@pytest.fixture +def sleep_queues() -> Dict[float, QueuePair]: + return defaultdict(QueuePair) + + +@pytest.fixture +def asyncio_sleep(monkeypatch, sleep_queues): + """Patch asyncio.sleep with a mock that routes each call through the QueuePair in sleep_queues for its delay.""" + mock = CoroutineMock() + monkeypatch.setattr(asyncio, "sleep", mock) + + async def sleeper(delay: float): + await make_queue_waiter(sleep_queues[delay].called, sleep_queues[delay].results)(delay) + + mock.side_effect = sleeper + return mock + + +@pytest.fixture() +def subscriber(connection_factory, initial_request): + return SubscriberImpl(initial_request.initial, FLUSH_SECONDS, connection_factory) + + +def as_request(flow: FlowControlRequest): + req = SubscribeRequest() + req.flow_control = flow + return req + + +def as_response(messages: List[SequencedMessage]): + res = SubscribeResponse() + res.messages.messages = messages + return res + + +async def test_basic_flow_control_after_timeout( + subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, sleep_queues): + sleep_called = sleep_queues[FLUSH_SECONDS].called + sleep_results = 
sleep_queues[FLUSH_SECONDS].results + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow_1 = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + flow_2 = FlowControlRequest(allowed_messages=5, allowed_bytes=10) + flow_3 = FlowControlRequest(allowed_messages=10, allowed_bytes=5) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut1 = asyncio.ensure_future(subscriber.allow_flow(flow_1)) + assert not flow_fut1.done() + + # Handle the inline write since initial tokens are 100% of outstanding. 
+ await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut1 + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1))]) + + # Should complete without writing to the connection + await subscriber.allow_flow(flow_2) + await subscriber.allow_flow(flow_3) + + # Wait for the flush loop to be sleeping before its next write + await sleep_called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + + # Handle the connection write + await sleep_results.put(None) + await write_called_queue.get() + await write_result_queue.put(None) + # Called with aggregate + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1)), call( + as_request(FlowControlRequest(allowed_messages=15, allowed_bytes=15)))]) + + +async def test_flow_resent_on_restart(subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, + sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow_1 = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + flow_2 = FlowControlRequest(allowed_messages=5, allowed_bytes=10) + flow_3 = FlowControlRequest(allowed_messages=10, allowed_bytes=5) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut1 = asyncio.ensure_future(subscriber.allow_flow(flow_1)) + assert not flow_fut1.done() + + # Handle the inline write since initial tokens are 100% of outstanding. 
+ await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut1 + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1))]) + + # Should complete without writing to the connection + await subscriber.allow_flow(flow_2) + await subscriber.allow_flow(flow_3) + + # Fail the connection with a retryable error + await read_called_queue.get() + await read_result_queue.put(InternalServerError("retryable")) + await sleep_queues[_MIN_BACKOFF_SECS].called.get() + await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) + # Reinitialization + await write_called_queue.get() + await write_result_queue.put(None) + await read_called_queue.get() + await read_result_queue.put(SubscribeResponse(initial={})) + # Re-sending flow tokens on the new stream + await write_called_queue.get() + await write_result_queue.put(None) + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1)), call(initial_request), + call(as_request( + FlowControlRequest(allowed_messages=115, allowed_bytes=115)))]) + + +async def test_message_receipt(subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, + sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5) + message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + 
default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) + assert not flow_fut.done() + + # Handle the inline write since initial tokens are 100% of outstanding. + await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow))]) + + message1_fut = asyncio.ensure_future(subscriber.read()) + + # Send messages to the subscriber. + await read_result_queue.put(as_response([message_1, message_2])) + # Wait for the next read call + await read_called_queue.get() + + assert (await message1_fut) == message_1 + assert (await subscriber.read()) == message_2 + + # Fail the connection with a retryable error + await read_called_queue.get() + await read_result_queue.put(InternalServerError("retryable")) + await sleep_queues[_MIN_BACKOFF_SECS].called.get() + await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) + # Reinitialization + await write_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow)), call(initial_request)]) + await write_result_queue.put(None) + await read_called_queue.get() + await read_result_queue.put(SubscribeResponse(initial={})) + # Sends fetch offset seek on the stream, and checks the response. + seek_req = SubscribeRequest(seek=SeekRequest(cursor=Cursor(offset=message_2.cursor.offset + 1))) + await write_called_queue.get() + default_connection.write.assert_has_calls( + [call(initial_request), call(as_request(flow)), call(initial_request), call(seek_req)]) + await write_result_queue.put(None) + await read_called_queue.get() + await read_result_queue.put(SubscribeResponse(seek={})) + # Re-sending flow tokens on the new stream. 
+ await write_called_queue.get() + await write_result_queue.put(None) + default_connection.write.assert_has_calls( + [call(initial_request), call(as_request(flow)), call(initial_request), call(seek_req), + call(as_request(FlowControlRequest(allowed_messages=98, allowed_bytes=85)))]) + + +async def test_out_of_order_receipt_failure(subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, + sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5) + message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) + assert not flow_fut.done() + + # Handle the inline write since initial tokens are 100% of outstanding. + await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow))]) + + read_fut = asyncio.ensure_future(subscriber.read()) + + # Send out of order messages to the subscriber. 
+ await read_result_queue.put(as_response([message_2, message_1])) + # Wait for the next read call + await read_called_queue.get() + + try: + await read_fut + assert False + except GoogleAPICallError as e: + assert e.grpc_status_code == StatusCode.FAILED_PRECONDITION + pass