Skip to content

Commit

Permalink
fix: Numerous small performance and correctness issues (#211)
Browse files Browse the repository at this point in the history
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-pubsublite/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea.
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
  • Loading branch information
dpcollins-google committed Aug 16, 2021
1 parent 1248cd8 commit 358a1d8
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 182 deletions.
82 changes: 36 additions & 46 deletions google/cloud/pubsublite/cloudpubsub/internal/client_multiplexer.py
Expand Up @@ -16,40 +16,33 @@
import threading
from typing import Generic, TypeVar, Callable, Dict, Awaitable

from google.api_core.exceptions import FailedPrecondition

_Key = TypeVar("_Key")
_Client = TypeVar("_Client")


class ClientMultiplexer(Generic[_Key, _Client]):
_OpenedClientFactory = Callable[[], _Client]
_OpenedClientFactory = Callable[[_Key], _Client]
_ClientCloser = Callable[[_Client], None]

_factory: _OpenedClientFactory
_closer: _ClientCloser
_lock: threading.Lock
_live_clients: Dict[_Key, _Client]

def __init__(
self, closer: _ClientCloser = lambda client: client.__exit__(None, None, None)
self,
factory: _OpenedClientFactory,
closer: _ClientCloser = lambda client: client.__exit__(None, None, None),
):
self._factory = factory
self._closer = closer
self._lock = threading.Lock()
self._live_clients = {}

def get_or_create(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
def get_or_create(self, key: _Key) -> _Client:
with self._lock:
if key not in self._live_clients:
self._live_clients[key] = factory()
return self._live_clients[key]

def create_or_fail(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
with self._lock:
if key in self._live_clients:
raise FailedPrecondition(
f"Cannot create two clients with the same key. {_Key}"
)
self._live_clients[key] = factory()
self._live_clients[key] = self._factory(key)
return self._live_clients[key]

def try_erase(self, key: _Key, client: _Client):
Expand All @@ -75,52 +68,49 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class AsyncClientMultiplexer(Generic[_Key, _Client]):
_OpenedClientFactory = Callable[[], Awaitable[_Client]]
_OpenedClientFactory = Callable[[_Key], Awaitable[_Client]]
_ClientCloser = Callable[[_Client], Awaitable[None]]

_factory: _OpenedClientFactory
_closer: _ClientCloser
_lock: asyncio.Lock
_live_clients: Dict[_Key, _Client]
_live_clients: Dict[_Key, Awaitable[_Client]]

def __init__(
self, closer: _ClientCloser = lambda client: client.__aexit__(None, None, None)
self,
factory: _OpenedClientFactory,
closer: _ClientCloser = lambda client: client.__aexit__(None, None, None),
):
self._factory = factory
self._closer = closer
self._live_clients = {}

async def get_or_create(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
async with self._lock:
if key not in self._live_clients:
self._live_clients[key] = await factory()
return self._live_clients[key]

async def create_or_fail(self, key: _Key, factory: _OpenedClientFactory) -> _Client:
async with self._lock:
if key in self._live_clients:
raise FailedPrecondition(
f"Cannot create two clients with the same key. {_Key}"
)
self._live_clients[key] = await factory()
return self._live_clients[key]
async def get_or_create(self, key: _Key) -> _Client:
    """Return the opened client for `key`, creating it on first use.

    The first caller for a key schedules `self._factory(key)` with
    `asyncio.ensure_future` and caches the resulting future, so callers that
    arrive while the client is still being created await that same future
    instead of invoking the factory again.

    Args:
        key: Identifies the client to look up or create.

    Returns:
        The client produced by the factory for `key`.

    Raises:
        BaseException: Whatever awaiting the factory's future raised. The
            failed future is removed from the cache before re-raising, so a
            later call for the same key invokes the factory again.
    """
    if key not in self._live_clients:
        self._live_clients[key] = asyncio.ensure_future(self._factory(key))
    pending = self._live_clients[key]
    try:
        return await pending
    except BaseException:
        # A failed or cancelled future would otherwise stay cached and re-raise
        # the same error on every later call for this key. Evict it, but only
        # if the entry was not replaced while this coroutine was suspended.
        if self._live_clients.get(key) is pending:
            del self._live_clients[key]
        raise

async def try_erase(self, key: _Key, client: _Client):
async with self._lock:
if key not in self._live_clients:
return
current_client = self._live_clients[key]
if current_client is not client:
return
del self._live_clients[key]
if key not in self._live_clients:
return
client_future = self._live_clients[key]
current_client = await client_future
if current_client is not client:
return
# duplicate check after await that no one raced with us
if (
key not in self._live_clients
or self._live_clients[key] is not client_future
):
return
del self._live_clients[key]
await self._closer(client)

async def __aenter__(self):
self._lock = asyncio.Lock()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
live_clients: Dict[_Key, _Client]
async with self._lock:
live_clients = self._live_clients
self._live_clients = {}
live_clients: Dict[_Key, Awaitable[_Client]]
live_clients = self._live_clients
self._live_clients = {}
for topic, client in live_clients.items():
await self._closer(client)
await self._closer(await client)
Expand Up @@ -22,9 +22,9 @@ class ManagedEventLoop(ContextManager):
_loop: AbstractEventLoop
_thread: Thread

def __init__(self):
def __init__(self, name=None):
self._loop = new_event_loop()
self._thread = Thread(target=lambda: self._loop.run_forever())
self._thread = Thread(target=lambda: self._loop.run_forever(), name=name)

def __enter__(self):
self._thread.start()
Expand Down
Expand Up @@ -38,7 +38,14 @@ class MultiplexedAsyncPublisherClient(AsyncPublisherClientInterface):

def __init__(self, publisher_factory: AsyncPublisherFactory):
self._publisher_factory = publisher_factory
self._multiplexer = AsyncClientMultiplexer()
self._multiplexer = AsyncClientMultiplexer(
lambda topic: self._create_and_open(topic)
)

async def _create_and_open(self, topic: TopicPath):
    """Build the publisher for `topic` and enter it before handing it back.

    Args:
        topic: Passed through unchanged to `self._publisher_factory`.

    Returns:
        The object returned by the factory, after its `__aenter__` coroutine
        has been awaited. The value `__aenter__` itself returns is discarded.
    """
    publisher = self._publisher_factory(topic)
    # Entered explicitly rather than with `async with`, which would exit the
    # publisher again before it could be returned.
    await publisher.__aenter__()
    return publisher

@overrides
async def publish(
Expand All @@ -51,12 +58,7 @@ async def publish(
if isinstance(topic, str):
topic = TopicPath.parse(topic)

async def create_and_open():
client = self._publisher_factory(topic)
await client.__aenter__()
return client

publisher = await self._multiplexer.get_or_create(topic, create_and_open)
publisher = await self._multiplexer.get_or_create(topic)
try:
return await publisher.publish(
data=data, ordering_key=ordering_key, **attrs
Expand Down
Expand Up @@ -23,9 +23,6 @@

from google.cloud.pubsub_v1.subscriber.message import Message

from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import (
AsyncClientMultiplexer,
)
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
AsyncSubscriberFactory,
AsyncSingleSubscriber,
Expand Down Expand Up @@ -66,11 +63,11 @@ def __aiter__(self):

class MultiplexedAsyncSubscriberClient(AsyncSubscriberClientInterface):
_underlying_factory: AsyncSubscriberFactory
_multiplexer: AsyncClientMultiplexer[SubscriptionPath, AsyncSingleSubscriber]
_live_clients: Set[AsyncSingleSubscriber]

def __init__(self, underlying_factory: AsyncSubscriberFactory):
self._underlying_factory = underlying_factory
self._multiplexer = AsyncClientMultiplexer()
self._live_clients = set()

@overrides
async def subscribe(
Expand All @@ -82,25 +79,28 @@ async def subscribe(
if isinstance(subscription, str):
subscription = SubscriptionPath.parse(subscription)

async def create_and_open():
client = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
await client.__aenter__()
return client

subscriber = await self._multiplexer.get_or_create(
subscription, create_and_open
subscriber = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
await subscriber.__aenter__()
self._live_clients.add(subscriber)

return _SubscriberAsyncIterator(
subscriber, lambda: self._multiplexer.try_erase(subscription, subscriber)
subscriber, lambda: self._try_remove_client(subscriber)
)

@overrides
async def __aenter__(self):
await self._multiplexer.__aenter__()
return self

async def _try_remove_client(self, client: AsyncSingleSubscriber):
    """Stop tracking `client` and exit it, if it is still tracked.

    Does nothing when `client` is not in `self._live_clients`; otherwise the
    client is removed from the set and its `__aexit__` is awaited with no
    exception information.

    Args:
        client: The subscriber to remove and close.
    """
    if client not in self._live_clients:
        return
    # Untrack before awaiting so a second call for the same client hits the
    # guard above instead of exiting it again.
    self._live_clients.remove(client)
    await client.__aexit__(None, None, None)

@overrides
async def __aexit__(self, exc_type, exc_value, traceback):
await self._multiplexer.__aexit__(exc_type, exc_value, traceback)
live_clients = self._live_clients
self._live_clients = set()
for client in live_clients:
await client.__aexit__(None, None, None)
Expand Up @@ -38,7 +38,9 @@ class MultiplexedPublisherClient(PublisherClientInterface):

def __init__(self, publisher_factory: PublisherFactory):
self._publisher_factory = publisher_factory
self._multiplexer = ClientMultiplexer()
self._multiplexer = ClientMultiplexer(
lambda topic: self._create_and_start_publisher(topic)
)

@overrides
def publish(
Expand All @@ -51,9 +53,7 @@ def publish(
if isinstance(topic, str):
topic = TopicPath.parse(topic)
try:
publisher = self._multiplexer.get_or_create(
topic, lambda: self._create_and_start_publisher(topic)
)
publisher = self._multiplexer.get_or_create(topic)
except GoogleAPICallError as e:
failed = Future()
failed.set_exception(e)
Expand Down
Expand Up @@ -14,12 +14,10 @@

from concurrent.futures.thread import ThreadPoolExecutor
from typing import Union, Optional, Set
from threading import Lock

from google.cloud.pubsub_v1.subscriber.futures import StreamingPullFuture

from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import (
ClientMultiplexer,
)
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
AsyncSubscriberFactory,
)
Expand All @@ -40,22 +38,16 @@ class MultiplexedSubscriberClient(SubscriberClientInterface):
_executor: ThreadPoolExecutor
_underlying_factory: AsyncSubscriberFactory

_multiplexer: ClientMultiplexer[SubscriptionPath, StreamingPullFuture]
_lock: Lock
_live_clients: Set[StreamingPullFuture]

def __init__(
self, executor: ThreadPoolExecutor, underlying_factory: AsyncSubscriberFactory
):
self._executor = executor
self._underlying_factory = underlying_factory

def cancel_streaming_pull_future(fut: StreamingPullFuture):
try:
fut.cancel()
fut.result()
except: # noqa: E722
pass

self._multiplexer = ClientMultiplexer(cancel_streaming_pull_future)
self._lock = Lock()
self._live_clients = set()

@overrides
def subscribe(
Expand All @@ -68,28 +60,40 @@ def subscribe(
if isinstance(subscription, str):
subscription = SubscriptionPath.parse(subscription)

def create_and_open():
underlying = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
subscriber = SubscriberImpl(underlying, callback, self._executor)
future = StreamingPullFuture(subscriber)
subscriber.__enter__()
return future

future = self._multiplexer.create_or_fail(subscription, create_and_open)
future.add_done_callback(
lambda fut: self._multiplexer.try_erase(subscription, future)
underlying = self._underlying_factory(
subscription, fixed_partitions, per_partition_flow_control_settings
)
subscriber = SubscriberImpl(underlying, callback, self._executor)
future = StreamingPullFuture(subscriber)
subscriber.__enter__()
future.add_done_callback(lambda fut: self._try_remove_client(future))
return future

@staticmethod
def _cancel_streaming_pull_future(fut: StreamingPullFuture):
    """Cancel `fut` and then call `result()` on it, ignoring any failure.

    Best-effort: anything raised by `cancel()` or `result()` is swallowed, so
    this never propagates an error to the caller.

    Args:
        fut: The streaming pull future to cancel.
    """
    try:
        fut.cancel()
        fut.result()
    except:  # noqa: E722
        # Deliberately broad: whatever the cancelled future raises is ignored.
        return

def _try_remove_client(self, future: StreamingPullFuture):
# Untrack `future` and cancel it; returns early when `future` is not in
# `self._live_clients`.
# NOTE(review): indentation is lost in this view. Presumably the membership
# check and the removal run under `self._lock` and the cancel call runs after
# the lock is released — confirm against the real file.
with self._lock:
# Not tracked: nothing to remove or cancel.
if future not in self._live_clients:
return
self._live_clients.remove(future)
self._cancel_streaming_pull_future(future)

@overrides
def __enter__(self):
self._executor.__enter__()
self._multiplexer.__enter__()
return self

@overrides
def __exit__(self, exc_type, exc_value, traceback):
self._multiplexer.__exit__(exc_type, exc_value, traceback)
with self._lock:
live_clients = self._live_clients
self._live_clients = set()
for client in live_clients:
self._cancel_streaming_pull_future(client)
self._executor.__exit__(exc_type, exc_value, traceback)
Expand Up @@ -30,7 +30,7 @@ class SinglePublisherImpl(SinglePublisher):

def __init__(self, underlying: AsyncSinglePublisher):
super().__init__()
self._managed_loop = ManagedEventLoop()
self._managed_loop = ManagedEventLoop("PublisherLoopThread")
self._underlying = underlying

def publish(
Expand Down
Expand Up @@ -54,7 +54,7 @@ def __init__(
self._underlying = underlying
self._callback = callback
self._unowned_executor = unowned_executor
self._event_loop = ManagedEventLoop()
self._event_loop = ManagedEventLoop("SubscriberLoopThread")
self._close_lock = threading.Lock()
self._failure = None
self._close_callback = None
Expand Down

0 comments on commit 358a1d8

Please sign in to comment.