Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure tasks are always awaited to remove shutdown errors #57

Merged
merged 2 commits into from Nov 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -73,8 +73,9 @@ async def _assign_action(self):
for partition in added_partitions:
await self._start_subscriber(partition)
for partition in removed_partitions:
await self._stop_subscriber(self._subscribers[partition])
subscriber = self._subscribers[partition]
del self._subscribers[partition]
await self._stop_subscriber(subscriber)

async def __aenter__(self):
self._messages = Queue()
Expand All @@ -89,3 +90,4 @@ async def __aexit__(self, exc_type, exc_value, traceback):
await self._assigner.__aexit__(exc_type, exc_value, traceback)
for running in self._subscribers.values():
await self._stop_subscriber(running)
pass
Expand Up @@ -6,6 +6,7 @@
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.types import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer
Expand Down Expand Up @@ -54,10 +55,10 @@ def __init__(
self._messages_by_offset = {}

async def read(self) -> Message:
message: SequencedMessage = await self.await_unless_failed(
self._underlying.read()
)
try:
message: SequencedMessage = await self.await_unless_failed(
self._underlying.read()
)
cps_message = self._transformer.transform(message)
offset = message.cursor.offset
self._ack_set_tracker.track(offset)
Expand Down Expand Up @@ -156,9 +157,6 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_value, traceback):
self._looper_future.cancel()
try:
await self._looper_future
except asyncio.CancelledError:
pass
await wait_ignore_cancelled(self._looper_future)
await self._underlying.__aexit__(exc_type, exc_value, traceback)
await self._ack_set_tracker.__aexit__(exc_type, exc_value, traceback)
7 changes: 7 additions & 0 deletions google/cloud/pubsublite/internal/wait_ignore_cancelled.py
Expand Up @@ -7,3 +7,10 @@ async def wait_ignore_cancelled(awaitable: Awaitable):
await awaitable
except CancelledError:
pass


async def wait_ignore_errors(awaitable: Awaitable) -> None:
    """Await ``awaitable`` and discard any error or cancellation it raises.

    Intended for best-effort cleanup of tasks that have already been
    cancelled or may have failed, where the outcome no longer matters.

    Args:
        awaitable: The awaitable to wait on.

    Note:
        Ordinary exceptions and ``CancelledError`` are suppressed. A
        ``CancelledError`` delivered while suspended here is suppressed too.
        Interpreter-level signals (``KeyboardInterrupt``, ``SystemExit``,
        ``GeneratorExit``) are deliberately NOT suppressed, so shutdown
        requests are never silently eaten by cleanup code.
    """
    try:
        await awaitable
    # CancelledError is listed explicitly because it derives from
    # BaseException (not Exception) on Python 3.8+.
    except (Exception, CancelledError):
        pass
33 changes: 16 additions & 17 deletions google/cloud/pubsublite/internal/wire/assigner_impl.py
Expand Up @@ -2,6 +2,8 @@
from typing import Optional, Set

