Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Subscriber, which handles flow control and batch message processing. #16

Merged
merged 7 commits into from Sep 15, 2020
4 changes: 3 additions & 1 deletion google/cloud/pubsublite/internal/wire/assigner_impl.py
Expand Up @@ -39,6 +39,7 @@ def __init__(self, initial: InitialPartitionAssignmentRequest,

async def __aenter__(self):
await self._connection.__aenter__()
return self

def _start_receiver(self):
assert self._receiver is None
Expand All @@ -63,10 +64,11 @@ async def _receive_loop(self):
for partition in response.partitions:
partitions.add(Partition(partition))
self._new_assignment.put_nowait(partitions)
except asyncio.CancelledError:
except (asyncio.CancelledError, GoogleAPICallError):
return

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_receiver()
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

async def reinitialize(self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment]):
Expand Down
10 changes: 7 additions & 3 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -10,11 +10,13 @@
from google.cloud.pubsublite.internal.wire.connection import Connection
from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher, BatchTester
from google.cloud.pubsublite_v1 import Cursor
from google.cloud.pubsublite_v1.types import StreamingCommitCursorRequest, StreamingCommitCursorResponse, InitialCommitCursorRequest
from google.cloud.pubsublite_v1.types import StreamingCommitCursorRequest, StreamingCommitCursorResponse, \
InitialCommitCursorRequest
from google.cloud.pubsublite.internal.wire.work_item import WorkItem


