From 3fb08965debc3e133dc031a3e5835ed843968bcb Mon Sep 17 00:00:00 2001
From: Daniel Collins
Date: Mon, 10 Aug 2020 10:40:27 -0400
Subject: [PATCH 1/2] feat: Implement SerialBatcher, which transforms single
 writes into batch writes.

---
 .../internal/wire/serial_batcher.py           | 50 +++++++++++++++++++
 1 file changed, 50 insertions(+)
 create mode 100644 google/cloud/pubsublite/internal/wire/serial_batcher.py

diff --git a/google/cloud/pubsublite/internal/wire/serial_batcher.py b/google/cloud/pubsublite/internal/wire/serial_batcher.py
new file mode 100644
index 00000000..cea809cc
--- /dev/null
+++ b/google/cloud/pubsublite/internal/wire/serial_batcher.py
@@ -0,0 +1,50 @@
+from abc import ABC, abstractmethod
+from typing import Generic, List, Iterable
+import asyncio
+
+from google.cloud.pubsublite.internal.wire.connection import Request, Response
+from google.cloud.pubsublite.internal.wire.work_item import WorkItem
+
+
+class BatchTester(Generic[Request], ABC):
+  """A BatchTester determines whether a given batch of messages must be sent."""
+
+  @abstractmethod
+  def test(self, requests: Iterable[Request]) -> bool:
+    """
+    Args:
+      requests: The current outstanding batch.
+
+    Returns: Whether that batch must be sent.
+    """
+    raise NotImplementedError()
+
+
+class SerialBatcher(Generic[Request, Response]):
+  _tester: BatchTester[Request]
+  _requests: List[WorkItem[Request]]  # A list of outstanding requests
+
+  def __init__(self, tester: BatchTester[Request]):
+    self._tester = tester
+    self._requests = []
+
+  def add(self, request: Request) -> 'asyncio.Future[Response]':
+    """Add a new request to this batcher.
+
+    Callers must always call should_flush() after add, and flush() if it returns true.
+
+    Args:
+      request: The request to send.
+
+    Returns:
+      A future that will resolve to the response or a GoogleAPICallError.
+    """
+    item = WorkItem[Request](request)
+    self._requests.append(item)
+    return item.response_future
+
+  def should_flush(self) -> bool:
+    return self._tester.test(item.request for item in self._requests)
+
+  def flush(self) -> Iterable[WorkItem[Request]]:
+    requests = self._requests
+    self._requests = []
+    return requests

From 5e950bca32e97a77f7270ba49225cf5df7db97e6 Mon Sep 17 00:00:00 2001
From: Daniel Collins
Date: Mon, 10 Aug 2020 16:46:09 -0400
Subject: [PATCH 2/2] feat: Implement SinglePartitionPublisher, which publishes
 to a single partition and handles retries.
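A minimal usage sketch of the new Publisher surface (illustrative only, not
part of the diff; it assumes an already-constructed SinglePartitionPublisher
and the PubSubMessage type used below):

    async def publish_one(publisher: Publisher) -> None:
      # Entering the context opens the underlying retrying stream.
      async with publisher:
        metadata = await publisher.publish(PubSubMessage(data=b"hello"))
        # PublishMetadata carries the partition and the server-assigned cursor.
        print(metadata.partition, metadata.cursor.offset)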
---
 .../internal/wire/permanent_failable.py       |   7 +-
 .../pubsublite/internal/wire/publisher.py     |  29 ++
 .../internal/wire/retrying_connection.py      |   6 +-
 .../internal/wire/serial_batcher.py           |   6 +-
 .../wire/single_partition_publisher.py        | 146 ++++++++++
 .../pubsublite/internal/wire/work_item.py     |  12 +-
 google/cloud/pubsublite/publish_metadata.py   |   8 +
 google/cloud/pubsublite/testing/test_utils.py |  25 +-
 .../wire/single_partition_publisher_test.py   | 264 ++++++++++++++++++
 9 files changed, 489 insertions(+), 14 deletions(-)
 create mode 100644 google/cloud/pubsublite/internal/wire/publisher.py
 create mode 100644 google/cloud/pubsublite/internal/wire/single_partition_publisher.py
 create mode 100644 google/cloud/pubsublite/publish_metadata.py
 create mode 100644 tests/unit/pubsublite/internal/wire/single_partition_publisher_test.py

diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py
index 1151de78..3efa1c99 100644
--- a/google/cloud/pubsublite/internal/wire/permanent_failable.py
+++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Awaitable, TypeVar
+from typing import Awaitable, TypeVar, Optional

 from google.api_core.exceptions import GoogleAPICallError

@@ -29,3 +29,8 @@ async def await_or_fail(self, awaitable: Awaitable[T]) -> T:
   def fail(self, err: GoogleAPICallError):
     if not self._failure_task.done():
       self._failure_task.set_exception(err)
+
+  def error(self) -> Optional[GoogleAPICallError]:
+    if not self._failure_task.done():
+      return None
+    return self._failure_task.exception()
diff --git a/google/cloud/pubsublite/internal/wire/publisher.py b/google/cloud/pubsublite/internal/wire/publisher.py
new file mode 100644
index 00000000..4963c91a
--- /dev/null
+++ b/google/cloud/pubsublite/internal/wire/publisher.py
@@ -0,0 +1,29 @@
+from abc import ABC, abstractmethod
+from google.cloud.pubsublite_v1.types import PubSubMessage
+from google.cloud.pubsublite.publish_metadata import PublishMetadata
+
+
+class Publisher(ABC):
+  @abstractmethod
+  async def __aenter__(self):
+    raise NotImplementedError()
+
+  @abstractmethod
+  async def __aexit__(self, exc_type, exc_val, exc_tb):
+    raise NotImplementedError()
+
+  @abstractmethod
+  async def publish(self, message: PubSubMessage) -> PublishMetadata:
+    """
+    Publish the provided message.
+
+    Args:
+      message: The message to be published.
+
+    Returns:
+      Metadata about the published message.
+
+    Raises:
+      GoogleAPICallError: On a permanent error.
+    """
+    raise NotImplementedError()
diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py
index 894022d4..a787deda 100644
--- a/google/cloud/pubsublite/internal/wire/retrying_connection.py
+++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py
@@ -19,7 +19,7 @@ class RetryingConnection(Connection[Request, Response], PermanentFailable):
   _loop_task: asyncio.Future

-  _write_queue: 'asyncio.Queue[WorkItem[Request]]'
+  _write_queue: 'asyncio.Queue[WorkItem[Request, None]]'
   _read_queue: 'asyncio.Queue[Response]'

   def __init__(self, connection_factory: ConnectionFactory[Request, Response], reinitializer: ConnectionReinitializer[Request, Response]):
@@ -56,7 +56,7 @@ async def _run_loop(self):
         await self._reinitializer.reinitialize(connection)
         bad_retries = 0
         await self._loop_connection(connection)
-      except (Exception, GoogleAPICallError) as e:
+      except GoogleAPICallError as e:
         if not is_retryable(e):
           self.fail(e)
           return
@@ -79,7 +79,7 @@ async def _loop_connection(self, connection: Connection[Request, Response]):
         read_task = asyncio.ensure_future(connection.read())

   @staticmethod
-  async def _handle_write(connection: Connection[Request, Response], to_write: WorkItem[Request]):
+  async def _handle_write(connection: Connection[Request, Response], to_write: WorkItem[Request, Response]):
     try:
       await connection.write(to_write.request)
       to_write.response_future.set_result(None)
diff --git a/google/cloud/pubsublite/internal/wire/serial_batcher.py b/google/cloud/pubsublite/internal/wire/serial_batcher.py
index cea809cc..b04cc664 100644
--- a/google/cloud/pubsublite/internal/wire/serial_batcher.py
+++ b/google/cloud/pubsublite/internal/wire/serial_batcher.py
@@ -21,7 +21,7 @@ def test(self, requests: Iterable[Request]) -> bool:

 class SerialBatcher(Generic[Request, Response]):
   _tester: BatchTester[Request]
-  _requests: List[WorkItem[Request]]  # A list of outstanding requests
+  _requests: List[WorkItem[Request, Response]]  # A list of outstanding requests

   def __init__(self, tester: BatchTester[Request]):
     self._tester = tester
@@ -37,14 +37,14 @@ def add(self, request: Request) -> 'asyncio.Future[Response]':
     Returns:
       A future that will resolve to the response or a GoogleAPICallError.
     """
-    item = WorkItem[Request](request)
+    item = WorkItem[Request, Response](request)
     self._requests.append(item)
     return item.response_future

   def should_flush(self) -> bool:
     return self._tester.test(item.request for item in self._requests)

-  def flush(self) -> Iterable[WorkItem[Request]]:
+  def flush(self) -> List[WorkItem[Request, Response]]:
     requests = self._requests
     self._requests = []
     return requests
""" - item = WorkItem[Request](request) + item = WorkItem[Request, Response](request) self._requests.append(item) return item.response_future def should_flush(self) -> bool: return self._tester.test(item.request for item in self._requests) - def flush(self) -> Iterable[WorkItem[Request]]: + def flush(self) -> List[WorkItem[Request, Response]]: requests = self._requests self._requests = [] return requests diff --git a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py new file mode 100644 index 00000000..081afa70 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py @@ -0,0 +1,146 @@ +import asyncio +from typing import Optional, List, Iterable + +from absl import logging +from google.cloud.pubsublite.internal.wire.publisher import Publisher +from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection, ConnectionFactory +from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError +from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer +from google.cloud.pubsublite.internal.wire.connection import Connection +from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher, BatchTester +from google.cloud.pubsublite.partition import Partition +from google.cloud.pubsublite.publish_metadata import PublishMetadata +from google.cloud.pubsublite_v1.types import PubSubMessage, Cursor, PublishRequest, PublishResponse, \ + InitialPublishRequest +from google.cloud.pubsublite.internal.wire.work_item import WorkItem + +# Maximum bytes per batch at 3.5 MiB to avoid GRPC limit of 4 MiB +_MAX_BYTES = int(3.5 * 1024 * 1024) + +# Maximum messages per batch at 1000 +_MAX_MESSAGES = 1000 + + +class SinglePartitionPublisher(Publisher, ConnectionReinitializer[PublishRequest, PublishResponse], BatchTester[PubSubMessage]): + _initial: InitialPublishRequest + _flush_seconds: float + _connection: RetryingConnection[PublishRequest, PublishResponse] + + _batcher: SerialBatcher[PubSubMessage, Cursor] + _outstanding_writes: List[List[WorkItem[PubSubMessage, Cursor]]] + + _receiver: Optional[asyncio.Future] + _flusher: Optional[asyncio.Future] + + def __init__(self, initial: InitialPublishRequest, flush_seconds: float, + factory: ConnectionFactory[PublishRequest, PublishResponse]): + self._initial = initial + self._flush_seconds = flush_seconds + self._connection = RetryingConnection(factory, self) + self._batcher = SerialBatcher(self) + self._outstanding_writes = [] + self._receiver = None + self._flusher = None + + @property + def _partition(self) -> Partition: + return Partition(self._initial.partition) + + async def __aenter__(self): + await self._connection.__aenter__() + + def _start_loopers(self): + assert self._receiver is None + assert self._flusher is None + self._receiver = asyncio.ensure_future(self._receive_loop()) + self._flusher = asyncio.ensure_future(self._flush_loop()) + + async def _stop_loopers(self): + if self._receiver: + self._receiver.cancel() + await self._receiver + self._receiver = None + if self._flusher: + self._flusher.cancel() + await self._flusher + self._flusher = None + + def _handle_response(self, response: PublishResponse): + if "message_response" not in response: + self._connection.fail(FailedPrecondition("Received an invalid subsequent response on the publish stream.")) + if not self._outstanding_writes: + self._connection.fail( + FailedPrecondition("Received an 
publish response on the stream with no outstanding publishes.")) + next_offset: Cursor = response.message_response.start_cursor.offset + batch: List[WorkItem[PubSubMessage]] = self._outstanding_writes.pop(0) + for item in batch: + item.response_future.set_result(Cursor(offset=next_offset)) + next_offset += 1 + + async def _receive_loop(self): + try: + while True: + response = await self._connection.read() + self._handle_response(response) + except asyncio.CancelledError: + return + + async def _flush_loop(self): + try: + while True: + await asyncio.sleep(self._flush_seconds) + await self._flush() + except asyncio.CancelledError: + return + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._connection.error(): + self._fail_if_retrying_failed() + else: + await self._flush() + await self._connection.__aexit__(exc_type, exc_val, exc_tb) + + def _fail_if_retrying_failed(self): + if self._connection.error(): + for batch in self._outstanding_writes: + for item in batch: + item.response_future.set_exception(self._connection.error()) + + async def _flush(self): + batch = self._batcher.flush() + if not batch: + return + self._outstanding_writes.append(batch) + aggregate = PublishRequest() + aggregate.message_publish_request.messages = [item.request for item in batch] + try: + await self._connection.write(aggregate) + except GoogleAPICallError as e: + logging.debug(f"Failed publish on stream: {e}") + self._fail_if_retrying_failed() + + async def publish(self, message: PubSubMessage) -> PublishMetadata: + cursor_future = self._batcher.add(message) + if self._batcher.should_flush(): + await self._flush() + return PublishMetadata(self._partition, await cursor_future) + + async def reinitialize(self, connection: Connection[PublishRequest, PublishResponse]): + await self._stop_loopers() + await connection.write(PublishRequest(initial_request=self._initial)) + response = await connection.read() + if "initial_response" not in response: + self._connection.fail(FailedPrecondition("Received an invalid initial response on the publish stream.")) + for batch in self._outstanding_writes: + aggregate = PublishRequest() + aggregate.message_publish_request.messages = [item.request for item in batch] + await connection.write(aggregate) + self._start_loopers() + + def test(self, requests: Iterable[PubSubMessage]) -> bool: + request_count = 0 + byte_count = 0 + for req in requests: + request_count += 1 + byte_count += PubSubMessage.pb(req).ByteSize() + return (request_count >= _MAX_MESSAGES) or (byte_count >= _MAX_BYTES) diff --git a/google/cloud/pubsublite/internal/wire/work_item.py b/google/cloud/pubsublite/internal/wire/work_item.py index 3685fb84..29cf125d 100644 --- a/google/cloud/pubsublite/internal/wire/work_item.py +++ b/google/cloud/pubsublite/internal/wire/work_item.py @@ -1,14 +1,14 @@ import asyncio -from typing import Generic, TypeVar +from typing import Generic -T = TypeVar('T') +from google.cloud.pubsublite.internal.wire.connection import Request, Response -class WorkItem(Generic[T]): +class WorkItem(Generic[Request, Response]): """An item of work and a future to complete when it is finished.""" - request: T - response_future: "asyncio.Future[None]" + request: Request + response_future: "asyncio.Future[Response]" - def __init__(self, request: T): + def __init__(self, request: Request): self.request = request self.response_future = asyncio.Future() diff --git a/google/cloud/pubsublite/publish_metadata.py b/google/cloud/pubsublite/publish_metadata.py new file mode 100644 index 
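To make the flush thresholds concrete, a quick sketch (illustrative only, not
part of the patch; it imports the module-private constants defined above):

    from google.cloud.pubsublite.internal.wire.single_partition_publisher import _MAX_BYTES, _MAX_MESSAGES
    from google.cloud.pubsublite_v1.types import PubSubMessage

    small = [PubSubMessage(data=b"x")] * (_MAX_MESSAGES - 1)
    big = PubSubMessage(data=b"y" * _MAX_BYTES)
    # 999 tiny messages stay under both limits, so test() would not flush them...
    assert sum(PubSubMessage.pb(m).ByteSize() for m in small) < _MAX_BYTES
    # ...while a single payload of _MAX_BYTES forces a flush on its own.
    assert PubSubMessage.pb(big).ByteSize() >= _MAX_BYTES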
diff --git a/google/cloud/pubsublite/internal/wire/work_item.py b/google/cloud/pubsublite/internal/wire/work_item.py
index 3685fb84..29cf125d 100644
--- a/google/cloud/pubsublite/internal/wire/work_item.py
+++ b/google/cloud/pubsublite/internal/wire/work_item.py
@@ -1,14 +1,14 @@
 import asyncio
-from typing import Generic, TypeVar
+from typing import Generic

-T = TypeVar('T')
+from google.cloud.pubsublite.internal.wire.connection import Request, Response


-class WorkItem(Generic[T]):
+class WorkItem(Generic[Request, Response]):
   """An item of work and a future to complete when it is finished."""
-  request: T
-  response_future: "asyncio.Future[None]"
+  request: Request
+  response_future: "asyncio.Future[Response]"

-  def __init__(self, request: T):
+  def __init__(self, request: Request):
     self.request = request
     self.response_future = asyncio.Future()
diff --git a/google/cloud/pubsublite/publish_metadata.py b/google/cloud/pubsublite/publish_metadata.py
new file mode 100644
index 00000000..6b37211f
--- /dev/null
+++ b/google/cloud/pubsublite/publish_metadata.py
@@ -0,0 +1,8 @@
+from typing import NamedTuple
+from google.cloud.pubsublite_v1.types.common import Cursor
+from google.cloud.pubsublite.partition import Partition
+
+
+class PublishMetadata(NamedTuple):
+  partition: Partition
+  cursor: Cursor
diff --git a/google/cloud/pubsublite/testing/test_utils.py b/google/cloud/pubsublite/testing/test_utils.py
index b9531acd..e93c16ba 100644
--- a/google/cloud/pubsublite/testing/test_utils.py
+++ b/google/cloud/pubsublite/testing/test_utils.py
@@ -1,4 +1,7 @@
-from typing import List, Union, Any
+import asyncio
+from typing import List, Union, Any, TypeVar, Generic, Optional
+
+T = TypeVar("T")


 async def async_iterable(elts: List[Union[Any, Exception]]):
@@ -6,3 +9,23 @@ async def async_iterable(elts: List[Union[Any, Exception]]):
     if isinstance(elt, Exception):
       raise elt
     yield elt
+
+
+def make_queue_waiter(started_q: "asyncio.Queue[None]", result_q: "asyncio.Queue[Union[T, Exception]]"):
+  """
+  Given a queue to notify when started and a queue to get results from, return a waiter which
+  notifies started_q when started and returns from result_q when done.
+  """
+
+  async def waiter(*args, **kwargs):
+    await started_q.put(None)
+    result = await result_q.get()
+    if isinstance(result, Exception):
+      raise result
+    return result
+
+  return waiter
+
+
+class Box(Generic[T]):
+  """A mutable holder for a single optional value."""
+  val: Optional[T]
diff --git a/tests/unit/pubsublite/internal/wire/single_partition_publisher_test.py b/tests/unit/pubsublite/internal/wire/single_partition_publisher_test.py
new file mode 100644
index 00000000..58465e4c
--- /dev/null
+++ b/tests/unit/pubsublite/internal/wire/single_partition_publisher_test.py
@@ -0,0 +1,264 @@
+import asyncio
+from unittest.mock import call
+from collections import defaultdict
+from typing import Dict, List
+
+from asynctest.mock import MagicMock, CoroutineMock
+import pytest
+from google.cloud.pubsublite.internal.wire.connection import Connection, ConnectionFactory
+from google.api_core.exceptions import InternalServerError
+from google.cloud.pubsublite_v1.types.publisher import InitialPublishRequest, PublishRequest, PublishResponse, \
+  MessagePublishResponse
+from google.cloud.pubsublite_v1.types.common import PubSubMessage, Cursor
+from google.cloud.pubsublite.internal.wire.single_partition_publisher import SinglePartitionPublisher
+from google.cloud.pubsublite.internal.wire.publisher import Publisher
+from google.cloud.pubsublite.testing.test_utils import make_queue_waiter
+from google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS
+
+# Large enough that the flush loop never fires on its own during a test.
+FLUSH_SECONDS = 100000
+
+# All test coroutines will be treated as marked.
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture()
+def default_connection():
+  conn = MagicMock(spec=Connection[int, int])
+  conn.__aenter__.return_value = conn
+  return conn
+
+
+@pytest.fixture()
+def connection_factory(default_connection):
+  factory = MagicMock(spec=ConnectionFactory[int, int])
+  factory.new.return_value = default_connection
+  return factory
+
+
+@pytest.fixture()
+def initial_request():
+  return PublishRequest(initial_request=InitialPublishRequest(topic="mytopic"))
+
+
+class QueuePair:
+  called: asyncio.Queue
+  results: asyncio.Queue
+
+  def __init__(self):
+    self.called = asyncio.Queue()
+    self.results = asyncio.Queue()
+
+
+@pytest.fixture
+def sleep_queues() -> Dict[float, QueuePair]:
+  return defaultdict(QueuePair)
+
+
+@pytest.fixture
+def asyncio_sleep(monkeypatch, sleep_queues):
+  """asyncio.sleep mocked to block on the per-delay queue pair in sleep_queues."""
+  mock = CoroutineMock()
+  monkeypatch.setattr(asyncio, "sleep", mock)
+
+  async def sleeper(delay: float):
+    await make_queue_waiter(sleep_queues[delay].called, sleep_queues[delay].results)(delay)
+
+  mock.side_effect = sleeper
+  return mock
+
+
+@pytest.fixture()
+def publisher(connection_factory, initial_request):
+  return SinglePartitionPublisher(initial_request.initial_request, FLUSH_SECONDS, connection_factory)
+
+
+def as_publish_request(messages: List[PubSubMessage]):
+  req = PublishRequest()
+  req.message_publish_request.messages = messages
+  return req
+
+
+def as_publish_response(start_cursor: int):
+  return PublishResponse(message_response=MessagePublishResponse(start_cursor=Cursor(offset=start_cursor)))
+
+
+async def test_basic_publish_after_timeout(publisher: Publisher, default_connection, initial_request, asyncio_sleep,
+                                           sleep_queues):
+  sleep_called = sleep_queues[FLUSH_SECONDS].called
+  sleep_results = sleep_queues[FLUSH_SECONDS].results
+  message1 = PubSubMessage(data=b"abc")
+  message2 = PubSubMessage(data=b"def")
+  read_called_queue = asyncio.Queue()
+  read_result_queue = asyncio.Queue()
+  default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue)
+  read_result_queue.put_nowait(PublishResponse(initial_response={}))
+  async with publisher:
+    # Set up connection
+    await read_called_queue.get()
+    default_connection.write.assert_has_calls([call(initial_request)])
+
+    # Write messages
+    publish_fut1 = asyncio.ensure_future(publisher.publish(message1))
+    publish_fut2 = asyncio.ensure_future(publisher.publish(message2))
+    assert not publish_fut1.done()
+    assert not publish_fut2.done()
+
+    # Wait for writes to be waiting
+    await sleep_called.get()
+    asyncio_sleep.assert_called_with(FLUSH_SECONDS)
+
+    # Handle the connection write
+    write_future = asyncio.Future()
+
+    async def write(val: PublishRequest):
+      write_future.set_result(None)
+
+    default_connection.write.side_effect = write
+    await sleep_results.put(None)
+    await write_future
+    default_connection.write.assert_has_calls(
+      [call(initial_request), call(as_publish_request([message1, message2]))])
+    assert not publish_fut1.done()
+    assert not publish_fut2.done()
+
+    # Send the connection response
+    await read_result_queue.put(as_publish_response(100))
+    cursor1 = (await publish_fut1).cursor
+    cursor2 = (await publish_fut2).cursor
+    assert cursor1.offset == 100
+    assert cursor2.offset == 101
+
+
+async def test_publishes_multi_cycle(publisher: Publisher, default_connection, initial_request, asyncio_sleep,
+                                     sleep_queues):
+  sleep_called = sleep_queues[FLUSH_SECONDS].called
+  sleep_results = sleep_queues[FLUSH_SECONDS].results
+  message1 = PubSubMessage(data=b"abc")
+  message2 = PubSubMessage(data=b"def")
+  write_called_queue = asyncio.Queue()
+  write_result_queue = asyncio.Queue()
+  default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue)
+  read_called_queue = asyncio.Queue()
+  read_result_queue = asyncio.Queue()
+  default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue)
+  write_result_queue.put_nowait(None)
+  read_result_queue.put_nowait(PublishResponse(initial_response={}))
+  async with publisher:
+    # Set up connection
+    await write_called_queue.get()
+    await read_called_queue.get()
+    default_connection.write.assert_has_calls([call(initial_request)])
+
+    # Write message 1
+    publish_fut1 = asyncio.ensure_future(publisher.publish(message1))
+    assert not publish_fut1.done()
+
+    # Wait for writes to be waiting
+    await sleep_called.get()
+    asyncio_sleep.assert_called_with(FLUSH_SECONDS)
+
+    # Handle the connection write
+    await sleep_results.put(None)
+    await write_called_queue.get()
+    await write_result_queue.put(None)
+    default_connection.write.assert_has_calls([call(initial_request), call(as_publish_request([message1]))])
+    assert not publish_fut1.done()
+
+    # Wait for writes to be waiting
+    await sleep_called.get()
+    asyncio_sleep.assert_has_calls([call(FLUSH_SECONDS), call(FLUSH_SECONDS)])
+
+    # Write message 2
+    publish_fut2 = asyncio.ensure_future(publisher.publish(message2))
+    assert not publish_fut2.done()
+
+    # Handle the connection write
+    await sleep_results.put(None)
+    await write_called_queue.get()
+    await write_result_queue.put(None)
+    default_connection.write.assert_has_calls(
+      [call(initial_request), call(as_publish_request([message1])), call(as_publish_request([message2]))])
+    assert not publish_fut1.done()
+    assert not publish_fut2.done()
+
+    # Send the connection responses
+    await read_result_queue.put(as_publish_response(100))
+    cursor1 = (await publish_fut1).cursor
+    assert cursor1.offset == 100
+    assert not publish_fut2.done()
+    await read_result_queue.put(as_publish_response(105))
+    cursor2 = (await publish_fut2).cursor
+    assert cursor2.offset == 105
+
+
+async def test_publishes_retried_on_restart(publisher: Publisher, default_connection, initial_request, asyncio_sleep,
+                                            sleep_queues):
+  sleep_called = sleep_queues[FLUSH_SECONDS].called
+  sleep_results = sleep_queues[FLUSH_SECONDS].results
+  message1 = PubSubMessage(data=b"abc")
+  message2 = PubSubMessage(data=b"def")
+  write_called_queue = asyncio.Queue()
+  write_result_queue = asyncio.Queue()
+  default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue)
+  read_called_queue = asyncio.Queue()
+  read_result_queue = asyncio.Queue()
+  default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue)
+  write_result_queue.put_nowait(None)
+  read_result_queue.put_nowait(PublishResponse(initial_response={}))
+  async with publisher:
+    # Set up connection
+    await write_called_queue.get()
+    await read_called_queue.get()
+    default_connection.write.assert_has_calls([call(initial_request)])
+
+    # Write message 1
+    publish_fut1 = asyncio.ensure_future(publisher.publish(message1))
+    assert not publish_fut1.done()
+
+    # Wait for writes to be waiting
+    await sleep_called.get()
+    asyncio_sleep.assert_called_with(FLUSH_SECONDS)
+
+    # Handle the connection write
+    await sleep_results.put(None)
+    await write_called_queue.get()
+    await write_result_queue.put(None)
+    default_connection.write.assert_has_calls([call(initial_request), call(as_publish_request([message1]))])
+    assert not publish_fut1.done()
+
+    # Wait for writes to be waiting
+    await sleep_called.get()
+    asyncio_sleep.assert_has_calls([call(FLUSH_SECONDS), call(FLUSH_SECONDS)])
+
+    # Write message 2
+    publish_fut2 = asyncio.ensure_future(publisher.publish(message2))
+    assert not publish_fut2.done()
+
+    # Handle the connection write
+    await sleep_results.put(None)
+    await write_called_queue.get()
+    await write_result_queue.put(None)
+    default_connection.write.assert_has_calls(
+      [call(initial_request), call(as_publish_request([message1])), call(as_publish_request([message2]))])
+    assert not publish_fut1.done()
+    assert not publish_fut2.done()
+
+    # Fail the connection with a retryable error
+    await read_called_queue.get()
+    await read_result_queue.put(InternalServerError("retryable"))
+    await sleep_queues[_MIN_BACKOFF_SECS].called.get()
+    await sleep_queues[_MIN_BACKOFF_SECS].results.put(None)
+    # Reinitialization
+    await write_called_queue.get()
+    write_result_queue.put_nowait(None)
+    await read_called_queue.get()
+    read_result_queue.put_nowait(PublishResponse(initial_response={}))
+    # Re-sending messages on the new stream
+    await write_called_queue.get()
+    await write_result_queue.put(None)
+    await write_called_queue.get()
+    await write_result_queue.put(None)
+    asyncio_sleep.assert_has_calls(
+      [call(FLUSH_SECONDS), call(FLUSH_SECONDS), call(FLUSH_SECONDS), call(_MIN_BACKOFF_SECS)])
+    default_connection.write.assert_has_calls([
+      call(initial_request), call(as_publish_request([message1])), call(as_publish_request([message2])),
+      call(initial_request), call(as_publish_request([message1])), call(as_publish_request([message2]))])
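The tests above sequence coroutines deterministically through make_queue_waiter
rather than real timers. In isolation the pattern looks like this (a sketch,
not part of the patch):

    import asyncio
    from google.cloud.pubsublite.testing.test_utils import make_queue_waiter

    async def demo() -> None:
      started, results = asyncio.Queue(), asyncio.Queue()
      waiter = make_queue_waiter(started, results)
      task = asyncio.ensure_future(waiter())
      await started.get()        # the waiter has signalled that it started
      await results.put("done")  # supplying a result unblocks it
      assert await task == "done"

    asyncio.run(demo())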