from absl import logging

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.retrying_connection import (
RetryingConnection,
Expand Down Expand Up @@ -62,27 +64,24 @@ def _start_receiver(self):
async def _stop_receiver(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
if self._outstanding_assignment or not self._new_assignment.empty():
self._connection.fail(
FailedPrecondition(
"Received a duplicate assignment on the stream while one was outstanding."
)
while True:
response = await self._connection.read()
if self._outstanding_assignment or not self._new_assignment.empty():
self._connection.fail(
FailedPrecondition(
"Received a duplicate assignment on the stream while one was outstanding."
)
return
self._outstanding_assignment = True
partitions = set()
for partition in response.partitions:
partitions.add(Partition(partition))
self._new_assignment.put_nowait(partitions)
except (asyncio.CancelledError, GoogleAPICallError):
return
)
return
self._outstanding_assignment = True
partitions = set()
for partition in response.partitions:
partitions.add(Partition(partition))
self._new_assignment.put_nowait(partitions)

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_receiver()
Expand Down
23 changes: 9 additions & 14 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -3,6 +3,7 @@

from absl import logging

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite.internal.wire.retrying_connection import (
RetryingConnection,
Expand Down Expand Up @@ -75,11 +76,11 @@ def _start_loopers(self):
async def _stop_loopers(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None
if self._flusher:
self._flusher.cancel()
await self._flusher
await wait_ignore_errors(self._flusher)
self._flusher = None

def _handle_response(self, response: StreamingCommitCursorResponse):
Expand All @@ -101,20 +102,14 @@ def _handle_response(self, response: StreamingCommitCursorResponse):
item.response_future.set_result(None)

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
self._handle_response(response)
except (asyncio.CancelledError, GoogleAPICallError):
return
while True:
response = await self._connection.read()
self._handle_response(response)

async def _flush_loop(self):
try:
while True:
await asyncio.sleep(self._flush_seconds)
await self._flush()
except asyncio.CancelledError:
return
while True:
await asyncio.sleep(self._flush_seconds)
await self._flush()

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
Expand Down
42 changes: 23 additions & 19 deletions google/cloud/pubsublite/internal/wire/permanent_failable.py
Expand Up @@ -3,9 +3,24 @@

from google.api_core.exceptions import GoogleAPICallError

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors

T = TypeVar("T")


# Async context manager that wraps an awaitable in a future for the duration
# of an `async with` block and cleans that future up when the block exits.
# NOTE(review): indentation was lost in this view of the diff; the nesting of
# the statements in __aexit__ below must be confirmed against the real file.
class _TaskWithCleanup:
# Schedules `a` immediately via asyncio.ensure_future and keeps the future.
def __init__(self, a: Awaitable):
self._task = asyncio.ensure_future(a)

# Entering the context yields the wrapped future itself (bound by `as`).
async def __aenter__(self):
return self._task

# Leaving the context cancels the future if it has not finished and awaits it
# through wait_ignore_errors, which suppresses errors raised by the await.
# Nothing is returned, so an exception raised in the `async with` body is
# not suppressed by this context manager.
async def __aexit__(self, exc_type, exc_val, exc_tb):
if not self._task.done():
self._task.cancel()
# NOTE(review): unclear from here whether this await is inside the `if` or
# runs unconditionally (the latter would also retrieve the exception of an
# already-failed task) — TODO confirm.
await wait_ignore_errors(self._task)


class PermanentFailable:
"""A class that can experience permanent failures, with helpers for forwarding these to client actions."""

Expand All @@ -21,14 +36,6 @@ def _failure_task(self) -> asyncio.Future:
self._maybe_failure_task = asyncio.Future()
return self._maybe_failure_task

@staticmethod
async def _fail_client_task(task: asyncio.Future):
task.cancel()
try:
await task
except: # noqa: E722 intentionally broad except clause
pass

async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
"""
Await the awaitable, unless fail() is called first.
Expand All @@ -38,18 +45,15 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
Returns: The result of the awaitable
Raises: The permanent error if fail() is called or the awaitable raises one.
"""

task = asyncio.ensure_future(awaitable)
if self._failure_task.done():
await self._fail_client_task(task)
async with _TaskWithCleanup(awaitable) as task:
if self._failure_task.done():
raise self._failure_task.exception()
done, _ = await asyncio.wait(
[task, self._failure_task], return_when=asyncio.FIRST_COMPLETED
)
if task in done:
return await task
raise self._failure_task.exception()
done, _ = await asyncio.wait(
[task, self._failure_task], return_when=asyncio.FIRST_COMPLETED
)
if task in done:
return await task
await self._fail_client_task(task)
raise self._failure_task.exception()

async def run_poller(self, poll_action: Callable[[], Awaitable[None]]):
"""
Expand Down
34 changes: 21 additions & 13 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
@@ -1,8 +1,10 @@
import asyncio
from asyncio import Future

from typing import Awaitable, Optional
from typing import Optional
from google.api_core.exceptions import GoogleAPICallError, Cancelled
from google.cloud.pubsublite.internal.status_codes import is_retryable
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.connection_reinitializer import (
ConnectionReinitializer,
)
Expand Down Expand Up @@ -101,20 +103,26 @@ async def _run_loop(self):
print(e)

async def _loop_connection(self, connection: Connection[Request, Response]):
read_task: Awaitable[Response] = asyncio.ensure_future(connection.read())
write_task: Awaitable[WorkItem[Request]] = asyncio.ensure_future(
read_task: "Future[Response]" = asyncio.ensure_future(connection.read())
write_task: "Future[WorkItem[Request]]" = asyncio.ensure_future(
self._write_queue.get()
)
while True:
done, _ = await asyncio.wait(
[write_task, read_task], return_when=asyncio.FIRST_COMPLETED
)
if write_task in done:
await self._handle_write(connection, await write_task)
write_task = asyncio.ensure_future(self._write_queue.get())
if read_task in done:
await self._read_queue.put(await read_task)
read_task = asyncio.ensure_future(connection.read())
try:
while True:
done, _ = await asyncio.wait(
[write_task, read_task], return_when=asyncio.FIRST_COMPLETED
)
if write_task in done:
await self._handle_write(connection, await write_task)
write_task = asyncio.ensure_future(self._write_queue.get())
if read_task in done:
await self._read_queue.put(await read_task)
read_task = asyncio.ensure_future(connection.read())
finally:
read_task.cancel()
write_task.cancel()
await wait_ignore_errors(read_task)
await wait_ignore_errors(write_task)

@staticmethod
async def _handle_write(
Expand Down
Expand Up @@ -4,6 +4,7 @@
from absl import logging
from google.cloud.pubsub_v1.types import BatchSettings

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.retrying_connection import (
RetryingConnection,
Expand Down Expand Up @@ -81,11 +82,11 @@ def _start_loopers(self):
async def _stop_loopers(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None
if self._flusher:
self._flusher.cancel()
await self._flusher
await wait_ignore_errors(self._flusher)
self._flusher = None

def _handle_response(self, response: PublishResponse):
Expand All @@ -108,20 +109,14 @@ def _handle_response(self, response: PublishResponse):
next_offset += 1

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
self._handle_response(response)
except (asyncio.CancelledError, GoogleAPICallError):
return
while True:
response = await self._connection.read()
self._handle_response(response)

async def _flush_loop(self):
try:
while True:
await asyncio.sleep(self._batching_settings.max_latency)
await self._flush()
except asyncio.CancelledError:
return
while True:
await asyncio.sleep(self._batching_settings.max_latency)
await self._flush()

async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._connection.error():
Expand Down
23 changes: 9 additions & 14 deletions google/cloud/pubsublite/internal/wire/subscriber_impl.py
Expand Up @@ -3,6 +3,7 @@

from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.connection import (
Connection,
ConnectionFactory,
Expand Down Expand Up @@ -72,11 +73,11 @@ def _start_loopers(self):
async def _stop_loopers(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
await wait_ignore_errors(self._receiver)
self._receiver = None
if self._flusher:
self._flusher.cancel()
await self._flusher
await wait_ignore_errors(self._flusher)
self._flusher = None

def _handle_response(self, response: SubscribeResponse):
Expand Down Expand Up @@ -107,12 +108,9 @@ def _handle_response(self, response: SubscribeResponse):
self._message_queue.put_nowait(message)

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
self._handle_response(response)
except (asyncio.CancelledError, GoogleAPICallError):
return
while True:
response = await self._connection.read()
self._handle_response(response)

async def _try_send_tokens(self):
req = self._outstanding_flow_control.release_pending_request()
Expand All @@ -125,12 +123,9 @@ async def _try_send_tokens(self):
pass

async def _flush_loop(self):
try:
while True:
await asyncio.sleep(self._token_flush_seconds)
await self._try_send_tokens()
except asyncio.CancelledError:
return
while True:
await asyncio.sleep(self._token_flush_seconds)
await self._try_send_tokens()

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
Expand Down