Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Numerous small performance and correctness issues #211

Merged
merged 2 commits into from Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
78 changes: 32 additions & 46 deletions google/cloud/pubsublite/cloudpubsub/internal/client_multiplexer.py
Expand Up @@ -16,40 +16,33 @@
import threading
from typing import Generic, TypeVar, Callable, Dict, Awaitable

from google.api_core.exceptions import FailedPrecondition

_Key = TypeVar("_Key")
_Client = TypeVar("_Client")


class ClientMultiplexer(Generic[_Key, _Client]):
_OpenedClientFactory = Callable[[], _Client]
_OpenedClientFactory = Callable[[_Key], _Client]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Callers of get_or_create showed a lot of "self" time spent constructing the factory to pass in. This is in the publish hot path, hence this change.

_ClientCloser = Callable[[_Client], None]

_factory: _OpenedClientFactory
_closer: _ClientCloser
_lock: threading.Lock
_live_clients: Dict[_Key, _Client]

def __init__(
self, closer: _ClientCloser = lambda client: client.__exit__(None, None, None)
self,
factory: _OpenedClientFactory,
closer: _ClientCloser = lambda client: client.__exit__(None, None, None),
):
self._factory = factory
self._closer = closer
self._lock = threading.Lock()
self._live_clients = {}

def get_or_create(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
def get_or_create(self, key: _Key) -> _Client:
with self._lock:
if key not in self._live_clients:
self._live_clients[key] = factory()
return self._live_clients[key]

def create_or_fail(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
with self._lock:
if key in self._live_clients:
raise FailedPrecondition(
f"Cannot create two clients with the same key. {_Key}"
)
self._live_clients[key] = factory()
self._live_clients[key] = self._factory(key)
return self._live_clients[key]

def try_erase(self, key: _Key, client: _Client):
Expand All @@ -75,52 +68,45 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class AsyncClientMultiplexer(Generic[_Key, _Client]):
_OpenedClientFactory = Callable[[], Awaitable[_Client]]
_OpenedClientFactory = Callable[[_Key], Awaitable[_Client]]
_ClientCloser = Callable[[_Client], Awaitable[None]]

_factory: _OpenedClientFactory
_closer: _ClientCloser
_lock: asyncio.Lock
_live_clients: Dict[_Key, _Client]
_live_clients: Dict[_Key, Awaitable[_Client]]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of time was spent acquiring the lock in the publish hot path; this change makes the lock unnecessary.


def __init__(
self, closer: _ClientCloser = lambda client: client.__aexit__(None, None, None)
self,
factory: _OpenedClientFactory,
closer: _ClientCloser = lambda client: client.__aexit__(None, None, None),
):
self._factory = factory
self._closer = closer
self._live_clients = {}

async def get_or_create(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
async with self._lock:
if key not in self._live_clients:
self._live_clients[key] = await factory()
return self._live_clients[key]

async def create_or_fail(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
async with self._lock:
if key in self._live_clients:
raise FailedPrecondition(
f"Cannot create two clients with the same key. {_Key}"
)
self._live_clients[key] = await factory()
return self._live_clients[key]
async def get_or_create(self, key: _Key) -> _Client:
if key not in self._live_clients:
self._live_clients[key] = asyncio.ensure_future(self._factory(key))
return await self._live_clients[key]

async def try_erase(self, key: _Key, client: _Client):
async with self._lock:
if key not in self._live_clients:
return
current_client = self._live_clients[key]
if current_client is not client:
return
del self._live_clients[key]
if key not in self._live_clients:
return
current_client = await self._live_clients[key]
if current_client is not client:
return
# duplicate check after await that no one raced with us
if key not in self._live_clients:
return
del self._live_clients[key]
await self._closer(client)

async def __aenter__(self):
self._lock = asyncio.Lock()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
live_clients: Dict[_Key, _Client]
async with self._lock:
live_clients = self._live_clients
self._live_clients = {}
live_clients: Dict[_Key, Awaitable[_Client]]
live_clients = self._live_clients
self._live_clients = {}
for topic, client in live_clients.items():
await self._closer(client)
await self._closer(await client)
Expand Up @@ -22,9 +22,9 @@ class ManagedEventLoop(ContextManager):
_loop: AbstractEventLoop
_thread: Thread

def __init__(self):
def __init__(self, name=None):
self._loop = new_event_loop()
self._thread = Thread(target=lambda: self._loop.run_forever())
self._thread = Thread(target=lambda: self._loop.run_forever(), name=name)

def __enter__(self):
self._thread.start()
Expand Down
Expand Up @@ -38,7 +38,14 @@ class MultiplexedAsyncPublisherClient(AsyncPublisherClientInterface):

def __init__(self, publisher_factory: AsyncPublisherFactory):
self._publisher_factory = publisher_factory
self._multiplexer = AsyncClientMultiplexer()
self._multiplexer = AsyncClientMultiplexer(
lambda topic: self._create_and_open(topic)
)

async def _create_and_open(self, topic: TopicPath):
client = self._publisher_factory(topic)
await client.__aenter__()
return client

@overrides
async def publish(
Expand All @@ -51,12 +58,7 @@ async def publish(
if isinstance(topic, str):
topic = TopicPath.parse(topic)

async def create_and_open():
client = self._publisher_factory(topic)
await client.__aenter__()
return client

publisher = await self._multiplexer.get_or_create(topic, create_and_open)
publisher = await self._multiplexer.get_or_create(topic)
try:
return await publisher.publish(
data=data, ordering_key=ordering_key, **attrs
Expand Down
Expand Up @@ -23,9 +23,6 @@

from google.cloud.pubsub_v1.subscriber.message import Message

from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import (
AsyncClientMultiplexer,
)
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
AsyncSubscriberFactory,
AsyncSingleSubscriber,
Expand Down Expand Up @@ -66,11 +63,11 @@ def __aiter__(self):

class MultiplexedAsyncSubscriberClient(AsyncSubscriberClientInterface):
_underlying_factory: AsyncSubscriberFactory
_multiplexer: AsyncClientMultiplexer[SubscriptionPath, AsyncSingleSubscriber]
_live_clients: Set[AsyncSingleSubscriber]

def __init__(self, underlying_factory: AsyncSubscriberFactory):
self._underlying_factory = underlying_factory
self._multiplexer = AsyncClientMultiplexer()
self._live_clients = set()

@overrides
async def subscribe(
Expand All @@ -82,25 +79,28 @@ async def subscribe(
if isinstance(subscription, str):
subscription = SubscriptionPath.parse(subscription)

async def create_and_open():
client = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
await client.__aenter__()
return client

subscriber = await self._multiplexer.get_or_create(
subscription, create_and_open
subscriber = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
await subscriber.__aenter__()
self._live_clients.add(subscriber)

return _SubscriberAsyncIterator(
subscriber, lambda: self._multiplexer.try_erase(subscription, subscriber)
subscriber, lambda: self._try_remove_client(subscriber)
)

@overrides
async def __aenter__(self):
await self._multiplexer.__aenter__()
return self

async def _try_remove_client(self, client: AsyncSingleSubscriber):
if client in self._live_clients:
self._live_clients.remove(client)
await client.__aexit__(None, None, None)

@overrides
async def __aexit__(self, exc_type, exc_value, traceback):
await self._multiplexer.__aexit__(exc_type, exc_value, traceback)
live_clients = self._live_clients
self._live_clients = set()
for client in live_clients:
await client.__aexit__(None, None, None)
Expand Up @@ -38,7 +38,9 @@ class MultiplexedPublisherClient(PublisherClientInterface):

def __init__(self, publisher_factory: PublisherFactory):
self._publisher_factory = publisher_factory
self._multiplexer = ClientMultiplexer()
self._multiplexer = ClientMultiplexer(
lambda topic: self._create_and_start_publisher(topic)
)

@overrides
def publish(
Expand All @@ -51,9 +53,7 @@ def publish(
if isinstance(topic, str):
topic = TopicPath.parse(topic)
try:
publisher = self._multiplexer.get_or_create(
topic, lambda: self._create_and_start_publisher(topic)
)
publisher = self._multiplexer.get_or_create(topic)
except GoogleAPICallError as e:
failed = Future()
failed.set_exception(e)
Expand Down
Expand Up @@ -14,12 +14,10 @@

from concurrent.futures.thread import ThreadPoolExecutor
from typing import Union, Optional, Set
from threading import Lock

from google.cloud.pubsub_v1.subscriber.futures import StreamingPullFuture

from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import (
ClientMultiplexer,
)
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
AsyncSubscriberFactory,
)
Expand All @@ -40,22 +38,16 @@ class MultiplexedSubscriberClient(SubscriberClientInterface):
_executor: ThreadPoolExecutor
_underlying_factory: AsyncSubscriberFactory

_multiplexer: ClientMultiplexer[SubscriptionPath, StreamingPullFuture]
_lock: Lock
_live_clients: Set[StreamingPullFuture]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, the subscriber enforced that there was only one open subscription stream per-subscription per-client, but there's actually no need for this.


def __init__(
self, executor: ThreadPoolExecutor, underlying_factory: AsyncSubscriberFactory
):
self._executor = executor
self._underlying_factory = underlying_factory

def cancel_streaming_pull_future(fut: StreamingPullFuture):
try:
fut.cancel()
fut.result()
except: # noqa: E722
pass

self._multiplexer = ClientMultiplexer(cancel_streaming_pull_future)
self._lock = Lock()
self._live_clients = set()

@overrides
def subscribe(
Expand All @@ -68,28 +60,40 @@ def subscribe(
if isinstance(subscription, str):
subscription = SubscriptionPath.parse(subscription)

def create_and_open():
underlying = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
subscriber = SubscriberImpl(underlying, callback, self._executor)
future = StreamingPullFuture(subscriber)
subscriber.__enter__()
return future

future = self._multiplexer.create_or_fail(subscription, create_and_open)
future.add_done_callback(
lambda fut: self._multiplexer.try_erase(subscription, future)
underlying = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
subscriber = SubscriberImpl(underlying, callback, self._executor)
future = StreamingPullFuture(subscriber)
subscriber.__enter__()
future.add_done_callback(lambda fut: self._try_remove_client(future))
return future

@staticmethod
def _cancel_streaming_pull_future(fut: StreamingPullFuture):
try:
fut.cancel()
fut.result()
except: # noqa: E722
pass

def _try_remove_client(self, future: StreamingPullFuture):
with self._lock:
if future not in self._live_clients:
return
self._live_clients.remove(future)
self._cancel_streaming_pull_future(future)

@overrides
def __enter__(self):
self._executor.__enter__()
self._multiplexer.__enter__()
return self

@overrides
def __exit__(self, exc_type, exc_value, traceback):
self._multiplexer.__exit__(exc_type, exc_value, traceback)
with self._lock:
live_clients = self._live_clients
self._live_clients = set()
for client in live_clients:
self._cancel_streaming_pull_future(client)
self._executor.__exit__(exc_type, exc_value, traceback)
Expand Up @@ -30,7 +30,7 @@ class SinglePublisherImpl(SinglePublisher):

def __init__(self, underlying: AsyncSinglePublisher):
super().__init__()
self._managed_loop = ManagedEventLoop()
self._managed_loop = ManagedEventLoop("PublisherLoopThread")
self._underlying = underlying

def publish(
Expand Down
Expand Up @@ -54,7 +54,7 @@ def __init__(
self._underlying = underlying
self._callback = callback
self._unowned_executor = unowned_executor
self._event_loop = ManagedEventLoop()
self._event_loop = ManagedEventLoop("SubscriberLoopThread")
self._close_lock = threading.Lock()
self._failure = None
self._close_callback = None
Expand Down