Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement assigning subscriber #23

Merged
merged 4 commits into from Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,73 @@
from asyncio import Future, Queue, ensure_future
from typing import Callable, NamedTuple, Dict, Set

from google.cloud.pubsub_v1.subscriber.message import Message

from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.partition import Partition

_PartitionSubscriberFactory = Callable[[Partition], AsyncSubscriber]


class _RunningSubscriber(NamedTuple):
    """A per-partition subscriber paired with the future that polls it."""

    # Subscriber built by the factory for a single partition.
    subscriber: AsyncSubscriber
    # Future driving the read loop that drains `subscriber`.
    poller: Future


class AssigningSubscriber(AsyncSubscriber, PermanentFailable):
_assigner: Assigner
_subscriber_factory: _PartitionSubscriberFactory

_subscribers: Dict[Partition, _RunningSubscriber]
_messages: "Queue[Message]"
_assign_poller: Future

def __init__(self, assigner: Assigner, subscriber_factory: _PartitionSubscriberFactory):
    """Set up an assigning subscriber with no running partitions.

    Args:
        assigner: Object whose get_assignment() yields the partitions to read.
        subscriber_factory: Callable that builds the subscriber for a Partition.
    """
    super().__init__()
    # State that changes as assignments arrive: buffered messages and the
    # currently running per-partition subscribers.
    self._messages = Queue()
    self._subscribers = {}
    # Collaborators supplied by the caller.
    self._subscriber_factory = subscriber_factory
    self._assigner = assigner

async def read(self) -> Message:
    """Return the next message buffered from any running partition subscriber.

    The wait is routed through await_unless_failed rather than awaiting the
    queue directly.
    """
    next_message = self._messages.get()
    return await self.await_unless_failed(next_message)

async def _subscribe_action(self, subscriber: AsyncSubscriber):
    """Read one message from `subscriber` and enqueue it on the shared queue."""
    await self._messages.put(await subscriber.read())

async def _start_subscriber(self, partition: Partition):
    """Build, enter and start polling the subscriber for `partition`.

    The running subscriber is recorded in self._subscribers under `partition`.
    """
    partition_subscriber = self._subscriber_factory(partition)
    await partition_subscriber.__aenter__()

    def poll_once():
        # One iteration of the read loop for this partition's subscriber.
        return self._subscribe_action(partition_subscriber)

    read_loop = ensure_future(self.run_poller(poll_once))
    self._subscribers[partition] = _RunningSubscriber(partition_subscriber, read_loop)

async def _stop_subscriber(self, running: _RunningSubscriber):
    """Cancel the read loop of `running`, wait for it, then exit its subscriber.

    This does not modify self._subscribers; removing the map entry is left to
    the caller.
    """
    read_loop = running.poller
    read_loop.cancel()
    await wait_ignore_cancelled(read_loop)
    await running.subscriber.__aexit__(None, None, None)

async def _assign_action(self):
assignment: Set[Partition] = await self._assigner.get_assignment()
added_partitions = assignment - self._subscribers.keys()
removed_partitions = self._subscribers.keys() - assignment
for partition in added_partitions:
await self._start_subscriber(partition)
for partition in removed_partitions:
await self._stop_subscriber(self._subscribers[partition])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I would make `_stop_subscriber` take a partition and have it remove that partition from the map. Then we can never forget to keep the map in sync after calling it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nack. Then this function cannot be reused on teardown when looping over the active subscribers without making an explicit copy of the key set; otherwise you'll get `RuntimeError: dictionary changed size during iteration`.

del self._subscribers[partition]

async def __aenter__(self):
    """Enter the assigner, start the assignment polling loop and return self."""
    await self._assigner.__aenter__()
    assignment_loop = self.run_poller(self._assign_action)
    self._assign_poller = ensure_future(assignment_loop)
    return self

