Skip to content

Commit

Permalink
fix: ensure ack() doesn't wait on stream messages (#234)
Browse files: browse the repository at this point in the history
* fix: ensure ack() doesn't wait on stream messages

also fix error propagation to streaming pull future

* fix: ensure ack() doesn't wait on stream messages

also fix error propagation to streaming pull future

* fix: ensure ack() doesn't wait on stream messages

also fix error propagation to streaming pull future

* fix: remove debug log
  • Loading branch information
dpcollins-google committed Sep 13, 2021
1 parent 435ad27 commit 03db702
Show file tree
Hide file tree
Showing 15 changed files with 146 additions and 178 deletions.
Expand Up @@ -35,7 +35,7 @@ def track(self, offset: int):
"""

@abstractmethod
async def ack(self, offset: int):
def ack(self, offset: int):
"""
Acknowledge the message with the provided offset. The offset must have previously been tracked.
Expand Down
Expand Up @@ -17,6 +17,7 @@
from typing import Optional

from google.api_core.exceptions import FailedPrecondition

from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite_v1 import Cursor
Expand All @@ -43,9 +44,7 @@ def track(self, offset: int):
)
self._receipts.append(offset)

async def ack(self, offset: int):
# Note: put_nowait is used here and below to ensure that the below logic is executed without yielding
# to another coroutine in the event loop. The queue is unbounded so it will never throw.
def ack(self, offset: int):
self._acks.put_nowait(offset)
prefix_acked_offset: Optional[int] = None
while len(self._receipts) != 0 and not self._acks.empty():
Expand All @@ -60,7 +59,7 @@ async def ack(self, offset: int):
if prefix_acked_offset is None:
return
# Convert from last acked to first unacked.
await self._committer.commit(Cursor(offset=prefix_acked_offset + 1))
self._committer.commit(Cursor(offset=prefix_acked_offset + 1))

async def clear_and_commit(self):
self._receipts.clear()
Expand Down
Expand Up @@ -20,7 +20,10 @@
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
AsyncSingleSubscriber,
)
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wait_ignore_cancelled import (
wait_ignore_cancelled,
wait_ignore_errors,
)
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.types import Partition
Expand Down Expand Up @@ -100,8 +103,10 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_value, traceback):
self._assign_poller.cancel()
await wait_ignore_cancelled(self._assign_poller)
await self._assigner.__aexit__(exc_type, exc_value, traceback)
await wait_ignore_errors(self._assign_poller)
await wait_ignore_errors(
self._assigner.__aexit__(exc_type, exc_value, traceback)
)
for running in self._subscribers.values():
await self._stop_subscriber(running)
await wait_ignore_errors(self._stop_subscriber(running))
pass
Expand Up @@ -126,8 +126,8 @@ async def read(self) -> List[Message]:
self.fail(e)
raise e

async def _handle_ack(self, message: requests.AckRequest):
await self._underlying.allow_flow(
def _handle_ack(self, message: requests.AckRequest):
self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=1,
allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes,
Expand All @@ -138,7 +138,7 @@ async def _handle_ack(self, message: requests.AckRequest):
ack_id = _AckId.parse(message.ack_id)
if ack_id.generation == self._ack_generation_id:
try:
await self._ack_set_tracker.ack(ack_id.offset)
self._ack_set_tracker.ack(ack_id.offset)
except GoogleAPICallError as e:
self.fail(e)

Expand Down Expand Up @@ -179,7 +179,7 @@ async def _handle_queue_message(
)
)
elif isinstance(message, requests.AckRequest):
await self._handle_ack(message)
self._handle_ack(message)
else:
self._handle_nack(message)

