diff --git a/.gitignore b/.gitignore index b9daa52f..71f99fe7 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ develop-eggs .installed.cfg lib lib64 +venv __pycache__ # Installer logs diff --git a/google/cloud/pubsublite/internal/wire/make_publisher.py b/google/cloud/pubsublite/internal/wire/make_publisher.py index 8aabd357..26c2486e 100644 --- a/google/cloud/pubsublite/internal/wire/make_publisher.py +++ b/google/cloud/pubsublite/internal/wire/make_publisher.py @@ -49,6 +49,7 @@ max_messages=1000, max_latency=0.05, # 50 ms ) +DEFAULT_PARTITION_POLL_PERIOD = 600 # ten minutes def make_publisher( @@ -108,7 +109,7 @@ def policy_factory(partition_count: int): return DefaultRoutingPolicy(partition_count) return PartitionCountWatchingPublisher( - PartitionCountWatcherImpl(admin_client, topic, 10), + PartitionCountWatcherImpl(admin_client, topic, DEFAULT_PARTITION_POLL_PERIOD), publisher_factory, policy_factory, ) diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py b/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py index 6e12f95b..1b0796ae 100644 --- a/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py +++ b/google/cloud/pubsublite/internal/wire/partition_count_watcher_impl.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from concurrent.futures.thread import ThreadPoolExecutor import asyncio @@ -25,6 +26,14 @@ class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable): + _admin: AdminClientInterface + _topic_path: TopicPath + _duration: float + _any_success: bool + _thread: ThreadPoolExecutor + _queue: asyncio.Queue + _poll_partition_loop: asyncio.Future + def __init__( self, admin: AdminClientInterface, topic_path: TopicPath, duration: float ): @@ -59,6 +68,7 @@ async def _poll_partition_loop(self): except GoogleAPICallError as e: if not self._any_success: raise e + logging.exception("Failed to retrieve partition count") await asyncio.sleep(self._duration) async def get_partition_count(self) -> int: diff --git a/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py b/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py index 0da5cf62..55618de1 100644 --- a/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py +++ b/google/cloud/pubsublite/internal/wire/partition_count_watching_publisher.py @@ -13,8 +13,7 @@ # limitations under the License. import asyncio import sys -import threading -from typing import Callable +from typing import Callable, Dict from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled from google.cloud.pubsublite.internal.wire.partition_count_watcher import ( @@ -27,6 +26,12 @@ class PartitionCountWatchingPublisher(Publisher): + _publishers: Dict[Partition, Publisher] + _publisher_factory: Callable[[Partition], Publisher] + _policy_factory: Callable[[int], RoutingPolicy] + _watcher: PartitionCountWatcher + _partition_count_poller: asyncio.Future + def __init__( self, watcher: PartitionCountWatcher, @@ -34,7 +39,6 @@ def __init__( policy_factory: Callable[[int], RoutingPolicy], ): self._publishers = {} - self._lock = threading.Lock() self._publisher_factory = publisher_factory self._policy_factory = policy_factory self._watcher = watcher @@ -55,9 +59,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self._partition_count_poller.cancel() await wait_ignore_cancelled(self._partition_count_poller) await self._watcher.__aexit__(exc_type, exc_val, exc_tb) - with self._lock: - for publisher in self._publishers.values(): - await publisher.__aexit__(exc_type, exc_val, exc_tb) + for publisher in self._publishers.values(): + await publisher.__aexit__(exc_type, exc_val, exc_tb) async def _poll_partition_count_action(self): partition_count = await self._watcher.get_partition_count() @@ -68,8 +71,7 @@ async def _watch_partition_count(self): await self._poll_partition_count_action() async def _handle_partition_count_update(self, partition_count: int): - with self._lock: - current_count = len(self._publishers) + current_count = len(self._publishers) if current_count == partition_count: return if current_count > partition_count: @@ -82,14 +84,11 @@ async def _handle_partition_count_update(self, partition_count: int): await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()]) routing_policy = self._policy_factory(partition_count) - with self._lock: - self._publishers.update(new_publishers) - self._routing_policy = routing_policy + self._publishers.update(new_publishers) + self._routing_policy = routing_policy async def publish(self, message: PubSubMessage) -> PublishMetadata: - with self._lock: - partition = self._routing_policy.route(message) - assert partition in self._publishers - publisher = self._publishers[partition] - + partition = self._routing_policy.route(message) + assert partition in self._publishers + publisher = self._publishers[partition] return await publisher.publish(message) diff --git a/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py index 152724cf..ee72e848 100644 --- a/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/partition_count_watcher_impl_test.py @@ -53,7 +53,7 @@ def watcher(mock_admin, topic): def set_box(): box.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001) - # Initialize publisher on another thread with a different event loop. + # Initialize watcher on another thread with a different event loop. thread = threading.Thread(target=set_box) thread.start() thread.join()