async def __aexit__(self, exc_type, exc_value, traceback):
    """Stop assignment polling, exit the assigner, then stop each running subscriber.

    The exception triple is forwarded unchanged to the assigner's __aexit__.
    """
    assignment_loop = self._assign_poller
    assignment_loop.cancel()
    await wait_ignore_cancelled(assignment_loop)
    await self._assigner.__aexit__(exc_type, exc_value, traceback)
    # _stop_subscriber leaves the map untouched, so iterating the live view is safe.
    for active in self._subscribers.values():
        await self._stop_subscriber(active)
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
def __enter__(self):
self._thread.start()

def __exit__(self, __exc_type, __exc_value, __traceback):
def __exit__(self, exc_type, exc_value, traceback):
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()

Expand Down
9 changes: 9 additions & 0 deletions google/cloud/pubsublite/internal/wait_ignore_cancelled.py
@@ -0,0 +1,9 @@
from asyncio import CancelledError
from typing import Awaitable


async def wait_ignore_cancelled(awaitable: Awaitable):
    """Await `awaitable`, treating a CancelledError as normal completion.

    The awaited result is discarded and None is always returned. Any exception
    other than CancelledError propagates to the caller.
    """
    try:
        await awaitable
    except CancelledError:
        # Cancellation is the expected way for a stopped poller to finish.
        return None
    return None
15 changes: 14 additions & 1 deletion google/cloud/pubsublite/internal/wire/permanent_failable.py
@@ -1,5 +1,5 @@
import asyncio
from typing import Awaitable, TypeVar, Optional
from typing import Awaitable, TypeVar, Optional, Callable

from google.api_core.exceptions import GoogleAPICallError

Expand Down Expand Up @@ -31,6 +31,19 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
task.cancel()
raise self._failure_task.exception()

async def run_poller(self, poll_action: Callable[[], Awaitable[None]]):
    """Repeatedly run `poll_action` until a GoogleAPICallError stops the loop.

    Each iteration calls `poll_action()` and awaits the result through
    await_unless_failed. A GoogleAPICallError raised inside the loop ends it
    and is passed to fail(); other exceptions propagate.

    Args:
        poll_action: Zero-argument callable returning an awaitable for one
          polling step, e.g. an async function.
    """
    try:
        while True:
            step = poll_action()
            await self.await_unless_failed(step)
    except GoogleAPICallError as failure:
        self.fail(failure)

def fail(self, err: GoogleAPICallError):
if not self._failure_task.done():
self._failure_task.set_exception(err)
Expand Down
17 changes: 17 additions & 0 deletions google/cloud/pubsublite/testing/test_utils.py
@@ -1,6 +1,8 @@
import asyncio
from typing import List, Union, Any, TypeVar, Generic, Optional

from asynctest import CoroutineMock

T = TypeVar("T")


Expand All @@ -27,5 +29,20 @@ async def waiter(*args, **kwargs):
return waiter


class QueuePair:
    """Two fresh asyncio queues: one for observed calls, one for results."""

    called: asyncio.Queue
    results: asyncio.Queue

    def __init__(self):
        self.called, self.results = asyncio.Queue(), asyncio.Queue()


def wire_queues(mock: CoroutineMock) -> QueuePair:
    """Attach a queue-backed side effect to `mock` and return the queues used.

    The side effect is built by make_queue_waiter from the pair's `called` and
    `results` queues.
    """
    pair = QueuePair()
    waiter = make_queue_waiter(pair.called, pair.results)
    mock.side_effect = waiter
    return pair


class Box(Generic[T]):
val: Optional[T]
@@ -0,0 +1,127 @@
import asyncio
from typing import Callable, Set

from asynctest.mock import MagicMock, call
import pytest
from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import AssigningSubscriber
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite.testing.test_utils import make_queue_waiter, wire_queues

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


def mock_async_context_manager(cm):
    """Make `cm.__aenter__` return `cm` itself, then hand `cm` back."""
    enter = cm.__aenter__
    enter.return_value = cm
    return cm


@pytest.fixture()
def assigner():
    """Provide a MagicMock specced to Assigner that works as an async context manager."""
    mocked_assigner = MagicMock(spec=Assigner)
    return mock_async_context_manager(mocked_assigner)


@pytest.fixture()
def subscriber_factory():
    """Provide a MagicMock standing in for the partition-to-subscriber factory."""
    factory_type = Callable[[Partition], AsyncSubscriber]
    return MagicMock(spec=factory_type)


