Skip to content

Commit

Permalink
feat: Implement assigner, which handles partition-subscriber assignme…
Browse files Browse the repository at this point in the history
…nt. (#14)

* feat: Implement assigner which generates subscription-partition assignments.

Also slightly change the semantics of PermanentFailable to not fail a RetryingConnection on retryable errors from a watched awaitable.
  • Loading branch information
dpcollins-google committed Sep 15, 2020
1 parent db78799 commit b2d0d36
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 12 deletions.
15 changes: 15 additions & 0 deletions google/cloud/pubsublite/internal/wire/assigner.py
@@ -0,0 +1,15 @@
from abc import abstractmethod
from typing import AsyncContextManager, Set

from google.cloud.pubsublite.partition import Partition


class Assigner(AsyncContextManager):
    """
    Async context manager that hands out a continuous stream of partition assignments.

    Every call to get_assignment() yields the next assignment. Callers should finish all
    necessary work with the assignment they hold before requesting the next one.
    """

    @abstractmethod
    async def get_assignment(self) -> Set[Partition]:
        """Return the next set of assigned partitions. Must be overridden by implementations."""
        raise NotImplementedError()
88 changes: 88 additions & 0 deletions google/cloud/pubsublite/internal/wire/assigner_impl.py
@@ -0,0 +1,88 @@
import asyncio
from typing import Optional, Set

from absl import logging
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection, ConnectionFactory
from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError
from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer
from google.cloud.pubsublite.internal.wire.connection import Connection
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite_v1.types import PartitionAssignmentRequest, PartitionAssignment, \
InitialPartitionAssignmentRequest, PartitionAssignmentAck

# Maximum bytes per batch at 3.5 MiB to avoid GRPC limit of 4 MiB
_MAX_BYTES = int(3.5 * 1024 * 1024)

# Maximum messages per batch at 1000
_MAX_MESSAGES = 1000


class AssignerImpl(Assigner, ConnectionReinitializer[PartitionAssignmentRequest, PartitionAssignment]):
    """
    Assigner that reads partition assignments from a RetryingConnection.

    A background receiver task reads assignments from the current stream and hands them to
    get_assignment() through a single-element queue. A delivered assignment is acked on the stream
    when the next assignment is requested. reinitialize() resets this state and restarts the
    receiver for each new underlying connection (presumably invoked by the RetryingConnection,
    which is given this object at construction).
    """
    _initial: InitialPartitionAssignmentRequest
    _connection: RetryingConnection[PartitionAssignmentRequest, PartitionAssignment]

    # True once an assignment has been received on the current stream and has not been acked yet.
    _outstanding_assignment: bool

    # Background task running _receive_loop() for the current stream, or None when not running.
    _receiver: Optional[asyncio.Future]

    # A queue that may only hold one element with the next assignment.
    _new_assignment: 'asyncio.Queue[Set[Partition]]'

    def __init__(self, initial: InitialPartitionAssignmentRequest,
                 factory: ConnectionFactory[PartitionAssignmentRequest, PartitionAssignment]):
        """
        Args:
          initial: The initial request sent as the first message on every new stream.
          factory: Factory handed to the RetryingConnection to create underlying connections.
        """
        self._initial = initial
        self._connection = RetryingConnection(factory, self)
        self._outstanding_assignment = False
        self._receiver = None
        self._new_assignment = asyncio.Queue(maxsize=1)

    async def __aenter__(self):
        """Open the retrying connection and return this assigner so `async with ... as x` works."""
        await self._connection.__aenter__()
        return self

    def _start_receiver(self):
        """Start the background receive loop. A previous receiver must already be stopped."""
        assert self._receiver is None
        self._receiver = asyncio.ensure_future(self._receive_loop())

    async def _stop_receiver(self):
        """Cancel the background receive loop, if any, and wait for it to finish."""
        # Clear the attribute first so the object is in a consistent state even if waiting fails.
        receiver, self._receiver = self._receiver, None
        if receiver is None:
            return
        receiver.cancel()
        try:
            await receiver
        except asyncio.CancelledError:
            # A task cancelled before its first step never reaches the handler inside
            # _receive_loop, so the cancellation surfaces here instead.
            pass
        except GoogleAPICallError as e:
            # The loop ended because reading from the connection failed. get_assignment() awaits
            # through the same connection, so the failure is not lost by dropping it here.
            logging.debug("Assignment receiver had already stopped due to stream failure: %s", e)

    async def _receive_loop(self):
        """Read assignments from the connection and publish each one to the assignment queue."""
        try:
            while True:
                response = await self._connection.read()
                # A new assignment must not arrive before the previous one was delivered and acked.
                if self._outstanding_assignment or not self._new_assignment.empty():
                    self._connection.fail(FailedPrecondition(
                        "Received a duplicate assignment on the stream while one was outstanding."))
                    return
                self._outstanding_assignment = True
                partitions = {Partition(partition) for partition in response.partitions}
                # Cannot raise QueueFull: the queue was verified to be empty above.
                self._new_assignment.put_nowait(partitions)
        except asyncio.CancelledError:
            return

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the retrying connection, then stop the receiver so no background task is leaked."""
        try:
            await self._connection.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            await self._stop_receiver()

    async def reinitialize(self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment]):
        """
        Prepare a freshly created connection: discard state tied to the previous stream, send the
        initial request and start receiving assignments.

        Args:
          connection: The new underlying connection.
        """
        # Stop the old receiver before resetting state so it can no longer touch that state.
        await self._stop_receiver()
        self._outstanding_assignment = False
        while not self._new_assignment.empty():
            self._new_assignment.get_nowait()
        await connection.write(PartitionAssignmentRequest(initial=self._initial))
        self._start_receiver()

    async def get_assignment(self) -> Set[Partition]:
        """
        Ack the previously delivered assignment, if one is outstanding, then wait for the next one.

        Returns: The next set of assigned partitions.
        Raises: The connection's permanent error if it fails while waiting.
        """
        if self._outstanding_assignment:
            try:
                await self._connection.write(PartitionAssignmentRequest(ack=PartitionAssignmentAck()))
                self._outstanding_assignment = False
            except GoogleAPICallError as e:
                # If there is a failure to ack, keep going. The stream likely restarted.
                logging.debug("Assignment ack attempt failed due to stream failure: %s", e)
        return await self._connection.await_unless_failed(self._new_assignment.get())
16 changes: 12 additions & 4 deletions google/cloud/pubsublite/internal/wire/gapic_connection.py
@@ -1,6 +1,8 @@
from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable
import asyncio

from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition

from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response, ConnectionFactory
from google.cloud.pubsublite.internal.wire.work_item import WorkItem
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
Expand All @@ -22,11 +24,17 @@ def set_response_it(self, response_it: AsyncIterator[Response]):

async def write(self, request: Request) -> None:
item = WorkItem(request)
await self.await_or_fail(self._write_queue.put(item))
await self.await_or_fail(item.response_future)
await self.await_unless_failed(self._write_queue.put(item))
await self.await_unless_failed(item.response_future)

async def read(self) -> Response:
return await self.await_or_fail(self._response_it.__anext__())
try:
return await self.await_unless_failed(self._response_it.__anext__())
except StopAsyncIteration:
self.fail(FailedPrecondition("Server sent unprompted half close."))
except GoogleAPICallError as e:
self.fail(e)
raise self.error()

def __aenter__(self):
return self
Expand All @@ -35,7 +43,7 @@ def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass

async def __anext__(self) -> Request:
item: WorkItem[Request] = await self.await_or_fail(self._write_queue.get())
item: WorkItem[Request] = await self.await_unless_failed(self._write_queue.get())
item.response_future.set_result(None)
return item.request

Expand Down
15 changes: 10 additions & 5 deletions google/cloud/pubsublite/internal/wire/permanent_failable.py
Expand Up @@ -13,16 +13,21 @@ class PermanentFailable:
def __init__(self):
self._failure_task = asyncio.Future()

async def await_or_fail(self, awaitable: Awaitable[T]) -> T:
async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
"""
Await the awaitable, unless fail() is called first.
Args:
awaitable: An awaitable
Returns: The result of the awaitable
Raises: The permanent error if fail() is called or the awaitable raises one.
"""
if self._failure_task.done():
raise self._failure_task.exception()
task = asyncio.ensure_future(awaitable)
done, _ = await asyncio.wait([task, self._failure_task], return_when=asyncio.FIRST_COMPLETED)
if task in done:
try:
return await task
except GoogleAPICallError as e:
self.fail(e)
return await task
task.cancel()
raise self._failure_task.exception()

Expand Down
6 changes: 3 additions & 3 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
Expand Up @@ -38,11 +38,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

async def write(self, request: Request) -> None:
item = WorkItem(request)
await self.await_or_fail(self._write_queue.put(item))
return await self.await_or_fail(item.response_future)
await self.await_unless_failed(self._write_queue.put(item))
return await self.await_unless_failed(item.response_future)

async def read(self) -> Response:
return await self.await_or_fail(self._read_queue.get())
return await self.await_unless_failed(self._read_queue.get())

async def _run_loop(self):
"""
Expand Down
183 changes: 183 additions & 0 deletions tests/unit/pubsublite/internal/wire/assigner_impl_test.py
@@ -0,0 +1,183 @@
import asyncio
from unittest.mock import call
from collections import defaultdict
from typing import Dict, Set

from asynctest.mock import MagicMock, CoroutineMock
import pytest

from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.assigner_impl import AssignerImpl
from google.cloud.pubsublite.internal.wire.connection import Connection, ConnectionFactory
from google.api_core.exceptions import InternalServerError

from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite_v1.types.subscriber import PartitionAssignmentRequest, InitialPartitionAssignmentRequest, \
PartitionAssignment, PartitionAssignmentAck
from google.cloud.pubsublite.testing.test_utils import make_queue_waiter
from google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS

# Module-level pytest marker: applies the `asyncio` mark to every test in this module, so all
# test coroutines are treated as marked without decorating each one.
pytestmark = pytest.mark.asyncio


@pytest.fixture()
def default_connection():
    """Mock Connection whose `__aenter__` returns the mock itself."""
    connection = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment])
    connection.__aenter__.return_value = connection
    return connection


