From ba3140b22439d86bb2fce439fe625d618bf806c5 Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Thu, 13 Aug 2020 11:14:21 -0400 Subject: [PATCH 1/7] feat: Implement committer Also small fix to retrying connection so it doesn't leak reads/writes from previous connections. --- google/cloud/pubsublite/internal/wire/retrying_connection.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index e97ebd9b..7fcf001b 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -53,6 +53,8 @@ async def _run_loop(self): bad_retries = 0 while True: try: + self._read_queue = asyncio.Queue(maxsize=1) + self._write_queue = asyncio.Queue(maxsize=1) async with self._connection_factory.new() as connection: # Needs to happen prior to reinitialization to clear outstanding waiters. while not self._write_queue.empty(): From 8d3465312a627bfbcc4ea662a77b1f3347119f18 Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Fri, 14 Aug 2020 09:49:44 -0400 Subject: [PATCH 2/7] fix: Patch retrying connection and add comments. --- google/cloud/pubsublite/internal/wire/retrying_connection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index 7fcf001b..e97ebd9b 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -53,8 +53,6 @@ async def _run_loop(self): bad_retries = 0 while True: try: - self._read_queue = asyncio.Queue(maxsize=1) - self._write_queue = asyncio.Queue(maxsize=1) async with self._connection_factory.new() as connection: # Needs to happen prior to reinitialization to clear outstanding waiters. 
while not self._write_queue.empty(): From e3701188be3680fffd01f905728397600fc596e6 Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Tue, 18 Aug 2020 15:11:58 -0400 Subject: [PATCH 3/7] feat: Implement assigner which generates subscription-partition assignments. Also slightly change the semantics of PermanentFailable to not fail a RetryingConnection on retryable errors from a watched awaitable. --- .../pubsublite/internal/wire/assigner.py | 15 ++ .../pubsublite/internal/wire/assigner_impl.py | 88 +++++++++ .../internal/wire/gapic_connection.py | 16 +- .../internal/wire/permanent_failable.py | 15 +- .../internal/wire/retrying_connection.py | 6 +- .../internal/wire/assigner_impl_test.py | 183 ++++++++++++++++++ 6 files changed, 311 insertions(+), 12 deletions(-) create mode 100644 google/cloud/pubsublite/internal/wire/assigner.py create mode 100644 google/cloud/pubsublite/internal/wire/assigner_impl.py create mode 100644 tests/unit/pubsublite/internal/wire/assigner_impl_test.py diff --git a/google/cloud/pubsublite/internal/wire/assigner.py b/google/cloud/pubsublite/internal/wire/assigner.py new file mode 100644 index 00000000..6307850a --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/assigner.py @@ -0,0 +1,15 @@ +from abc import abstractmethod +from typing import AsyncContextManager, Set + +from google.cloud.pubsublite.partition import Partition + + +class Assigner(AsyncContextManager): + """ + An assigner will deliver a continuous stream of assignments when called into. Perform all necessary work with the + assignment before attempting to get the next one. 
+ """ + + @abstractmethod + async def get_assignment(self) -> Set[Partition]: + raise NotImplementedError() diff --git a/google/cloud/pubsublite/internal/wire/assigner_impl.py b/google/cloud/pubsublite/internal/wire/assigner_impl.py new file mode 100644 index 00000000..65a0117c --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/assigner_impl.py @@ -0,0 +1,88 @@ +import asyncio +from typing import Optional, Set + +from absl import logging +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection, ConnectionFactory +from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError +from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer +from google.cloud.pubsublite.internal.wire.connection import Connection +from google.cloud.pubsublite.partition import Partition +from google.cloud.pubsublite_v1.types import PartitionAssignmentRequest, PartitionAssignment, \ + InitialPartitionAssignmentRequest, PartitionAssignmentAck + +# Maximum bytes per batch at 3.5 MiB to avoid GRPC limit of 4 MiB +_MAX_BYTES = int(3.5 * 1024 * 1024) + +# Maximum messages per batch at 1000 +_MAX_MESSAGES = 1000 + + +class AssignerImpl(Assigner, ConnectionReinitializer[PartitionAssignmentRequest, PartitionAssignment]): + _initial: InitialPartitionAssignmentRequest + _connection: RetryingConnection[PartitionAssignmentRequest, PartitionAssignment] + + _outstanding_assignment: bool + + _receiver: Optional[asyncio.Future] + + # A queue that may only hold one element with the next assignment. 
+ _new_assignment: 'asyncio.Queue[Set[Partition]]' + + def __init__(self, initial: InitialPartitionAssignmentRequest, + factory: ConnectionFactory[PartitionAssignmentRequest, PartitionAssignment]): + self._initial = initial + self._connection = RetryingConnection(factory, self) + self._outstanding_assignment = False + self._receiver = None + self._new_assignment = asyncio.Queue(maxsize=1) + + async def __aenter__(self): + await self._connection.__aenter__() + + def _start_receiver(self): + assert self._receiver is None + self._receiver = asyncio.ensure_future(self._receive_loop()) + + async def _stop_receiver(self): + if self._receiver: + self._receiver.cancel() + await self._receiver + self._receiver = None + + async def _receive_loop(self): + try: + while True: + response = await self._connection.read() + if self._outstanding_assignment or not self._new_assignment.empty(): + self._connection.fail(FailedPrecondition( + "Received a duplicate assignment on the stream while one was outstanding.")) + return + self._outstanding_assignment = True + partitions = set() + for partition in response.partitions: + partitions.add(Partition(partition)) + self._new_assignment.put_nowait(partitions) + except asyncio.CancelledError: + return + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._connection.__aexit__(exc_type, exc_val, exc_tb) + + async def reinitialize(self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment]): + self._outstanding_assignment = False + while not self._new_assignment.empty(): + self._new_assignment.get_nowait() + await self._stop_receiver() + await connection.write(PartitionAssignmentRequest(initial=self._initial)) + self._start_receiver() + + async def get_assignment(self) -> Set[Partition]: + if self._outstanding_assignment: + try: + await self._connection.write(PartitionAssignmentRequest(ack=PartitionAssignmentAck())) + self._outstanding_assignment = False + except GoogleAPICallError as e: + # If there is a 
failure to ack, keep going. The stream likely restarted. + logging.debug(f"Assignment ack attempt failed due to stream failure: {e}") + return await self._connection.await_unless_failed(self._new_assignment.get()) diff --git a/google/cloud/pubsublite/internal/wire/gapic_connection.py b/google/cloud/pubsublite/internal/wire/gapic_connection.py index 0ae193a7..5e87d914 100644 --- a/google/cloud/pubsublite/internal/wire/gapic_connection.py +++ b/google/cloud/pubsublite/internal/wire/gapic_connection.py @@ -1,6 +1,8 @@ from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable import asyncio +from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition + from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response, ConnectionFactory from google.cloud.pubsublite.internal.wire.work_item import WorkItem from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable @@ -22,11 +24,17 @@ def set_response_it(self, response_it: AsyncIterator[Response]): async def write(self, request: Request) -> None: item = WorkItem(request) - await self.await_or_fail(self._write_queue.put(item)) - await self.await_or_fail(item.response_future) + await self.await_unless_failed(self._write_queue.put(item)) + await self.await_unless_failed(item.response_future) async def read(self) -> Response: - return await self.await_or_fail(self._response_it.__anext__()) + try: + return await self.await_unless_failed(self._response_it.__anext__()) + except StopAsyncIteration: + self.fail(FailedPrecondition("Server sent unprompted half close.")) + except GoogleAPICallError as e: + self.fail(e) + raise self.error() def __aenter__(self): return self @@ -35,7 +43,7 @@ def __aexit__(self, exc_type, exc_value, traceback) -> None: pass async def __anext__(self) -> Request: - item: WorkItem[Request] = await self.await_or_fail(self._write_queue.get()) + item: WorkItem[Request] = await 
self.await_unless_failed(self._write_queue.get()) item.response_future.set_result(None) return item.request diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py index 3efa1c99..7d8a01f8 100644 --- a/google/cloud/pubsublite/internal/wire/permanent_failable.py +++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py @@ -13,16 +13,21 @@ class PermanentFailable: def __init__(self): self._failure_task = asyncio.Future() - async def await_or_fail(self, awaitable: Awaitable[T]) -> T: + async def await_unless_failed(self, awaitable: Awaitable[T]) -> T: + """ + Await the awaitable, unless fail() is called first. + Args: + awaitable: An awaitable + + Returns: The result of the awaitable + Raises: The permanent error if fail() is called or the awaitable raises one. + """ if self._failure_task.done(): raise self._failure_task.exception() task = asyncio.ensure_future(awaitable) done, _ = await asyncio.wait([task, self._failure_task], return_when=asyncio.FIRST_COMPLETED) if task in done: - try: - return await task - except GoogleAPICallError as e: - self.fail(e) + return await task task.cancel() raise self._failure_task.exception() diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index e97ebd9b..dfc7eb87 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -38,11 +38,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def write(self, request: Request) -> None: item = WorkItem(request) - await self.await_or_fail(self._write_queue.put(item)) - return await self.await_or_fail(item.response_future) + await self.await_unless_failed(self._write_queue.put(item)) + return await self.await_unless_failed(item.response_future) async def read(self) -> Response: - return await 
self.await_or_fail(self._read_queue.get()) + return await self.await_unless_failed(self._read_queue.get()) async def _run_loop(self): """ diff --git a/tests/unit/pubsublite/internal/wire/assigner_impl_test.py b/tests/unit/pubsublite/internal/wire/assigner_impl_test.py new file mode 100644 index 00000000..14dc8abf --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/assigner_impl_test.py @@ -0,0 +1,183 @@ +import asyncio +from unittest.mock import call +from collections import defaultdict +from typing import Dict, Set + +from asynctest.mock import MagicMock, CoroutineMock +import pytest + +from google.cloud.pubsublite.internal.wire.assigner import Assigner +from google.cloud.pubsublite.internal.wire.assigner_impl import AssignerImpl +from google.cloud.pubsublite.internal.wire.connection import Connection, ConnectionFactory +from google.api_core.exceptions import InternalServerError + +from google.cloud.pubsublite.partition import Partition +from google.cloud.pubsublite_v1.types.subscriber import PartitionAssignmentRequest, InitialPartitionAssignmentRequest, \ + PartitionAssignment, PartitionAssignmentAck +from google.cloud.pubsublite.testing.test_utils import make_queue_waiter +from google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS + +# All test coroutines will be treated as marked. 
+pytestmark = pytest.mark.asyncio + + +@pytest.fixture() +def default_connection(): + conn = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment]) + conn.__aenter__.return_value = conn + return conn + + +@pytest.fixture() +def connection_factory(default_connection): + factory = MagicMock(spec=ConnectionFactory[PartitionAssignmentRequest, PartitionAssignment]) + factory.new.return_value = default_connection + return factory + + +@pytest.fixture() +def initial_request(): + return PartitionAssignmentRequest(initial=InitialPartitionAssignmentRequest(subscription="mysub")) + + +class QueuePair: + called: asyncio.Queue + results: asyncio.Queue + + def __init__(self): + self.called = asyncio.Queue() + self.results = asyncio.Queue() + + +@pytest.fixture +def sleep_queues() -> Dict[float, QueuePair]: + return defaultdict(QueuePair) + + +@pytest.fixture +def asyncio_sleep(monkeypatch, sleep_queues): + """Requests.get() mocked to return {'mock_key':'mock_response'}.""" + mock = CoroutineMock() + monkeypatch.setattr(asyncio, "sleep", mock) + + async def sleeper(delay: float): + await make_queue_waiter(sleep_queues[delay].called, sleep_queues[delay].results)(delay) + + mock.side_effect = sleeper + return mock + + +@pytest.fixture() +def assigner(connection_factory, initial_request): + return AssignerImpl(initial_request.initial, connection_factory) + + +def as_response(partitions: Set[Partition]): + req = PartitionAssignment() + req.partitions = [partition.value for partition in partitions] + return req + + +def ack_request(): + return PartitionAssignmentRequest(ack=PartitionAssignmentAck()) + + +async def test_basic_assign( + assigner: Assigner, default_connection, initial_request): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + 
default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + write_result_queue.put_nowait(None) + async with assigner: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Wait for the first assignment + assign_fut1 = asyncio.ensure_future(assigner.get_assignment()) + assert not assign_fut1.done() + + partitions = {Partition(2), Partition(7)} + + # Send the first assignment. + await read_result_queue.put(as_response(partitions=partitions)) + assert (await assign_fut1) == partitions + + # Get the next assignment: should send an ack on the stream + assign_fut2 = asyncio.ensure_future(assigner.get_assignment()) + await write_called_queue.get() + await write_result_queue.put(None) + default_connection.write.assert_has_calls([call(initial_request), call(ack_request())]) + + partitions = {Partition(5)} + + # Send the second assignment. + await read_called_queue.get() + await read_result_queue.put(as_response(partitions=partitions)) + assert (await assign_fut2) == partitions + + +async def test_restart( + assigner: Assigner, default_connection, connection_factory, initial_request, asyncio_sleep, sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + write_result_queue.put_nowait(None) + async with assigner: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Wait for the first assignment + assign_fut1 = asyncio.ensure_future(assigner.get_assignment()) + assert not assign_fut1.done() + + partitions = {Partition(2), 
Partition(7)} + + # Send the first assignment. + await read_result_queue.put(as_response(partitions=partitions)) + await read_called_queue.get() + assert (await assign_fut1) == partitions + + # Get the next assignment: should attempt to send an ack on the stream + assign_fut2 = asyncio.ensure_future(assigner.get_assignment()) + await write_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request), call(ack_request())]) + + # Set up the next connection + conn2 = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment]) + conn2.__aenter__.return_value = conn2 + connection_factory.new.return_value = conn2 + write_called_queue_2 = asyncio.Queue() + write_result_queue_2 = asyncio.Queue() + conn2.write.side_effect = make_queue_waiter(write_called_queue_2, write_result_queue_2) + read_called_queue_2 = asyncio.Queue() + read_result_queue_2 = asyncio.Queue() + conn2.read.side_effect = make_queue_waiter(read_called_queue_2, read_result_queue_2) + + # Fail the connection by failing the write call. + await write_result_queue.put(InternalServerError("failed")) + await sleep_queues[_MIN_BACKOFF_SECS].called.get() + await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) + + # Reinitialize + await write_called_queue_2.get() + write_result_queue_2.put_nowait(None) + conn2.write.assert_has_calls([call(initial_request)]) + + partitions = {Partition(5)} + + # Send the second assignment on the new connection. + await read_called_queue_2.get() + await read_result_queue_2.put(as_response(partitions=partitions)) + assert (await assign_fut2) == partitions + # No ack call ever made. + conn2.write.assert_has_calls([call(initial_request)]) From 0500806b2188520193c618f4da0ad507d13a6272 Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Thu, 10 Sep 2020 16:14:03 -0400 Subject: [PATCH 4/7] feat: Implement FlowControlBatcher This handles aggregating flow control requests without allowing them to get above the max int64 value. 
--- .../internal/wire/flow_control_batcher.py | 67 +++++++++++++++++++ .../wire/flow_control_batcher_test.py | 28 ++++++++ 2 files changed, 95 insertions(+) create mode 100644 google/cloud/pubsublite/internal/wire/flow_control_batcher.py create mode 100644 tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py diff --git a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py new file mode 100644 index 00000000..821c42e1 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py @@ -0,0 +1,67 @@ +from typing import NamedTuple, List, Optional + +from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage + +_EXPEDITE_BATCH_REQUEST_RATIO = 0.5 +_MAX_INT64 = 0x7FFFFFFFFFFFFFFF + + +class _AggregateRequest: + request: FlowControlRequest + + def __init__(self): + self.request = FlowControlRequest() + + def __add__(self, other: FlowControlRequest): + self.request.allowed_bytes += other.allowed_bytes + self.request.allowed_bytes = min(self.request.allowed_bytes, _MAX_INT64) + self.request.allowed_messages += other.allowed_messages + self.request.allowed_messages = min(self.request.allowed_messages, _MAX_INT64) + return self + + +def _exceeds_expedite_ratio(pending: int, client: int): + if client <= 0: + return False + return (pending/client) >= _EXPEDITE_BATCH_REQUEST_RATIO + + +def _to_optional(req: FlowControlRequest) -> Optional[FlowControlRequest]: + if req.allowed_messages == 0 and req.allowed_bytes == 0: + return None + return req + + +class FlowControlBatcher: + _client_tokens: _AggregateRequest + _pending_tokens: _AggregateRequest + + def __init__(self): + self._client_tokens = _AggregateRequest() + self._pending_tokens = _AggregateRequest() + + def add(self, request: FlowControlRequest): + self._client_tokens += request + self._pending_tokens += request + + def on_messages(self, messages: List[SequencedMessage]): + byte_size = 
sum(message.size_bytes for message in messages) + self._client_tokens += FlowControlRequest(allowed_bytes=-byte_size, allowed_messages=-len(messages)) + + def request_for_restart(self) -> Optional[FlowControlRequest]: + self._pending_tokens = _AggregateRequest() + return _to_optional(self._client_tokens.request) + + def release_pending_request(self) -> Optional[FlowControlRequest]: + request = self._pending_tokens.request + self._pending_tokens = _AggregateRequest() + return _to_optional(request) + + def should_expedite(self): + pending_request = self._pending_tokens.request + client_request = self._pending_tokens.request + if _exceeds_expedite_ratio(pending_request.allowed_bytes, client_request.allowed_bytes): + return True + if _exceeds_expedite_ratio(pending_request.allowed_messages, client_request.allowed_messages): + return True + return False diff --git a/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py b/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py new file mode 100644 index 00000000..45e78c70 --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py @@ -0,0 +1,28 @@ +from google.cloud.pubsublite.internal.wire.flow_control_batcher import FlowControlBatcher +from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage + + +def test_restart_clears_send(): + batcher = FlowControlBatcher() + batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3)) + assert batcher.should_expedite() + to_send = batcher.release_pending_request() + assert to_send.allowed_bytes == 10 + assert to_send.allowed_messages == 3 + restart_1 = batcher.request_for_restart() + assert restart_1.allowed_bytes == 10 + assert restart_1.allowed_messages == 3 + assert not batcher.should_expedite() + assert batcher.release_pending_request() is None + + +def test_add_remove(): + batcher = FlowControlBatcher() + batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3)) + restart_1 = 
batcher.request_for_restart() + assert restart_1.allowed_bytes == 10 + assert restart_1.allowed_messages == 3 + batcher.on_messages([SequencedMessage(size_bytes=2), SequencedMessage(size_bytes=3)]) + restart_2 = batcher.request_for_restart() + assert restart_2.allowed_bytes == 5 + assert restart_2.allowed_messages == 1 From 5eab873fb69efc908335f5a5b58fb3631b54a8ef Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Mon, 14 Sep 2020 12:01:39 -0400 Subject: [PATCH 5/7] Use correct request for comparisson. --- google/cloud/pubsublite/internal/wire/flow_control_batcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py index 821c42e1..7afce610 100644 --- a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py +++ b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py @@ -59,7 +59,7 @@ def release_pending_request(self) -> Optional[FlowControlRequest]: def should_expedite(self): pending_request = self._pending_tokens.request - client_request = self._pending_tokens.request + client_request = self._client_tokens.request if _exceeds_expedite_ratio(pending_request.allowed_bytes, client_request.allowed_bytes): return True if _exceeds_expedite_ratio(pending_request.allowed_messages, client_request.allowed_messages): From 6c44e92ad53cf309d82a7062d8c6635a1bdbdfc4 Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Mon, 14 Sep 2020 13:47:35 -0400 Subject: [PATCH 6/7] feat: Implement Subscriber, which handles flow control and batch message processing. Also ensure all asynchronous loopers are torn down when their underlying objects are. 
--- .../pubsublite/internal/wire/assigner_impl.py | 4 +- .../internal/wire/committer_impl.py | 10 +- .../internal/wire/routing_publisher.py | 1 + .../wire/single_partition_publisher.py | 4 +- .../pubsublite/internal/wire/subscriber.py | 28 ++ .../internal/wire/subscriber_impl.py | 135 ++++++++ .../internal/wire/subscriber_impl_test.py | 296 ++++++++++++++++++ 7 files changed, 473 insertions(+), 5 deletions(-) create mode 100644 google/cloud/pubsublite/internal/wire/subscriber.py create mode 100644 google/cloud/pubsublite/internal/wire/subscriber_impl.py create mode 100644 tests/unit/pubsublite/internal/wire/subscriber_impl_test.py diff --git a/google/cloud/pubsublite/internal/wire/assigner_impl.py b/google/cloud/pubsublite/internal/wire/assigner_impl.py index 65a0117c..c754d0c2 100644 --- a/google/cloud/pubsublite/internal/wire/assigner_impl.py +++ b/google/cloud/pubsublite/internal/wire/assigner_impl.py @@ -39,6 +39,7 @@ def __init__(self, initial: InitialPartitionAssignmentRequest, async def __aenter__(self): await self._connection.__aenter__() + return self def _start_receiver(self): assert self._receiver is None @@ -63,10 +64,11 @@ async def _receive_loop(self): for partition in response.partitions: partitions.add(Partition(partition)) self._new_assignment.put_nowait(partitions) - except asyncio.CancelledError: + except (asyncio.CancelledError, GoogleAPICallError): return async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._stop_receiver() await self._connection.__aexit__(exc_type, exc_val, exc_tb) async def reinitialize(self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment]): diff --git a/google/cloud/pubsublite/internal/wire/committer_impl.py b/google/cloud/pubsublite/internal/wire/committer_impl.py index 36a84a1c..dd3a907e 100644 --- a/google/cloud/pubsublite/internal/wire/committer_impl.py +++ b/google/cloud/pubsublite/internal/wire/committer_impl.py @@ -10,11 +10,13 @@ from 
google.cloud.pubsublite.internal.wire.connection import Connection from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher, BatchTester from google.cloud.pubsublite_v1 import Cursor -from google.cloud.pubsublite_v1.types import StreamingCommitCursorRequest, StreamingCommitCursorResponse, InitialCommitCursorRequest +from google.cloud.pubsublite_v1.types import StreamingCommitCursorRequest, StreamingCommitCursorResponse, \ + InitialCommitCursorRequest from google.cloud.pubsublite.internal.wire.work_item import WorkItem -class CommitterImpl(Committer, ConnectionReinitializer[StreamingCommitCursorRequest, StreamingCommitCursorResponse], BatchTester[Cursor]): +class CommitterImpl(Committer, ConnectionReinitializer[StreamingCommitCursorRequest, StreamingCommitCursorResponse], + BatchTester[Cursor]): _initial: InitialCommitCursorRequest _flush_seconds: float _connection: RetryingConnection[StreamingCommitCursorRequest, StreamingCommitCursorResponse] @@ -38,6 +40,7 @@ def __init__(self, initial: InitialCommitCursorRequest, flush_seconds: float, async def __aenter__(self): await self._connection.__aenter__() + return self def _start_loopers(self): assert self._receiver is None @@ -71,7 +74,7 @@ async def _receive_loop(self): while True: response = await self._connection.read() self._handle_response(response) - except asyncio.CancelledError: + except (asyncio.CancelledError, GoogleAPICallError): return async def _flush_loop(self): @@ -83,6 +86,7 @@ async def _flush_loop(self): return async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._stop_loopers() if self._connection.error(): self._fail_if_retrying_failed() else: diff --git a/google/cloud/pubsublite/internal/wire/routing_publisher.py b/google/cloud/pubsublite/internal/wire/routing_publisher.py index 3f269550..6b8c0b77 100644 --- a/google/cloud/pubsublite/internal/wire/routing_publisher.py +++ b/google/cloud/pubsublite/internal/wire/routing_publisher.py @@ -18,6 +18,7 @@ def 
__init__(self, routing_policy: RoutingPolicy, publishers: Mapping[Partition, async def __aenter__(self): for publisher in self._publishers.values(): await publisher.__aenter__() + return self async def __aexit__(self, exc_type, exc_val, exc_tb): for publisher in self._publishers.values(): diff --git a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py index 081afa70..6c6edcc6 100644 --- a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py +++ b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py @@ -48,6 +48,7 @@ def _partition(self) -> Partition: async def __aenter__(self): await self._connection.__aenter__() + return self def _start_loopers(self): assert self._receiver is None @@ -82,7 +83,7 @@ async def _receive_loop(self): while True: response = await self._connection.read() self._handle_response(response) - except asyncio.CancelledError: + except (asyncio.CancelledError, GoogleAPICallError): return async def _flush_loop(self): @@ -98,6 +99,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self._fail_if_retrying_failed() else: await self._flush() + await self._stop_loopers() await self._connection.__aexit__(exc_type, exc_val, exc_tb) def _fail_if_retrying_failed(self): diff --git a/google/cloud/pubsublite/internal/wire/subscriber.py b/google/cloud/pubsublite/internal/wire/subscriber.py new file mode 100644 index 00000000..0e02024d --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/subscriber.py @@ -0,0 +1,28 @@ +from abc import abstractmethod +from typing import AsyncContextManager +from google.cloud.pubsublite_v1.types import SequencedMessage, FlowControlRequest + + +class Subscriber(AsyncContextManager): + """ + A Pub/Sub Lite asynchronous wire protocol subscriber. + """ + @abstractmethod + async def read(self) -> SequencedMessage: + """ + Read the next message off of the stream. + + Returns: + The next message. 
+ + Raises: + GoogleAPICallError: On a permanent error. + """ + raise NotImplementedError() + + @abstractmethod + async def allow_flow(self, request: FlowControlRequest): + """ + Allow an additional amount of messages and bytes to be sent to this client. + """ + raise NotImplementedError() diff --git a/google/cloud/pubsublite/internal/wire/subscriber_impl.py b/google/cloud/pubsublite/internal/wire/subscriber_impl.py new file mode 100644 index 00000000..4a28a8be --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/subscriber_impl.py @@ -0,0 +1,135 @@ +import asyncio +from typing import Optional + +from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition + +from google.cloud.pubsublite.internal.wire.connection import Request, Connection, Response, ConnectionFactory +from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer +from google.cloud.pubsublite.internal.wire.flow_control_batcher import FlowControlBatcher +from google.cloud.pubsublite.internal.wire.retrying_connection import RetryingConnection +from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite_v1 import SubscribeRequest, SubscribeResponse, FlowControlRequest, SequencedMessage, \ + InitialSubscribeRequest, SeekRequest, Cursor + + +class SubscriberImpl(Subscriber, ConnectionReinitializer[SubscribeRequest, SubscribeResponse]): + _initial: InitialSubscribeRequest + _token_flush_seconds: float + _connection: RetryingConnection[SubscribeRequest, SubscribeResponse] + + _outstanding_flow_control: FlowControlBatcher + + _reinitializing: bool + _last_received_offset: Optional[int] + + _message_queue: 'asyncio.Queue[SequencedMessage]' + + _receiver: Optional[asyncio.Future] + _flusher: Optional[asyncio.Future] + + def __init__(self, initial: InitialSubscribeRequest, token_flush_seconds: float, + factory: ConnectionFactory[SubscribeRequest, SubscribeResponse]): + self._initial = initial + 
self._token_flush_seconds = token_flush_seconds + self._connection = RetryingConnection(factory, self) + self._outstanding_flow_control = FlowControlBatcher() + self._reinitializing = False + self._last_received_offset = None + self._message_queue = asyncio.Queue() + self._receiver = None + self._flusher = None + + async def __aenter__(self): + await self._connection.__aenter__() + return self + + def _start_loopers(self): + assert self._receiver is None + assert self._flusher is None + self._receiver = asyncio.ensure_future(self._receive_loop()) + self._flusher = asyncio.ensure_future(self._flush_loop()) + + async def _stop_loopers(self): + if self._receiver: + self._receiver.cancel() + await self._receiver + self._receiver = None + if self._flusher: + self._flusher.cancel() + await self._flusher + self._flusher = None + + def _handle_response(self, response: SubscribeResponse): + if "messages" not in response: + self._connection.fail(FailedPrecondition("Received an invalid subsequent response on the subscribe stream.")) + return + self._outstanding_flow_control.on_messages(response.messages.messages) + for message in response.messages.messages: + if self._last_received_offset is not None and message.cursor.offset <= self._last_received_offset: + self._connection.fail(FailedPrecondition( + "Received an invalid out of order message from the server. Message is {}, previous last received is {}.".format( + message.cursor.offset, self._last_received_offset))) + return + self._last_received_offset = message.cursor.offset + for message in response.messages.messages: + # queue is unbounded. 
+ self._message_queue.put_nowait(message) + + async def _receive_loop(self): + try: + while True: + response = await self._connection.read() + self._handle_response(response) + except (asyncio.CancelledError, GoogleAPICallError): + return + + async def _try_send_tokens(self): + req = self._outstanding_flow_control.release_pending_request() + if req is None: + return + try: + await self._connection.write(SubscribeRequest(flow_control=req)) + except GoogleAPICallError: + # May be transient, in which case these tokens will be resent. + pass + + async def _flush_loop(self): + try: + while True: + await asyncio.sleep(self._token_flush_seconds) + await self._try_send_tokens() + except asyncio.CancelledError: + return + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._stop_loopers() + await self._connection.__aexit__(exc_type, exc_val, exc_tb) + + async def reinitialize(self, connection: Connection[SubscribeRequest, SubscribeResponse]): + self._reinitializing = True + await self._stop_loopers() + await connection.write(SubscribeRequest(initial=self._initial)) + response = await connection.read() + if "initial" not in response: + self._connection.fail(FailedPrecondition("Received an invalid initial response on the subscribe stream.")) + return + if self._last_received_offset is not None: + # Perform a seek to get the next message after the one we received. 
+ await connection.write(SubscribeRequest(seek=SeekRequest(cursor=Cursor(offset=self._last_received_offset + 1)))) + seek_response = await connection.read() + if "seek" not in seek_response: + self._connection.fail(FailedPrecondition("Received an invalid seek response on the subscribe stream.")) + return + tokens = self._outstanding_flow_control.request_for_restart() + if tokens is not None: + await connection.write(SubscribeRequest(flow_control=tokens)) + self._reinitializing = False + self._start_loopers() + + async def read(self) -> SequencedMessage: + return await self._connection.await_unless_failed(self._message_queue.get()) + + async def allow_flow(self, request: FlowControlRequest): + self._outstanding_flow_control.add(request) + if not self._reinitializing and self._outstanding_flow_control.should_expedite(): + await self._try_send_tokens() diff --git a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py new file mode 100644 index 00000000..25564f48 --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py @@ -0,0 +1,296 @@ +import asyncio +from unittest.mock import call +from collections import defaultdict +from typing import Dict, List + +from asynctest.mock import MagicMock, CoroutineMock +import pytest +from grpc import StatusCode + +from google.cloud.pubsublite.internal.wire.connection import Connection, ConnectionFactory +from google.api_core.exceptions import InternalServerError, GoogleAPICallError + +from google.cloud.pubsublite.internal.wire.subscriber import Subscriber +from google.cloud.pubsublite.internal.wire.subscriber_impl import SubscriberImpl +from google.cloud.pubsublite_v1 import SubscribeRequest, SubscribeResponse, InitialSubscribeRequest, FlowControlRequest, \ + SeekRequest +from google.cloud.pubsublite_v1.types.common import Cursor, SequencedMessage +from google.cloud.pubsublite.testing.test_utils import make_queue_waiter +from 
google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS
+
+FLUSH_SECONDS = 100000
+
+# All test coroutines will be treated as marked.
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture()
+def default_connection():
+  conn = MagicMock(spec=Connection[SubscribeRequest, SubscribeResponse])
+  conn.__aenter__.return_value = conn
+  return conn
+
+
+@pytest.fixture()
+def connection_factory(default_connection):
+  factory = MagicMock(spec=ConnectionFactory[SubscribeRequest, SubscribeResponse])
+  factory.new.return_value = default_connection
+  return factory
+
+
+@pytest.fixture()
+def initial_request():
+  return SubscribeRequest(initial=InitialSubscribeRequest(subscription="mysub"))
+
+
+class QueuePair:
+  called: asyncio.Queue
+  results: asyncio.Queue
+
+  def __init__(self):
+    self.called = asyncio.Queue()
+    self.results = asyncio.Queue()
+
+
+@pytest.fixture
+def sleep_queues() -> Dict[float, QueuePair]:
+  return defaultdict(QueuePair)
+
+
+@pytest.fixture
+def asyncio_sleep(monkeypatch, sleep_queues):
+  """asyncio.sleep mocked to route each delay through its per-delay called/results queue pair."""
+  mock = CoroutineMock()
+  monkeypatch.setattr(asyncio, "sleep", mock)
+
+  async def sleeper(delay: float):
+    await make_queue_waiter(sleep_queues[delay].called, sleep_queues[delay].results)(delay)
+
+  mock.side_effect = sleeper
+  return mock
+
+
+@pytest.fixture()
+def subscriber(connection_factory, initial_request):
+  return SubscriberImpl(initial_request.initial, FLUSH_SECONDS, connection_factory)
+
+
+def as_request(flow: FlowControlRequest):
+  req = SubscribeRequest()
+  req.flow_control = flow
+  return req
+
+
+def as_response(messages: List[SequencedMessage]):
+  res = SubscribeResponse()
+  res.messages.messages = messages
+  return res
+
+
+async def test_basic_flow_control_after_timeout(
+    subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, sleep_queues):
+  sleep_called = sleep_queues[FLUSH_SECONDS].called
+  sleep_results = 
sleep_queues[FLUSH_SECONDS].results + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow_1 = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + flow_2 = FlowControlRequest(allowed_messages=5, allowed_bytes=10) + flow_3 = FlowControlRequest(allowed_messages=10, allowed_bytes=5) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut1 = asyncio.ensure_future(subscriber.allow_flow(flow_1)) + assert not flow_fut1.done() + + # Handle the inline write since initial tokens are 100% of outstanding. 
+ await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut1 + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1))]) + + # Should complete without writing to the connection + await subscriber.allow_flow(flow_2) + await subscriber.allow_flow(flow_3) + + # Wait for writes to be waiting + await sleep_called.get() + asyncio_sleep.assert_called_with(FLUSH_SECONDS) + + # Handle the connection write + await sleep_results.put(None) + await write_called_queue.get() + await write_result_queue.put(None) + # Called with aggregate + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1)), call( + as_request(FlowControlRequest(allowed_messages=15, allowed_bytes=15)))]) + + +async def test_flow_resent_on_restart(subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, + sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow_1 = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + flow_2 = FlowControlRequest(allowed_messages=5, allowed_bytes=10) + flow_3 = FlowControlRequest(allowed_messages=10, allowed_bytes=5) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut1 = asyncio.ensure_future(subscriber.allow_flow(flow_1)) + assert not flow_fut1.done() + + # Handle the inline write since initial tokens are 100% of outstanding. 
+ await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut1 + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1))]) + + # Should complete without writing to the connection + await subscriber.allow_flow(flow_2) + await subscriber.allow_flow(flow_3) + + # Fail the connection with a retryable error + await read_called_queue.get() + await read_result_queue.put(InternalServerError("retryable")) + await sleep_queues[_MIN_BACKOFF_SECS].called.get() + await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) + # Reinitialization + await write_called_queue.get() + await write_result_queue.put(None) + await read_called_queue.get() + await read_result_queue.put(SubscribeResponse(initial={})) + # Re-sending flow tokens on the new stream + await write_called_queue.get() + await write_result_queue.put(None) + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow_1)), call(initial_request), + call(as_request( + FlowControlRequest(allowed_messages=115, allowed_bytes=115)))]) + + +async def test_message_receipt(subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, + sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5) + message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + 
default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) + assert not flow_fut.done() + + # Handle the inline write since initial tokens are 100% of outstanding. + await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow))]) + + message1_fut = asyncio.ensure_future(subscriber.read()) + + # Send messages to the subscriber. + await read_result_queue.put(as_response([message_1, message_2])) + # Wait for the next read call + await read_called_queue.get() + + assert (await message1_fut) == message_1 + assert (await subscriber.read()) == message_2 + + # Fail the connection with a retryable error + await read_called_queue.get() + await read_result_queue.put(InternalServerError("retryable")) + await sleep_queues[_MIN_BACKOFF_SECS].called.get() + await sleep_queues[_MIN_BACKOFF_SECS].results.put(None) + # Reinitialization + await write_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow)), call(initial_request)]) + await write_result_queue.put(None) + await read_called_queue.get() + await read_result_queue.put(SubscribeResponse(initial={})) + # Sends fetch offset seek on the stream, and checks the response. + seek_req = SubscribeRequest(seek=SeekRequest(cursor=Cursor(offset=message_2.cursor.offset + 1))) + await write_called_queue.get() + default_connection.write.assert_has_calls( + [call(initial_request), call(as_request(flow)), call(initial_request), call(seek_req)]) + await write_result_queue.put(None) + await read_called_queue.get() + await read_result_queue.put(SubscribeResponse(seek={})) + # Re-sending flow tokens on the new stream. 
+ await write_called_queue.get() + await write_result_queue.put(None) + default_connection.write.assert_has_calls( + [call(initial_request), call(as_request(flow)), call(initial_request), call(seek_req), + call(as_request(FlowControlRequest(allowed_messages=98, allowed_bytes=85)))]) + + +async def test_out_of_order_receipt_failure(subscriber: Subscriber, default_connection, initial_request, asyncio_sleep, + sleep_queues): + write_called_queue = asyncio.Queue() + write_result_queue = asyncio.Queue() + flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) + message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5) + message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10) + default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue) + read_called_queue = asyncio.Queue() + read_result_queue = asyncio.Queue() + default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue) + read_result_queue.put_nowait(SubscribeResponse(initial={})) + write_result_queue.put_nowait(None) + async with subscriber: + # Set up connection + await write_called_queue.get() + await read_called_queue.get() + default_connection.write.assert_has_calls([call(initial_request)]) + + # Send tokens. + flow_fut = asyncio.ensure_future(subscriber.allow_flow(flow)) + assert not flow_fut.done() + + # Handle the inline write since initial tokens are 100% of outstanding. + await write_called_queue.get() + await write_result_queue.put(None) + await flow_fut + default_connection.write.assert_has_calls([call(initial_request), call(as_request(flow))]) + + read_fut = asyncio.ensure_future(subscriber.read()) + + # Send out of order messages to the subscriber. 
+ await read_result_queue.put(as_response([message_2, message_1])) + # Wait for the next read call + await read_called_queue.get() + + try: + await read_fut + assert False + except GoogleAPICallError as e: + assert e.grpc_status_code == StatusCode.FAILED_PRECONDITION + pass From 159d3efd98b6f342a8f18f93e680eee3869214e3 Mon Sep 17 00:00:00 2001 From: Daniel Collins Date: Mon, 14 Sep 2020 15:25:42 -0400 Subject: [PATCH 7/7] feat: Implement AdminClient, which helps users perform admin operations in a given region. --- google/cloud/pubsublite/admin_client.py | 73 +++++++++++++++++++ .../internal/wire/admin_client_impl.py | 61 ++++++++++++++++ .../internal/wire/make_publisher.py | 12 ++- google/cloud/pubsublite/location.py | 10 +++ google/cloud/pubsublite/paths.py | 44 +++++++++++ 5 files changed, 193 insertions(+), 7 deletions(-) create mode 100644 google/cloud/pubsublite/admin_client.py create mode 100644 google/cloud/pubsublite/internal/wire/admin_client_impl.py diff --git a/google/cloud/pubsublite/admin_client.py b/google/cloud/pubsublite/admin_client.py new file mode 100644 index 00000000..5d400e26 --- /dev/null +++ b/google/cloud/pubsublite/admin_client.py @@ -0,0 +1,73 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from google.api_core.client_options import ClientOptions +from google.protobuf.field_mask_pb2 import FieldMask + +from google.cloud.pubsublite.endpoints import regional_endpoint +from google.cloud.pubsublite.internal.wire.admin_client_impl import AdminClientImpl +from google.cloud.pubsublite.location import CloudRegion +from google.cloud.pubsublite.paths import TopicPath, LocationPath, SubscriptionPath +from google.cloud.pubsublite_v1 import Topic, Subscription, AdminServiceClient +from google.auth.credentials import Credentials + + +class AdminClient(ABC): + @abstractmethod + def region(self) -> CloudRegion: + """The region this client is for.""" + + @abstractmethod + def create_topic(self, topic: Topic) -> Topic: + 
"""Create a topic, returns the created topic.""" + + @abstractmethod + def get_topic(self, topic_path: TopicPath) -> Topic: + """Get the topic object from the server.""" + + @abstractmethod + def get_topic_partition_count(self, topic_path: TopicPath) -> int: + """Get the number of partitions in the provided topic.""" + + @abstractmethod + def list_topics(self, location_path: LocationPath) -> List[Topic]: + """List the Pub/Sub lite topics that exist for a project in a given location.""" + + @abstractmethod + def update_topic(self, topic: Topic, update_mask: FieldMask) -> Topic: + """Update the masked fields of the provided topic.""" + + @abstractmethod + def delete_topic(self, topic_path: TopicPath): + """Delete a topic and all associated messages.""" + + @abstractmethod + def list_topic_subscriptions(self, topic_path: TopicPath): + """List the subscriptions that exist for a given topic.""" + + @abstractmethod + def create_subscription(self, subscription: Subscription) -> Subscription: + """Create a subscription, returns the created subscription.""" + + @abstractmethod + def get_subscription(self, subscription_path: SubscriptionPath) -> Subscription: + """Get the subscription object from the server.""" + + @abstractmethod + def list_subscriptions(self, location_path: LocationPath) -> List[Subscription]: + """List the Pub/Sub lite subscriptions that exist for a project in a given location.""" + + @abstractmethod + def update_subscription(self, subscription: Subscription, update_mask: FieldMask) -> Subscription: + """Update the masked fields of the provided subscription.""" + + @abstractmethod + def delete_subscription(self, subscription_path: SubscriptionPath): + """Delete a subscription and all associated messages.""" + + +def make_admin_client(region: CloudRegion, credentials: Optional[Credentials] = None, + client_options: Optional[ClientOptions] = None) -> AdminClient: + if client_options is None: + client_options = 
ClientOptions(api_endpoint=regional_endpoint(region)) + return AdminClientImpl(AdminServiceClient(client_options=client_options, credentials=credentials), region) diff --git a/google/cloud/pubsublite/internal/wire/admin_client_impl.py b/google/cloud/pubsublite/internal/wire/admin_client_impl.py new file mode 100644 index 00000000..3962e2cf --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/admin_client_impl.py @@ -0,0 +1,61 @@ +from typing import List + +from google.protobuf.field_mask_pb2 import FieldMask + +from google.cloud.pubsublite.admin_client import AdminClient +from google.cloud.pubsublite.location import CloudRegion +from google.cloud.pubsublite.paths import SubscriptionPath, LocationPath, TopicPath +from google.cloud.pubsublite_v1 import Subscription, Topic, AdminServiceClient, TopicPartitions + + +class AdminClientImpl(AdminClient): + _underlying: AdminServiceClient + _region: CloudRegion + + def __init__(self, underlying: AdminServiceClient, region: CloudRegion): + self._underlying = underlying + self._region = region + + def region(self) -> CloudRegion: + return self._region + + def create_topic(self, topic: Topic) -> Topic: + path = TopicPath.parse(topic.name) + return self._underlying.create_topic(parent=str(path.to_location_path()), topic=topic, topic_id=path.name) + + def get_topic(self, topic_path: TopicPath) -> Topic: + return self._underlying.get_topic(name=str(topic_path)) + + def get_topic_partition_count(self, topic_path: TopicPath) -> int: + partitions: TopicPartitions = self._underlying.get_topic_partitions(name=str(topic_path)) + return partitions.partition_count + + def list_topics(self, location_path: LocationPath) -> List[Topic]: + return [x for x in self._underlying.list_topics(parent=str(location_path))] + + def update_topic(self, topic: Topic, update_mask: FieldMask) -> Topic: + return self._underlying.update_topic(topic=topic, update_mask=update_mask) + + def delete_topic(self, topic_path: TopicPath): + 
self._underlying.delete_topic(name=str(topic_path)) + + def list_topic_subscriptions(self, topic_path: TopicPath): + subscription_strings = [x for x in self._underlying.list_topic_subscriptions(name=str(topic_path))] + return [SubscriptionPath.parse(x) for x in subscription_strings] + + def create_subscription(self, subscription: Subscription) -> Subscription: + path = SubscriptionPath.parse(subscription.name) + return self._underlying.create_subscription(parent=str(path.to_location_path()), subscription=subscription, + subscription_id=path.name) + + def get_subscription(self, subscription_path: SubscriptionPath) -> Subscription: + return self._underlying.get_subscription(name=str(subscription_path)) + + def list_subscriptions(self, location_path: LocationPath) -> List[Subscription]: + return [x for x in self._underlying.list_subscriptions(parent=str(location_path))] + + def update_subscription(self, subscription: Subscription, update_mask: FieldMask) -> Subscription: + return self._underlying.update_subscription(subscription=subscription, update_mask=update_mask) + + def delete_subscription(self, subscription_path: SubscriptionPath): + self._underlying.delete_subscription(name=str(subscription_path)) diff --git a/google/cloud/pubsublite/internal/wire/make_publisher.py b/google/cloud/pubsublite/internal/wire/make_publisher.py index ec2c690d..c9e0de28 100644 --- a/google/cloud/pubsublite/internal/wire/make_publisher.py +++ b/google/cloud/pubsublite/internal/wire/make_publisher.py @@ -1,5 +1,6 @@ from typing import AsyncIterator, Mapping, Optional, MutableMapping +from google.cloud.pubsublite.admin_client import make_admin_client from google.cloud.pubsublite.endpoints import regional_endpoint from google.cloud.pubsublite.internal.wire.default_routing_policy import DefaultRoutingPolicy from google.cloud.pubsublite.internal.wire.gapic_connection import GapicConnectionFactory @@ -12,8 +13,6 @@ from google.cloud.pubsublite.routing_metadata import topic_routing_metadata 
from google.cloud.pubsublite_v1 import InitialPublishRequest, PublishRequest from google.cloud.pubsublite_v1.services.publisher_service import async_client -from google.cloud.pubsublite_v1.services.admin_service.client import AdminServiceClient -from google.cloud.pubsublite_v1.types.admin import GetTopicPartitionsRequest from google.api_core.client_options import ClientOptions from google.auth.credentials import Credentials @@ -40,17 +39,16 @@ def make_publisher( Throws: GoogleApiCallException on any error determining topic structure. """ + admin_client = make_admin_client(region=topic.location.region, credentials=credentials, client_options=client_options) if client_options is None: client_options = ClientOptions(api_endpoint=regional_endpoint(topic.location.region)) client = async_client.PublisherServiceAsyncClient( credentials=credentials, client_options=client_options) # type: ignore - admin_client = AdminServiceClient(credentials=credentials, client_options=client_options) - partitions = admin_client.get_topic_partitions(GetTopicPartitionsRequest(name=str(topic))) - clients: MutableMapping[Partition, Publisher] = {} - for partition in range(partitions.partition_count): + partition_count = admin_client.get_topic_partition_count(topic) + for partition in range(partition_count): partition = Partition(partition) def connection_factory(requests: AsyncIterator[PublishRequest]): @@ -59,4 +57,4 @@ def connection_factory(requests: AsyncIterator[PublishRequest]): clients[partition] = SinglePartitionPublisher(InitialPublishRequest(topic=str(topic), partition=partition.value), batching_delay_secs, GapicConnectionFactory(connection_factory)) - return RoutingPublisher(DefaultRoutingPolicy(partitions.partition_count), clients) + return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients) diff --git a/google/cloud/pubsublite/location.py b/google/cloud/pubsublite/location.py index f73966ce..85414226 100644 --- a/google/cloud/pubsublite/location.py +++ 
b/google/cloud/pubsublite/location.py @@ -1,5 +1,7 @@ from typing import NamedTuple +from google.api_core.exceptions import InvalidArgument + class CloudRegion(NamedTuple): name: str @@ -11,3 +13,11 @@ class CloudZone(NamedTuple): def __str__(self): return f"{self.region.name}-{self.zone_id}" + + @staticmethod + def parse(to_parse: str): + splits = to_parse.split('-') + if len(splits) != 3 or len(splits[2]) != 1: + raise InvalidArgument("Invalid zone name: " + to_parse) + region = CloudRegion(name=splits[0] + '-' + splits[1]) + return CloudZone(region, zone_id=splits[2]) diff --git a/google/cloud/pubsublite/paths.py b/google/cloud/pubsublite/paths.py index 3a7208d0..f0921a45 100644 --- a/google/cloud/pubsublite/paths.py +++ b/google/cloud/pubsublite/paths.py @@ -1,8 +1,18 @@ from typing import NamedTuple +from google.api_core.exceptions import InvalidArgument + from google.cloud.pubsublite.location import CloudZone +class LocationPath(NamedTuple): + project_number: int + location: CloudZone + + def __str__(self): + return f"projects/{self.project_number}/locations/{self.location}" + + class TopicPath(NamedTuple): project_number: int location: CloudZone @@ -11,6 +21,23 @@ class TopicPath(NamedTuple): def __str__(self): return f"projects/{self.project_number}/locations/{self.location}/topics/{self.name}" + def to_location_path(self): + return LocationPath(self.project_number, self.location) + + @staticmethod + def parse(to_parse: str) -> "TopicPath": + splits = to_parse.split("/") + if len(splits) != 6 or splits[0] != "projects" or splits[2] != "locations" or splits[4] != "topics": + raise InvalidArgument( + "Topic path must be formatted like projects/{project_number}/locations/{location}/topics/{name} but was instead " + to_parse) + project_number: int + try: + project_number = int(splits[1]) + except ValueError: + raise InvalidArgument( + "Topic path must be formatted like projects/{project_number}/locations/{location}/topics/{name} but was instead " + to_parse) + 
return TopicPath(project_number, CloudZone.parse(splits[3]), splits[5]) + class SubscriptionPath(NamedTuple): project_number: int @@ -19,3 +46,20 @@ class SubscriptionPath(NamedTuple): def __str__(self): return f"projects/{self.project_number}/locations/{self.location}/subscriptions/{self.name}" + + def to_location_path(self): + return LocationPath(self.project_number, self.location) + + @staticmethod + def parse(to_parse: str) -> "SubscriptionPath": + splits = to_parse.split("/") + if len(splits) != 6 or splits[0] != "projects" or splits[2] != "locations" or splits[4] != "subscriptions": + raise InvalidArgument( + "Subscription path must be formatted like projects/{project_number}/locations/{location}/subscriptions/{name} but was instead " + to_parse) + project_number: int + try: + project_number = int(splits[1]) + except ValueError: + raise InvalidArgument( + "Subscription path must be formatted like projects/{project_number}/locations/{location}/subscriptions/{name} but was instead " + to_parse) + return SubscriptionPath(project_number, CloudZone.parse(splits[3]), splits[5])