diff --git a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py index 01f0157f..78611c37 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py @@ -35,7 +35,7 @@ def track(self, offset: int): """ @abstractmethod - async def ack(self, offset: int): + def ack(self, offset: int): """ Acknowledge the message with the provided offset. The offset must have previously been tracked. diff --git a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py index 5a648091..b1e6aedd 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py @@ -17,6 +17,7 @@ from typing import Optional from google.api_core.exceptions import FailedPrecondition + from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker from google.cloud.pubsublite.internal.wire.committer import Committer from google.cloud.pubsublite_v1 import Cursor @@ -43,9 +44,7 @@ def track(self, offset: int): ) self._receipts.append(offset) - async def ack(self, offset: int): - # Note: put_nowait is used here and below to ensure that the below logic is executed without yielding - # to another coroutine in the event loop. The queue is unbounded so it will never throw. + def ack(self, offset: int): self._acks.put_nowait(offset) prefix_acked_offset: Optional[int] = None while len(self._receipts) != 0 and not self._acks.empty(): @@ -60,7 +59,7 @@ async def ack(self, offset: int): if prefix_acked_offset is None: return # Convert from last acked to first unacked. - await self._committer.commit(Cursor(offset=prefix_acked_offset + 1)) + self._committer.commit(Cursor(offset=prefix_acked_offset + 1)) async def clear_and_commit(self): self._receipts.clear() diff --git a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py index 8a37ad89..c2706113 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py @@ -20,7 +20,10 @@ from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import ( AsyncSingleSubscriber, ) -from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled +from google.cloud.pubsublite.internal.wait_ignore_cancelled import ( + wait_ignore_cancelled, + wait_ignore_errors, +) from google.cloud.pubsublite.internal.wire.assigner import Assigner from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable from google.cloud.pubsublite.types import Partition @@ -100,8 +103,10 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): self._assign_poller.cancel() - await wait_ignore_cancelled(self._assign_poller) - await self._assigner.__aexit__(exc_type, exc_value, traceback) + await wait_ignore_errors(self._assign_poller) + await wait_ignore_errors( + self._assigner.__aexit__(exc_type, exc_value, traceback) + ) for running in self._subscribers.values(): - await self._stop_subscriber(running) + await wait_ignore_errors(self._stop_subscriber(running)) pass diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py index 00290889..801adfd9 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py @@ -126,8 +126,8 @@ async def read(self) -> List[Message]: self.fail(e) raise e - async def _handle_ack(self, message: requests.AckRequest): - await self._underlying.allow_flow( + def _handle_ack(self, message: requests.AckRequest): + self._underlying.allow_flow( FlowControlRequest( allowed_messages=1, allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes, @@ -138,7 +138,7 @@ async def _handle_ack(self, message: requests.AckRequest): ack_id = _AckId.parse(message.ack_id) if ack_id.generation == self._ack_generation_id: try: - await self._ack_set_tracker.ack(ack_id.offset) + self._ack_set_tracker.ack(ack_id.offset) except GoogleAPICallError as e: self.fail(e) @@ -179,7 +179,7 @@ async def _handle_queue_message( ) ) elif isinstance(message, requests.AckRequest): - await self._handle_ack(message) + self._handle_ack(message) else: self._handle_nack(message) @@ -198,7 +198,7 @@ async def __aenter__(self): await self._ack_set_tracker.__aenter__() await self._underlying.__aenter__() self._looper_future = asyncio.ensure_future(self._looper()) - await self._underlying.allow_flow( + self._underlying.allow_flow( FlowControlRequest( allowed_messages=self._flow_control_settings.messages_outstanding, allowed_bytes=self._flow_control_settings.bytes_outstanding, diff --git a/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py index 05a58c13..776d3ffe 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py @@ -17,6 +17,9 @@ from concurrent.futures.thread import ThreadPoolExecutor from typing import ContextManager, Optional from google.api_core.exceptions import GoogleAPICallError +from functools import partial + +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.cloudpubsub.internal.managed_event_loop import ( ManagedEventLoop, ) @@ -86,8 +89,8 @@ async def _poller(self): while True: batch = await self._underlying.read() self._unowned_executor.map(self._callback, batch) - except GoogleAPICallError as e: # noqa: F841 Flake8 thinks e is unused - self._unowned_executor.submit(lambda: self._fail(e)) # noqa: F821 + except GoogleAPICallError as e: + self._unowned_executor.submit(partial(self._fail, e)) def __enter__(self): assert self._close_callback is not None @@ -97,13 +100,15 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): + self._poller_future.cancel() try: - self._poller_future.cancel() - self._poller_future.result() - except concurrent.futures.CancelledError: + self._poller_future.result() # Ignore error. + except: # noqa: E722 pass self._event_loop.submit( - self._underlying.__aexit__(exc_type, exc_value, traceback) + wait_ignore_errors( + self._underlying.__aexit__(exc_type, exc_value, traceback) + ) ).result() self._event_loop.__exit__(exc_type, exc_value, traceback) assert self._close_callback is not None diff --git a/google/cloud/pubsublite/internal/wire/committer.py b/google/cloud/pubsublite/internal/wire/committer.py index f2485f9e..2e0e4ffa 100644 --- a/google/cloud/pubsublite/internal/wire/committer.py +++ b/google/cloud/pubsublite/internal/wire/committer.py @@ -24,7 +24,13 @@ class Committer(AsyncContextManager, metaclass=ABCMeta): """ @abstractmethod - async def commit(self, cursor: Cursor) -> None: + def commit(self, cursor: Cursor) -> None: + """ + Start the commit for a cursor. + + Raises: + GoogleAPICallError: When the committer terminates in failure. + """ pass @abstractmethod diff --git a/google/cloud/pubsublite/internal/wire/committer_impl.py b/google/cloud/pubsublite/internal/wire/committer_impl.py index d8f06f03..cafaa344 100644 --- a/google/cloud/pubsublite/internal/wire/committer_impl.py +++ b/google/cloud/pubsublite/internal/wire/committer_impl.py @@ -28,14 +28,12 @@ ConnectionReinitializer, ) from google.cloud.pubsublite.internal.wire.connection import Connection -from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher from google.cloud.pubsublite_v1 import Cursor from google.cloud.pubsublite_v1.types import ( StreamingCommitCursorRequest, StreamingCommitCursorResponse, InitialCommitCursorRequest, ) -from google.cloud.pubsublite.internal.wire.work_item import WorkItem _LOGGER = logging.getLogger(__name__) @@ -53,9 +51,8 @@ class CommitterImpl( StreamingCommitCursorRequest, StreamingCommitCursorResponse ] - _batcher: SerialBatcher[Cursor, None] - - _outstanding_commits: List[List[WorkItem[Cursor, None]]] + _next_to_commit: Optional[Cursor] + _outstanding_commits: List[Cursor] _receiver: Optional[asyncio.Future] _flusher: Optional[asyncio.Future] @@ -72,7 +69,7 @@ def __init__( self._initial = initial self._flush_seconds = flush_seconds self._connection = RetryingConnection(factory, self) - self._batcher = SerialBatcher() + self._next_to_commit = None self._outstanding_commits = [] self._receiver = None self._flusher = None @@ -113,9 +110,7 @@ def _handle_response(self, response: StreamingCommitCursorResponse): ) ) for _ in range(response.commit.acknowledged_commits): - batch = self._outstanding_commits.pop(0) - for item in batch: - item.response_future.set_result(None) + self._outstanding_commits.pop(0) if len(self._outstanding_commits) == 0: self._empty.set() @@ -131,39 +126,31 @@ async def _flush_loop(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self._stop_loopers() - if self._connection.error(): - self._fail_if_retrying_failed() - else: + if not self._connection.error(): await self._flush() await self._connection.__aexit__(exc_type, exc_val, exc_tb) - def _fail_if_retrying_failed(self): - if self._connection.error(): - for batch in self._outstanding_commits: - for item in batch: - item.response_future.set_exception(self._connection.error()) - async def _flush(self): - batch = self._batcher.flush() - if not batch: + if self._next_to_commit is None: return - self._outstanding_commits.append(batch) - self._empty.clear() req = StreamingCommitCursorRequest() - req.commit.cursor = batch[-1].request + req.commit.cursor = self._next_to_commit + self._outstanding_commits.append(self._next_to_commit) + self._next_to_commit = None + self._empty.clear() try: await self._connection.write(req) except GoogleAPICallError as e: _LOGGER.debug(f"Failed commit on stream: {e}") - self._fail_if_retrying_failed() async def wait_until_empty(self): await self._flush() await self._connection.await_unless_failed(self._empty.wait()) - async def commit(self, cursor: Cursor) -> None: - future = self._batcher.add(cursor) - await future + def commit(self, cursor: Cursor) -> None: + if self._connection.error(): + raise self._connection.error() + self._next_to_commit = cursor async def reinitialize( self, @@ -181,14 +168,8 @@ async def reinitialize( "Received an invalid initial response on the publish stream." ) ) - if self._outstanding_commits: - # Roll up outstanding commits - rollup: List[WorkItem[Cursor, None]] = [] - for batch in self._outstanding_commits: - for item in batch: - rollup.append(item) - self._outstanding_commits = [rollup] - req = StreamingCommitCursorRequest() - req.commit.cursor = rollup[-1].request - await connection.write(req) + if self._next_to_commit is None: + if self._outstanding_commits: + self._next_to_commit = self._outstanding_commits[-1] + self._outstanding_commits = [] self._start_loopers() diff --git a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py index ee442808..8eceaf01 100644 --- a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py +++ b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py @@ -26,11 +26,14 @@ class _AggregateRequest: def __init__(self): self._request = FlowControlRequest.meta.pb() - def __add__(self, other: FlowControlRequest.meta.pb): - self._request.allowed_bytes = self._request.allowed_bytes + other.allowed_bytes + def __add__(self, other: FlowControlRequest): + other_pb = other._pb + self._request.allowed_bytes = ( + self._request.allowed_bytes + other_pb.allowed_bytes + ) self._request.allowed_bytes = min(self._request.allowed_bytes, _MAX_INT64) self._request.allowed_messages = ( - self._request.allowed_messages + other.allowed_messages + self._request.allowed_messages + other_pb.allowed_messages ) self._request.allowed_messages = min(self._request.allowed_messages, _MAX_INT64) return self @@ -77,16 +80,3 @@ def release_pending_request(self) -> Optional[FlowControlRequest]: request = self._pending_tokens self._pending_tokens = _AggregateRequest() return request.to_optional() - - def should_expedite(self): - pending_request = self._pending_tokens._request - client_request = self._client_tokens._request - if _exceeds_expedite_ratio( - pending_request.allowed_bytes, client_request.allowed_bytes - ): - return True - if _exceeds_expedite_ratio( - pending_request.allowed_messages, client_request.allowed_messages - ): - return True - return False diff --git a/google/cloud/pubsublite/internal/wire/subscriber.py b/google/cloud/pubsublite/internal/wire/subscriber.py index dec650f6..7a091572 100644 --- a/google/cloud/pubsublite/internal/wire/subscriber.py +++ b/google/cloud/pubsublite/internal/wire/subscriber.py @@ -36,7 +36,7 @@ async def read(self) -> List[SequencedMessage.meta.pb]: raise NotImplementedError() @abstractmethod - async def allow_flow(self, request: FlowControlRequest): + def allow_flow(self, request: FlowControlRequest): """ Allow an additional amount of messages and bytes to be sent to this client. """ diff --git a/google/cloud/pubsublite/internal/wire/subscriber_impl.py b/google/cloud/pubsublite/internal/wire/subscriber_impl.py index 02466bee..9cbb9fff 100644 --- a/google/cloud/pubsublite/internal/wire/subscriber_impl.py +++ b/google/cloud/pubsublite/internal/wire/subscriber_impl.py @@ -201,10 +201,5 @@ async def reinitialize( async def read(self) -> List[SequencedMessage.meta.pb]: return await self._connection.await_unless_failed(self._message_queue.get()) - async def allow_flow(self, request: FlowControlRequest): + def allow_flow(self, request: FlowControlRequest): self._outstanding_flow_control.add(request) - if ( - not self._reinitializing - and self._outstanding_flow_control.should_expedite() - ): - await self._try_send_tokens() diff --git a/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py b/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py index 8824d89e..f859592f 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py @@ -49,17 +49,17 @@ async def test_track_and_aggregate_acks(committer, tracker: AckSetTracker): tracker.track(offset=7) committer.commit.assert_has_calls([]) - await tracker.ack(offset=3) + tracker.ack(offset=3) committer.commit.assert_has_calls([]) - await tracker.ack(offset=1) + tracker.ack(offset=1) committer.commit.assert_has_calls([call(Cursor(offset=4))]) - await tracker.ack(offset=5) + tracker.ack(offset=5) committer.commit.assert_has_calls( [call(Cursor(offset=4)), call(Cursor(offset=6))] ) tracker.track(offset=8) - await tracker.ack(offset=7) + tracker.ack(offset=7) committer.commit.assert_has_calls( [call(Cursor(offset=4)), call(Cursor(offset=6)), call(Cursor(offset=8))] ) @@ -74,7 +74,7 @@ async def test_clear_and_commit(committer, tracker: AckSetTracker): with pytest.raises(FailedPrecondition): tracker.track(offset=1) - await tracker.ack(offset=5) + tracker.ack(offset=5) committer.commit.assert_has_calls([]) await tracker.clear_and_commit() @@ -82,6 +82,6 @@ async def test_clear_and_commit(committer, tracker: AckSetTracker): # After clearing, it should be possible to track earlier offsets. tracker.track(offset=1) - await tracker.ack(offset=1) + tracker.ack(offset=1) committer.commit.assert_has_calls([call(Cursor(offset=2))]) committer.__aexit__.assert_called_once() diff --git a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py index 6949fddc..3d355642 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py @@ -36,7 +36,6 @@ from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import ( SubscriberResetHandler, ) -from google.cloud.pubsublite.testing.test_utils import make_queue_waiter from google.cloud.pubsublite_v1 import Cursor, FlowControlRequest, SequencedMessage # All test coroutines will be treated as marked. @@ -71,8 +70,15 @@ def initial_flow_request(flow_control_settings): @pytest.fixture() -def ack_set_tracker(): - return mock_async_context_manager(MagicMock(spec=AckSetTracker)) +def ack_queue(): + return asyncio.Queue() + + +@pytest.fixture() +def ack_set_tracker(ack_queue): + tracker = mock_async_context_manager(MagicMock(spec=AckSetTracker)) + tracker.ack.side_effect = lambda offset: ack_queue.put_nowait(None) + return tracker @pytest.fixture() @@ -125,13 +131,12 @@ async def test_failed_transform(subscriber, underlying, transformer): async def test_ack( - subscriber: AsyncSingleSubscriber, underlying, transformer, ack_set_tracker + subscriber: AsyncSingleSubscriber, + underlying, + transformer, + ack_set_tracker, + ack_queue, ): - ack_called_queue = asyncio.Queue() - ack_result_queue = asyncio.Queue() - ack_set_tracker.ack.side_effect = make_queue_waiter( - ack_called_queue, ack_result_queue - ) async with subscriber: message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb message_2 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10)._pb @@ -144,12 +149,10 @@ async def test_ack( assert read_1.message_id == "1" assert read_2.message_id == "2" read_2.ack() - await ack_called_queue.get() - await ack_result_queue.put(None) + await ack_queue.get() ack_set_tracker.ack.assert_has_calls([call(2)]) read_1.ack() - await ack_called_queue.get() - await ack_result_queue.put(None) + await ack_queue.get() ack_set_tracker.ack.assert_has_calls([call(2), call(1)]) @@ -173,22 +176,23 @@ async def test_ack_failure( underlying, transformer, ack_set_tracker, + ack_queue, ): - ack_called_queue = asyncio.Queue() - ack_result_queue = asyncio.Queue() - ack_set_tracker.ack.side_effect = make_queue_waiter( - ack_called_queue, ack_result_queue - ) async with subscriber: message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb underlying.read.return_value = [message] read: List[Message] = await subscriber.read() assert len(read) == 1 ack_set_tracker.track.assert_has_calls([call(1)]) + + def bad_ack(offset): + ack_queue.put_nowait(None) + raise FailedPrecondition("Bad ack") + + ack_set_tracker.ack.side_effect = bad_ack read[0].ack() - await ack_called_queue.get() + await ack_queue.get() ack_set_tracker.ack.assert_has_calls([call(1)]) - await ack_result_queue.put(FailedPrecondition("Bad ack")) async def sleep_forever(): await asyncio.sleep(float("inf")) @@ -228,12 +232,8 @@ async def test_nack_calls_ack( transformer, ack_set_tracker, nack_handler, + ack_queue, ): - ack_called_queue = asyncio.Queue() - ack_result_queue = asyncio.Queue() - ack_set_tracker.ack.side_effect = make_queue_waiter( - ack_called_queue, ack_result_queue - ) async with subscriber: message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb underlying.read.return_value = [message] @@ -247,8 +247,7 @@ def on_nack(nacked: PubsubMessage, ack: Callable[[], None]): nack_handler.on_nack.side_effect = on_nack read[0].nack() - await ack_called_queue.get() - await ack_result_queue.put(None) + await ack_queue.get() ack_set_tracker.ack.assert_has_calls([call(1)]) @@ -257,12 +256,8 @@ async def test_handle_reset( underlying, transformer, ack_set_tracker, + ack_queue, ): - ack_called_queue = asyncio.Queue() - ack_result_queue = asyncio.Queue() - ack_set_tracker.ack.side_effect = make_queue_waiter( - ack_called_queue, ack_result_queue - ) async with subscriber: message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb underlying.read.return_value = [message_1] @@ -287,8 +282,7 @@ async def test_handle_reset( assert read_2[0].message_id == "2" assert read_2[0].ack_id == ack_id(1, 2) read_2[0].ack() - await ack_called_queue.get() - await ack_result_queue.put(None) + await ack_queue.get() underlying.allow_flow.assert_has_calls( [ call(FlowControlRequest(allowed_messages=1000, allowed_bytes=1000,)), diff --git a/tests/unit/pubsublite/internal/wire/committer_impl_test.py b/tests/unit/pubsublite/internal/wire/committer_impl_test.py index 773b2b85..1c963f6b 100644 --- a/tests/unit/pubsublite/internal/wire/committer_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/committer_impl_test.py @@ -139,11 +139,9 @@ async def test_basic_commit_after_timeout( default_connection.write.assert_has_calls([call(initial_request)]) # Commit cursors - commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) - commit_fut2 = asyncio.ensure_future(committer.commit(cursor2)) + committer.commit(cursor1) + committer.commit(cursor2) empty_fut = asyncio.ensure_future(committer.wait_until_empty()) - assert not commit_fut1.done() - assert not commit_fut2.done() assert not empty_fut.done() # Wait for writes to be waiting @@ -158,14 +156,10 @@ async def test_basic_commit_after_timeout( default_connection.write.assert_has_calls( [call(initial_request), call(as_request(cursor2))] ) - assert not commit_fut1.done() - assert not commit_fut2.done() assert not empty_fut.done() # Send the connection response with 1 ack since only one request was sent. await read_result_queue.put(as_response(count=1)) - await commit_fut1 - await commit_fut2 await empty_fut @@ -199,9 +193,8 @@ async def test_commits_multi_cycle( default_connection.write.assert_has_calls([call(initial_request)]) # Write message 1 - commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) + committer.commit(cursor1) empty_fut = asyncio.ensure_future(committer.wait_until_empty()) - assert not commit_fut1.done() assert not empty_fut.done() # Wait for writes to be waiting @@ -215,7 +208,6 @@ async def test_commits_multi_cycle( default_connection.write.assert_has_calls( [call(initial_request), call(as_request(cursor1))] ) - assert not commit_fut1.done() assert not empty_fut.done() # Wait for writes to be waiting @@ -223,8 +215,7 @@ async def test_commits_multi_cycle( asyncio_sleep.assert_has_calls([call(FLUSH_SECONDS), call(FLUSH_SECONDS)]) # Write message 2 - commit_fut2 = asyncio.ensure_future(committer.commit(cursor2)) - assert not commit_fut2.done() + committer.commit(cursor2) assert not empty_fut.done() # Handle the connection write @@ -238,14 +229,10 @@ async def test_commits_multi_cycle( call(as_request(cursor2)), ] ) - assert not commit_fut1.done() - assert not commit_fut2.done() assert not empty_fut.done() # Send the connection responses await read_result_queue.put(as_response(count=2)) - await commit_fut1 - await commit_fut2 await empty_fut @@ -279,9 +266,8 @@ async def test_publishes_retried_on_restart( default_connection.write.assert_has_calls([call(initial_request)]) # Write message 1 - commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) + committer.commit(cursor1) empty_fut = asyncio.ensure_future(committer.wait_until_empty()) - assert not commit_fut1.done() assert not empty_fut.done() # Wait for writes to be waiting @@ -295,7 +281,6 @@ async def test_publishes_retried_on_restart( default_connection.write.assert_has_calls( [call(initial_request), call(as_request(cursor1))] ) - assert not commit_fut1.done() assert not empty_fut.done() # Wait for writes to be waiting @@ -303,8 +288,7 @@ async def test_publishes_retried_on_restart( asyncio_sleep.assert_has_calls([call(FLUSH_SECONDS), call(FLUSH_SECONDS)]) # Write message 2 - commit_fut2 = asyncio.ensure_future(committer.commit(cursor2)) - assert not commit_fut2.done() + committer.commit(cursor2) assert not empty_fut.done() # Handle the connection write @@ -318,8 +302,6 @@ async def test_publishes_retried_on_restart( call(as_request(cursor2)), ] ) - assert not commit_fut1.done() - assert not commit_fut2.done() assert not empty_fut.done() # Fail the connection with a retryable error @@ -333,6 +315,8 @@ async def test_publishes_retried_on_restart( await read_called_queue.get() await read_result_queue.put(StreamingCommitCursorResponse(initial={})) # Re-sending messages on the new stream + await sleep_queues[FLUSH_SECONDS].called.get() + await sleep_queues[FLUSH_SECONDS].results.put(None) await write_called_queue.get() await write_result_queue.put(None) asyncio_sleep.assert_has_calls( @@ -355,8 +339,6 @@ async def test_publishes_retried_on_restart( # Sending the response for the one commit finishes both await read_called_queue.get() await read_result_queue.put(as_response(count=1)) - await commit_fut1 - await commit_fut2 await empty_fut @@ -392,9 +374,8 @@ async def test_wait_until_empty_completes_on_failure( await committer.wait_until_empty() # Write message 1 - commit_fut1 = asyncio.ensure_future(committer.commit(cursor1)) + committer.commit(cursor1) empty_fut = asyncio.ensure_future(committer.wait_until_empty()) - assert not commit_fut1.done() assert not empty_fut.done() # Wait for writes to be waiting @@ -408,7 +389,6 @@ async def test_wait_until_empty_completes_on_failure( default_connection.write.assert_has_calls( [call(initial_request), call(as_request(cursor1))] ) - assert not commit_fut1.done() assert not empty_fut.done() # Wait for writes to be waiting diff --git a/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py b/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py index 000a32f7..c037135c 100644 --- a/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py +++ b/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py @@ -21,14 +21,12 @@ def test_restart_clears_send(): batcher = FlowControlBatcher() batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3)) - assert batcher.should_expedite() to_send = batcher.release_pending_request() assert to_send.allowed_bytes == 10 assert to_send.allowed_messages == 3 restart_1 = batcher.request_for_restart() assert restart_1.allowed_bytes == 10 assert restart_1.allowed_messages == 3 - assert not batcher.should_expedite() assert batcher.release_pending_request() is None diff --git a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py index 82df4b2c..e2de2481 100644 --- a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py @@ -168,20 +168,23 @@ async def test_basic_flow_control_after_timeout( default_connection.write.assert_has_calls([call(initial_request)]) # Send tokens. - flow_fut1 = asyncio.ensure_future(subscriber.allow_flow(flow_1)) - assert not flow_fut1.done() + subscriber.allow_flow(flow_1) - # Handle the inline write since initial tokens are 100% of outstanding. + # Wait for writes to be waiting + await sleep_called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + await sleep_results.put(None) + + # Handle the connection write. await write_called_queue.get() await write_result_queue.put(None) - await flow_fut1 default_connection.write.assert_has_calls( [call(initial_request), call(as_request(flow_1))] ) - # Should complete without writing to the connection - await subscriber.allow_flow(flow_2) - await subscriber.allow_flow(flow_3) + # Multiple requests are batched + subscriber.allow_flow(flow_2) + subscriber.allow_flow(flow_3) # Wait for writes to be waiting await sleep_called.get() @@ -234,20 +237,23 @@ async def test_flow_resent_on_restart( default_connection.write.assert_has_calls([call(initial_request)]) # Send tokens. - flow_fut1 = asyncio.ensure_future(subscriber.allow_flow(flow_1)) - assert not flow_fut1.done() + subscriber.allow_flow(flow_1) - # Handle the inline write since initial tokens are 100% of outstanding. + # Wait for writes to be waiting + await sleep_queues[FLUSH_SECONDS].called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + await sleep_queues[FLUSH_SECONDS].results.put(None) + + # Handle the connection write. await write_called_queue.get() await write_result_queue.put(None) - await flow_fut1 default_connection.write.assert_has_calls( [call(initial_request), call(as_request(flow_1))] ) - # Should complete without writing to the connection - await subscriber.allow_flow(flow_2) - await subscriber.allow_flow(flow_3) + # Send more tokens + subscriber.allow_flow(flow_2) + subscriber.allow_flow(flow_3) # Fail the connection with a retryable error await read_called_queue.get() @@ -306,13 +312,16 @@ async def test_message_receipt( default_connection.write.assert_has_calls([call(initial_request)]) # Send tokens. - flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) - assert not flow_fut.done() + subscriber.allow_flow(flow) - # Handle the inline write since initial tokens are 100% of outstanding. + # Wait for writes to be waiting + await sleep_queues[FLUSH_SECONDS].called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + await sleep_queues[FLUSH_SECONDS].results.put(None) + + # Handle the connection write. await write_called_queue.get() await write_result_queue.put(None) - await flow_fut default_connection.write.assert_has_calls( [call(initial_request), call(as_request(flow))] ) @@ -395,13 +404,16 @@ async def test_out_of_order_receipt_failure( default_connection.write.assert_has_calls([call(initial_request)]) # Send tokens. - flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) - assert not flow_fut.done() + subscriber.allow_flow(flow) + + # Wait for writes to be waiting + await sleep_queues[FLUSH_SECONDS].called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + await sleep_queues[FLUSH_SECONDS].results.put(None) - # Handle the inline write since initial tokens are 100% of outstanding. + # Handle the connection write. await write_called_queue.get() await write_result_queue.put(None) - await flow_fut default_connection.write.assert_has_calls( [call(initial_request), call(as_request(flow))] ) @@ -453,13 +465,16 @@ async def test_handle_reset_signal( default_connection.write.assert_has_calls([call(initial_request)]) # Send tokens. - flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) - assert not flow_fut.done() + subscriber.allow_flow(flow) + + # Wait for writes to be waiting + await sleep_queues[FLUSH_SECONDS].called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + await sleep_queues[FLUSH_SECONDS].results.put(None) - # Handle the inline write since initial tokens are 100% of outstanding. + # Handle the connection write. await write_called_queue.get() await write_result_queue.put(None) - await flow_fut default_connection.write.assert_has_calls( [call(initial_request), call(as_request(flow))] )