diff --git a/.gitignore b/.gitignore index b9daa52f..efad5203 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ docs.metadata # Virtual environment env/ +venv/ coverage.xml sponge_log.xml diff --git a/google/cloud/pubsublite/internal/wire/make_publisher.py b/google/cloud/pubsublite/internal/wire/make_publisher.py index 27d2d4cc..26c2486e 100644 --- a/google/cloud/pubsublite/internal/wire/make_publisher.py +++ b/google/cloud/pubsublite/internal/wire/make_publisher.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AsyncIterator, Mapping, Optional, MutableMapping +from typing import AsyncIterator, Mapping, Optional from google.cloud.pubsub_v1.types import BatchSettings @@ -25,8 +25,13 @@ GapicConnectionFactory, ) from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata +from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import ( + PartitionCountWatcherImpl, +) +from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import ( + PartitionCountWatchingPublisher, +) from google.cloud.pubsublite.internal.wire.publisher import Publisher -from google.cloud.pubsublite.internal.wire.routing_publisher import RoutingPublisher from google.cloud.pubsublite.internal.wire.single_partition_publisher import ( SinglePartitionPublisher, ) @@ -37,7 +42,6 @@ from google.api_core.client_options import ClientOptions from google.auth.credentials import Credentials - DEFAULT_BATCHING_SETTINGS = BatchSettings( max_bytes=( 3 * 1024 * 1024 @@ -45,6 +49,7 @@ max_messages=1000, max_latency=0.05, # 50 ms ) +DEFAULT_PARTITION_POLL_PERIOD = 600 # ten minutes def make_publisher( @@ -87,21 +92,24 @@ def make_publisher( credentials=credentials, transport=transport, client_options=client_options ) # type: ignore - clients: MutableMapping[Partition, Publisher] = {} - - partition_count = admin_client.get_topic_partition_count(topic) - for 
partition in range(partition_count): - partition = Partition(partition) - + def publisher_factory(partition: Partition): def connection_factory(requests: AsyncIterator[PublishRequest]): final_metadata = merge_metadata( metadata, topic_routing_metadata(topic, partition) ) return client.publish(requests, metadata=list(final_metadata.items())) - clients[partition] = SinglePartitionPublisher( + return SinglePartitionPublisher( InitialPublishRequest(topic=str(topic), partition=partition.value), per_partition_batching_settings, GapicConnectionFactory(connection_factory), ) - return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients) + + def policy_factory(partition_count: int): + return DefaultRoutingPolicy(partition_count) + + return PartitionCountWatchingPublisher( + PartitionCountWatcherImpl(admin_client, topic, DEFAULT_PARTITION_POLL_PERIOD), + publisher_factory, + policy_factory, + ) diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watcher.py b/google/cloud/pubsublite/internal/wire/partition_count_watcher.py new file mode 100644 index 00000000..f0527755 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/partition_count_watcher.py @@ -0,0 +1,22 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class PartitionCountWatcher(AsyncContextManager):
    """Interface for an async context manager that reports partition counts."""

    @abstractmethod
    async def get_partition_count(self) -> int:
        """Return a partition count.

        Concrete subclasses must override this coroutine; the base
        implementation unconditionally raises ``NotImplementedError``.
        """
        raise NotImplementedError
