diff --git a/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py index 39b69b74..cc91e101 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/async_publisher_impl.py @@ -25,9 +25,9 @@ async def publish( psl_message = from_cps_publish_message(cps_message) return (await self._publisher.publish(psl_message)).encode() - def __aenter__(self): - self._publisher.__aenter__() + async def __aenter__(self): + await self._publisher.__aenter__() return self - def __aexit__(self, exc_type, exc_value, traceback): - self._publisher.__aexit__(exc_type, exc_value, traceback) + async def __aexit__(self, exc_type, exc_value, traceback): + await self._publisher.__aexit__(exc_type, exc_value, traceback) diff --git a/google/cloud/pubsublite/internal/wire/default_routing_policy.py b/google/cloud/pubsublite/internal/wire/default_routing_policy.py index 354bb8b2..90858ab9 100644 --- a/google/cloud/pubsublite/internal/wire/default_routing_policy.py +++ b/google/cloud/pubsublite/internal/wire/default_routing_policy.py @@ -17,15 +17,15 @@ class DefaultRoutingPolicy(RoutingPolicy): def __init__(self, num_partitions: int): self._num_partitions = num_partitions - self._current_round_robin = Partition(random.randint(0, num_partitions)) + self._current_round_robin = Partition(random.randint(0, num_partitions - 1)) def route(self, message: PubSubMessage) -> Partition: """Route the message using the key if set or round robin if unset.""" if not message.key: result = Partition(self._current_round_robin.value) - self._current_round_robin.value = ( - self._current_round_robin.value + 1 - ) % self._num_partitions + self._current_round_robin = Partition( + (self._current_round_robin.value + 1) % self._num_partitions + ) return result sha = hashlib.sha256() sha.update(message.key) diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py index 5688e3e7..fa7fb6f2 100644 --- a/google/cloud/pubsublite/internal/wire/permanent_failable.py +++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py @@ -9,10 +9,17 @@ class PermanentFailable: """A class that can experience permanent failures, with helpers for forwarding these to client actions.""" - _failure_task: asyncio.Future + _maybe_failure_task: Optional[asyncio.Future] def __init__(self): - self._failure_task = asyncio.Future() + self._maybe_failure_task = None + + @property + def _failure_task(self) -> asyncio.Future: + """Get the failure task, initializing it lazily, since it needs to be initialized in the event loop.""" + if self._maybe_failure_task is None: + self._maybe_failure_task = asyncio.Future() + return self._maybe_failure_task async def await_unless_failed(self, awaitable: Awaitable[T]) -> T: """ diff --git a/google/cloud/pubsublite/location.py b/google/cloud/pubsublite/location.py index 6f617b8e..6bacd83a 100644 --- a/google/cloud/pubsublite/location.py +++ b/google/cloud/pubsublite/location.py @@ -6,6 +6,9 @@ class CloudRegion(NamedTuple): name: str + def __str__(self): + return self.name + class CloudZone(NamedTuple): region: CloudRegion diff --git a/google/cloud/pubsublite/routing_metadata.py b/google/cloud/pubsublite/routing_metadata.py index 59e71246..ae8e6d00 100644 --- a/google/cloud/pubsublite/routing_metadata.py +++ b/google/cloud/pubsublite/routing_metadata.py @@ -8,12 +8,14 @@ def topic_routing_metadata(topic: TopicPath, partition: Partition) -> Mapping[str, str]: - encoded = urlencode(topic) - return {_PARAMS_HEADER: f"partition={partition.value}&topic={encoded}"} + encoded = urlencode({"partition": str(partition.value), "topic": str(topic)}) + return {_PARAMS_HEADER: encoded} def subscription_routing_metadata( subscription: SubscriptionPath, partition: Partition ) -> Mapping[str, str]: - encoded = urlencode(subscription) - return {_PARAMS_HEADER: f"partition={partition.value}&subscription={encoded}"} + encoded = urlencode( + {"partition": str(partition.value), "subscription": str(subscription)} + ) + return {_PARAMS_HEADER: encoded}