Skip to content

Commit

Permalink
fix: Ensure tasks are always awaited to remove shutdown errors (#57)
Browse files Browse the repository at this point in the history
* fix: Ensure tasks are always awaited to remove shutdown errors

* fix: Fix lint errors and missing asynccontextmanager in 3.6
  • Loading branch information
dpcollins-google committed Nov 5, 2020
1 parent 7cf02ae commit 7735d2f
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 99 deletions.
Expand Up @@ -73,8 +73,9 @@ async def _assign_action(self):
for partition in added_partitions:
await self._start_subscriber(partition)
for partition in removed_partitions:
await self._stop_subscriber(self._subscribers[partition])
subscriber = self._subscribers[partition]
del self._subscribers[partition]
await self._stop_subscriber(subscriber)

async def __aenter__(self):
self._messages = Queue()
Expand All @@ -89,3 +90,4 @@ async def __aexit__(self, exc_type, exc_value, traceback):
await self._assigner.__aexit__(exc_type, exc_value, traceback)
for running in self._subscribers.values():
await self._stop_subscriber(running)
pass
Expand Up @@ -6,6 +6,7 @@
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.types import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer
Expand Down Expand Up @@ -54,10 +55,10 @@ def __init__(
self._messages_by_offset = {}

async def read(self) -> Message:
message: SequencedMessage = await self.await_unless_failed(
self._underlying.read()
)
try:
message: SequencedMessage = await self.await_unless_failed(
self._underlying.read()
)
cps_message = self._transformer.transform(message)
offset = message.cursor.offset
self._ack_set_tracker.track(offset)
Expand Down Expand Up @@ -156,9 +157,6 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_value, traceback):
self._looper_future.cancel()
try:
await self._looper_future
except asyncio.CancelledError:
pass
await wait_ignore_cancelled(self._looper_future)
await self._underlying.__aexit__(exc_type, exc_value, traceback)
await self._ack_set_tracker.__aexit__(exc_type, exc_value, traceback)
7 changes: 7 additions & 0 deletions google/cloud/pubsublite/internal/wait_ignore_cancelled.py
Expand Up @@ -7,3 +7,10 @@ async def wait_ignore_cancelled(awaitable: Awaitable):
await awaitable
except CancelledError:
pass


async def wait_ignore_errors(awaitable: Awaitable):
    """Await ``awaitable`` and discard any ordinary error or cancellation.

    This is a best-effort wait: whatever the awaitable returns is dropped,
    and any ``Exception`` or ``CancelledError`` raised while waiting is
    swallowed.

    Interpreter-exit signals (``KeyboardInterrupt``, ``SystemExit``,
    ``GeneratorExit``) are deliberately NOT swallowed, so waiting on a
    task can never mask a shutdown request.

    Args:
        awaitable: The object to await.

    Returns:
        None, whether the awaitable completed, failed or was cancelled.
    """
    try:
        await awaitable
    except (Exception, CancelledError):
        # CancelledError is listed explicitly because from Python 3.8 on it
        # derives from BaseException and is not covered by ``Exception``.
        pass
33 changes: 16 additions & 17 deletions google/cloud/pubsublite/internal/wire/assigner_impl.py
Expand Up @@ -2,6 +2,8 @@
from typing import Optional, Set

from absl import logging

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.retrying_connection import (
RetryingConnection,
Expand Down Expand Up @@ -62,27 +64,24 @@ def _start_receiver(self):
async def _stop_receiver(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
if self._outstanding_assignment or not self._new_assignment.empty():
self._connection.fail(
FailedPrecondition(
"Received a duplicate assignment on the stream while one was outstanding."
)
while True:
response = await self._connection.read()
if self._outstanding_assignment or not self._new_assignment.empty():
self._connection.fail(
FailedPrecondition(
"Received a duplicate assignment on the stream while one was outstanding."
)
return
self._outstanding_assignment = True
partitions = set()
for partition in response.partitions:
partitions.add(Partition(partition))
self._new_assignment.put_nowait(partitions)
except (asyncio.CancelledError, GoogleAPICallError):
return
)
return
self._outstanding_assignment = True
partitions = set()
for partition in response.partitions:
partitions.add(Partition(partition))
self._new_assignment.put_nowait(partitions)

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_receiver()
Expand Down
23 changes: 9 additions & 14 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -3,6 +3,7 @@

from absl import logging

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite.internal.wire.retrying_connection import (
RetryingConnection,
Expand Down Expand Up @@ -75,11 +76,11 @@ def _start_loopers(self):
async def _stop_loopers(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None
if self._flusher:
self._flusher.cancel()
await self._flusher
await wait_ignore_errors(self._flusher)
self._flusher = None

def _handle_response(self, response: StreamingCommitCursorResponse):
Expand All @@ -101,20 +102,14 @@ def _handle_response(self, response: StreamingCommitCursorResponse):
item.response_future.set_result(None)

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
self._handle_response(response)
except (asyncio.CancelledError, GoogleAPICallError):
return
while True:
response = await self._connection.read()
self._handle_response(response)

async def _flush_loop(self):
try:
while True:
await asyncio.sleep(self._flush_seconds)
await self._flush()
except asyncio.CancelledError:
return
while True:
await asyncio.sleep(self._flush_seconds)
await self._flush()

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
Expand Down
42 changes: 23 additions & 19 deletions google/cloud/pubsublite/internal/wire/permanent_failable.py
Expand Up @@ -3,9 +3,24 @@

from google.api_core.exceptions import GoogleAPICallError

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors

T = TypeVar("T")


class _TaskWithCleanup:
    """Async context manager that schedules an awaitable and cleans it up on exit.

    The awaitable is scheduled with ``asyncio.ensure_future`` as soon as the
    object is constructed. Entering the context yields the resulting future.
    Leaving the context cancels the future if it has not finished and waits
    for it through ``wait_ignore_errors``.
    """

    def __init__(self, a: Awaitable):
        # Scheduling happens at construction time, not on __aenter__.
        self._task = asyncio.ensure_future(a)

    async def __aenter__(self):
        """Return the future wrapping the scheduled awaitable."""
        return self._task

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Cancel the future if it is still pending, then wait for it to finish.

        Returns None, so an exception raised inside the ``async with`` body
        keeps propagating.
        """
        if self._task.done():
            # Already finished: nothing to cancel or wait for.
            return
        self._task.cancel()
        await wait_ignore_errors(self._task)


class PermanentFailable:
"""A class that can experience permanent failures, with helpers for forwarding these to client actions."""

Expand All @@ -21,14 +36,6 @@ def _failure_task(self) -> asyncio.Future:
self._maybe_failure_task = asyncio.Future()
return self._maybe_failure_task

@staticmethod
async def _fail_client_task(task: asyncio.Future):
task.cancel()
try:
await task
except: # noqa: E722 intentionally broad except clause
pass

async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
"""
Await the awaitable, unless fail() is called first.
Expand All @@ -38,18 +45,15 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
Returns: The result of the awaitable
Raises: The permanent error if fail() is called or the awaitable raises one.
"""

task = asyncio.ensure_future(awaitable)
if self._failure_task.done():
await self._fail_client_task(task)
async with _TaskWithCleanup(awaitable) as task:
if self._failure_task.done():
raise self._failure_task.exception()
done, _ = await asyncio.wait(
[task, self._failure_task], return_when=asyncio.FIRST_COMPLETED
)
if task in done:
return await task
raise self._failure_task.exception()
done, _ = await asyncio.wait(
[task, self._failure_task], return_when=asyncio.FIRST_COMPLETED
)
if task in done:
return await task
await self._fail_client_task(task)
raise self._failure_task.exception()

async def run_poller(self, poll_action: Callable[[], Awaitable[None]]):
"""
Expand Down
34 changes: 21 additions & 13 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
@@ -1,8 +1,10 @@
import asyncio
from asyncio import Future

from typing import Awaitable, Optional
from typing import Optional
from google.api_core.exceptions import GoogleAPICallError, Cancelled
from google.cloud.pubsublite.internal.status_codes import is_retryable
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.connection_reinitializer import (
ConnectionReinitializer,
)
Expand Down Expand Up @@ -101,20 +103,26 @@ async def _run_loop(self):
print(e)

async def _loop_connection(self, connection: Connection[Request, Response]):
read_task: Awaitable[Response] = asyncio.ensure_future(connection.read())
write_task: Awaitable[WorkItem[Request]] = asyncio.ensure_future(
read_task: "Future[Response]" = asyncio.ensure_future(connection.read())
write_task: "Future[WorkItem[Request]]" = asyncio.ensure_future(
self._write_queue.get()
)
while True:
done, _ = await asyncio.wait(
[write_task, read_task], return_when=asyncio.FIRST_COMPLETED
)
if write_task in done:
await self._handle_write(connection, await write_task)
write_task = asyncio.ensure_future(self._write_queue.get())
if read_task in done:
await self._read_queue.put(await read_task)
read_task = asyncio.ensure_future(connection.read())
try:
while True:
done, _ = await asyncio.wait(
[write_task, read_task], return_when=asyncio.FIRST_COMPLETED
)
if write_task in done:
await self._handle_write(connection, await write_task)
write_task = asyncio.ensure_future(self._write_queue.get())
if read_task in done:
await self._read_queue.put(await read_task)
read_task = asyncio.ensure_future(connection.read())
finally:
read_task.cancel()
write_task.cancel()
await wait_ignore_errors(read_task)
await wait_ignore_errors(write_task)

@staticmethod
async def _handle_write(
Expand Down
Expand Up @@ -4,6 +4,7 @@
from absl import logging
from google.cloud.pubsub_v1.types import BatchSettings

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.retrying_connection import (
RetryingConnection,
Expand Down Expand Up @@ -81,11 +82,11 @@ def _start_loopers(self):
async def _stop_loopers(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None
if self._flusher:
self._flusher.cancel()
await self._flusher
await wait_ignore_errors(self._flusher)
self._flusher = None

def _handle_response(self, response: PublishResponse):
Expand All @@ -108,20 +109,14 @@ def _handle_response(self, response: PublishResponse):
next_offset += 1

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
self._handle_response(response)
except (asyncio.CancelledError, GoogleAPICallError):
return
while True:
response = await self._connection.read()
self._handle_response(response)

async def _flush_loop(self):
try:
while True:
await asyncio.sleep(self._batching_settings.max_latency)
await self._flush()
except asyncio.CancelledError:
return
while True:
await asyncio.sleep(self._batching_settings.max_latency)
await self._flush()

async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._connection.error():
Expand Down
23 changes: 9 additions & 14 deletions google/cloud/pubsublite/internal/wire/subscriber_impl.py
Expand Up @@ -3,6 +3,7 @@

from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.connection import (
Connection,
ConnectionFactory,
Expand Down Expand Up @@ -72,11 +73,11 @@ def _start_loopers(self):
async def _stop_loopers(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None
if self._flusher:
self._flusher.cancel()
await self._flusher
await wait_ignore_errors(self._flusher)
self._flusher = None

def _handle_response(self, response: SubscribeResponse):
Expand Down Expand Up @@ -107,12 +108,9 @@ def _handle_response(self, response: SubscribeResponse):
self._message_queue.put_nowait(message)

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
self._handle_response(response)
except (asyncio.CancelledError, GoogleAPICallError):
return
while True:
response = await self._connection.read()
self._handle_response(response)

async def _try_send_tokens(self):
req = self._outstanding_flow_control.release_pending_request()
Expand All @@ -125,12 +123,9 @@ async def _try_send_tokens(self):
pass

async def _flush_loop(self):
try:
while True:
await asyncio.sleep(self._token_flush_seconds)
await self._try_send_tokens()
except asyncio.CancelledError:
return
while True:
await asyncio.sleep(self._token_flush_seconds)
await self._try_send_tokens()

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
Expand Down

0 comments on commit 7735d2f

Please sign in to comment.