@pytest.fixture()
def subscriber(assigner, subscriber_factory):
    """Provide the AssigningSubscriber under test, built from the mocked fixtures."""
    under_test = AssigningSubscriber(assigner, subscriber_factory)
    return under_test


async def test_init(subscriber, assigner):
    """Entering enters the assigner and polls it once; leaving exits the assigner."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        assigner.__aenter__.assert_called_once()
        # Block until the assignment poller has called get_assignment().
        await assign_queues.called.get()
        assigner.get_assignment.assert_called_once()
    # NOTE(review): nesting was reconstructed from a flattened diff; this check
    # is assumed to run after the context has exited - confirm.
    assigner.__aexit__.assert_called_once()


async def test_initial_assignment(subscriber, assigner, subscriber_factory):
    """A first assignment of two partitions creates and enters one subscriber each."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2
        await assign_queues.results.put({Partition(1), Partition(2)})
        # The next get_assignment() call shows the previous assignment was processed.
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True)
        sub1.__aenter__.assert_called_once()
        sub2.__aenter__.assert_called_once()
    # NOTE(review): nesting reconstructed; the exit checks are assumed to run
    # after the context has exited - confirm.
    sub1.__aexit__.assert_called_once()
    sub2.__aexit__.assert_called_once()


async def test_assigner_failure(subscriber, assigner, subscriber_factory):
    """An error from the assigner surfaces from read()."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        # Feed an exception as the result of the pending get_assignment() call.
        await assign_queues.results.put(FailedPrecondition("bad assign"))
        with pytest.raises(FailedPrecondition):
            await subscriber.read()


async def test_assignment_change(subscriber, assigner, subscriber_factory):
    """Changing {1, 2} to {1, 3} starts partition 3 and stops partition 2."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub3 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(
            1) else sub2 if partition == Partition(2) else sub3
        # First assignment: partitions 1 and 2.
        await assign_queues.results.put({Partition(1), Partition(2)})
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True)
        sub1.__aenter__.assert_called_once()
        sub2.__aenter__.assert_called_once()
        # Second assignment: partition 2 is replaced by partition 3.
        await assign_queues.results.put({Partition(1), Partition(3)})
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2)), call(Partition(3))], any_order=True)
        sub3.__aenter__.assert_called_once()
        # NOTE(review): nesting reconstructed; this check is assumed to run while
        # the context is still open, verifying the removed partition was stopped.
        sub2.__aexit__.assert_called_once()
    # Assumed to run after the context has exited - confirm.
    sub1.__aexit__.assert_called_once()
    sub2.__aexit__.assert_called_once()
    sub3.__aexit__.assert_called_once()


async def test_subscriber_failure(subscriber, assigner, subscriber_factory):
    """An error from a partition subscriber's read() surfaces from read()."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub1_queues = wire_queues(sub1.read)
        subscriber_factory.return_value = sub1
        await assign_queues.results.put({Partition(1)})
        # Wait for the partition's read loop to call sub1.read(), then fail it.
        await sub1_queues.called.get()
        await sub1_queues.results.put(FailedPrecondition("sub failed"))
        with pytest.raises(FailedPrecondition):
            await subscriber.read()


async def test_delivery_from_multiple(subscriber, assigner, subscriber_factory):
    """Messages from two partition subscribers are both delivered through read()."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub1_queues = wire_queues(sub1.read)
        sub2_queues = wire_queues(sub2.read)
        subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2
        await assign_queues.results.put({Partition(1), Partition(2)})
        # One message is queued as the read() result for each partition subscriber.
        await sub1_queues.results.put(Message(PubsubMessage(message_id="1")._pb, "", 0, None))
        await sub2_queues.results.put(Message(PubsubMessage(message_id="2")._pb, "", 0, None))
        # Collected as a set, so the arrival order of the two messages does not matter.
        message_ids: Set[str] = set()
        message_ids.add((await subscriber.read()).message_id)
        message_ids.add((await subscriber.read()).message_id)
        assert message_ids == {"1", "2"}