Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Handle out of band seeks #158

Merged
merged 5 commits into from Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -45,3 +45,12 @@ async def ack(self, offset: int):
Raises:
GoogleAPICallError: On a commit failure.
"""

@abstractmethod
async def clear_and_commit(self) -> None:
    """
    Discard all outstanding acks and wait for the commit offset to be
    acknowledged by the server.

    Raises:
      GoogleAPICallError: If the committer has shut down due to a permanent error.
    """
Expand Up @@ -62,6 +62,11 @@ async def ack(self, offset: int):
# Convert from last acked to first unacked.
await self._committer.commit(Cursor(offset=prefix_acked_offset + 1))

async def clear_and_commit(self):
    """
    Drop every tracked receipt and pending ack, then wait for the committer
    to report that it has no outstanding commits.

    Errors raised by the committer's wait_until_empty() propagate to the caller.
    """
    # Forget all tracked state first so nothing recorded before this call
    # is carried over.
    self._receipts.clear()
    # Replace the ack queue with a fresh, empty one rather than draining it.
    self._acks = queue.PriorityQueue()
    await self._committer.wait_until_empty()

async def __aenter__(self):
await self._committer.__aenter__()

Expand Down
22 changes: 14 additions & 8 deletions google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py
Expand Up @@ -52,6 +52,9 @@
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
from google.cloud.pubsublite.internal.wire.pubsub_context import pubsub_context
import google.cloud.pubsublite.internal.wire.subscriber_impl as wire_subscriber
from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import (
SubscriberResetHandler,
)
from google.cloud.pubsublite.types import Partition, SubscriptionPath
from google.cloud.pubsublite.internal.routing_metadata import (
subscription_routing_metadata,
Expand Down Expand Up @@ -131,13 +134,16 @@ def cursor_connection_factory(
requests, metadata=list(final_metadata.items())
)

subscriber = wire_subscriber.SubscriberImpl(
InitialSubscribeRequest(
subscription=str(subscription), partition=partition.value
),
_DEFAULT_FLUSH_SECONDS,
GapicConnectionFactory(subscribe_connection_factory),
)
def subscriber_factory(reset_handler: SubscriberResetHandler):
    """
    Build the wire subscriber for this subscription/partition.

    Construction is deferred behind a factory so the caller can supply the
    reset handler that the wire subscriber is created with.

    Args:
      reset_handler: The handler passed through to the wire subscriber.

    Returns:
      A new wire_subscriber.SubscriberImpl.
    """
    return wire_subscriber.SubscriberImpl(
        InitialSubscribeRequest(
            subscription=str(subscription), partition=partition.value
        ),
        _DEFAULT_FLUSH_SECONDS,
        GapicConnectionFactory(subscribe_connection_factory),
        reset_handler,
    )

committer = CommitterImpl(
InitialCommitCursorRequest(
subscription=str(subscription), partition=partition.value
Expand All @@ -147,7 +153,7 @@ def cursor_connection_factory(
)
ack_set_tracker = AckSetTrackerImpl(committer)
return SinglePartitionSingleSubscriber(
subscriber,
subscriber_factory,
flow_control_settings,
ack_set_tracker,
nack_handler,
Expand Down
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.

import asyncio
from typing import Union, Dict, NamedTuple
import json
from typing import Callable, Union, Dict, NamedTuple
import queue

from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError
Expand All @@ -30,6 +31,9 @@
)
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.internal.wire.subscriber import Subscriber
from google.cloud.pubsublite.internal.wire.subscriber_reset_handler import (
SubscriberResetHandler,
)
from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage
from google.cloud.pubsub_v1.subscriber._protocol import requests

Expand All @@ -39,34 +43,61 @@ class _SizedMessage(NamedTuple):
size_bytes: int


class SinglePartitionSingleSubscriber(PermanentFailable, AsyncSingleSubscriber):
class _AckId(NamedTuple):
generation: int
offset: int

def str(self) -> str:
return json.dumps({"generation": self.generation, "offset": self.offset})

@staticmethod
def parse(payload: str) -> "_AckId":
loaded = json.loads(payload)
return _AckId(
generation=int(loaded["generation"]), offset=int(loaded["offset"]),
)


ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber]


class SinglePartitionSingleSubscriber(
PermanentFailable, AsyncSingleSubscriber, SubscriberResetHandler
):
_underlying: Subscriber
_flow_control_settings: FlowControlSettings
_ack_set_tracker: AckSetTracker
_nack_handler: NackHandler
_transformer: MessageTransformer

_queue: queue.Queue
_messages_by_offset: Dict[int, _SizedMessage]
_ack_generation_id: int
_messages_by_ack_id: Dict[str, _SizedMessage]
_looper_future: asyncio.Future

def __init__(
self,
underlying: Subscriber,
subscriber_factory: ResettableSubscriberFactory,
flow_control_settings: FlowControlSettings,
ack_set_tracker: AckSetTracker,
nack_handler: NackHandler,
transformer: MessageTransformer,
):
super().__init__()
self._underlying = underlying
self._underlying = subscriber_factory(self)
self._flow_control_settings = flow_control_settings
self._ack_set_tracker = ack_set_tracker
self._nack_handler = nack_handler
self._transformer = transformer

self._queue = queue.Queue()
self._messages_by_offset = {}
self._ack_generation_id = 0
self._messages_by_ack_id = {}

async def handle_reset(self):
    """
    Handle a reset: start a new ack generation, then discard outstanding
    acks and wait for the commit to complete.

    Errors raised by the ack set tracker's clear_and_commit() propagate to
    the caller.
    """
    # Advance the ack generation so acks for messages delivered before the
    # reset are recognised as stale and their offsets are not committed.
    # (`++x` is just two unary pluses in Python and leaves x unchanged.)
    self._ack_generation_id += 1
    await self._ack_set_tracker.clear_and_commit()

async def read(self) -> Message:
try:
Expand All @@ -75,13 +106,14 @@ async def read(self) -> Message:
)
cps_message = self._transformer.transform(message)
offset = message.cursor.offset
ack_id = _AckId(self._ack_generation_id, offset)
self._ack_set_tracker.track(offset)
self._messages_by_offset[offset] = _SizedMessage(
self._messages_by_ack_id[ack_id.str()] = _SizedMessage(
cps_message, message.size_bytes
)
wrapped_message = Message(
cps_message._pb,
ack_id=str(offset),
ack_id=ack_id.str(),
delivery_attempt=0,
request_queue=self._queue,
)
Expand All @@ -91,22 +123,23 @@ async def read(self) -> Message:
raise e

async def _handle_ack(self, message: requests.AckRequest):
offset = int(message.ack_id)
await self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=1,
allowed_bytes=self._messages_by_offset[offset].size_bytes,
allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes,
)
)
del self._messages_by_offset[offset]
try:
await self._ack_set_tracker.ack(offset)
except GoogleAPICallError as e:
self.fail(e)
del self._messages_by_ack_id[message.ack_id]
# Always refill flow control tokens, but do not commit offsets from outdated generations.
ack_id = _AckId.parse(message.ack_id)
if ack_id.generation == self._ack_generation_id:
try:
await self._ack_set_tracker.ack(ack_id.offset)
except GoogleAPICallError as e:
self.fail(e)

def _handle_nack(self, message: requests.NackRequest):
offset = int(message.ack_id)
sized_message = self._messages_by_offset[offset]
sized_message = self._messages_by_ack_id[message.ack_id]
try:
# Put the ack request back into the queue since the callback may be called from another thread.
self._nack_handler.on_nack(
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/pubsublite/internal/wire/committer.py
Expand Up @@ -26,3 +26,13 @@ class Committer(AsyncContextManager):
@abstractmethod
async def commit(self, cursor: Cursor) -> None:
pass

@abstractmethod
async def wait_until_empty(self) -> None:
    """
    Flushes pending commits and waits for all outstanding commit responses
    from the server.

    Raises:
      GoogleAPICallError: When the committer terminates in failure.
    """
10 changes: 10 additions & 0 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -63,6 +63,7 @@ class CommitterImpl(

_receiver: Optional[asyncio.Future]
_flusher: Optional[asyncio.Future]
_empty: asyncio.Event

def __init__(
self,
Expand All @@ -79,6 +80,8 @@ def __init__(
self._outstanding_commits = []
self._receiver = None
self._flusher = None
self._empty = asyncio.Event()
self._empty.set()

async def __aenter__(self):
await self._connection.__aenter__()
Expand Down Expand Up @@ -117,6 +120,8 @@ def _handle_response(self, response: StreamingCommitCursorResponse):
batch = self._outstanding_commits.pop(0)
for item in batch:
item.response_future.set_result(None)
if len(self._outstanding_commits) == 0:
self._empty.set()

async def _receive_loop(self):
while True:
Expand Down Expand Up @@ -147,6 +152,7 @@ async def _flush(self):
if not batch:
return
self._outstanding_commits.append(batch)
self._empty.clear()
req = StreamingCommitCursorRequest()
req.commit.cursor = batch[-1].request
try:
Expand All @@ -155,6 +161,10 @@ async def _flush(self):
_LOGGER.debug(f"Failed commit on stream: {e}")
self._fail_if_retrying_failed()

async def wait_until_empty(self):
    """
    Flush any batched commits, then wait until the `_empty` event is set.

    The wait is routed through the connection's await_unless_failed(), so
    whatever that call raises propagates to the caller.
    """
    # Flush first so the wait below also covers commits that were still
    # batched when this method was called.
    await self._flush()
    await self._connection.await_unless_failed(self._empty.wait())

async def commit(self, cursor: Cursor) -> None:
future = self._batcher.add(cursor)
if self._batcher.should_flush():
Expand Down
42 changes: 42 additions & 0 deletions google/cloud/pubsublite/internal/wire/reset_signal.py
@@ -0,0 +1,42 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.api_core.exceptions import GoogleAPICallError
from google.cloud.pubsublite.internal.status_codes import is_retryable
from google.rpc.error_details_pb2 import ErrorInfo
from grpc_status import rpc_status


def is_reset_signal(error: GoogleAPICallError) -> bool:
    """
    Determines whether the given error contains the stream RESET signal, sent by
    the server to instruct clients to reset stream state.

    Args:
      error: The error received from the stream.

    Returns: True if the error contains the RESET signal.
    """
    # Only retryable errors that carry the underlying call can hold the signal.
    if not is_retryable(error) or not error.response:
        return False
    try:
        status = rpc_status.from_call(error.response)
        # from_call() yields None when the call carries no rich status
        # payload; without this guard `status.details` raises AttributeError.
        if status is None:
            return False
        for detail in status.details:
            info = ErrorInfo()
            # Unpack() is falsy when the detail is not an ErrorInfo message.
            if (
                detail.Unpack(info)
                and info.reason == "RESET"
                and info.domain == "pubsublite.googleapis.com"
            ):
                return True
    except ValueError:
        # A ValueError while extracting the status is treated as "no signal".
        pass
    return False