Skip to content

Commit

Permalink
feat: implement assigning subscriber (#23)
Browse files Browse the repository at this point in the history
* feat: Implement SinglePartitionSubscriber.

This handles mapping a single partition to a Cloud Pub/Sub-like asynchronous subscriber.

* feat: Add DefaultNackHandler.

* feat: Add AssigningSubscriber.

This handles changing partition assignments and creates AsyncSubscribers per-partition.
  • Loading branch information
dpcollins-google committed Sep 24, 2020
1 parent bb76d90 commit 6afd477
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 2 deletions.
@@ -0,0 +1,73 @@
from asyncio import Future, Queue, ensure_future
from typing import Callable, NamedTuple, Dict, Set

from google.cloud.pubsub_v1.subscriber.message import Message

from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.partition import Partition

_PartitionSubscriberFactory = Callable[[Partition], AsyncSubscriber]


class _RunningSubscriber(NamedTuple):
    """Pairs a started per-partition subscriber with the task that polls it."""

    # The per-partition subscriber produced by the subscriber factory.
    subscriber: AsyncSubscriber
    # The future running the poll loop that reads from `subscriber`.
    poller: Future


class AssigningSubscriber(AsyncSubscriber, PermanentFailable):
    """Subscriber that follows partition assignments and merges per-partition reads.

    While entered, a background poller repeatedly asks the assigner for the
    current set of partitions. A subscriber is created (via the factory) and
    started for each newly assigned partition, and the subscriber of each
    partition that is no longer assigned is stopped. Every message read from
    any running per-partition subscriber is placed on one shared queue, which
    `read()` drains.
    """

    _assigner: Assigner
    _subscriber_factory: _PartitionSubscriberFactory

    # Running per-partition subscribers, keyed by the partition they serve.
    _subscribers: Dict[Partition, _RunningSubscriber]
    # Messages received from any partition that have not yet been read.
    _messages: "Queue[Message]"
    # Task running the assignment poll loop; only set once __aenter__ has run.
    _assign_poller: Future

    def __init__(self, assigner: Assigner, subscriber_factory: _PartitionSubscriberFactory):
        """
        Args:
          assigner: Source of partition assignments; entered and exited together with this object.
          subscriber_factory: Callable building an (unstarted) subscriber for a given partition.
        """
        super().__init__()
        self._subscribers = {}
        self._messages = Queue()
        self._assigner = assigner
        self._subscriber_factory = subscriber_factory

    async def read(self) -> Message:
        """Return the next queued message from any assigned partition.

        The wait goes through `await_unless_failed`, so it is subject to whatever
        failure handling PermanentFailable applies there.
        """
        next_message = self._messages.get()
        return await self.await_unless_failed(next_message)

    async def _subscribe_action(self, subscriber: AsyncSubscriber):
        """Single poll step: read one message from `subscriber` and enqueue it."""
        await self._messages.put(await subscriber.read())

    async def _start_subscriber(self, partition: Partition):
        """Create and enter the subscriber for `partition`, then start polling it."""
        partition_subscriber = self._subscriber_factory(partition)
        await partition_subscriber.__aenter__()

        def poll_once():
            # One iteration of the poll loop for this specific subscriber.
            return self._subscribe_action(partition_subscriber)

        poll_task = ensure_future(self.run_poller(poll_once))
        self._subscribers[partition] = _RunningSubscriber(
            subscriber=partition_subscriber, poller=poll_task
        )

    async def _stop_subscriber(self, running: _RunningSubscriber):
        """Cancel the poll task of `running`, wait for it to finish, then exit its subscriber."""
        subscriber, poller = running
        poller.cancel()
        await wait_ignore_cancelled(poller)
        await subscriber.__aexit__(None, None, None)

    async def _assign_action(self):
        """Single assignment step: fetch the assignment and reconcile running subscribers."""
        assignment: Set[Partition] = await self._assigner.get_assignment()
        # Both differences are computed before any subscriber is started or
        # stopped, so they reflect the state at the time the assignment arrived.
        current = self._subscribers.keys()
        to_start = assignment - current
        to_stop = current - assignment
        for partition in to_start:
            await self._start_subscriber(partition)
        for partition in to_stop:
            await self._stop_subscriber(self._subscribers[partition])
            # Only forget the subscriber once it has been stopped.
            del self._subscribers[partition]

    async def __aenter__(self):
        """Enter the assigner and launch the background assignment poll loop."""
        await self._assigner.__aenter__()
        assign_loop = self.run_poller(self._assign_action)
        self._assign_poller = ensure_future(assign_loop)
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Stop assignment polling, exit the assigner, then stop every running subscriber."""
        assign_poller = self._assign_poller
        assign_poller.cancel()
        await wait_ignore_cancelled(assign_poller)
        await self._assigner.__aexit__(exc_type, exc_value, traceback)
        for running in self._subscribers.values():
            await self._stop_subscriber(running)
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
def __enter__(self):
self._thread.start()

def __exit__(self, __exc_type, __exc_value, __traceback):
def __exit__(self, exc_type, exc_value, traceback):
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()

Expand Down
9 changes: 9 additions & 0 deletions google/cloud/pubsublite/internal/wait_ignore_cancelled.py
@@ -0,0 +1,9 @@
from asyncio import CancelledError
from typing import Awaitable


async def wait_ignore_cancelled(awaitable: Awaitable):
    """Await `awaitable`, discarding its result and swallowing CancelledError.

    Any other exception raised by the awaitable propagates to the caller.
    """
    try:
        await awaitable
    except CancelledError:
        # Cancellation of the awaited work is the expected outcome here.
        return
15 changes: 14 additions & 1 deletion google/cloud/pubsublite/internal/wire/permanent_failable.py
@@ -1,5 +1,5 @@
import asyncio
from typing import Awaitable, TypeVar, Optional
from typing import Awaitable, TypeVar, Optional, Callable

from google.api_core.exceptions import GoogleAPICallError

Expand Down Expand Up @@ -31,6 +31,19 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
task.cancel()
raise self._failure_task.exception()

async def run_poller(self, poll_action: Callable[[], Awaitable[None]]):
    """
    Repeatedly invoke poll_action and await its result, without end, until an exception escapes.

    A GoogleAPICallError raised from the loop is passed to self.fail() instead of being
    re-raised; any other exception propagates to the caller.

    Args:
      poll_action: A zero-argument callable returning an awaitable for one polling step. An async function that
        completes a single step and returns fits this.
    """
    try:
        while True:
            pending_step = poll_action()
            await self.await_unless_failed(pending_step)
    except GoogleAPICallError as api_error:
        self.fail(api_error)

def fail(self, err: GoogleAPICallError):
if not self._failure_task.done():
self._failure_task.set_exception(err)
Expand Down
17 changes: 17 additions & 0 deletions google/cloud/pubsublite/testing/test_utils.py
@@ -1,6 +1,8 @@
import asyncio
from typing import List, Union, Any, TypeVar, Generic, Optional

from asynctest import CoroutineMock

T = TypeVar("T")


Expand All @@ -27,5 +29,20 @@ async def waiter(*args, **kwargs):
return waiter


class QueuePair:
    """Holds the two queues handed to make_queue_waiter for a mocked coroutine."""

    # Queue passed as the first argument to make_queue_waiter.
    called: asyncio.Queue
    # Queue passed as the second argument to make_queue_waiter.
    results: asyncio.Queue

    def __init__(self):
        self.called, self.results = asyncio.Queue(), asyncio.Queue()


def wire_queues(mock: CoroutineMock) -> QueuePair:
    """Install a queue-backed side effect on `mock` and return the queues driving it.

    Args:
      mock: The coroutine mock whose side_effect is replaced.

    Returns:
      The QueuePair whose `called` and `results` queues were given to make_queue_waiter.
    """
    pair = QueuePair()
    called, results = pair.called, pair.results
    mock.side_effect = make_queue_waiter(called, results)
    return pair


class Box(Generic[T]):
val: Optional[T]
@@ -0,0 +1,127 @@
import asyncio
from typing import Callable, Set

from asynctest.mock import MagicMock, call
import pytest
from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import AssigningSubscriber
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite.testing.test_utils import make_queue_waiter, wire_queues

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


def mock_async_context_manager(cm):
    """Configure `cm` so its `__aenter__` mock returns `cm` itself, then return `cm`."""
    enter_mock = cm.__aenter__
    enter_mock.return_value = cm
    return cm


@pytest.fixture()
def assigner():
    """Fixture: an Assigner-spec'd mock whose `__aenter__` returns the mock itself."""
    assigner_mock = MagicMock(spec=Assigner)
    return mock_async_context_manager(assigner_mock)


@pytest.fixture()
def subscriber_factory():
    """Fixture: a mock standing in for the partition -> AsyncSubscriber factory."""
    factory_type = Callable[[Partition], AsyncSubscriber]
    return MagicMock(spec=factory_type)


@pytest.fixture()
def subscriber(assigner, subscriber_factory):
    """Fixture: the AssigningSubscriber under test, built from the mocked collaborators."""
    return AssigningSubscriber(
        assigner=assigner, subscriber_factory=subscriber_factory
    )


async def test_init(subscriber, assigner):
    """Entering the subscriber enters the assigner and polls it; leaving exits the assigner."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        assigner.__aenter__.assert_called_once()
        # Presumably resolves once get_assignment has been invoked by the poller
        # (depends on make_queue_waiter, not shown here -- confirm).
        await assign_queues.called.get()
        assigner.get_assignment.assert_called_once()
    # The assigner must be exited when the subscriber's context is left.
    assigner.__aexit__.assert_called_once()


async def test_initial_assignment(subscriber, assigner, subscriber_factory):
    """A first assignment of two partitions creates and enters one subscriber per partition."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        # Partition 1 maps to sub1; any other partition maps to sub2.
        subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2
        await assign_queues.results.put({Partition(1), Partition(2)})
        # Waiting for the next get_assignment call; presumably this means the
        # previous assignment has been fully processed (confirm against make_queue_waiter).
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True)
        sub1.__aenter__.assert_called_once()
        sub2.__aenter__.assert_called_once()
    # Leaving the context must exit every per-partition subscriber.
    sub1.__aexit__.assert_called_once()
    sub2.__aexit__.assert_called_once()