class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable):
    """Partition count watcher that polls an admin client from a background task.

    Each poll runs the blocking admin lookup on a dedicated single-worker
    thread pool and hands the result to consumers through a one-slot queue,
    so a new count is only enqueued once the previous one has been taken
    via ``get_partition_count``.
    """

    _admin: AdminClientInterface
    _topic_path: TopicPath
    _duration: float
    _any_success: bool
    _thread: ThreadPoolExecutor
    _queue: asyncio.Queue
    # Deliberately NOT named `_poll_partition_loop`: storing the future under
    # the method's name would shadow the method on the instance.
    _partition_loop_poller: asyncio.Future

    def __init__(
        self, admin: AdminClientInterface, topic_path: TopicPath, duration: float
    ):
        """Store the polling configuration; no work starts until ``__aenter__``.

        Args:
            admin: Client used to look up the topic's partition count.
            topic_path: Topic whose partition count is polled.
            duration: Seconds to sleep after each poll attempt.
        """
        super().__init__()
        self._admin = admin
        self._topic_path = topic_path
        self._duration = duration
        self._any_success = False

    async def __aenter__(self):
        """Create the executor and queue, start the poller, and return ``self``."""
        self._thread = ThreadPoolExecutor(max_workers=1)
        self._queue = asyncio.Queue(maxsize=1)
        # `run_poller` is inherited; presumably it invokes the given coroutine
        # function repeatedly until failure/cancellation -- confirm against
        # PermanentFailable.
        self._partition_loop_poller = asyncio.ensure_future(
            self.run_poller(self._poll_partition_loop)
        )
        # Returning self makes `async with watcher as w:` bind the watcher.
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Cancel the poller task, wait for it, and release the executor."""
        self._partition_loop_poller.cancel()
        await wait_ignore_cancelled(self._partition_loop_poller)
        # wait=False: do not block the event loop on an in-flight admin call.
        self._thread.shutdown(wait=False)

    def _get_partition_count_sync(self) -> int:
        """Blocking partition-count lookup; executed on the worker thread."""
        return self._admin.get_topic_partition_count(self._topic_path)

    async def _poll_partition_loop(self):
        """Run one poll attempt, then sleep for the configured duration.

        Raises:
            GoogleAPICallError: If the lookup fails before any lookup has
                succeeded. Failures after the first success are logged and
                swallowed.
        """
        try:
            partition_count = await asyncio.get_event_loop().run_in_executor(
                self._thread, self._get_partition_count_sync
            )
            self._any_success = True
            # Blocks while the single queue slot is still occupied.
            await self._queue.put(partition_count)
        except GoogleAPICallError:
            if not self._any_success:
                # Bare raise keeps the original traceback intact.
                raise
            logging.exception("Failed to retrieve partition count")
        await asyncio.sleep(self._duration)

    async def get_partition_count(self) -> int:
        """Wait for and return the next polled partition count.

        The wait goes through the inherited ``await_unless_failed``;
        presumably that aborts the wait once the watcher has permanently
        failed -- confirm against PermanentFailable.
        """
        return await self.await_unless_failed(self._queue.get())
class PartitionCountWatchingPublisher(Publisher):
    """Publisher that fans out to per-partition publishers and grows with the topic.

    A ``PartitionCountWatcher`` supplies partition counts. Whenever the count
    increases, publishers for the new partitions are created and entered and
    the routing policy is rebuilt for the new count. Unchanged or decreased
    counts are ignored.
    """

    _publishers: Dict[Partition, Publisher]
    _publisher_factory: Callable[[Partition], Publisher]
    _policy_factory: Callable[[int], RoutingPolicy]
    _watcher: PartitionCountWatcher
    _partition_count_poller: asyncio.Future
    # Set by the first partition count update (during __aenter__).
    _routing_policy: RoutingPolicy

    def __init__(
        self,
        watcher: PartitionCountWatcher,
        publisher_factory: Callable[[Partition], Publisher],
        policy_factory: Callable[[int], RoutingPolicy],
    ):
        """Store collaborators; nothing is started until ``__aenter__``.

        Args:
            watcher: Source of partition counts.
            publisher_factory: Builds the publisher for a given partition.
            policy_factory: Builds the routing policy for a partition count.
        """
        self._publishers = {}
        self._publisher_factory = publisher_factory
        self._policy_factory = policy_factory
        self._watcher = watcher

    async def __aenter__(self):
        """Enter the watcher, apply the first count, then start background polling.

        If entering the watcher or handling the first count raises, the
        watcher is exited with the active exception info and the error is
        re-raised.
        """
        try:
            await self._watcher.__aenter__()
            await self._poll_partition_count_action()
        except Exception:
            await self._watcher.__aexit__(*sys.exc_info())
            raise
        self._partition_count_poller = asyncio.ensure_future(
            self._watch_partition_count()
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop background polling, then exit the watcher and every publisher."""
        self._partition_count_poller.cancel()
        try:
            # The poller may already have finished with an error (for example
            # the watcher's get_partition_count raised). NOTE(review): this
            # assumes wait_ignore_cancelled only suppresses cancellation and
            # re-raises anything else -- confirm.
            await wait_ignore_cancelled(self._partition_count_poller)
        finally:
            # Always release the watcher and publishers, even when awaiting
            # the poller re-raises; the error still propagates afterwards.
            await self._watcher.__aexit__(exc_type, exc_val, exc_tb)
            for publisher in self._publishers.values():
                await publisher.__aexit__(exc_type, exc_val, exc_tb)

    async def _poll_partition_count_action(self):
        """Wait for one partition count from the watcher and apply it."""
        partition_count = await self._watcher.get_partition_count()
        await self._handle_partition_count_update(partition_count)

    async def _watch_partition_count(self):
        """Apply partition counts forever; ends only by cancellation or error."""
        while True:
            await self._poll_partition_count_action()

    async def _handle_partition_count_update(self, partition_count: int):
        """Add publishers for new partitions and swap in a new routing policy.

        Args:
            partition_count: The most recently observed partition count.
        """
        current_count = len(self._publishers)
        if partition_count <= current_count:
            # Same or smaller count: keep the existing publishers and policy.
            return

        new_publishers = {}
        for index in range(current_count, partition_count):
            partition = Partition(index)
            new_publishers[partition] = self._publisher_factory(partition)
        await asyncio.gather(*(p.__aenter__() for p in new_publishers.values()))
        routing_policy = self._policy_factory(partition_count)

        # Publish the new state only after every new publisher has been
        # entered, so publish() never routes to an unopened publisher.
        self._publishers.update(new_publishers)
        self._routing_policy = routing_policy

    async def publish(self, message: PubSubMessage) -> PublishMetadata:
        """Route ``message`` to a partition and publish it there.

        Args:
            message: The message to publish.

        Returns:
            Whatever the selected partition's publisher returns.
        """
        partition = self._routing_policy.route(message)
        assert partition in self._publishers
        publisher = self._publishers[partition]
        return await publisher.publish(message)
