diff --git a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py index 518120a1..c1f73b7c 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py @@ -45,3 +45,12 @@ async def ack(self, offset: int): Returns: GoogleAPICallError: On a commit failure. """ + + @abstractmethod + async def clear_and_commit(self): + """ + Discard all outstanding acks and wait for the commit offset to be acknowledged by the server. + + Raises: + GoogleAPICallError: If the committer has shut down due to a permanent error. + """ diff --git a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py index d271bed1..3b10d3c1 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py @@ -62,6 +62,11 @@ async def ack(self, offset: int): # Convert from last acked to first unacked. await self._committer.commit(Cursor(offset=prefix_acked_offset + 1)) + async def clear_and_commit(self): + self._receipts.clear() + self._acks = queue.PriorityQueue() + await self._committer.wait_until_empty() + async def __aenter__(self): await self._committer.__aenter__() diff --git a/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py index 4d4f52ce..e51a5e64 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py @@ -52,6 +52,9 @@ from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata from google.cloud.pubsublite.internal.wire.pubsub_context import pubsub_context import google.cloud.pubsublite.internal.wire.subscriber_impl as wire_subscriber +from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import ( + SubscriberResetHandler, +) from google.cloud.pubsublite.types import Partition, SubscriptionPath from google.cloud.pubsublite.internal.routing_metadata import ( subscription_routing_metadata, @@ -131,13 +134,16 @@ def cursor_connection_factory( requests, metadata=list(final_metadata.items()) ) - subscriber = wire_subscriber.SubscriberImpl( - InitialSubscribeRequest( - subscription=str(subscription), partition=partition.value - ), - _DEFAULT_FLUSH_SECONDS, - GapicConnectionFactory(subscribe_connection_factory), - ) + def subscriber_factory(reset_handler: SubscriberResetHandler): + return wire_subscriber.SubscriberImpl( + InitialSubscribeRequest( + subscription=str(subscription), partition=partition.value + ), + _DEFAULT_FLUSH_SECONDS, + GapicConnectionFactory(subscribe_connection_factory), + reset_handler, + ) + committer = CommitterImpl( InitialCommitCursorRequest( subscription=str(subscription), partition=partition.value @@ -147,7 +153,7 @@ def cursor_connection_factory( ) ack_set_tracker = AckSetTrackerImpl(committer) return SinglePartitionSingleSubscriber( - subscriber, + subscriber_factory, flow_control_settings, ack_set_tracker, nack_handler, diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py index 5cdcfcd5..1e244aa4 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py @@ -13,7 +13,8 @@ # limitations under the License. import asyncio -from typing import Union, Dict, NamedTuple +import json +from typing import Callable, Union, Dict, NamedTuple import queue from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError @@ -30,6 +31,9 @@ ) from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import ( + SubscriberResetHandler, +) from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage from google.cloud.pubsub_v1.subscriber._protocol import requests @@ -39,7 +43,27 @@ class _SizedMessage(NamedTuple): size_bytes: int -class SinglePartitionSingleSubscriber(PermanentFailable, AsyncSingleSubscriber): +class _AckId(NamedTuple): + generation: int + offset: int + + def str(self) -> str: + return json.dumps({"generation": self.generation, "offset": self.offset}) + + @staticmethod + def parse(payload: str) -> "_AckId": + loaded = json.loads(payload) + return _AckId( + generation=int(loaded["generation"]), offset=int(loaded["offset"]), + ) + + +ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber] + + +class SinglePartitionSingleSubscriber( + PermanentFailable, AsyncSingleSubscriber, SubscriberResetHandler +): _underlying: Subscriber _flow_control_settings: FlowControlSettings _ack_set_tracker: AckSetTracker @@ -47,26 +71,33 @@ class SinglePartitionSingleSubscriber(PermanentFailable, AsyncSingleSubscriber): _transformer: MessageTransformer _queue: queue.Queue - _messages_by_offset: Dict[int, _SizedMessage] + _ack_generation_id: int + _messages_by_ack_id: Dict[str, _SizedMessage] _looper_future: asyncio.Future def __init__( self, - underlying: Subscriber, + subscriber_factory: ResettableSubscriberFactory, flow_control_settings: FlowControlSettings, ack_set_tracker: AckSetTracker, nack_handler: NackHandler, transformer: MessageTransformer, ): super().__init__() - self._underlying = underlying + self._underlying = subscriber_factory(self) self._flow_control_settings = flow_control_settings self._ack_set_tracker = ack_set_tracker self._nack_handler = nack_handler self._transformer = transformer self._queue = queue.Queue() - self._messages_by_offset = {} + self._ack_generation_id = 0 + self._messages_by_ack_id = {} + + async def handle_reset(self): + # Increment ack generation id to ignore unacked messages. + ++self._ack_generation_id + await self._ack_set_tracker.clear_and_commit() async def read(self) -> Message: try: @@ -75,13 +106,14 @@ async def read(self) -> Message: ) cps_message = self._transformer.transform(message) offset = message.cursor.offset + ack_id = _AckId(self._ack_generation_id, offset) self._ack_set_tracker.track(offset) - self._messages_by_offset[offset] = _SizedMessage( + self._messages_by_ack_id[ack_id.str()] = _SizedMessage( cps_message, message.size_bytes ) wrapped_message = Message( cps_message._pb, - ack_id=str(offset), + ack_id=ack_id.str(), delivery_attempt=0, request_queue=self._queue, ) @@ -91,22 +123,23 @@ async def read(self) -> Message: raise e async def _handle_ack(self, message: requests.AckRequest): - offset = int(message.ack_id) await self._underlying.allow_flow( FlowControlRequest( allowed_messages=1, - allowed_bytes=self._messages_by_offset[offset].size_bytes, + allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes, ) ) - del self._messages_by_offset[offset] - try: - await self._ack_set_tracker.ack(offset) - except GoogleAPICallError as e: - self.fail(e) + del self._messages_by_ack_id[message.ack_id] + # Always refill flow control tokens, but do not commit offsets from outdated generations. + ack_id = _AckId.parse(message.ack_id) + if ack_id.generation == self._ack_generation_id: + try: + await self._ack_set_tracker.ack(ack_id.offset) + except GoogleAPICallError as e: + self.fail(e) def _handle_nack(self, message: requests.NackRequest): - offset = int(message.ack_id) - sized_message = self._messages_by_offset[offset] + sized_message = self._messages_by_ack_id[message.ack_id] try: # Put the ack request back into the queue since the callback may be called from another thread. self._nack_handler.on_nack( diff --git a/google/cloud/pubsublite/internal/wire/committer.py b/google/cloud/pubsublite/internal/wire/committer.py index 23e8e96a..2f52c8d1 100644 --- a/google/cloud/pubsublite/internal/wire/committer.py +++ b/google/cloud/pubsublite/internal/wire/committer.py @@ -26,3 +26,13 @@ class Committer(AsyncContextManager): @abstractmethod async def commit(self, cursor: Cursor) -> None: pass + + @abstractmethod + async def wait_until_empty(self): + """ + Flushes pending commits and waits for all outstanding commit responses from the server. + + Raises: + GoogleAPICallError: When the committer terminates in failure. + """ + pass diff --git a/google/cloud/pubsublite/internal/wire/committer_impl.py b/google/cloud/pubsublite/internal/wire/committer_impl.py index 8d2451b2..4bf0c021 100644 --- a/google/cloud/pubsublite/internal/wire/committer_impl.py +++ b/google/cloud/pubsublite/internal/wire/committer_impl.py @@ -63,6 +63,7 @@ class CommitterImpl( _receiver: Optional[asyncio.Future] _flusher: Optional[asyncio.Future] + _empty: asyncio.Event def __init__( self, @@ -79,6 +80,8 @@ def __init__( self._outstanding_commits = [] self._receiver = None self._flusher = None + self._empty = asyncio.Event() + self._empty.set() async def __aenter__(self): await self._connection.__aenter__() @@ -117,6 +120,8 @@ def _handle_response(self, response: StreamingCommitCursorResponse): batch = self._outstanding_commits.pop(0) for item in batch: item.response_future.set_result(None) + if len(self._outstanding_commits) == 0: + self._empty.set() async def _receive_loop(self): while True: @@ -147,6 +152,7 @@ async def _flush(self): if not batch: return self._outstanding_commits.append(batch) + self._empty.clear() req = StreamingCommitCursorRequest() req.commit.cursor = batch[-1].request try: @@ -155,6 +161,10 @@ async def _flush(self): _LOGGER.debug(f"Failed commit on stream: {e}") self._fail_if_retrying_failed() + async def wait_until_empty(self): + await self._flush() + await self._connection.await_unless_failed(self._empty.wait()) + async def commit(self, cursor: Cursor) -> None: future = self._batcher.add(cursor) if self._batcher.should_flush(): diff --git a/google/cloud/pubsublite/internal/wire/reset_signal.py b/google/cloud/pubsublite/internal/wire/reset_signal.py new file mode 100644 index 00000000..ce4fa3d5 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/reset_signal.py @@ -0,0 +1,42 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.api_core.exceptions import GoogleAPICallError +from google.cloud.pubsublite.internal.status_codes import is_retryable +from google.rpc.error_details_pb2 import ErrorInfo +from grpc_status import rpc_status + + +def is_reset_signal(error: GoogleAPICallError) -> bool: + """ + Determines whether the given error contains the stream RESET signal, sent by + the server to instruct clients to reset stream state. + + Returns: True if the error contains the RESET signal. + """ + if not is_retryable(error) or not error.response: + return False + try: + status = rpc_status.from_call(error.response) + for detail in status.details: + info = ErrorInfo() + if ( + detail.Unpack(info) + and info.reason == "RESET" + and info.domain == "pubsublite.googleapis.com" + ): + return True + except ValueError: + pass + return False diff --git a/google/cloud/pubsublite/internal/wire/subscriber_impl.py b/google/cloud/pubsublite/internal/wire/subscriber_impl.py index cb8f268a..89b9a2a2 100644 --- a/google/cloud/pubsublite/internal/wire/subscriber_impl.py +++ b/google/cloud/pubsublite/internal/wire/subscriber_impl.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from copy import deepcopy from typing import Optional from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition @@ -28,6 +29,7 @@ from google.cloud.pubsublite.internal.wire.flow_control_batcher import ( FlowControlBatcher, ) +from google.cloud.pubsublite.internal.wire.reset_signal import is_reset_signal from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection from google.cloud.pubsublite.internal.wire.subscriber import Subscriber from google.cloud.pubsublite_v1 import ( @@ -39,14 +41,18 @@ SeekRequest, Cursor, ) +from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import ( + SubscriberResetHandler, +) class SubscriberImpl( Subscriber, ConnectionReinitializer[SubscribeRequest, SubscribeResponse] ): - _initial: InitialSubscribeRequest + _base_initial: InitialSubscribeRequest _token_flush_seconds: float _connection: RetryingConnection[SubscribeRequest, SubscribeResponse] + _reset_handler: SubscriberResetHandler _outstanding_flow_control: FlowControlBatcher @@ -60,13 +66,15 @@ class SubscriberImpl( def __init__( self, - initial: InitialSubscribeRequest, + base_initial: InitialSubscribeRequest, token_flush_seconds: float, factory: ConnectionFactory[SubscribeRequest, SubscribeResponse], + reset_handler: SubscriberResetHandler, ): - self._initial = initial + self._base_initial = base_initial self._token_flush_seconds = token_flush_seconds self._connection = RetryingConnection(factory, self) + self._reset_handler = reset_handler self._outstanding_flow_control = FlowControlBatcher() self._reinitializing = False self._last_received_offset = None @@ -152,7 +160,27 @@ async def reinitialize( ): self._reinitializing = True await self._stop_loopers() - await connection.write(SubscribeRequest(initial=self._initial)) + if last_error and is_reset_signal(last_error): + # Discard undelivered messages and refill flow control tokens. + while not self._message_queue.empty(): + msg = self._message_queue.get_nowait() + self._outstanding_flow_control.add( + FlowControlRequest( + allowed_messages=1, allowed_bytes=msg.size_bytes, + ) + ) + await self._reset_handler.handle_reset() + self._last_received_offset = None + initial = deepcopy(self._base_initial) + if self._last_received_offset is not None: + initial.initial_location = SeekRequest( + cursor=Cursor(offset=self._last_received_offset + 1) + ) + else: + initial.initial_location = SeekRequest( + named_target=SeekRequest.NamedTarget.COMMITTED_CURSOR + ) + await connection.write(SubscribeRequest(initial=initial)) response = await connection.read() if "initial" not in response: self._connection.fail( @@ -161,23 +189,6 @@ async def reinitialize( ) ) return - if self._last_received_offset is not None: - # Perform a seek to get the next message after the one we received. - await connection.write( - SubscribeRequest( - seek=SeekRequest( - cursor=Cursor(offset=self._last_received_offset + 1) - ) - ) - ) - seek_response = await connection.read() - if "seek" not in seek_response: - self._connection.fail( - FailedPrecondition( - "Received an invalid seek response on the subscribe stream." - ) - ) - return tokens = self._outstanding_flow_control.request_for_restart() if tokens is not None: await connection.write(SubscribeRequest(flow_control=tokens)) diff --git a/google/cloud/pubsublite/internal/wire/subscriber_reset_handler.py b/google/cloud/pubsublite/internal/wire/subscriber_reset_handler.py new file mode 100644 index 00000000..695255f4 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/subscriber_reset_handler.py @@ -0,0 +1,28 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABCMeta, abstractmethod + + +class SubscriberResetHandler(metaclass=ABCMeta): + """Helps to reset subscriber state when the `RESET` signal is received from the server.""" + + @abstractmethod + async def handle_reset(self): + """Reset subscriber state. + + Raises: + GoogleAPICallError: If reset handling fails. The subscriber will shut down. + """ + raise NotImplementedError() diff --git a/google/cloud/pubsublite/testing/test_reset_signal.py b/google/cloud/pubsublite/testing/test_reset_signal.py new file mode 100644 index 00000000..97a45c42 --- /dev/null +++ b/google/cloud/pubsublite/testing/test_reset_signal.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asynctest.mock import MagicMock +from google.api_core.exceptions import Aborted, GoogleAPICallError +from google.protobuf.any_pb2 import Any +from google.rpc.error_details_pb2 import ErrorInfo +from google.rpc.status_pb2 import Status +import grpc +from grpc_status import rpc_status + + +def make_call(status_pb: Status) -> grpc.Call: + status = rpc_status.to_status(status_pb) + mock_call = MagicMock(spec=grpc.Call) + mock_call.details.return_value = status.details + mock_call.code.return_value = status.code + mock_call.trailing_metadata.return_value = status.trailing_metadata + return mock_call + + +def make_reset_signal() -> GoogleAPICallError: + any = Any() + any.Pack(ErrorInfo(reason="RESET", domain="pubsublite.googleapis.com")) + status_pb = Status(code=10, details=[any]) + return Aborted("", response=make_call(status_pb)) diff --git a/setup.py b/setup.py index 14a34701..7f0fd1da 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,8 @@ dependencies = [ "google-cloud-pubsub >= 2.1.0, <3.0.0dev", + "grpcio >= 1.18.0", + "grpcio-status >= 1.18.0", "overrides>=6.0.1, <7.0.0", ] diff --git a/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py b/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py index 5eb3d053..52a5d317 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py @@ -15,6 +15,8 @@ from asynctest.mock import MagicMock, call import pytest +from google.api_core.exceptions import FailedPrecondition + # All test coroutines will be treated as marked. from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import ( @@ -60,3 +62,24 @@ async def test_track_and_aggregate_acks(committer, tracker: AckSetTracker): [call(Cursor(offset=6)), call(Cursor(offset=8))] ) committer.__aexit__.assert_called_once() + + +async def test_clear_and_commit(committer, tracker: AckSetTracker): + async with tracker: + committer.__aenter__.assert_called_once() + tracker.track(offset=3) + tracker.track(offset=5) + + with pytest.raises(FailedPrecondition): + tracker.track(offset=1) + await tracker.ack(offset=5) + committer.commit.assert_has_calls([]) + + await tracker.clear_and_commit() + committer.wait_until_empty.assert_called_once() + + # After clearing, it should be possible to track earlier offsets. + tracker.track(offset=1) + await tracker.ack(offset=1) + committer.commit.assert_has_calls([call(Cursor(offset=2))]) + committer.__aexit__.assert_called_once() diff --git a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py index 55f6fe30..f2a43928 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py @@ -32,6 +32,9 @@ AsyncSingleSubscriber, ) from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import ( + SubscriberResetHandler, +) from google.cloud.pubsublite.testing.test_utils import make_queue_waiter from google.cloud.pubsublite_v1 import Cursor, FlowControlRequest, SequencedMessage @@ -85,8 +88,15 @@ def transformer(): def subscriber( underlying, flow_control_settings, ack_set_tracker, nack_handler, transformer ): + def subscriber_factory(reset_handler: SubscriberResetHandler): + return underlying + return SinglePartitionSingleSubscriber( - underlying, flow_control_settings, ack_set_tracker, nack_handler, transformer + subscriber_factory, + flow_control_settings, + ack_set_tracker, + nack_handler, + transformer, ) @@ -230,3 +240,55 @@ def on_nack(nacked: PubsubMessage, ack: Callable[[], None]): await ack_called_queue.get() await ack_result_queue.put(None) ack_set_tracker.ack.assert_has_calls([call(1)]) + + +async def test_handle_reset( + subscriber: SinglePartitionSingleSubscriber, + underlying, + transformer, + ack_set_tracker, +): + ack_called_queue = asyncio.Queue() + ack_result_queue = asyncio.Queue() + ack_set_tracker.ack.side_effect = make_queue_waiter( + ack_called_queue, ack_result_queue + ) + async with subscriber: + message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) + underlying.read.return_value = message_1 + read_1: Message = await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1)]) + assert read_1.message_id == "1" + + await subscriber.handle_reset() + ack_set_tracker.clear_and_commit.assert_called_once() + + # After reset, flow control tokens of unacked messages are refilled, + # but offset not committed. + read_1.ack() + await ack_called_queue.get() + await ack_result_queue.put(None) + underlying.allow_flow.assert_has_calls( + [ + call(FlowControlRequest(allowed_messages=1000, allowed_bytes=1000,)), + call(FlowControlRequest(allowed_messages=1, allowed_bytes=5,)), + ] + ) + ack_set_tracker.ack.assert_has_calls([]) + + message_2 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10) + underlying.read.return_value = message_2 + read_2: Message = await subscriber.read() + ack_set_tracker.track.assert_has_calls([call(1), call(2)]) + assert read_2.message_id == "2" + read_2.ack() + await ack_called_queue.get() + await ack_result_queue.put(None) + underlying.allow_flow.assert_has_calls( + [ + call(FlowControlRequest(allowed_messages=1000, allowed_bytes=1000,)), + call(FlowControlRequest(allowed_messages=1, allowed_bytes=5,)), + call(FlowControlRequest(allowed_messages=1, allowed_bytes=10,)), + ] + ) + ack_set_tracker.ack.assert_has_calls([call(2)]) diff --git a/tests/unit/pubsublite/internal/wire/committer_impl_test.py b/tests/unit/pubsublite/internal/wire/committer_impl_test.py index c1105d46..773b2b85 100644 --- a/tests/unit/pubsublite/internal/wire/committer_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/committer_impl_test.py @@ -26,7 +26,7 @@ Connection, ConnectionFactory, ) -from google.api_core.exceptions import InternalServerError +from google.api_core.exceptions import InternalServerError, InvalidArgument from google.cloud.pubsublite_v1.types.cursor import ( StreamingCommitCursorRequest, StreamingCommitCursorResponse, @@ -141,8 +141,10 @@ async def test_basic_commit_after_timeout( # Commit cursors commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) commit_fut2 = asyncio.ensure_future(committer.commit(cursor2)) + empty_fut = asyncio.ensure_future(committer.wait_until_empty()) assert not commit_fut1.done() assert not commit_fut2.done() + assert not empty_fut.done() # Wait for writes to be waiting await sleep_called.get() @@ -158,11 +160,13 @@ async def test_basic_commit_after_timeout( ) assert not commit_fut1.done() assert not commit_fut2.done() + assert not empty_fut.done() # Send the connection response with 1 ack since only one request was sent. await read_result_queue.put(as_response(count=1)) await commit_fut1 await commit_fut2 + await empty_fut async def test_commits_multi_cycle( @@ -196,7 +200,9 @@ async def test_commits_multi_cycle( # Write message 1 commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) + empty_fut = asyncio.ensure_future(committer.wait_until_empty()) assert not commit_fut1.done() + assert not empty_fut.done() # Wait for writes to be waiting await sleep_called.get() @@ -210,6 +216,7 @@ async def test_commits_multi_cycle( [call(initial_request), call(as_request(cursor1))] ) assert not commit_fut1.done() + assert not empty_fut.done() # Wait for writes to be waiting await sleep_called.get() @@ -218,6 +225,7 @@ async def test_commits_multi_cycle( # Write message 2 commit_fut2 = asyncio.ensure_future(committer.commit(cursor2)) assert not commit_fut2.done() + assert not empty_fut.done() # Handle the connection write await sleep_results.put(None) @@ -232,11 +240,13 @@ async def test_commits_multi_cycle( ) assert not commit_fut1.done() assert not commit_fut2.done() + assert not empty_fut.done() # Send the connection responses await read_result_queue.put(as_response(count=2)) await commit_fut1 await commit_fut2 + await empty_fut async def test_publishes_retried_on_restart( @@ -270,7 +280,9 @@ async def test_publishes_retried_on_restart( # Write message 1 commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) + empty_fut = asyncio.ensure_future(committer.wait_until_empty()) assert not commit_fut1.done() + assert not empty_fut.done() # Wait for writes to be waiting await sleep_called.get() @@ -284,6 +296,7 @@ async def test_publishes_retried_on_restart( [call(initial_request), call(as_request(cursor1))] ) assert not commit_fut1.done() + assert not empty_fut.done() # Wait for writes to be waiting await sleep_called.get() @@ -292,6 +305,7 @@ async def test_publishes_retried_on_restart( # Write message 2 commit_fut2 = asyncio.ensure_future(committer.commit(cursor2)) assert not commit_fut2.done() + assert not empty_fut.done() # Handle the connection write await sleep_results.put(None) @@ -306,6 +320,7 @@ async def test_publishes_retried_on_restart( ) assert not commit_fut1.done() assert not commit_fut2.done() + assert not empty_fut.done() # Fail the connection with a retryable error await read_called_queue.get() @@ -335,9 +350,74 @@ async def test_publishes_retried_on_restart( call(as_request(cursor2)), ] ) + assert not empty_fut.done() # Sending the response for the one commit finishes both await read_called_queue.get() await read_result_queue.put(as_response(count=1)) await commit_fut1 await commit_fut2 + await empty_fut + + +async def test_wait_until_empty_completes_on_failure( + committer: Committer, + default_connection, + initial_request, + asyncio_sleep, + sleep_queues, +): + sleep_called = sleep_queues[FLUSH_SECONDS].called + sleep_results = sleep_queues[FLUSH_SECONDS].results + cursor1 = Cursor(offset=1) + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + default_connection.write.side_effect = make_queue_waiter( + write_called_queue, write_result_queue + ) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter( + read_called_queue, read_result_queue + ) + read_result_queue.put_nowait(StreamingCommitCursorResponse(initial={})) + write_result_queue.put_nowait(None) + async with committer: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # New committer is empty. + await committer.wait_until_empty() + + # Write message 1 + commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) + empty_fut = asyncio.ensure_future(committer.wait_until_empty()) + assert not commit_fut1.done() + assert not empty_fut.done() + + # Wait for writes to be waiting + await sleep_called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + + # Handle the connection write + await sleep_results.put(None) + await write_called_queue.get() + await write_result_queue.put(None) + default_connection.write.assert_has_calls( + [call(initial_request), call(as_request(cursor1))] + ) + assert not commit_fut1.done() + assert not empty_fut.done() + + # Wait for writes to be waiting + await sleep_called.get() + asyncio_sleep.assert_has_calls([call(FLUSH_SECONDS), call(FLUSH_SECONDS)]) + + # Fail the connection with a permanent error + await read_called_queue.get() + await read_result_queue.put(InvalidArgument("permanent")) + + with pytest.raises(InvalidArgument): + await empty_fut diff --git a/tests/unit/pubsublite/internal/wire/reset_signal_test.py b/tests/unit/pubsublite/internal/wire/reset_signal_test.py new file mode 100644 index 00000000..8925d525 --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/reset_signal_test.py @@ -0,0 +1,69 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from google.api_core.exceptions import Aborted, NotFound +from google.cloud.pubsublite.internal.wire.reset_signal import is_reset_signal +from google.cloud.pubsublite.testing.test_reset_signal import ( + make_call, + make_reset_signal, +) +from google.protobuf.any_pb2 import Any +from google.rpc.error_details_pb2 import ErrorInfo, RetryInfo +from google.rpc.status_pb2 import Status + +# All test coroutines will be treated as marked. +pytestmark = pytest.mark.asyncio + + +async def test_is_reset_signal(): + assert is_reset_signal(make_reset_signal()) + + +async def test_non_retryable(): + assert not is_reset_signal(NotFound("")) + + +async def test_missing_call(): + assert not is_reset_signal(Aborted("")) + + +async def test_wrong_reason(): + any = Any() + any.Pack(ErrorInfo(reason="OTHER", domain="pubsublite.googleapis.com")) + status_pb = Status(code=10, details=[any]) + assert not is_reset_signal(Aborted("", response=make_call(status_pb))) + + +async def test_wrong_domain(): + any = Any() + any.Pack(ErrorInfo(reason="RESET", domain="other.googleapis.com")) + status_pb = Status(code=10, details=[any]) + assert not is_reset_signal(Aborted("", response=make_call(status_pb))) + + +async def test_wrong_error_detail(): + any = Any() + any.Pack(RetryInfo()) + status_pb = Status(code=10, details=[any]) + assert not is_reset_signal(Aborted("", response=make_call(status_pb))) + + +async def test_other_error_details_present(): + any1 = Any() + any1.Pack(RetryInfo()) + any2 = Any() + any2.Pack(ErrorInfo(reason="RESET", domain="pubsublite.googleapis.com")) + status_pb = Status(code=10, details=[any1, any2]) + assert is_reset_signal(Aborted("", response=make_call(status_pb))) diff --git a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py index ed773695..8656aa07 100644 --- a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from copy import deepcopy from unittest.mock import call from collections import defaultdict from typing import Dict, List @@ -36,8 +37,12 @@ FlowControlRequest, SeekRequest, ) +from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import ( + SubscriberResetHandler, +) from google.cloud.pubsublite_v1.types.common import Cursor, SequencedMessage from google.cloud.pubsublite.testing.test_utils import make_queue_waiter +from google.cloud.pubsublite.testing.test_reset_signal import make_reset_signal from google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS FLUSH_SECONDS = 100000 @@ -61,8 +66,19 @@ def connection_factory(default_connection): @pytest.fixture() -def initial_request(): - return SubscribeRequest(initial=InitialSubscribeRequest(subscription="mysub")) +def reset_handler(): + return MagicMock(spec=SubscriberResetHandler) + + +@pytest.fixture() +def base_initial_subscribe(): + return InitialSubscribeRequest(subscription="mysub") + + +@pytest.fixture() +def initial_request(base_initial_subscribe): + location = SeekRequest(named_target=SeekRequest.NamedTarget.COMMITTED_CURSOR) + return make_initial_subscribe_request(base_initial_subscribe, location) class QueuePair: @@ -95,8 +111,10 @@ async def sleeper(delay: float): @pytest.fixture() -def subscriber(connection_factory, initial_request): - return SubscriberImpl(initial_request.initial, FLUSH_SECONDS, connection_factory) +def subscriber(connection_factory, base_initial_subscribe, reset_handler): + return SubscriberImpl( + base_initial_subscribe, FLUSH_SECONDS, connection_factory, reset_handler + ) def as_request(flow: FlowControlRequest): @@ -111,6 +129,14 @@ def as_response(messages: List[SequencedMessage]): return res +def make_initial_subscribe_request( + base: InitialSubscribeRequest, location: SeekRequest +): + initial_subscribe = deepcopy(base) + initial_subscribe.initial_location = location + return SubscribeRequest(initial=initial_subscribe) + + async def test_basic_flow_control_after_timeout( subscriber: Subscriber, default_connection, @@ -253,6 +279,7 @@ async def test_flow_resent_on_restart( async def test_message_receipt( subscriber: Subscriber, default_connection, + base_initial_subscribe, initial_request, asyncio_sleep, sleep_queues, @@ -307,28 +334,20 @@ async def test_message_receipt( await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) # Reinitialization await write_called_queue.get() - default_connection.write.assert_has_calls( - [call(initial_request), call(as_request(flow)), call(initial_request)] - ) - await write_result_queue.put(None) - await read_called_queue.get() - await read_result_queue.put(SubscribeResponse(initial={})) - # Sends fetch offset seek on the stream, and checks the response. - seek_req = SubscribeRequest( - seek=SeekRequest(cursor=Cursor(offset=message_2.cursor.offset + 1)) + seek_to_cursor_request = make_initial_subscribe_request( + base_initial_subscribe, + SeekRequest(cursor=Cursor(offset=message_2.cursor.offset + 1)), ) - await write_called_queue.get() default_connection.write.assert_has_calls( [ call(initial_request), call(as_request(flow)), - call(initial_request), - call(seek_req), + call(seek_to_cursor_request), ] ) await write_result_queue.put(None) await read_called_queue.get() - await read_result_queue.put(SubscribeResponse(seek={})) + await read_result_queue.put(SubscribeResponse(initial={})) # Re-sending flow tokens on the new stream. await write_called_queue.get() await write_result_queue.put(None) @@ -336,8 +355,7 @@ async def test_message_receipt( [ call(initial_request), call(as_request(flow)), - call(initial_request), - call(seek_req), + call(seek_to_cursor_request), call( as_request( FlowControlRequest(allowed_messages=98, allowed_bytes=85) @@ -400,3 +418,87 @@ async def test_out_of_order_receipt_failure( except GoogleAPICallError as e: assert e.grpc_status_code == StatusCode.FAILED_PRECONDITION pass + + +async def test_handle_reset_signal( + subscriber: Subscriber, + default_connection, + initial_request, + asyncio_sleep, + sleep_queues, + reset_handler, +): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + message_1 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=5) + message_2 = SequencedMessage(cursor=Cursor(offset=4), size_bytes=10) + # Ensure messages with earlier offsets can be handled post-reset. + message_3 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=20) + default_connection.write.side_effect = make_queue_waiter( + write_called_queue, write_result_queue + ) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter( + read_called_queue, read_result_queue + ) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) + assert not flow_fut.done() + + # Handle the inline write since initial tokens are 100% of outstanding. + await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut + default_connection.write.assert_has_calls( + [call(initial_request), call(as_request(flow))] + ) + + # Send messages to the subscriber. + await read_result_queue.put(as_response([message_1, message_2])) + + # Read one message. + await read_called_queue.get() + assert (await subscriber.read()) == message_1 + + # Fail the connection with an error containing the RESET signal. + await read_called_queue.get() + await read_result_queue.put(make_reset_signal()) + await sleep_queues[_MIN_BACKOFF_SECS].called.get() + await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) + # Reinitialization. + await write_called_queue.get() + await write_result_queue.put(None) + await read_called_queue.get() + await read_result_queue.put(SubscribeResponse(initial={})) + # Re-sending flow tokens on the new stream. + await write_called_queue.get() + await write_result_queue.put(None) + reset_handler.handle_reset.assert_called_once() + default_connection.write.assert_has_calls( + [ + call(initial_request), + call(as_request(flow)), + call(initial_request), + call( + as_request( + # Tokens for undelivered message_2 refilled. + FlowControlRequest(allowed_messages=99, allowed_bytes=95) + ) + ), + ] + ) + + # Ensure the subscriber accepts an earlier message. + await read_result_queue.put(as_response([message_3])) + await read_called_queue.get() + assert (await subscriber.read()) == message_3