@pytest.fixture()
def connection_factory(default_connection):
    """Mock ConnectionFactory whose `new()` hands out `default_connection`."""
    mock_factory = MagicMock(spec=ConnectionFactory[PartitionAssignmentRequest, PartitionAssignment])
    mock_factory.new.return_value = default_connection
    return mock_factory


@pytest.fixture()
def initial_request():
    """The stream-opening request for subscription "mysub", wrapped in a PartitionAssignmentRequest."""
    initial = InitialPartitionAssignmentRequest(subscription="mysub")
    return PartitionAssignmentRequest(initial=initial)


class QueuePair:
    """
    Two independent queues, `called` and `results`.

    In this module they are passed together to make_queue_waiter (presumably `called` signals that
    the mocked coroutine was invoked and `results` supplies its outcome - confirm in test_utils).
    """
    called: asyncio.Queue
    results: asyncio.Queue

    def __init__(self):
        self.called, self.results = asyncio.Queue(), asyncio.Queue()


@pytest.fixture
def sleep_queues() -> Dict[float, QueuePair]:
    """Mapping from sleep delay to its QueuePair; a pair is created on first access to a delay."""
    queues_by_delay: Dict[float, QueuePair] = defaultdict(QueuePair)
    return queues_by_delay


@pytest.fixture
def asyncio_sleep(monkeypatch, sleep_queues):
    """
    Replace `asyncio.sleep` with a mock driven by `sleep_queues`.

    Each call with a given delay is routed through make_queue_waiter using the QueuePair registered
    for that delay, so a test can control sleeps per delay value. Returns the installed mock.
    """

    async def sleeper(delay: float):
        pair = sleep_queues[delay]
        await make_queue_waiter(pair.called, pair.results)(delay)

    sleep_mock = CoroutineMock()
    sleep_mock.side_effect = sleeper
    monkeypatch.setattr(asyncio, "sleep", sleep_mock)
    return sleep_mock


