diff --git a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py index 4a0c410d..bfc33ff7 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py @@ -73,8 +73,9 @@ async def _assign_action(self): for partition in added_partitions: await self._start_subscriber(partition) for partition in removed_partitions: - await self._stop_subscriber(self._subscribers[partition]) + subscriber = self._subscribers[partition] del self._subscribers[partition] + await self._stop_subscriber(subscriber) async def __aenter__(self): self._messages = Queue() @@ -89,3 +90,4 @@ async def __aexit__(self, exc_type, exc_value, traceback): await self._assigner.__aexit__(exc_type, exc_value, traceback) for running in self._subscribers.values(): await self._stop_subscriber(running) + pass diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py index 8b441507..6150e818 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py @@ -6,6 +6,7 @@ from google.cloud.pubsub_v1.subscriber.message import Message from google.pubsub_v1 import PubsubMessage +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled from google.cloud.pubsublite.types import FlowControlSettings from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer @@ -54,10 +55,10 @@ def __init__( self._messages_by_offset = {} async def read(self) -> Message: - message: SequencedMessage = await self.await_unless_failed( - self._underlying.read() - ) try: + message: SequencedMessage = await self.await_unless_failed( + self._underlying.read() + ) cps_message = self._transformer.transform(message) offset = message.cursor.offset self._ack_set_tracker.track(offset) @@ -156,9 +157,6 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): self._looper_future.cancel() - try: - await self._looper_future - except asyncio.CancelledError: - pass + await wait_ignore_cancelled(self._looper_future) await self._underlying.__aexit__(exc_type, exc_value, traceback) await self._ack_set_tracker.__aexit__(exc_type, exc_value, traceback) diff --git a/google/cloud/pubsublite/internal/wait_ignore_cancelled.py b/google/cloud/pubsublite/internal/wait_ignore_cancelled.py index f2ead03a..2f3d2946 100644 --- a/google/cloud/pubsublite/internal/wait_ignore_cancelled.py +++ b/google/cloud/pubsublite/internal/wait_ignore_cancelled.py @@ -7,3 +7,10 @@ async def wait_ignore_cancelled(awaitable: Awaitable): await awaitable except CancelledError: pass + + +async def wait_ignore_errors(awaitable: Awaitable): + try: + await awaitable + except: # noqa: E722 + pass diff --git a/google/cloud/pubsublite/internal/wire/assigner_impl.py b/google/cloud/pubsublite/internal/wire/assigner_impl.py index fe6962b7..acff74b4 100644 --- a/google/cloud/pubsublite/internal/wire/assigner_impl.py +++ b/google/cloud/pubsublite/internal/wire/assigner_impl.py @@ -2,6 +2,8 @@ from typing import Optional, Set from absl import logging + +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.assigner import Assigner from google.cloud.pubsublite.internal.wire.retrying_connection import ( RetryingConnection, @@ -62,27 +64,24 @@ def _start_receiver(self): async def _stop_receiver(self): if self._receiver: self._receiver.cancel() - await self._receiver + await wait_ignore_errors(self._receiver) self._receiver = None async def _receive_loop(self): - try: - while True: - response = await self._connection.read() - if self._outstanding_assignment or not self._new_assignment.empty(): - self._connection.fail( - FailedPrecondition( - "Received a duplicate assignment on the stream while one was outstanding." - ) + while True: + response = await self._connection.read() + if self._outstanding_assignment or not self._new_assignment.empty(): + self._connection.fail( + FailedPrecondition( + "Received a duplicate assignment on the stream while one was outstanding." ) - return - self._outstanding_assignment = True - partitions = set() - for partition in response.partitions: - partitions.add(Partition(partition)) - self._new_assignment.put_nowait(partitions) - except (asyncio.CancelledError, GoogleAPICallError): - return + ) + return + self._outstanding_assignment = True + partitions = set() + for partition in response.partitions: + partitions.add(Partition(partition)) + self._new_assignment.put_nowait(partitions) async def __aexit__(self, exc_type, exc_val, exc_tb): await self._stop_receiver() diff --git a/google/cloud/pubsublite/internal/wire/committer_impl.py b/google/cloud/pubsublite/internal/wire/committer_impl.py index d4f8d641..6647c752 100644 --- a/google/cloud/pubsublite/internal/wire/committer_impl.py +++ b/google/cloud/pubsublite/internal/wire/committer_impl.py @@ -3,6 +3,7 @@ from absl import logging +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.committer import Committer from google.cloud.pubsublite.internal.wire.retrying_connection import ( RetryingConnection, @@ -75,11 +76,11 @@ def _start_loopers(self): async def _stop_loopers(self): if self._receiver: self._receiver.cancel() - await self._receiver + await wait_ignore_errors(self._receiver) self._receiver = None if self._flusher: self._flusher.cancel() - await self._flusher + await wait_ignore_errors(self._flusher) self._flusher = None def _handle_response(self, response: StreamingCommitCursorResponse): @@ -101,20 +102,14 @@ def _handle_response(self, response: StreamingCommitCursorResponse): item.response_future.set_result(None) async def _receive_loop(self): - try: - while True: - response = await self._connection.read() - self._handle_response(response) - except (asyncio.CancelledError, GoogleAPICallError): - return + while True: + response = await self._connection.read() + self._handle_response(response) async def _flush_loop(self): - try: - while True: - await asyncio.sleep(self._flush_seconds) - await self._flush() - except asyncio.CancelledError: - return + while True: + await asyncio.sleep(self._flush_seconds) + await self._flush() async def __aexit__(self, exc_type, exc_val, exc_tb): await self._stop_loopers() diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py index dcaa9467..6bcb1125 100644 --- a/google/cloud/pubsublite/internal/wire/permanent_failable.py +++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py @@ -3,9 +3,24 @@ from google.api_core.exceptions import GoogleAPICallError +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors + T = TypeVar("T") +class _TaskWithCleanup: + def __init__(self, a: Awaitable): + self._task = asyncio.ensure_future(a) + + async def __aenter__(self): + return self._task + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if not self._task.done(): + self._task.cancel() + await wait_ignore_errors(self._task) + + class PermanentFailable: """A class that can experience permanent failures, with helpers for forwarding these to client actions.""" @@ -21,14 +36,6 @@ def _failure_task(self) -> asyncio.Future: self._maybe_failure_task = asyncio.Future() return self._maybe_failure_task - @staticmethod - async def _fail_client_task(task: asyncio.Future): - task.cancel() - try: - await task - except: # noqa: E722 intentionally broad except clause - pass - async def await_unless_failed(self, awaitable: Awaitable[T]) -> T: """ Await the awaitable, unless fail() is called first. @@ -38,18 +45,15 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T: Returns: The result of the awaitable Raises: The permanent error if fail() is called or the awaitable raises one. """ - - task = asyncio.ensure_future(awaitable) - if self._failure_task.done(): - await self._fail_client_task(task) + async with _TaskWithCleanup(awaitable) as task: + if self._failure_task.done(): + raise self._failure_task.exception() + done, _ = await asyncio.wait( + [task, self._failure_task], return_when=asyncio.FIRST_COMPLETED + ) + if task in done: + return await task raise self._failure_task.exception() - done, _ = await asyncio.wait( - [task, self._failure_task], return_when=asyncio.FIRST_COMPLETED - ) - if task in done: - return await task - await self._fail_client_task(task) - raise self._failure_task.exception() async def run_poller(self, poll_action: Callable[[], Awaitable[None]]): """ diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index 2d543253..c766ea8f 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -1,8 +1,10 @@ import asyncio +from asyncio import Future -from typing import Awaitable, Optional +from typing import Optional from google.api_core.exceptions import GoogleAPICallError, Cancelled from google.cloud.pubsublite.internal.status_codes import is_retryable +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.connection_reinitializer import ( ConnectionReinitializer, ) @@ -101,20 +103,26 @@ async def _run_loop(self): print(e) async def _loop_connection(self, connection: Connection[Request, Response]): - read_task: Awaitable[Response] = asyncio.ensure_future(connection.read()) - write_task: Awaitable[WorkItem[Request]] = asyncio.ensure_future( + read_task: "Future[Response]" = asyncio.ensure_future(connection.read()) + write_task: "Future[WorkItem[Request]]" = asyncio.ensure_future( self._write_queue.get() ) - while True: - done, _ = await asyncio.wait( - [write_task, read_task], return_when=asyncio.FIRST_COMPLETED - ) - if write_task in done: - await self._handle_write(connection, await write_task) - write_task = asyncio.ensure_future(self._write_queue.get()) - if read_task in done: - await self._read_queue.put(await read_task) - read_task = asyncio.ensure_future(connection.read()) + try: + while True: + done, _ = await asyncio.wait( + [write_task, read_task], return_when=asyncio.FIRST_COMPLETED + ) + if write_task in done: + await self._handle_write(connection, await write_task) + write_task = asyncio.ensure_future(self._write_queue.get()) + if read_task in done: + await self._read_queue.put(await read_task) + read_task = asyncio.ensure_future(connection.read()) + finally: + read_task.cancel() + write_task.cancel() + await wait_ignore_errors(read_task) + await wait_ignore_errors(write_task) @staticmethod async def _handle_write( diff --git a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py index 80596502..bc48fd6a 100644 --- a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py +++ b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py @@ -4,6 +4,7 @@ from absl import logging from google.cloud.pubsub_v1.types import BatchSettings +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.publisher import Publisher from google.cloud.pubsublite.internal.wire.retrying_connection import ( RetryingConnection, @@ -81,11 +82,11 @@ def _start_loopers(self): async def _stop_loopers(self): if self._receiver: self._receiver.cancel() - await self._receiver + await wait_ignore_errors(self._receiver) self._receiver = None if self._flusher: self._flusher.cancel() - await self._flusher + await wait_ignore_errors(self._flusher) self._flusher = None def _handle_response(self, response: PublishResponse): @@ -108,20 +109,14 @@ def _handle_response(self, response: PublishResponse): next_offset += 1 async def _receive_loop(self): - try: - while True: - response = await self._connection.read() - self._handle_response(response) - except (asyncio.CancelledError, GoogleAPICallError): - return + while True: + response = await self._connection.read() + self._handle_response(response) async def _flush_loop(self): - try: - while True: - await asyncio.sleep(self._batching_settings.max_latency) - await self._flush() - except asyncio.CancelledError: - return + while True: + await asyncio.sleep(self._batching_settings.max_latency) + await self._flush() async def __aexit__(self, exc_type, exc_val, exc_tb): if self._connection.error(): diff --git a/google/cloud/pubsublite/internal/wire/subscriber_impl.py b/google/cloud/pubsublite/internal/wire/subscriber_impl.py index f41ec628..8c2df143 100644 --- a/google/cloud/pubsublite/internal/wire/subscriber_impl.py +++ b/google/cloud/pubsublite/internal/wire/subscriber_impl.py @@ -3,6 +3,7 @@ from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.connection import ( Connection, ConnectionFactory, @@ -72,11 +73,11 @@ def _start_loopers(self): async def _stop_loopers(self): if self._receiver: self._receiver.cancel() - await self._receiver + await wait_ignore_errors(self._receiver) self._receiver = None if self._flusher: self._flusher.cancel() - await self._flusher + await wait_ignore_errors(self._flusher) self._flusher = None def _handle_response(self, response: SubscribeResponse): @@ -107,12 +108,9 @@ def _handle_response(self, response: SubscribeResponse): self._message_queue.put_nowait(message) async def _receive_loop(self): - try: - while True: - response = await self._connection.read() - self._handle_response(response) - except (asyncio.CancelledError, GoogleAPICallError): - return + while True: + response = await self._connection.read() + self._handle_response(response) async def _try_send_tokens(self): req = self._outstanding_flow_control.release_pending_request() @@ -125,12 +123,9 @@ async def _try_send_tokens(self): pass async def _flush_loop(self): - try: - while True: - await asyncio.sleep(self._token_flush_seconds) - await self._try_send_tokens() - except asyncio.CancelledError: - return + while True: + await asyncio.sleep(self._token_flush_seconds) + await self._try_send_tokens() async def __aexit__(self, exc_type, exc_val, exc_tb): await self._stop_loopers()