Skip to content

Commit

Permalink
feat: Implement a single partition publisher (#8)
Browse files Browse the repository at this point in the history
* feat: Implement SerialBatcher which helps with transforming single writes into batch writes.

* feat: Implement SinglePartitionPublisher which publishes to a single partition and handles retries.
  • Loading branch information
dpcollins-google committed Aug 11, 2020
1 parent a6dc15f commit fd1d76f
Show file tree
Hide file tree
Showing 9 changed files with 489 additions and 14 deletions.
7 changes: 6 additions & 1 deletion google/cloud/pubsublite/internal/wire/permanent_failable.py
@@ -1,5 +1,5 @@
import asyncio
from typing import Awaitable, TypeVar
from typing import Awaitable, TypeVar, Optional

from google.api_core.exceptions import GoogleAPICallError

Expand Down Expand Up @@ -29,3 +29,8 @@ async def await_or_fail(self, awaitable: Awaitable[T]) -> T:
def fail(self, err: GoogleAPICallError):
    """Record a permanent failure on this object.

    Only the first failure is retained; later calls are ignored.
    """
    if self._failure_task.done():
        return
    self._failure_task.set_exception(err)

def error(self) -> Optional[GoogleAPICallError]:
    """Return the recorded permanent failure, or None if none has occurred."""
    if self._failure_task.done():
        return self._failure_task.exception()
    return None
29 changes: 29 additions & 0 deletions google/cloud/pubsublite/internal/wire/publisher.py
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from google.cloud.pubsublite_v1.types import PubSubMessage
from google.cloud.pubsublite.publish_metadata import PublishMetadata


class Publisher(ABC):
    """An interface for publishing messages, used as an async context manager.

    Implementations establish their underlying resources in __aenter__ and
    release them in __aexit__; publish() may only be called in between.
    """

    @abstractmethod
    async def __aenter__(self):
        # Set up the publisher's underlying resources before first publish.
        raise NotImplementedError()

    @abstractmethod
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Release resources; implementations decide how outstanding publishes
        # are completed or failed on shutdown.
        raise NotImplementedError()

    @abstractmethod
    async def publish(self, message: PubSubMessage) -> PublishMetadata:
        """
        Publish the provided message.

        Args:
            message: The message to be published.

        Returns:
            Metadata about the published message.

        Raises:
            GoogleAPICallError: On a permanent error.
        """
        raise NotImplementedError()
6 changes: 3 additions & 3 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
Expand Up @@ -19,7 +19,7 @@ class RetryingConnection(Connection[Request, Response], PermanentFailable):

_loop_task: asyncio.Future

_write_queue: 'asyncio.Queue[WorkItem[Request]]'
_write_queue: 'asyncio.Queue[WorkItem[Request, None]]'
_read_queue: 'asyncio.Queue[Response]'

def __init__(self, connection_factory: ConnectionFactory[Request, Response], reinitializer: ConnectionReinitializer[Request, Response]):
Expand Down Expand Up @@ -56,7 +56,7 @@ async def _run_loop(self):
await self._reinitializer.reinitialize(connection)
bad_retries = 0
await self._loop_connection(connection)
except (Exception, GoogleAPICallError) as e:
except GoogleAPICallError as e:
if not is_retryable(e):
self.fail(e)
return
Expand All @@ -79,7 +79,7 @@ async def _loop_connection(self, connection: Connection[Request, Response]):
read_task = asyncio.ensure_future(connection.read())

@staticmethod
async def _handle_write(connection: Connection[Request, Response], to_write: WorkItem[Request]):
async def _handle_write(connection: Connection[Request, Response], to_write: WorkItem[Request, Response]):
try:
await connection.write(to_write.request)
to_write.response_future.set_result(None)
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/pubsublite/internal/wire/serial_batcher.py
Expand Up @@ -21,7 +21,7 @@ def test(self, requests: Iterable[Request]) -> bool:

class SerialBatcher(Generic[Request, Response]):
_tester: BatchTester[Request]
_requests: List[WorkItem[Request]] # A list of outstanding requests
_requests: List[WorkItem[Request, Response]] # A list of outstanding requests

def __init__(self, tester: BatchTester[Request]):
self._tester = tester
Expand All @@ -37,14 +37,14 @@ def add(self, request: Request) -> 'asyncio.Future[Response]':
Returns:
A future that will resolve to the response or a GoogleAPICallError.
"""
item = WorkItem[Request](request)
item = WorkItem[Request, Response](request)
self._requests.append(item)
return item.response_future

def should_flush(self) -> bool:
    """Ask the batch tester whether the queued requests form a full batch."""
    pending_requests = (work.request for work in self._requests)
    return self._tester.test(pending_requests)

def flush(self) -> List[WorkItem[Request, Response]]:
    """Hand back every outstanding work item and reset the batcher to empty."""
    flushed, self._requests = self._requests, []
    return flushed
146 changes: 146 additions & 0 deletions google/cloud/pubsublite/internal/wire/single_partition_publisher.py
@@ -0,0 +1,146 @@
import asyncio
from typing import Optional, List, Iterable

from absl import logging
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection, ConnectionFactory
from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError
from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer
from google.cloud.pubsublite.internal.wire.connection import Connection
from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher, BatchTester
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite.publish_metadata import PublishMetadata
from google.cloud.pubsublite_v1.types import PubSubMessage, Cursor, PublishRequest, PublishResponse, \
InitialPublishRequest
from google.cloud.pubsublite.internal.wire.work_item import WorkItem

# Maximum bytes per batch at 3.5 MiB to avoid GRPC limit of 4 MiB
_MAX_BYTES = int(3.5 * 1024 * 1024)

# Maximum messages per batch at 1000
_MAX_MESSAGES = 1000


class SinglePartitionPublisher(Publisher, ConnectionReinitializer[PublishRequest, PublishResponse],
                               BatchTester[PubSubMessage]):
    """A Publisher that writes to a single Pub/Sub Lite partition.

    Individual messages are accumulated in a SerialBatcher; batches are
    written to a RetryingConnection and are acknowledged by the server in
    FIFO order. On stream reinitialization, unacknowledged batches are
    resent in their original order before new work resumes.
    """
    _initial: InitialPublishRequest
    _flush_seconds: float
    _connection: RetryingConnection[PublishRequest, PublishResponse]

    _batcher: SerialBatcher[PubSubMessage, Cursor]
    # Batches written to the stream but not yet acknowledged, oldest first.
    _outstanding_writes: List[List[WorkItem[PubSubMessage, Cursor]]]

    # Background tasks; set only while the stream is initialized.
    _receiver: Optional[asyncio.Future]
    _flusher: Optional[asyncio.Future]

    def __init__(self, initial: InitialPublishRequest, flush_seconds: float,
                 factory: ConnectionFactory[PublishRequest, PublishResponse]):
        """
        Args:
            initial: The initial request identifying the topic partition.
            flush_seconds: The interval between periodic flushes of pending
                messages to the stream.
            factory: A factory for constructing new stream connections.
        """
        self._initial = initial
        self._flush_seconds = flush_seconds
        self._connection = RetryingConnection(factory, self)
        self._batcher = SerialBatcher(self)
        self._outstanding_writes = []
        self._receiver = None
        self._flusher = None

    @property
    def _partition(self) -> Partition:
        """The partition this publisher writes to."""
        return Partition(self._initial.partition)

    async def __aenter__(self):
        await self._connection.__aenter__()
        # Return self so `async with SinglePartitionPublisher(...) as p:` works.
        return self

    def _start_loopers(self):
        # Called only from reinitialize(); loopers must have been stopped first.
        assert self._receiver is None
        assert self._flusher is None
        self._receiver = asyncio.ensure_future(self._receive_loop())
        self._flusher = asyncio.ensure_future(self._flush_loop())

    async def _stop_loopers(self):
        # Both loops swallow CancelledError, so awaiting them here completes.
        if self._receiver:
            self._receiver.cancel()
            await self._receiver
            self._receiver = None
        if self._flusher:
            self._flusher.cancel()
            await self._flusher
            self._flusher = None

    def _handle_response(self, response: PublishResponse):
        """Resolve the oldest outstanding batch using a server response."""
        if "message_response" not in response:
            self._connection.fail(FailedPrecondition("Received an invalid subsequent response on the publish stream."))
            # fail() records the error but does not raise; stop processing.
            return
        if not self._outstanding_writes:
            self._connection.fail(
                FailedPrecondition("Received a publish response on the stream with no outstanding publishes."))
            return
        # The server acknowledges a whole batch with the cursor of its first
        # message; subsequent messages occupy consecutive offsets.
        next_offset: int = response.message_response.start_cursor.offset
        batch: List[WorkItem[PubSubMessage, Cursor]] = self._outstanding_writes.pop(0)
        for item in batch:
            item.response_future.set_result(Cursor(offset=next_offset))
            next_offset += 1

    async def _receive_loop(self):
        """Read responses until cancelled, resolving outstanding batches."""
        try:
            while True:
                response = await self._connection.read()
                self._handle_response(response)
        except asyncio.CancelledError:
            return

    async def _flush_loop(self):
        """Periodically flush accumulated messages until cancelled."""
        try:
            while True:
                await asyncio.sleep(self._flush_seconds)
                await self._flush()
        except asyncio.CancelledError:
            return

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._connection.error():
            # The stream already failed permanently: fail outstanding work
            # instead of attempting a final flush.
            self._fail_if_retrying_failed()
        else:
            await self._flush()
        await self._connection.__aexit__(exc_type, exc_val, exc_tb)

    def _fail_if_retrying_failed(self):
        """If the connection failed permanently, fail all outstanding futures.

        Safe to call more than once: already-completed futures are skipped to
        avoid asyncio.InvalidStateError from a second set_exception().
        """
        error = self._connection.error()
        if error:
            for batch in self._outstanding_writes:
                for item in batch:
                    if not item.response_future.done():
                        item.response_future.set_exception(error)

    async def _flush(self):
        """Write all currently batched messages to the stream as one request."""
        batch = self._batcher.flush()
        if not batch:
            return
        self._outstanding_writes.append(batch)
        aggregate = PublishRequest()
        aggregate.message_publish_request.messages = [item.request for item in batch]
        try:
            await self._connection.write(aggregate)
        except GoogleAPICallError as e:
            # A retryable write failure leaves the batch outstanding; it will
            # be resent on reinitialize. A permanent failure fails its futures.
            logging.debug(f"Failed publish on stream: {e}")
            self._fail_if_retrying_failed()

    async def publish(self, message: PubSubMessage) -> PublishMetadata:
        """Publish a single message, returning its partition and cursor.

        Raises:
            GoogleAPICallError: On a permanent stream failure.
        """
        cursor_future = self._batcher.add(message)
        if self._batcher.should_flush():
            await self._flush()
        return PublishMetadata(self._partition, await cursor_future)

    async def reinitialize(self, connection: Connection[PublishRequest, PublishResponse]):
        """Re-send the initial request and all unacknowledged batches in order."""
        await self._stop_loopers()
        await connection.write(PublishRequest(initial_request=self._initial))
        response = await connection.read()
        if "initial_response" not in response:
            self._connection.fail(FailedPrecondition("Received an invalid initial response on the publish stream."))
            # fail() does not raise; do not resend batches or restart loopers
            # on a broken stream.
            return
        for batch in self._outstanding_writes:
            aggregate = PublishRequest()
            aggregate.message_publish_request.messages = [item.request for item in batch]
            await connection.write(aggregate)
        self._start_loopers()

    def test(self, requests: Iterable[PubSubMessage]) -> bool:
        """Return True when the pending messages constitute a full batch.

        A batch is full at _MAX_MESSAGES messages or _MAX_BYTES serialized
        bytes, whichever comes first.
        """
        request_count = 0
        byte_count = 0
        for req in requests:
            request_count += 1
            byte_count += PubSubMessage.pb(req).ByteSize()
        return (request_count >= _MAX_MESSAGES) or (byte_count >= _MAX_BYTES)
12 changes: 6 additions & 6 deletions google/cloud/pubsublite/internal/wire/work_item.py
@@ -1,14 +1,14 @@
import asyncio
from typing import Generic, TypeVar
from typing import Generic

T = TypeVar('T')
from google.cloud.pubsublite.internal.wire.connection import Request, Response


class WorkItem(Generic[Request, Response]):
    """An item of work and a future to complete when it is finished."""

    # The request payload to be sent.
    request: Request
    # Resolves with the response, or an error, once the request completes.
    response_future: "asyncio.Future[Response]"

    def __init__(self, request: Request):
        self.request = request
        self.response_future = asyncio.Future()
8 changes: 8 additions & 0 deletions google/cloud/pubsublite/publish_metadata.py
@@ -0,0 +1,8 @@
from typing import NamedTuple
from google.cloud.pubsublite_v1.types.common import Cursor
from google.cloud.pubsublite.partition import Partition


class PublishMetadata(NamedTuple):
    """Metadata about a successfully published message."""
    # The partition the message was published to.
    partition: Partition
    # The cursor (offset) assigned to the message within its partition.
    cursor: Cursor
25 changes: 24 additions & 1 deletion google/cloud/pubsublite/testing/test_utils.py
@@ -1,8 +1,31 @@
from typing import List, Union, Any
import asyncio
from typing import List, Union, Any, TypeVar, Generic, Optional

T = TypeVar("T")


async def async_iterable(elts: List[Union[Any, Exception]]):
    """Yield each element of *elts* in order, raising any that is an Exception.

    A test helper for simulating an async stream that may fail mid-iteration.
    """
    for item in elts:
        if isinstance(item, Exception):
            raise item
        yield item


def make_queue_waiter(started_q: "asyncio.Queue[None]", result_q: "asyncio.Queue[Union[T, Exception]]"):
    """
    Given a queue to notify when started and a queue to get results from, return a waiter which
    notifies started_q when started and returns from result_q when done.
    """

    async def waiter(*args, **kwargs):
        await started_q.put(None)
        outcome = await result_q.get()
        if isinstance(outcome, Exception):
            raise outcome
        return outcome

    return waiter


class Box(Generic[T]):
    """A mutable holder for a single optional value, used in tests."""
    # The held value; None until set. NOTE(review): no initializer is visible
    # here — confirm callers assign `val` before reading it.
    val: Optional[T]
0 comments on commit fd1d76f

Please sign in to comment.