Skip to content

Commit

Permalink
feat: Handle out of band seeks (#158)
Browse files Browse the repository at this point in the history
- Handles RESET signal (out of band seek notification) from the server.
- Sets the InitialSubscribeRequest.initial_location field when reconnecting subscribe streams.
  • Loading branch information
tmdiep committed Jun 11, 2021
1 parent edbd104 commit 77db700
Show file tree
Hide file tree
Showing 16 changed files with 596 additions and 67 deletions.
Expand Up @@ -45,3 +45,12 @@ async def ack(self, offset: int):
Returns:
GoogleAPICallError: On a commit failure.
"""

@abstractmethod
async def clear_and_commit(self):
"""
Discard all outstanding acks and wait for the commit offset to be acknowledged by the server.
Raises:
GoogleAPICallError: If the committer has shut down due to a permanent error.
"""
Expand Up @@ -62,6 +62,11 @@ async def ack(self, offset: int):
# Convert from last acked to first unacked.
await self._committer.commit(Cursor(offset=prefix_acked_offset + 1))

async def clear_and_commit(self):
self._receipts.clear()
self._acks = queue.PriorityQueue()
await self._committer.wait_until_empty()

async def __aenter__(self):
await self._committer.__aenter__()

Expand Down
22 changes: 14 additions & 8 deletions google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py
Expand Up @@ -52,6 +52,9 @@
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
from google.cloud.pubsublite.internal.wire.pubsub_context import pubsub_context
import google.cloud.pubsublite.internal.wire.subscriber_impl as wire_subscriber
from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import (
SubscriberResetHandler,
)
from google.cloud.pubsublite.types import Partition, SubscriptionPath
from google.cloud.pubsublite.internal.routing_metadata import (
subscription_routing_metadata,
Expand Down Expand Up @@ -131,13 +134,16 @@ def cursor_connection_factory(
requests, metadata=list(final_metadata.items())
)

subscriber = wire_subscriber.SubscriberImpl(
InitialSubscribeRequest(
subscription=str(subscription), partition=partition.value
),
_DEFAULT_FLUSH_SECONDS,
GapicConnectionFactory(subscribe_connection_factory),
)
def subscriber_factory(reset_handler: SubscriberResetHandler):
return wire_subscriber.SubscriberImpl(
InitialSubscribeRequest(
subscription=str(subscription), partition=partition.value
),
_DEFAULT_FLUSH_SECONDS,
GapicConnectionFactory(subscribe_connection_factory),
reset_handler,
)

committer = CommitterImpl(
InitialCommitCursorRequest(
subscription=str(subscription), partition=partition.value
Expand All @@ -147,7 +153,7 @@ def cursor_connection_factory(
)
ack_set_tracker = AckSetTrackerImpl(committer)
return SinglePartitionSingleSubscriber(
subscriber,
subscriber_factory,
flow_control_settings,
ack_set_tracker,
nack_handler,
Expand Down
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.

import asyncio
from typing import Union, Dict, NamedTuple
import json
from typing import Callable, Union, Dict, NamedTuple
import queue

from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError
Expand All @@ -30,6 +31,9 @@
)
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.internal.wire.subscriber import Subscriber
from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import (
SubscriberResetHandler,
)
from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage
from google.cloud.pubsub_v1.subscriber._protocol import requests

Expand All @@ -39,34 +43,61 @@ class _SizedMessage(NamedTuple):
size_bytes: int


class SinglePartitionSingleSubscriber(PermanentFailable, AsyncSingleSubscriber):
class _AckId(NamedTuple):
generation: int
offset: int

def str(self) -> str:
return json.dumps({"generation": self.generation, "offset": self.offset})

@staticmethod
def parse(payload: str) -> "_AckId":
loaded = json.loads(payload)
return _AckId(
generation=int(loaded["generation"]), offset=int(loaded["offset"]),
)


ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber]


class SinglePartitionSingleSubscriber(
PermanentFailable, AsyncSingleSubscriber, SubscriberResetHandler
):
_underlying: Subscriber
_flow_control_settings: FlowControlSettings
_ack_set_tracker: AckSetTracker
_nack_handler: NackHandler
_transformer: MessageTransformer

_queue: queue.Queue
_messages_by_offset: Dict[int, _SizedMessage]
_ack_generation_id: int
_messages_by_ack_id: Dict[str, _SizedMessage]
_looper_future: asyncio.Future

def __init__(
self,
underlying: Subscriber,
subscriber_factory: ResettableSubscriberFactory,
flow_control_settings: FlowControlSettings,
ack_set_tracker: AckSetTracker,
nack_handler: NackHandler,
transformer: MessageTransformer,
):
super().__init__()
self._underlying = underlying
self._underlying = subscriber_factory(self)
self._flow_control_settings = flow_control_settings
self._ack_set_tracker = ack_set_tracker
self._nack_handler = nack_handler
self._transformer = transformer

self._queue = queue.Queue()
self._messages_by_offset = {}
self._ack_generation_id = 0
self._messages_by_ack_id = {}

async def handle_reset(self):
# Increment ack generation id to ignore unacked messages.
++self._ack_generation_id
await self._ack_set_tracker.clear_and_commit()

async def read(self) -> Message:
try:
Expand All @@ -75,13 +106,14 @@ async def read(self) -> Message:
)
cps_message = self._transformer.transform(message)
offset = message.cursor.offset
ack_id = _AckId(self._ack_generation_id, offset)
self._ack_set_tracker.track(offset)
self._messages_by_offset[offset] = _SizedMessage(
self._messages_by_ack_id[ack_id.str()] = _SizedMessage(
cps_message, message.size_bytes
)
wrapped_message = Message(
cps_message._pb,
ack_id=str(offset),
ack_id=ack_id.str(),
delivery_attempt=0,
request_queue=self._queue,
)
Expand All @@ -91,22 +123,23 @@ async def read(self) -> Message:
raise e

async def _handle_ack(self, message: requests.AckRequest):
offset = int(message.ack_id)
await self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=1,
allowed_bytes=self._messages_by_offset[offset].size_bytes,
allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes,
)
)
del self._messages_by_offset[offset]
try:
await self._ack_set_tracker.ack(offset)
except GoogleAPICallError as e:
self.fail(e)
del self._messages_by_ack_id[message.ack_id]
# Always refill flow control tokens, but do not commit offsets from outdated generations.
ack_id = _AckId.parse(message.ack_id)
if ack_id.generation == self._ack_generation_id:
try:
await self._ack_set_tracker.ack(ack_id.offset)
except GoogleAPICallError as e:
self.fail(e)

def _handle_nack(self, message: requests.NackRequest):
offset = int(message.ack_id)
sized_message = self._messages_by_offset[offset]
sized_message = self._messages_by_ack_id[message.ack_id]
try:
# Put the ack request back into the queue since the callback may be called from another thread.
self._nack_handler.on_nack(
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/pubsublite/internal/wire/committer.py
Expand Up @@ -26,3 +26,13 @@ class Committer(AsyncContextManager):
@abstractmethod
async def commit(self, cursor: Cursor) -> None:
pass

@abstractmethod
async def wait_until_empty(self):
"""
Flushes pending commits and waits for all outstanding commit responses from the server.
Raises:
GoogleAPICallError: When the committer terminates in failure.
"""
pass
10 changes: 10 additions & 0 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -63,6 +63,7 @@ class CommitterImpl(

_receiver: Optional[asyncio.Future]
_flusher: Optional[asyncio.Future]
_empty: asyncio.Event

def __init__(
self,
Expand All @@ -79,6 +80,8 @@ def __init__(
self._outstanding_commits = []
self._receiver = None
self._flusher = None
self._empty = asyncio.Event()
self._empty.set()

async def __aenter__(self):
await self._connection.__aenter__()
Expand Down Expand Up @@ -117,6 +120,8 @@ def _handle_response(self, response: StreamingCommitCursorResponse):
batch = self._outstanding_commits.pop(0)
for item in batch:
item.response_future.set_result(None)
if len(self._outstanding_commits) == 0:
self._empty.set()

async def _receive_loop(self):
while True:
Expand Down Expand Up @@ -147,6 +152,7 @@ async def _flush(self):
if not batch:
return
self._outstanding_commits.append(batch)
self._empty.clear()
req = StreamingCommitCursorRequest()
req.commit.cursor = batch[-1].request
try:
Expand All @@ -155,6 +161,10 @@ async def _flush(self):
_LOGGER.debug(f"Failed commit on stream: {e}")
self._fail_if_retrying_failed()

async def wait_until_empty(self):
await self._flush()
await self._connection.await_unless_failed(self._empty.wait())

async def commit(self, cursor: Cursor) -> None:
future = self._batcher.add(cursor)
if self._batcher.should_flush():
Expand Down
42 changes: 42 additions & 0 deletions google/cloud/pubsublite/internal/wire/reset_signal.py
@@ -0,0 +1,42 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.api_core.exceptions import GoogleAPICallError
from google.cloud.pubsublite.internal.status_codes import is_retryable
from google.rpc.error_details_pb2 import ErrorInfo
from grpc_status import rpc_status


def is_reset_signal(error: GoogleAPICallError) -> bool:
"""
Determines whether the given error contains the stream RESET signal, sent by
the server to instruct clients to reset stream state.
Returns: True if the error contains the RESET signal.
"""
if not is_retryable(error) or not error.response:
return False
try:
status = rpc_status.from_call(error.response)
for detail in status.details:
info = ErrorInfo()
if (
detail.Unpack(info)
and info.reason == "RESET"
and info.domain == "pubsublite.googleapis.com"
):
return True
except ValueError:
pass
return False

0 comments on commit 77db700

Please sign in to comment.