async def test_assigner_failure(subscriber, assigner, subscriber_factory):
    """A FailedPrecondition from the assigner surfaces to callers of read()."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        # Presumably an exception placed on `results` is raised by the mocked
        # get_assignment (behaviour of make_queue_waiter -- confirm).
        await assign_queues.results.put(FailedPrecondition("bad assign"))
        with pytest.raises(FailedPrecondition):
            await subscriber.read()


async def test_assignment_change(subscriber, assigner, subscriber_factory):
    """Changing the assignment from {1, 2} to {1, 3} starts partition 3 and stops partition 2."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub3 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        # Partition 1 -> sub1, partition 2 -> sub2, anything else -> sub3.
        subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(
            1) else sub2 if partition == Partition(2) else sub3
        # First assignment: partitions 1 and 2.
        await assign_queues.results.put({Partition(1), Partition(2)})
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True)
        sub1.__aenter__.assert_called_once()
        sub2.__aenter__.assert_called_once()
        # Second assignment: partition 2 is replaced by partition 3.
        await assign_queues.results.put({Partition(1), Partition(3)})
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2)), call(Partition(3))], any_order=True)
        sub3.__aenter__.assert_called_once()
        # sub2 must already be exited while the context is still open.
        sub2.__aexit__.assert_called_once()
    # After leaving the context every subscriber has been exited exactly once,
    # including sub2, which must not be exited a second time.
    sub1.__aexit__.assert_called_once()
    sub2.__aexit__.assert_called_once()
    sub3.__aexit__.assert_called_once()


