From 089789c54e876615157ec7e05b79000fc93e2dd9 Mon Sep 17 00:00:00 2001 From: dpcollins-google <40498610+dpcollins-google@users.noreply.github.com> Date: Thu, 15 Jul 2021 22:18:52 -0400 Subject: [PATCH] fix: Add ClientCache which forces new client creation after 75 uses (#188) * fix: Add ClientCache which forces new client creation after 75 uses This avoids the 100 stream per channel GRPC limit. * fix: Access cache later for subscriber --- .../cloudpubsub/internal/make_subscriber.py | 20 ++++++--- .../pubsublite/internal/wire/client_cache.py | 42 +++++++++++++++++++ .../internal/wire/make_publisher.py | 13 ++++-- 3 files changed, 65 insertions(+), 10 deletions(-) create mode 100644 google/cloud/pubsublite/internal/wire/client_cache.py diff --git a/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py index c7926e13..eee816e2 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py @@ -22,6 +22,7 @@ to_cps_subscribe_message, add_id_to_cps_subscribe_transformer, ) +from google.cloud.pubsublite.internal.wire.client_cache import ClientCache from google.cloud.pubsublite.types import FlowControlSettings from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import ( AckSetTrackerImpl, @@ -113,24 +114,31 @@ def _make_partition_subscriber_factory( nack_handler: NackHandler, message_transformer: MessageTransformer, ) -> PartitionSubscriberFactory: + subscribe_client_cache = ClientCache( + lambda: SubscriberServiceAsyncClient( + credentials=credentials, transport=transport, client_options=client_options + ) + ) + cursor_client_cache = ClientCache( + lambda: CursorServiceAsyncClient( + credentials=credentials, transport=transport, client_options=client_options + ) + ) + def factory(partition: Partition) -> AsyncSingleSubscriber: - subscribe_client = SubscriberServiceAsyncClient( - credentials=credentials, client_options=client_options, transport=transport - ) # type: ignore - cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options, transport=transport) # type: ignore final_metadata = merge_metadata( base_metadata, subscription_routing_metadata(subscription, partition) ) def subscribe_connection_factory(requests: AsyncIterator[SubscribeRequest]): - return subscribe_client.subscribe( + return subscribe_client_cache.get().subscribe( requests, metadata=list(final_metadata.items()) ) def cursor_connection_factory( requests: AsyncIterator[StreamingCommitCursorRequest], ): - return cursor_client.streaming_commit_cursor( + return cursor_client_cache.get().streaming_commit_cursor( requests, metadata=list(final_metadata.items()) ) diff --git a/google/cloud/pubsublite/internal/wire/client_cache.py b/google/cloud/pubsublite/internal/wire/client_cache.py new file mode 100644 index 00000000..6221e888 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/client_cache.py @@ -0,0 +1,42 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from typing import Generic, TypeVar, Callable, Optional + +_Client = TypeVar("_Client") +_MAX_CLIENT_USES = 75 # GRPC channels are limited to 100 concurrent streams. + + +class ClientCache(Generic[_Client]): + _ClientFactory = Callable[[], _Client] + + _factory: _ClientFactory + _latest: Optional[_Client] + _remaining_uses: int + _lock: threading.Lock + + def __init__(self, factory: _ClientFactory): + self._factory = factory + self._latest = None + self._remaining_uses = 0 + self._lock = threading.Lock() + + def get(self) -> _Client: + with self._lock: + if self._remaining_uses <= 0: + self._remaining_uses = _MAX_CLIENT_USES + self._latest = self._factory() + self._remaining_uses -= 1 + return self._latest diff --git a/google/cloud/pubsublite/internal/wire/make_publisher.py b/google/cloud/pubsublite/internal/wire/make_publisher.py index 799c22b4..9d834591 100644 --- a/google/cloud/pubsublite/internal/wire/make_publisher.py +++ b/google/cloud/pubsublite/internal/wire/make_publisher.py @@ -18,6 +18,7 @@ from google.cloud.pubsublite.admin_client import AdminClient from google.cloud.pubsublite.internal.endpoints import regional_endpoint +from google.cloud.pubsublite.internal.wire.client_cache import ClientCache from google.cloud.pubsublite.internal.wire.default_routing_policy import ( DefaultRoutingPolicy, ) @@ -88,16 +89,20 @@ def make_publisher( client_options = ClientOptions( api_endpoint=regional_endpoint(topic.location.region) ) - client = async_client.PublisherServiceAsyncClient( - credentials=credentials, transport=transport, client_options=client_options - ) # type: ignore + client_cache = ClientCache( + lambda: async_client.PublisherServiceAsyncClient( + credentials=credentials, transport=transport, client_options=client_options + ) + ) def publisher_factory(partition: Partition): def connection_factory(requests: AsyncIterator[PublishRequest]): final_metadata = merge_metadata( metadata, topic_routing_metadata(topic, partition) ) - return client.publish(requests, metadata=list(final_metadata.items())) + return client_cache.get().publish( + requests, metadata=list(final_metadata.items()) + ) return SinglePartitionPublisher( InitialPublishRequest(topic=str(topic), partition=partition.value),