Skip to content

Commit

Permalink
updates to address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
palmere-google committed Dec 10, 2020
1 parent fdddd97 commit 1e8e5a5
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 18 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -19,6 +19,7 @@ develop-eggs
.installed.cfg
lib
lib64
venv
__pycache__

# Installer logs
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/pubsublite/internal/wire/make_publisher.py
Expand Up @@ -49,6 +49,7 @@
max_messages=1000,
max_latency=0.05, # 50 ms
)
DEFAULT_PARTITION_POLL_PERIOD = 600 # ten minutes


def make_publisher(
Expand Down Expand Up @@ -108,7 +109,7 @@ def policy_factory(partition_count: int):
return DefaultRoutingPolicy(partition_count)

return PartitionCountWatchingPublisher(
PartitionCountWatcherImpl(admin_client, topic, 10),
PartitionCountWatcherImpl(admin_client, topic, DEFAULT_PARTITION_POLL_PERIOD),
publisher_factory,
policy_factory,
)
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from concurrent.futures.thread import ThreadPoolExecutor
import asyncio

Expand All @@ -25,6 +26,14 @@


class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable):
_admin: AdminClientInterface
_topic_path: TopicPath
_duration: float
_any_success: bool
_thread: ThreadPoolExecutor
_queue: asyncio.Queue
_poll_partition_loop: asyncio.Future

def __init__(
self, admin: AdminClientInterface, topic_path: TopicPath, duration: float
):
Expand Down Expand Up @@ -59,6 +68,7 @@ async def _poll_partition_loop(self):
except GoogleAPICallError as e:
if not self._any_success:
raise e
logging.exception("Failed to retrieve partition count")
await asyncio.sleep(self._duration)

async def get_partition_count(self) -> int:
Expand Down
Expand Up @@ -13,8 +13,7 @@
# limitations under the License.
import asyncio
import sys
import threading
from typing import Callable
from typing import Callable, Dict

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
Expand All @@ -27,14 +26,19 @@


class PartitionCountWatchingPublisher(Publisher):
_publishers: Dict[Partition, Publisher]
_publisher_factory: Callable[[Partition], Publisher]
_policy_factory: Callable[[int], RoutingPolicy]
_watcher: PartitionCountWatcher
_partition_count_poller: asyncio.Future

def __init__(
self,
watcher: PartitionCountWatcher,
publisher_factory: Callable[[Partition], Publisher],
policy_factory: Callable[[int], RoutingPolicy],
):
self._publishers = {}
self._lock = threading.Lock()
self._publisher_factory = publisher_factory
self._policy_factory = policy_factory
self._watcher = watcher
Expand All @@ -55,9 +59,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
self._partition_count_poller.cancel()
await wait_ignore_cancelled(self._partition_count_poller)
await self._watcher.__aexit__(exc_type, exc_val, exc_tb)
with self._lock:
for publisher in self._publishers.values():
await publisher.__aexit__(exc_type, exc_val, exc_tb)
for publisher in self._publishers.values():
await publisher.__aexit__(exc_type, exc_val, exc_tb)

async def _poll_partition_count_action(self):
partition_count = await self._watcher.get_partition_count()
Expand All @@ -68,8 +71,7 @@ async def _watch_partition_count(self):
await self._poll_partition_count_action()

async def _handle_partition_count_update(self, partition_count: int):
with self._lock:
current_count = len(self._publishers)
current_count = len(self._publishers)
if current_count == partition_count:
return
if current_count > partition_count:
Expand All @@ -82,14 +84,11 @@ async def _handle_partition_count_update(self, partition_count: int):
await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()])
routing_policy = self._policy_factory(partition_count)

with self._lock:
self._publishers.update(new_publishers)
self._routing_policy = routing_policy
self._publishers.update(new_publishers)
self._routing_policy = routing_policy

async def publish(self, message: PubSubMessage) -> PublishMetadata:
with self._lock:
partition = self._routing_policy.route(message)
assert partition in self._publishers
publisher = self._publishers[partition]

partition = self._routing_policy.route(message)
assert partition in self._publishers
publisher = self._publishers[partition]
return await publisher.publish(message)
Expand Up @@ -53,7 +53,7 @@ def watcher(mock_admin, topic):
def set_box():
box.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001)

# Initialize publisher on another thread with a different event loop.
# Initialize watcher on another thread with a different event loop.
thread = threading.Thread(target=set_box)
thread.start()
thread.join()
Expand Down

0 comments on commit 1e8e5a5

Please sign in to comment.