async def test_subscriber_failure(subscriber, assigner, subscriber_factory):
    """A FailedPrecondition from a per-partition subscriber surfaces to callers of read()."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub1_queues = wire_queues(sub1.read)
        subscriber_factory.return_value = sub1
        await assign_queues.results.put({Partition(1)})
        # Wait for the partition subscriber's read() to be polled before failing it.
        await sub1_queues.called.get()
        await sub1_queues.results.put(FailedPrecondition("sub failed"))
        with pytest.raises(FailedPrecondition):
            await subscriber.read()


async def test_delivery_from_multiple(subscriber, assigner, subscriber_factory):
    """Messages from two different partition subscribers are both delivered through read()."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub1_queues = wire_queues(sub1.read)
        sub2_queues = wire_queues(sub2.read)
        # Partition 1 maps to sub1; any other partition maps to sub2.
        subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2
        await assign_queues.results.put({Partition(1), Partition(2)})
        # One message per partition subscriber, distinguished by message_id.
        await sub1_queues.results.put(Message(PubsubMessage(message_id="1")._pb, "", 0, None))
        await sub2_queues.results.put(Message(PubsubMessage(message_id="2")._pb, "", 0, None))
        # Arrival order across partitions is not asserted, so compare as a set.
        message_ids: Set[str] = set()
        message_ids.add((await subscriber.read()).message_id)
        message_ids.add((await subscriber.read()).message_id)
        assert message_ids == {"1", "2"}

0 comments on commit 6afd477

Please sign in to comment.