pytestmark = pytest.mark.asyncio


@pytest.fixture()
def mock_publishers():
    """Ten mock publishers keyed by partitions 0 through 9."""
    publishers = {}
    for index in range(10):
        publishers[Partition(index)] = MagicMock(spec=Publisher)
    return publishers


@pytest.fixture()
def topic():
    """A parsed topic path used by every test."""
    return TopicPath.parse("projects/1/locations/us-central1-a/topics/topic")


@pytest.fixture()
def mock_admin():
    """A mock admin client."""
    return MagicMock(spec=AdminClientInterface)


@pytest.fixture()
def watcher(mock_admin, topic):
    """A watcher with a 1 ms poll period, constructed on a separate thread."""

    def build():
        return PartitionCountWatcherImpl(mock_admin, topic, 0.001)

    return run_on_thread(build)


async def test_init(watcher, mock_admin, topic):
    """Entering and exiting the watcher works when the admin returns a count."""
    mock_admin.get_topic_partition_count.return_value = 2
    async with watcher:
        pass


async def test_get_count_first_failure(watcher, mock_admin, topic):
    """An API error on the very first poll surfaces to the caller."""
    mock_admin.get_topic_partition_count.side_effect = GoogleAPICallError("error")
    with pytest.raises(GoogleAPICallError):
        async with watcher:
            await watcher.get_partition_count()


async def test_get_multiple_counts(watcher, mock_admin, topic):
    """Two pending readers each receive one polled count, in order."""
    counts = queue.Queue()
    mock_admin.get_topic_partition_count.side_effect = counts.get
    async with watcher:
        first = asyncio.ensure_future(watcher.get_partition_count())
        second = asyncio.ensure_future(watcher.get_partition_count())
        assert not first.done()
        assert not second.done()
        counts.put(3)
        assert await first == 3
        assert not second.done()
        counts.put(4)
        assert await second == 4


async def test_subsequent_failures_ignored(watcher, mock_admin, topic):
    """After one success, an API error is skipped and the next count arrives."""
    responses = queue.Queue()

    def next_response(_topic_path):
        # Queued exceptions are raised; anything else is returned as the count.
        item = responses.get()
        if isinstance(item, Exception):
            raise item
        return item

    mock_admin.get_topic_partition_count.side_effect = next_response
    async with watcher:
        responses.put(3)
        assert await watcher.get_partition_count() == 3
        responses.put(GoogleAPICallError("error"))
        responses.put(4)
        assert await watcher.get_partition_count() == 4