@pytest.fixture()
def assigner(connection_factory, initial_request):
    """The AssignerImpl under test, built from the mocked factory and the initial request payload."""
    initial = initial_request.initial
    return AssignerImpl(initial, connection_factory)


def as_response(partitions: Set[Partition]):
    """Build a PartitionAssignment whose `partitions` field lists the `value` of each Partition."""
    response = PartitionAssignment()
    response.partitions = [item.value for item in partitions]
    return response


def ack_request():
    """Build the request expected on the stream when an assignment is acked."""
    ack = PartitionAssignmentAck()
    return PartitionAssignmentRequest(ack=ack)


async def test_basic_assign(
        assigner: Assigner, default_connection, initial_request):
    """Two assignments are delivered in order, and the first is acked when the second is requested."""
    writes_seen = asyncio.Queue()
    write_outcomes = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(writes_seen, write_outcomes)
    reads_seen = asyncio.Queue()
    read_outcomes = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(reads_seen, read_outcomes)
    # Pre-load the outcome of the initial write so connection setup can complete.
    write_outcomes.put_nowait(None)
    async with assigner:
        # Connection setup: the initial request is written, then a read is started.
        await writes_seen.get()
        await reads_seen.get()
        default_connection.write.assert_has_calls([call(initial_request)])

        # Request the first assignment before any has been sent.
        pending_first = asyncio.ensure_future(assigner.get_assignment())
        assert not pending_first.done()

        first_expected = {Partition(2), Partition(7)}

        # Deliver the first assignment on the stream.
        await read_outcomes.put(as_response(partitions=first_expected))
        assert (await pending_first) == first_expected

        # Requesting the next assignment should send an ack on the stream.
        pending_second = asyncio.ensure_future(assigner.get_assignment())
        await writes_seen.get()
        await write_outcomes.put(None)
        default_connection.write.assert_has_calls([call(initial_request), call(ack_request())])

        second_expected = {Partition(5)}

        # Deliver the second assignment on the stream.
        await reads_seen.get()
        await read_outcomes.put(as_response(partitions=second_expected))
        assert (await pending_second) == second_expected