Expand All @@ -198,7 +198,7 @@ async def __aenter__(self):
await self._ack_set_tracker.__aenter__()
await self._underlying.__aenter__()
self._looper_future = asyncio.ensure_future(self._looper())
await self._underlying.allow_flow(
self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=self._flow_control_settings.messages_outstanding,
allowed_bytes=self._flow_control_settings.bytes_outstanding,
Expand Down
17 changes: 11 additions & 6 deletions google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py
Expand Up @@ -17,6 +17,9 @@
from concurrent.futures.thread import ThreadPoolExecutor
from typing import ContextManager, Optional
from google.api_core.exceptions import GoogleAPICallError
from functools import partial

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.cloudpubsub.internal.managed_event_loop import (
ManagedEventLoop,
)
Expand Down Expand Up @@ -86,8 +89,8 @@ async def _poller(self):
while True:
batch = await self._underlying.read()
self._unowned_executor.map(self._callback, batch)
except GoogleAPICallError as e: # noqa: F841 Flake8 thinks e is unused
self._unowned_executor.submit(lambda: self._fail(e)) # noqa: F821
except GoogleAPICallError as e:
self._unowned_executor.submit(partial(self._fail, e))

def __enter__(self):
assert self._close_callback is not None
Expand All @@ -97,13 +100,15 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self._poller_future.cancel()
try:
self._poller_future.cancel()
self._poller_future.result()
except concurrent.futures.CancelledError:
self._poller_future.result() # Ignore error.
except: # noqa: E722
pass
self._event_loop.submit(
self._underlying.__aexit__(exc_type, exc_value, traceback)
wait_ignore_errors(
self._underlying.__aexit__(exc_type, exc_value, traceback)
)
).result()
self._event_loop.__exit__(exc_type, exc_value, traceback)
assert self._close_callback is not None
Expand Down
8 changes: 7 additions & 1 deletion google/cloud/pubsublite/internal/wire/committer.py
Expand Up @@ -24,7 +24,13 @@ class Committer(AsyncContextManager, metaclass=ABCMeta):
"""

@abstractmethod
async def commit(self, cursor: Cursor) -> None:
def commit(self, cursor: Cursor) -> None:
"""
Start the commit for a cursor.
Raises:
GoogleAPICallError: When the committer terminates in failure.
"""
pass

@abstractmethod
Expand Down
55 changes: 18 additions & 37 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -28,14 +28,12 @@
ConnectionReinitializer,
)
from google.cloud.pubsublite.internal.wire.connection import Connection
from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher
from google.cloud.pubsublite_v1 import Cursor
from google.cloud.pubsublite_v1.types import (
StreamingCommitCursorRequest,
StreamingCommitCursorResponse,
InitialCommitCursorRequest,
)
from google.cloud.pubsublite.internal.wire.work_item import WorkItem


_LOGGER = logging.getLogger(__name__)
Expand All @@ -53,9 +51,8 @@ class CommitterImpl(
StreamingCommitCursorRequest, StreamingCommitCursorResponse
]

_batcher: SerialBatcher[Cursor, None]

_outstanding_commits: List[List[WorkItem[Cursor, None]]]
_next_to_commit: Optional[Cursor]
_outstanding_commits: List[Cursor]

_receiver: Optional[asyncio.Future]
_flusher: Optional[asyncio.Future]
Expand All @@ -72,7 +69,7 @@ def __init__(
self._initial = initial
self._flush_seconds = flush_seconds
self._connection = RetryingConnection(factory, self)
self._batcher = SerialBatcher()
self._next_to_commit = None
self._outstanding_commits = []
self._receiver = None
self._flusher = None
Expand Down Expand Up @@ -113,9 +110,7 @@ def _handle_response(self, response: StreamingCommitCursorResponse):
)
)
for _ in range(response.commit.acknowledged_commits):
batch = self._outstanding_commits.pop(0)
for item in batch:
item.response_future.set_result(None)
self._outstanding_commits.pop(0)
if len(self._outstanding_commits) == 0:
self._empty.set()

Expand All @@ -131,39 +126,31 @@ async def _flush_loop(self):

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
if self._connection.error():
self._fail_if_retrying_failed()
else:
if not self._connection.error():
await self._flush()
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

def _fail_if_retrying_failed(self):
if self._connection.error():
for batch in self._outstanding_commits:
for item in batch:
item.response_future.set_exception(self._connection.error())

async def _flush(self):
batch = self._batcher.flush()
if not batch:
if self._next_to_commit is None:
return
self._outstanding_commits.append(batch)
self._empty.clear()
req = StreamingCommitCursorRequest()
req.commit.cursor = batch[-1].request
req.commit.cursor = self._next_to_commit
self._outstanding_commits.append(self._next_to_commit)
self._next_to_commit = None
self._empty.clear()
try:
await self._connection.write(req)
except GoogleAPICallError as e:
_LOGGER.debug(f"Failed commit on stream: {e}")
self._fail_if_retrying_failed()

async def wait_until_empty(self):
await self._flush()
await self._connection.await_unless_failed(self._empty.wait())

async def commit(self, cursor: Cursor) -> None:
future = self._batcher.add(cursor)
await future
def commit(self, cursor: Cursor) -> None:
if self._connection.error():
raise self._connection.error()
self._next_to_commit = cursor

async def reinitialize(
self,
Expand All @@ -181,14 +168,8 @@ async def reinitialize(
"Received an invalid initial response on the publish stream."
)
)
if self._outstanding_commits:
# Roll up outstanding commits
rollup: List[WorkItem[Cursor, None]] = []
for batch in self._outstanding_commits:
for item in batch:
rollup.append(item)
self._outstanding_commits = [rollup]
req = StreamingCommitCursorRequest()
req.commit.cursor = rollup[-1].request
await connection.write(req)
if self._next_to_commit is None:
if self._outstanding_commits:
self._next_to_commit = self._outstanding_commits[-1]
self._outstanding_commits = []
self._start_loopers()
22 changes: 6 additions & 16 deletions google/cloud/pubsublite/internal/wire/flow_control_batcher.py
Expand Up @@ -26,11 +26,14 @@ class _AggregateRequest:
def __init__(self):
self._request = FlowControlRequest.meta.pb()

def __add__(self, other: FlowControlRequest.meta.pb):
self._request.allowed_bytes = self._request.allowed_bytes + other.allowed_bytes
def __add__(self, other: FlowControlRequest):
other_pb = other._pb
self._request.allowed_bytes = (
self._request.allowed_bytes + other_pb.allowed_bytes
)
self._request.allowed_bytes = min(self._request.allowed_bytes, _MAX_INT64)
self._request.allowed_messages = (
self._request.allowed_messages + other.allowed_messages
self._request.allowed_messages + other_pb.allowed_messages
)
self._request.allowed_messages = min(self._request.allowed_messages, _MAX_INT64)
return self
Expand Down Expand Up @@ -77,16 +80,3 @@ def release_pending_request(self) -> Optional[FlowControlRequest]:
request = self._pending_tokens
self._pending_tokens = _AggregateRequest()
return request.to_optional()

def should_expedite(self):
pending_request = self._pending_tokens._request
client_request = self._client_tokens._request
if _exceeds_expedite_ratio(
pending_request.allowed_bytes, client_request.allowed_bytes
):
return True
if _exceeds_expedite_ratio(
pending_request.allowed_messages, client_request.allowed_messages
):
return True
return False
2 changes: 1 addition & 1 deletion google/cloud/pubsublite/internal/wire/subscriber.py
Expand Up @@ -36,7 +36,7 @@ async def read(self) -> List[SequencedMessage.meta.pb]:
raise NotImplementedError()

@abstractmethod
async def allow_flow(self, request: FlowControlRequest):
def allow_flow(self, request: FlowControlRequest):
"""
Allow an additional amount of messages and bytes to be sent to this client.
"""
Expand Down
7 changes: 1 addition & 6 deletions google/cloud/pubsublite/internal/wire/subscriber_impl.py
Expand Up @@ -201,10 +201,5 @@ async def reinitialize(
async def read(self) -> List[SequencedMessage.meta.pb]:
return await self._connection.await_unless_failed(self._message_queue.get())

async def allow_flow(self, request: FlowControlRequest):
def allow_flow(self, request: FlowControlRequest):
self._outstanding_flow_control.add(request)
if (
not self._reinitializing
and self._outstanding_flow_control.should_expedite()
):
await self._try_send_tokens()
Expand Up @@ -49,17 +49,17 @@ async def test_track_and_aggregate_acks(committer, tracker: AckSetTracker):
tracker.track(offset=7)

committer.commit.assert_has_calls([])
await tracker.ack(offset=3)
tracker.ack(offset=3)
committer.commit.assert_has_calls([])
await tracker.ack(offset=1)
tracker.ack(offset=1)
committer.commit.assert_has_calls([call(Cursor(offset=4))])
await tracker.ack(offset=5)
tracker.ack(offset=5)
committer.commit.assert_has_calls(
[call(Cursor(offset=4)), call(Cursor(offset=6))]
)

tracker.track(offset=8)
await tracker.ack(offset=7)
tracker.ack(offset=7)
committer.commit.assert_has_calls(
[call(Cursor(offset=4)), call(Cursor(offset=6)), call(Cursor(offset=8))]
)
Expand All @@ -74,14 +74,14 @@ async def test_clear_and_commit(committer, tracker: AckSetTracker):

with pytest.raises(FailedPrecondition):
tracker.track(offset=1)
await tracker.ack(offset=5)
tracker.ack(offset=5)
committer.commit.assert_has_calls([])

await tracker.clear_and_commit()
committer.wait_until_empty.assert_called_once()

# After clearing, it should be possible to track earlier offsets.
tracker.track(offset=1)
await tracker.ack(offset=1)
tracker.ack(offset=1)
committer.commit.assert_has_calls([call(Cursor(offset=2))])
committer.__aexit__.assert_called_once()

0 comments on commit 03db702

Please sign in to comment.