pytestmark = pytest.mark.asyncio


@pytest.fixture()
def mock_publishers():
    """Ten mock publishers keyed by partitions 0 through 9."""
    publishers = {}
    for index in range(10):
        publishers[Partition(index)] = MagicMock(spec=Publisher)
    return publishers


@pytest.fixture()
def mock_policies():
    """Ten mock routing policies keyed by partition count 0 through 9."""
    policies = {}
    for count in range(10):
        policies[count] = MagicMock(spec=RoutingPolicy)
    return policies


@pytest.fixture()
def mock_watcher():
    """A mock partition count watcher."""
    return MagicMock(spec=PartitionCountWatcher)


@pytest.fixture()
def publisher(mock_watcher, mock_publishers, mock_policies):
    """Publisher under test, wired to the mock dicts and built on another thread."""

    def build():
        # The factories are plain lookups into the fixture dictionaries.
        return PartitionCountWatchingPublisher(
            mock_watcher, mock_publishers.__getitem__, mock_policies.__getitem__
        )

    return run_on_thread(build)


async def test_init(mock_watcher, publisher):
    """The watcher is entered on start and exited on shutdown."""
    mock_watcher.get_partition_count.return_value = 2
    async with publisher:
        mock_watcher.__aenter__.assert_called_once()
    mock_watcher.__aexit__.assert_called_once()


async def test_failed_init(mock_watcher, publisher):
    """A failing first count propagates and still exits the watcher."""
    mock_watcher.get_partition_count.side_effect = GoogleAPICallError("error")
    with pytest.raises(GoogleAPICallError):
        async with publisher:
            pass
    mock_watcher.__aenter__.assert_called_once()
    mock_watcher.__aexit__.assert_called_once()


async def test_simple_publish(mock_publishers, mock_policies, mock_watcher, publisher):
    """A message is routed by the count-2 policy to the chosen publisher."""
    mock_watcher.get_partition_count.return_value = 2
    async with publisher:
        policy = mock_policies[2]
        target = mock_publishers[Partition(1)]
        policy.route.return_value = Partition(1)
        target.publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        policy.route.assert_called_with(PubSubMessage())
        target.publish.assert_called()


async def test_publish_after_increase(
    mock_publishers, mock_policies, mock_watcher, publisher
):
    """After the count grows from 2 to 3, the count-3 policy routes messages."""
    get_queues = wire_queues(mock_watcher.get_partition_count)
    await get_queues.results.put(2)
    async with publisher:
        get_queues.called.get_nowait()

        old_policy = mock_policies[2]
        old_target = mock_publishers[Partition(1)]
        old_policy.route.return_value = Partition(1)
        old_target.publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        old_policy.route.assert_called_with(PubSubMessage())
        old_target.publish.assert_called()

        # Let the poller request a count, feed it 3, and wait for the next
        # request so the update has been applied.
        await get_queues.called.get()
        await get_queues.results.put(3)
        await get_queues.called.get()

        new_policy = mock_policies[3]
        new_target = mock_publishers[Partition(2)]
        new_policy.route.return_value = Partition(2)
        new_target.publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        new_policy.route.assert_called_with(PubSubMessage())
        new_target.publish.assert_called()


async def test_decrease_ignored(
    mock_publishers, mock_policies, mock_watcher, publisher
):
    """A count drop from 2 to 1 leaves the count-2 policy in use."""
    get_queues = wire_queues(mock_watcher.get_partition_count)
    await get_queues.results.put(2)
    async with publisher:
        get_queues.called.get_nowait()

        policy = mock_policies[2]
        target = mock_publishers[Partition(1)]

        policy.route.return_value = Partition(1)
        target.publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        policy.route.assert_called_with(PubSubMessage())
        target.publish.assert_called()

        # Feed a smaller count and wait until the poller asks again.
        await get_queues.called.get()
        await get_queues.results.put(1)
        await get_queues.called.get()

        policy.route.return_value = Partition(1)
        target.publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        policy.route.assert_called_with(PubSubMessage())
        target.publish.assert_called()