class CommitterImpl(Committer, ConnectionReinitializer[StreamingCommitCursorRequest, StreamingCommitCursorResponse], BatchTester[Cursor]):
class CommitterImpl(Committer, ConnectionReinitializer[StreamingCommitCursorRequest, StreamingCommitCursorResponse],
BatchTester[Cursor]):
_initial: InitialCommitCursorRequest
_flush_seconds: float
_connection: RetryingConnection[StreamingCommitCursorRequest, StreamingCommitCursorResponse]
Expand All @@ -38,6 +40,7 @@ def __init__(self, initial: InitialCommitCursorRequest, flush_seconds: float,

async def __aenter__(self):
await self._connection.__aenter__()
return self

def _start_loopers(self):
assert self._receiver is None
Expand Down Expand Up @@ -71,7 +74,7 @@ async def _receive_loop(self):
while True:
response = await self._connection.read()
self._handle_response(response)
except asyncio.CancelledError:
except (asyncio.CancelledError, GoogleAPICallError):
return

async def _flush_loop(self):
Expand All @@ -83,6 +86,7 @@ async def _flush_loop(self):
return

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
if self._connection.error():
self._fail_if_retrying_failed()
else:
Expand Down
67 changes: 67 additions & 0 deletions google/cloud/pubsublite/internal/wire/flow_control_batcher.py
@@ -0,0 +1,67 @@
from typing import NamedTuple, List, Optional

from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage

_EXPEDITE_BATCH_REQUEST_RATIO = 0.5
_MAX_INT64 = 0x7FFFFFFFFFFFFFFF


class _AggregateRequest:
request: FlowControlRequest

def __init__(self):
self.request = FlowControlRequest()

def __add__(self, other: FlowControlRequest):
self.request.allowed_bytes += other.allowed_bytes
self.request.allowed_bytes = min(self.request.allowed_bytes, _MAX_INT64)
self.request.allowed_messages += other.allowed_messages
self.request.allowed_messages = min(self.request.allowed_messages, _MAX_INT64)
return self


def _exceeds_expedite_ratio(pending: int, client: int):
if client <= 0:
return False
return (pending/client) >= _EXPEDITE_BATCH_REQUEST_RATIO


def _to_optional(req: FlowControlRequest) -> Optional[FlowControlRequest]:
if req.allowed_messages == 0 and req.allowed_bytes == 0:
return None
return req


class FlowControlBatcher:
_client_tokens: _AggregateRequest
_pending_tokens: _AggregateRequest

def __init__(self):
self._client_tokens = _AggregateRequest()
self._pending_tokens = _AggregateRequest()

def add(self, request: FlowControlRequest):
self._client_tokens += request
self._pending_tokens += request

def on_messages(self, messages: List[SequencedMessage]):
byte_size = sum(message.size_bytes for message in messages)
self._client_tokens += FlowControlRequest(allowed_bytes=-byte_size, allowed_messages=-len(messages))

def request_for_restart(self) -> Optional[FlowControlRequest]:
self._pending_tokens = _AggregateRequest()
return _to_optional(self._client_tokens.request)

def release_pending_request(self) -> Optional[FlowControlRequest]:
request = self._pending_tokens.request
self._pending_tokens = _AggregateRequest()
return _to_optional(request)

def should_expedite(self):
pending_request = self._pending_tokens.request
client_request = self._client_tokens.request
if _exceeds_expedite_ratio(pending_request.allowed_bytes, client_request.allowed_bytes):
return True
if _exceeds_expedite_ratio(pending_request.allowed_messages, client_request.allowed_messages):
return True
return False
1 change: 1 addition & 0 deletions google/cloud/pubsublite/internal/wire/routing_publisher.py
Expand Up @@ -18,6 +18,7 @@ def __init__(self, routing_policy: RoutingPolicy, publishers: Mapping[Partition,
async def __aenter__(self):
for publisher in self._publishers.values():
await publisher.__aenter__()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
for publisher in self._publishers.values():
Expand Down
Expand Up @@ -48,6 +48,7 @@ def _partition(self) -> Partition:

async def __aenter__(self):
await self._connection.__aenter__()
return self

def _start_loopers(self):
assert self._receiver is None
Expand Down Expand Up @@ -82,7 +83,7 @@ async def _receive_loop(self):
while True:
response = await self._connection.read()
self._handle_response(response)
except asyncio.CancelledError:
except (asyncio.CancelledError, GoogleAPICallError):
return

async def _flush_loop(self):
Expand All @@ -98,6 +99,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
self._fail_if_retrying_failed()
else:
await self._flush()
await self._stop_loopers()
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

def _fail_if_retrying_failed(self):
Expand Down
28 changes: 28 additions & 0 deletions google/cloud/pubsublite/internal/wire/subscriber.py
@@ -0,0 +1,28 @@
from abc import abstractmethod
from typing import AsyncContextManager
from google.cloud.pubsublite_v1.types import SequencedMessage, FlowControlRequest


class Subscriber(AsyncContextManager):
"""
A Pub/Sub Lite asynchronous wire protocol subscriber.
"""
@abstractmethod
async def read(self) -> SequencedMessage:
"""
Read the next message off of the stream.

Returns:
The next message.

Raises:
GoogleAPICallError: On a permanent error.
"""
raise NotImplementedError()

@abstractmethod
async def allow_flow(self, request: FlowControlRequest):
"""
Allow an additional amount of messages and bytes to be sent to this client.
"""
raise NotImplementedError()
135 changes: 135 additions & 0 deletions google/cloud/pubsublite/internal/wire/subscriber_impl.py
@@ -0,0 +1,135 @@
import asyncio
from typing import Optional

from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition

from google.cloud.pubsublite.internal.wire.connection import Request, Connection, Response, ConnectionFactory
from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer
from google.cloud.pubsublite.internal.wire.flow_control_batcher import FlowControlBatcher
from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection
from google.cloud.pubsublite.internal.wire.subscriber import Subscriber
from google.cloud.pubsublite_v1 import SubscribeRequest, SubscribeResponse, FlowControlRequest, SequencedMessage, \
InitialSubscribeRequest, SeekRequest, Cursor


class SubscriberImpl(Subscriber, ConnectionReinitializer[SubscribeRequest, SubscribeResponse]):
_initial: InitialSubscribeRequest
_token_flush_seconds: float
_connection: RetryingConnection[SubscribeRequest, SubscribeResponse]

_outstanding_flow_control: FlowControlBatcher

_reinitializing: bool
_last_received_offset: Optional[int]

_message_queue: 'asyncio.Queue[SequencedMessage]'

_receiver: Optional[asyncio.Future]
_flusher: Optional[asyncio.Future]

def __init__(self, initial: InitialSubscribeRequest, token_flush_seconds: float,
factory: ConnectionFactory[SubscribeRequest, SubscribeResponse]):
self._initial = initial
self._token_flush_seconds = token_flush_seconds
self._connection = RetryingConnection(factory, self)
self._outstanding_flow_control = FlowControlBatcher()
self._reinitializing = False
self._last_received_offset = None
self._message_queue = asyncio.Queue()
self._receiver = None
self._flusher = None

async def __aenter__(self):
await self._connection.__aenter__()
return self

def _start_loopers(self):
assert self._receiver is None
assert self._flusher is None
self._receiver = asyncio.ensure_future(self._receive_loop())
self._flusher = asyncio.ensure_future(self._flush_loop())

async def _stop_loopers(self):
if self._receiver:
self._receiver.cancel()
await self._receiver
self._receiver = None
if self._flusher:
self._flusher.cancel()
await self._flusher
self._flusher = None

def _handle_response(self, response: SubscribeResponse):
if "messages" not in response:
self._connection.fail(FailedPrecondition("Received an invalid subsequent response on the subscribe stream."))
return
self._outstanding_flow_control.on_messages(response.messages.messages)
for message in response.messages.messages:
if self._last_received_offset is not None and message.cursor.offset <= self._last_received_offset:
self._connection.fail(FailedPrecondition(
"Received an invalid out of order message from the server. Message is {}, previous last received is {}.".format(
message.cursor.offset, self._last_received_offset)))
return
self._last_received_offset = message.cursor.offset
for message in response.messages.messages:
# queue is unbounded.
self._message_queue.put_nowait(message)

async def _receive_loop(self):
try:
while True:
response = await self._connection.read()
self._handle_response(response)
except (asyncio.CancelledError, GoogleAPICallError):
return

async def _try_send_tokens(self):
req = self._outstanding_flow_control.release_pending_request()
if req is None:
return
try:
await self._connection.write(SubscribeRequest(flow_control=req))
except GoogleAPICallError:
# May be transient, in which case these tokens will be resent.
pass

async def _flush_loop(self):
try:
while True:
await asyncio.sleep(self._token_flush_seconds)
await self._try_send_tokens()
except asyncio.CancelledError:
return

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

async def reinitialize(self, connection: Connection[SubscribeRequest, SubscribeResponse]):
self._reinitializing = True
await self._stop_loopers()
await connection.write(SubscribeRequest(initial=self._initial))
response = await connection.read()
if "initial" not in response:
self._connection.fail(FailedPrecondition("Received an invalid initial response on the subscribe stream."))
return
if self._last_received_offset is not None:
# Perform a seek to get the next message after the one we received.
await connection.write(SubscribeRequest(seek=SeekRequest(cursor=Cursor(offset=self._last_received_offset + 1))))
seek_response = await connection.read()
if "seek" not in seek_response:
self._connection.fail(FailedPrecondition("Received an invalid seek response on the subscribe stream."))
return
tokens = self._outstanding_flow_control.request_for_restart()
if tokens is not None:
await connection.write(SubscribeRequest(flow_control=tokens))
self._reinitializing = False
self._start_loopers()

async def read(self) -> SequencedMessage:
return await self._connection.await_unless_failed(self._message_queue.get())

async def allow_flow(self, request: FlowControlRequest):
self._outstanding_flow_control.add(request)
if not self._reinitializing and self._outstanding_flow_control.should_expedite():
await self._try_send_tokens()
28 changes: 28 additions & 0 deletions tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py
@@ -0,0 +1,28 @@
from google.cloud.pubsublite.internal.wire.flow_control_batcher import FlowControlBatcher
from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage


def test_restart_clears_send():
batcher = FlowControlBatcher()
batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3))
assert batcher.should_expedite()
to_send = batcher.release_pending_request()
assert to_send.allowed_bytes == 10
assert to_send.allowed_messages == 3
restart_1 = batcher.request_for_restart()
assert restart_1.allowed_bytes == 10
assert restart_1.allowed_messages == 3
assert not batcher.should_expedite()
assert batcher.release_pending_request() is None


def test_add_remove():
batcher = FlowControlBatcher()
batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3))
restart_1 = batcher.request_for_restart()
assert restart_1.allowed_bytes == 10
assert restart_1.allowed_messages == 3
batcher.on_messages([SequencedMessage(size_bytes=2), SequencedMessage(size_bytes=3)])
restart_2 = batcher.request_for_restart()
assert restart_2.allowed_bytes == 5
assert restart_2.allowed_messages == 1