Skip to content

Commit

Permalink
fix: Add ClientCache which forces new client creation after 75 uses (#…
Browse files Browse the repository at this point in the history
…188)

* fix: Add ClientCache which forces new client creation after 75 uses

This avoids the 100 stream per channel GRPC limit.

* fix: Access cache later for subscriber
  • Loading branch information
dpcollins-google committed Jul 16, 2021
1 parent c230ff1 commit 089789c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 10 deletions.
20 changes: 14 additions & 6 deletions google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py
Expand Up @@ -22,6 +22,7 @@
to_cps_subscribe_message,
add_id_to_cps_subscribe_transformer,
)
from google.cloud.pubsublite.internal.wire.client_cache import ClientCache
from google.cloud.pubsublite.types import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import (
AckSetTrackerImpl,
Expand Down Expand Up @@ -113,24 +114,31 @@ def _make_partition_subscriber_factory(
nack_handler: NackHandler,
message_transformer: MessageTransformer,
) -> PartitionSubscriberFactory:
subscribe_client_cache = ClientCache(
lambda: SubscriberServiceAsyncClient(
credentials=credentials, transport=transport, client_options=client_options
)
)
cursor_client_cache = ClientCache(
lambda: CursorServiceAsyncClient(
credentials=credentials, transport=transport, client_options=client_options
)
)

def factory(partition: Partition) -> AsyncSingleSubscriber:
subscribe_client = SubscriberServiceAsyncClient(
credentials=credentials, client_options=client_options, transport=transport
) # type: ignore
cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options, transport=transport) # type: ignore
final_metadata = merge_metadata(
base_metadata, subscription_routing_metadata(subscription, partition)
)

def subscribe_connection_factory(requests: AsyncIterator[SubscribeRequest]):
return subscribe_client.subscribe(
return subscribe_client_cache.get().subscribe(
requests, metadata=list(final_metadata.items())
)

def cursor_connection_factory(
requests: AsyncIterator[StreamingCommitCursorRequest],
):
return cursor_client.streaming_commit_cursor(
return cursor_client_cache.get().streaming_commit_cursor(
requests, metadata=list(final_metadata.items())
)

Expand Down
42 changes: 42 additions & 0 deletions google/cloud/pubsublite/internal/wire/client_cache.py
@@ -0,0 +1,42 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import Generic, TypeVar, Callable, Optional

_Client = TypeVar("_Client")
_MAX_CLIENT_USES = 75 # GRPC channels are limited to 100 concurrent streams.


class ClientCache(Generic[_Client]):
_ClientFactory = Callable[[], _Client]

_factory: _ClientFactory
_latest: Optional[_Client]
_remaining_uses: int
_lock: threading.Lock

def __init__(self, factory: _ClientFactory):
self._factory = factory
self._latest = None
self._remaining_uses = 0
self._lock = threading.Lock()

def get(self) -> _Client:
with self._lock:
if self._remaining_uses <= 0:
self._remaining_uses = _MAX_CLIENT_USES
self._latest = self._factory()
self._remaining_uses -= 1
return self._latest
13 changes: 9 additions & 4 deletions google/cloud/pubsublite/internal/wire/make_publisher.py
Expand Up @@ -18,6 +18,7 @@

from google.cloud.pubsublite.admin_client import AdminClient
from google.cloud.pubsublite.internal.endpoints import regional_endpoint
from google.cloud.pubsublite.internal.wire.client_cache import ClientCache
from google.cloud.pubsublite.internal.wire.default_routing_policy import (
DefaultRoutingPolicy,
)
Expand Down Expand Up @@ -88,16 +89,20 @@ def make_publisher(
client_options = ClientOptions(
api_endpoint=regional_endpoint(topic.location.region)
)
client = async_client.PublisherServiceAsyncClient(
credentials=credentials, transport=transport, client_options=client_options
) # type: ignore
client_cache = ClientCache(
lambda: async_client.PublisherServiceAsyncClient(
credentials=credentials, transport=transport, client_options=client_options
)
)

def publisher_factory(partition: Partition):
def connection_factory(requests: AsyncIterator[PublishRequest]):
final_metadata = merge_metadata(
metadata, topic_routing_metadata(topic, partition)
)
return client.publish(requests, metadata=list(final_metadata.items()))
return client_cache.get().publish(
requests, metadata=list(final_metadata.items())
)

return SinglePartitionPublisher(
InitialPublishRequest(topic=str(topic), partition=partition.value),
Expand Down

0 comments on commit 089789c

Please sign in to comment.