From fdddd975f3123e56773941e0a55d632b3e734ccc Mon Sep 17 00:00:00 2001 From: Evan Palmer Date: Wed, 9 Dec 2020 23:31:36 -0500 Subject: [PATCH 1/3] Add support for increasing partitions in python --- .../internal/wire/make_publisher.py | 29 ++-- .../internal/wire/partition_count_watcher.py | 22 +++ .../wire/partition_count_watcher_impl.py | 65 ++++++++ .../partition_count_watching_publisher.py | 95 ++++++++++++ .../wire/partition_count_watcher_impl_test.py | 106 +++++++++++++ ...partition_count_watching_publisher_test.py | 140 ++++++++++++++++++ 6 files changed, 446 insertions(+), 11 deletions(-) create mode 100644 google/cloud/pubsublite/internal/wire/partition_count_watcher.py create mode 100644 google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py create mode 100644 google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py create mode 100644 tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py create mode 100644 tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py diff --git a/google/cloud/pubsublite/internal/wire/make_publisher.py b/google/cloud/pubsublite/internal/wire/make_publisher.py index 27d2d4cc..8aabd357 100644 --- a/google/cloud/pubsublite/internal/wire/make_publisher.py +++ b/google/cloud/pubsublite/internal/wire/make_publisher.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AsyncIterator, Mapping, Optional, MutableMapping +from typing import AsyncIterator, Mapping, Optional from google.cloud.pubsub_v1.types import BatchSettings @@ -25,8 +25,13 @@ GapicConnectionFactory, ) from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata +from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import ( + PartitionCountWatcherImpl, +) +from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import ( + PartitionCountWatchingPublisher, +) from google.cloud.pubsublite.internal.wire.publisher import Publisher -from google.cloud.pubsublite.internal.wire.routing_publisher import RoutingPublisher from google.cloud.pubsublite.internal.wire.single_partition_publisher import ( SinglePartitionPublisher, ) @@ -37,7 +42,6 @@ from google.api_core.client_options import ClientOptions from google.auth.credentials import Credentials - DEFAULT_BATCHING_SETTINGS = BatchSettings( max_bytes=( 3 * 1024 * 1024 @@ -87,21 +91,24 @@ def make_publisher( credentials=credentials, transport=transport, client_options=client_options ) # type: ignore - clients: MutableMapping[Partition, Publisher] = {} - - partition_count = admin_client.get_topic_partition_count(topic) - for partition in range(partition_count): - partition = Partition(partition) - + def publisher_factory(partition: Partition): def connection_factory(requests: AsyncIterator[PublishRequest]): final_metadata = merge_metadata( metadata, topic_routing_metadata(topic, partition) ) return client.publish(requests, metadata=list(final_metadata.items())) - clients[partition] = SinglePartitionPublisher( + return SinglePartitionPublisher( InitialPublishRequest(topic=str(topic), partition=partition.value), per_partition_batching_settings, GapicConnectionFactory(connection_factory), ) - return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients) + + def policy_factory(partition_count: int): + return DefaultRoutingPolicy(partition_count) + + return PartitionCountWatchingPublisher( + PartitionCountWatcherImpl(admin_client, topic, 10), + publisher_factory, + policy_factory, + ) diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watcher.py b/google/cloud/pubsublite/internal/wire/partition_count_watcher.py new file mode 100644 index 00000000..f0527755 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/partition_count_watcher.py @@ -0,0 +1,22 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import AsyncContextManager + + +class PartitionCountWatcher(AsyncContextManager): + @abstractmethod + async def get_partition_count(self) -> int: + raise NotImplementedError() diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py b/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py new file mode 100644 index 00000000..6e12f95b --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py @@ -0,0 +1,65 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from concurrent.futures.thread import ThreadPoolExecutor +import asyncio + +from google.cloud.pubsublite import AdminClientInterface +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled +from google.cloud.pubsublite.internal.wire.partition_count_watcher import ( + PartitionCountWatcher, +) +from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable +from google.cloud.pubsublite.types import TopicPath +from google.api_core.exceptions import GoogleAPICallError + + +class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable): + def __init__( + self, admin: AdminClientInterface, topic_path: TopicPath, duration: float + ): + super().__init__() + self._admin = admin + self._topic_path = topic_path + self._duration = duration + self._any_success = False + + async def __aenter__(self): + self._thread = ThreadPoolExecutor(max_workers=1) + self._queue = asyncio.Queue(maxsize=1) + self._poll_partition_loop = asyncio.ensure_future( + self.run_poller(self._poll_partition_loop) + ) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._poll_partition_loop.cancel() + await wait_ignore_cancelled(self._poll_partition_loop) + self._thread.shutdown(wait=False) + + def _get_partition_count_sync(self) -> int: + return self._admin.get_topic_partition_count(self._topic_path) + + async def _poll_partition_loop(self): + try: + partition_count = await asyncio.get_event_loop().run_in_executor( + self._thread, self._get_partition_count_sync + ) + self._any_success = True + await self._queue.put(partition_count) + except GoogleAPICallError as e: + if not self._any_success: + raise e + await asyncio.sleep(self._duration) + + async def get_partition_count(self) -> int: + return await self.await_unless_failed(self._queue.get()) diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py b/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py new file mode 100644 index 00000000..0da5cf62 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py @@ -0,0 +1,95 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import sys +import threading +from typing import Callable + +from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled +from google.cloud.pubsublite.internal.wire.partition_count_watcher import ( + PartitionCountWatcher, +) +from google.cloud.pubsublite.internal.wire.publisher import Publisher +from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy +from google.cloud.pubsublite.types import PublishMetadata, Partition +from google.cloud.pubsublite_v1 import PubSubMessage + + +class PartitionCountWatchingPublisher(Publisher): + def __init__( + self, + watcher: PartitionCountWatcher, + publisher_factory: Callable[[Partition], Publisher], + policy_factory: Callable[[int], RoutingPolicy], + ): + self._publishers = {} + self._lock = threading.Lock() + self._publisher_factory = publisher_factory + self._policy_factory = policy_factory + self._watcher = watcher + + async def __aenter__(self): + try: + await self._watcher.__aenter__() + await self._poll_partition_count_action() + except Exception: + await self._watcher.__aexit__(*sys.exc_info()) + raise + self._partition_count_poller = asyncio.ensure_future( + self._watch_partition_count() + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._partition_count_poller.cancel() + await wait_ignore_cancelled(self._partition_count_poller) + await self._watcher.__aexit__(exc_type, exc_val, exc_tb) + with self._lock: + for publisher in self._publishers.values(): + await publisher.__aexit__(exc_type, exc_val, exc_tb) + + async def _poll_partition_count_action(self): + partition_count = await self._watcher.get_partition_count() + await self._handle_partition_count_update(partition_count) + + async def _watch_partition_count(self): + while True: + await self._poll_partition_count_action() + + async def _handle_partition_count_update(self, partition_count: int): + with self._lock: + current_count = len(self._publishers) + if current_count == partition_count: + return + if current_count > partition_count: + return + + new_publishers = { + Partition(index): self._publisher_factory(Partition(index)) + for index in range(current_count, partition_count) + } + await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()]) + routing_policy = self._policy_factory(partition_count) + + with self._lock: + self._publishers.update(new_publishers) + self._routing_policy = routing_policy + + async def publish(self, message: PubSubMessage) -> PublishMetadata: + with self._lock: + partition = self._routing_policy.route(message) + assert partition in self._publishers + publisher = self._publishers[partition] + + return await publisher.publish(message) diff --git a/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py new file mode 100644 index 00000000..152724cf --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py @@ -0,0 +1,106 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import queue +import threading +from asynctest.mock import MagicMock +import pytest + +from google.cloud.pubsublite import AdminClientInterface +from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import ( + PartitionCountWatcherImpl, +) +from google.cloud.pubsublite.internal.wire.publisher import Publisher +from google.cloud.pubsublite.testing.test_utils import Box +from google.cloud.pubsublite.types import Partition, TopicPath, CloudZone, CloudRegion +from google.api_core.exceptions import GoogleAPICallError + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture() +def mock_publishers(): + return {Partition(i): MagicMock(spec=Publisher) for i in range(10)} + + +@pytest.fixture() +def topic(): + zone = CloudZone(region=CloudRegion("a"), zone_id="a") + return TopicPath(project_number=1, location=zone, name="c") + + +@pytest.fixture() +def mock_admin(): + admin = MagicMock(spec=AdminClientInterface) + return admin + + +@pytest.fixture() +def watcher(mock_admin, topic): + box = Box() + + def set_box(): + box.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001) + + # Initialize publisher on another thread with a different event loop. + thread = threading.Thread(target=set_box) + thread.start() + thread.join() + return box.val + + +async def test_init(watcher, mock_admin, topic): + mock_admin.get_topic_partition_count.return_value = 2 + async with watcher: + pass + + +async def test_get_count_first_failure(watcher, mock_admin, topic): + mock_admin.get_topic_partition_count.side_effect = GoogleAPICallError("error") + with pytest.raises(GoogleAPICallError): + async with watcher: + await watcher.get_partition_count() + + +async def test_get_multiple_counts(watcher, mock_admin, topic): + q = queue.Queue() + mock_admin.get_topic_partition_count.side_effect = q.get + async with watcher: + task1 = asyncio.ensure_future(watcher.get_partition_count()) + task2 = asyncio.ensure_future(watcher.get_partition_count()) + assert not task1.done() + assert not task2.done() + q.put(3) + assert await task1 == 3 + assert not task2.done() + q.put(4) + assert await task2 == 4 + + +async def test_subsequent_failures_ignored(watcher, mock_admin, topic): + q = queue.Queue() + + def side_effect(): + value = q.get() + if isinstance(value, Exception): + raise value + return value + + mock_admin.get_topic_partition_count.side_effect = lambda x: side_effect() + async with watcher: + q.put(3) + assert await watcher.get_partition_count() == 3 + q.put(GoogleAPICallError("error")) + q.put(4) + assert await watcher.get_partition_count() == 4 diff --git a/tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py b/tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py new file mode 100644 index 00000000..7c840bd1 --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py @@ -0,0 +1,140 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +from asynctest.mock import MagicMock +import pytest + +from google.cloud.pubsublite.internal.wire.partition_count_watcher import ( + PartitionCountWatcher, +) +from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import ( + PartitionCountWatchingPublisher, +) +from google.cloud.pubsublite.internal.wire.publisher import Publisher +from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy +from google.cloud.pubsublite.testing.test_utils import Box, wire_queues +from google.cloud.pubsublite.types import Partition +from google.cloud.pubsublite_v1 import PubSubMessage +from google.api_core.exceptions import GoogleAPICallError + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture() +def mock_publishers(): + return {Partition(i): MagicMock(spec=Publisher) for i in range(10)} + + +@pytest.fixture() +def mock_policies(): + return {i: MagicMock(spec=RoutingPolicy) for i in range(10)} + + +@pytest.fixture() +def mock_watcher(): + watcher = MagicMock(spec=PartitionCountWatcher) + return watcher + + +@pytest.fixture() +def publisher(mock_watcher, mock_publishers, mock_policies): + box = Box() + + def set_box(): + box.val = PartitionCountWatchingPublisher( + mock_watcher, lambda p: mock_publishers[p], lambda c: mock_policies[c] + ) + + # Initialize publisher on another thread with a different event loop. + thread = threading.Thread(target=set_box) + thread.start() + thread.join() + return box.val + + +async def test_init(mock_watcher, publisher): + mock_watcher.get_partition_count.return_value = 2 + async with publisher: + mock_watcher.__aenter__.assert_called_once() + pass + mock_watcher.__aexit__.assert_called_once() + + +async def test_failed_init(mock_watcher, publisher): + mock_watcher.get_partition_count.side_effect = GoogleAPICallError("error") + with pytest.raises(GoogleAPICallError): + async with publisher: + pass + mock_watcher.__aenter__.assert_called_once() + mock_watcher.__aexit__.assert_called_once() + + +async def test_simple_publish(mock_publishers, mock_policies, mock_watcher, publisher): + mock_watcher.get_partition_count.return_value = 2 + async with publisher: + mock_policies[2].route.return_value = Partition(1) + mock_publishers[Partition(1)].publish.return_value = "a" + await publisher.publish(PubSubMessage()) + mock_policies[2].route.assert_called_with(PubSubMessage()) + mock_publishers[Partition(1)].publish.assert_called() + + +async def test_publish_after_increase( + mock_publishers, mock_policies, mock_watcher, publisher +): + get_queues = wire_queues(mock_watcher.get_partition_count) + await get_queues.results.put(2) + async with publisher: + get_queues.called.get_nowait() + + mock_policies[2].route.return_value = Partition(1) + mock_publishers[Partition(1)].publish.return_value = "a" + await publisher.publish(PubSubMessage()) + mock_policies[2].route.assert_called_with(PubSubMessage()) + mock_publishers[Partition(1)].publish.assert_called() + + await get_queues.called.get() + await get_queues.results.put(3) + await get_queues.called.get() + + mock_policies[3].route.return_value = Partition(2) + mock_publishers[Partition(2)].publish.return_value = "a" + await publisher.publish(PubSubMessage()) + mock_policies[3].route.assert_called_with(PubSubMessage()) + mock_publishers[Partition(2)].publish.assert_called() + + +async def test_decrease_ignored( + mock_publishers, mock_policies, mock_watcher, publisher +): + get_queues = wire_queues(mock_watcher.get_partition_count) + await get_queues.results.put(2) + async with publisher: + get_queues.called.get_nowait() + + mock_policies[2].route.return_value = Partition(1) + mock_publishers[Partition(1)].publish.return_value = "a" + await publisher.publish(PubSubMessage()) + mock_policies[2].route.assert_called_with(PubSubMessage()) + mock_publishers[Partition(1)].publish.assert_called() + + await get_queues.called.get() + await get_queues.results.put(1) + await get_queues.called.get() + + mock_policies[2].route.return_value = Partition(1) + mock_publishers[Partition(1)].publish.return_value = "a" + await publisher.publish(PubSubMessage()) + mock_policies[2].route.assert_called_with(PubSubMessage()) + mock_publishers[Partition(1)].publish.assert_called() From 1e8e5a53f500ffd40a270f14fa72a60e0cf35a89 Mon Sep 17 00:00:00 2001 From: Evan Palmer Date: Thu, 10 Dec 2020 14:34:47 -0500 Subject: [PATCH 2/3] updates to address comments --- .gitignore | 1 + .../internal/wire/make_publisher.py | 3 +- .../wire/partition_count_watcher_impl.py | 10 ++++++ .../partition_count_watching_publisher.py | 31 +++++++++---------- .../wire/partition_count_watcher_impl_test.py | 2 +- 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index b9daa52f..71f99fe7 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ develop-eggs .installed.cfg lib lib64 +venv __pycache__ # Installer logs diff --git a/google/cloud/pubsublite/internal/wire/make_publisher.py b/google/cloud/pubsublite/internal/wire/make_publisher.py index 8aabd357..26c2486e 100644 --- a/google/cloud/pubsublite/internal/wire/make_publisher.py +++ b/google/cloud/pubsublite/internal/wire/make_publisher.py @@ -49,6 +49,7 @@ max_messages=1000, max_latency=0.05, # 50 ms ) +DEFAULT_PARTITION_POLL_PERIOD = 600 # ten minutes def make_publisher( @@ -108,7 +109,7 @@ def policy_factory(partition_count: int): return DefaultRoutingPolicy(partition_count) return PartitionCountWatchingPublisher( - PartitionCountWatcherImpl(admin_client, topic, 10), + PartitionCountWatcherImpl(admin_client, topic, DEFAULT_PARTITION_POLL_PERIOD), publisher_factory, policy_factory, ) diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py b/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py index 6e12f95b..1b0796ae 100644 --- a/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py +++ b/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from concurrent.futures.thread import ThreadPoolExecutor import asyncio @@ -25,6 +26,14 @@ class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable): + _admin: AdminClientInterface + _topic_path: TopicPath + _duration: float + _any_success: bool + _thread: ThreadPoolExecutor + _queue: asyncio.Queue + _poll_partition_loop: asyncio.Future + def __init__( self, admin: AdminClientInterface, topic_path: TopicPath, duration: float ): @@ -59,6 +68,7 @@ async def _poll_partition_loop(self): except GoogleAPICallError as e: if not self._any_success: raise e + logging.exception("Failed to retrieve partition count") await asyncio.sleep(self._duration) async def get_partition_count(self) -> int: diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py b/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py index 0da5cf62..55618de1 100644 --- a/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py +++ b/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py @@ -13,8 +13,7 @@ # limitations under the License. import asyncio import sys -import threading -from typing import Callable +from typing import Callable, Dict from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled from google.cloud.pubsublite.internal.wire.partition_count_watcher import ( @@ -27,6 +26,12 @@ class PartitionCountWatchingPublisher(Publisher): + _publishers: Dict[Partition, Publisher] + _publisher_factory: Callable[[Partition], Publisher] + _policy_factory: Callable[[int], RoutingPolicy] + _watcher: PartitionCountWatcher + _partition_count_poller: asyncio.Future + def __init__( self, watcher: PartitionCountWatcher, @@ -34,7 +39,6 @@ def __init__( policy_factory: Callable[[int], RoutingPolicy], ): self._publishers = {} - self._lock = threading.Lock() self._publisher_factory = publisher_factory self._policy_factory = policy_factory self._watcher = watcher @@ -55,9 +59,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self._partition_count_poller.cancel() await wait_ignore_cancelled(self._partition_count_poller) await self._watcher.__aexit__(exc_type, exc_val, exc_tb) - with self._lock: - for publisher in self._publishers.values(): - await publisher.__aexit__(exc_type, exc_val, exc_tb) + for publisher in self._publishers.values(): + await publisher.__aexit__(exc_type, exc_val, exc_tb) async def _poll_partition_count_action(self): partition_count = await self._watcher.get_partition_count() @@ -68,8 +71,7 @@ async def _watch_partition_count(self): await self._poll_partition_count_action() async def _handle_partition_count_update(self, partition_count: int): - with self._lock: - current_count = len(self._publishers) + current_count = len(self._publishers) if current_count == partition_count: return if current_count > partition_count: @@ -82,14 +84,11 @@ async def _handle_partition_count_update(self, partition_count: int): await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()]) routing_policy = self._policy_factory(partition_count) - with self._lock: - self._publishers.update(new_publishers) - self._routing_policy = routing_policy + self._publishers.update(new_publishers) + self._routing_policy = routing_policy async def publish(self, message: PubSubMessage) -> PublishMetadata: - with self._lock: - partition = self._routing_policy.route(message) - assert partition in self._publishers - publisher = self._publishers[partition] - + partition = self._routing_policy.route(message) + assert partition in self._publishers + publisher = self._publishers[partition] return await publisher.publish(message) diff --git a/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py index 152724cf..ee72e848 100644 --- a/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py @@ -53,7 +53,7 @@ def watcher(mock_admin, topic): def set_box(): box.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001) - # Initialize publisher on another thread with a different event loop. + # Initialize watcher on another thread with a different event loop. thread = threading.Thread(target=set_box) thread.start() thread.join() From 4440ad3649b566daabd2c929bb6e2d9ab1f1883a Mon Sep 17 00:00:00 2001 From: Evan Palmer Date: Mon, 14 Dec 2020 12:16:01 -0500 Subject: [PATCH 3/3] addressing comments --- .gitignore | 2 +- google/cloud/pubsublite/testing/test_utils.py | 16 +++++++++++++++- .../wire/partition_count_watcher_impl_test.py | 19 ++++--------------- ...partition_count_watching_publisher_test.py | 16 ++++------------ 4 files changed, 24 insertions(+), 29 deletions(-) diff --git a/.gitignore b/.gitignore index 71f99fe7..efad5203 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,6 @@ develop-eggs .installed.cfg lib lib64 -venv __pycache__ # Installer logs @@ -51,6 +50,7 @@ docs.metadata # Virtual environment env/ +venv/ coverage.xml sponge_log.xml diff --git a/google/cloud/pubsublite/testing/test_utils.py b/google/cloud/pubsublite/testing/test_utils.py index f5077203..7655fcb8 100644 --- a/google/cloud/pubsublite/testing/test_utils.py +++ b/google/cloud/pubsublite/testing/test_utils.py @@ -13,7 +13,8 @@ # limitations under the License. import asyncio -from typing import List, Union, Any, TypeVar, Generic, Optional +import threading +from typing import List, Union, Any, TypeVar, Generic, Optional, Callable from asynctest import CoroutineMock @@ -62,3 +63,16 @@ def wire_queues(mock: CoroutineMock) -> QueuePair: class Box(Generic[T]): val: Optional[T] + + +def run_on_thread(func: Callable[[], T]) -> T: + box = Box() + + def set_box(): + box.val = func() + + # Initialize watcher on another thread with a different event loop. + thread = threading.Thread(target=set_box) + thread.start() + thread.join() + return box.val diff --git a/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py index ee72e848..038a2aa9 100644 --- a/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py @@ -13,7 +13,6 @@ # limitations under the License. import asyncio import queue -import threading from asynctest.mock import MagicMock import pytest @@ -22,8 +21,8 @@ PartitionCountWatcherImpl, ) from google.cloud.pubsublite.internal.wire.publisher import Publisher -from google.cloud.pubsublite.testing.test_utils import Box -from google.cloud.pubsublite.types import Partition, TopicPath, CloudZone, CloudRegion +from google.cloud.pubsublite.testing.test_utils import run_on_thread +from google.cloud.pubsublite.types import Partition, TopicPath from google.api_core.exceptions import GoogleAPICallError pytestmark = pytest.mark.asyncio @@ -36,8 +35,7 @@ def mock_publishers(): @pytest.fixture() def topic(): - zone = CloudZone(region=CloudRegion("a"), zone_id="a") - return TopicPath(project_number=1, location=zone, name="c") + return TopicPath.parse("projects/1/locations/us-central1-a/topics/topic") @pytest.fixture() @@ -48,16 +46,7 @@ def mock_admin(): @pytest.fixture() def watcher(mock_admin, topic): - box = Box() - - def set_box(): - box.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001) - - # Initialize watcher on another thread with a different event loop. - thread = threading.Thread(target=set_box) - thread.start() - thread.join() - return box.val + return run_on_thread(lambda: PartitionCountWatcherImpl(mock_admin, topic, 0.001)) async def test_init(watcher, mock_admin, topic): diff --git a/tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py b/tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py index 7c840bd1..a31f12fa 100644 --- a/tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py +++ b/tests/unit/pubsublite/internal/wire/partition_count_watching_publisher_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import threading from asynctest.mock import MagicMock import pytest @@ -23,7 +22,7 @@ ) from google.cloud.pubsublite.internal.wire.publisher import Publisher from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy -from google.cloud.pubsublite.testing.test_utils import Box, wire_queues +from google.cloud.pubsublite.testing.test_utils import wire_queues, run_on_thread from google.cloud.pubsublite.types import Partition from google.cloud.pubsublite_v1 import PubSubMessage from google.api_core.exceptions import GoogleAPICallError @@ -49,18 +48,11 @@ def mock_watcher(): @pytest.fixture() def publisher(mock_watcher, mock_publishers, mock_policies): - box = Box() - - def set_box(): - box.val = PartitionCountWatchingPublisher( + return run_on_thread( + lambda: PartitionCountWatchingPublisher( mock_watcher, lambda p: mock_publishers[p], lambda c: mock_policies[c] ) - - # Initialize publisher on another thread with a different event loop. - thread = threading.Thread(target=set_box) - thread.start() - thread.join() - return box.val + ) async def test_init(mock_watcher, publisher):