async def test_restart(
        assigner: Assigner, default_connection, connection_factory, initial_request, asyncio_sleep, sleep_queues):
    """
    A failed ack write restarts the stream: the initial request is resent on the new connection,
    the next assignment is still delivered, and no ack is written to the new connection.
    """
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue)
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue)
    # Pre-load the outcome of the initial write so connection setup can complete.
    write_result_queue.put_nowait(None)
    async with assigner:
        # Set up connection
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])

        # Wait for the first assignment
        assign_fut1 = asyncio.ensure_future(assigner.get_assignment())
        assert not assign_fut1.done()

        partitions = {Partition(2), Partition(7)}

        # Send the first assignment.
        await read_result_queue.put(as_response(partitions=partitions))
        await read_called_queue.get()
        assert (await assign_fut1) == partitions

        # Get the next assignment: should attempt to send an ack on the stream
        assign_fut2 = asyncio.ensure_future(assigner.get_assignment())
        await write_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request), call(ack_request())])

        # Set up the next connection
        conn2 = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment])
        conn2.__aenter__.return_value = conn2
        connection_factory.new.return_value = conn2
        write_called_queue_2 = asyncio.Queue()
        write_result_queue_2 = asyncio.Queue()
        conn2.write.side_effect = make_queue_waiter(write_called_queue_2, write_result_queue_2)
        read_called_queue_2 = asyncio.Queue()
        read_result_queue_2 = asyncio.Queue()
        conn2.read.side_effect = make_queue_waiter(read_called_queue_2, read_result_queue_2)

        # Fail the connection by failing the write call.
        await write_result_queue.put(InternalServerError("failed"))
        # Let the retry backoff sleep (mocked via asyncio_sleep) complete.
        await sleep_queues[_MIN_BACKOFF_SECS].called.get()
        await sleep_queues[_MIN_BACKOFF_SECS].results.put(None)

        # Reinitialize
        await write_called_queue_2.get()
        write_result_queue_2.put_nowait(None)
        conn2.write.assert_has_calls([call(initial_request)])

        partitions = {Partition(5)}

        # Send the second assignment on the new connection.
        await read_called_queue_2.get()
        await read_result_queue_2.put(as_response(partitions=partitions))
        assert (await assign_fut2) == partitions
        # No ack call ever made: the initial request must be the only write on the new connection.
        # assert_has_calls would also pass if extra writes (such as an ack) had been made.
        conn2.write.assert_called_once_with(initial_request)

0 comments on commit b2d0d36

Please sign in to comment.