From 358a1d8a429086ee75373260eb087a9dd171e3e6 Mon Sep 17 00:00:00 2001 From: dpcollins-google <40498610+dpcollins-google@users.noreply.github.com> Date: Mon, 16 Aug 2021 10:16:28 -0400 Subject: [PATCH] fix: Numerous small performance and correctness issues (#211) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-pubsublite/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes # 🦕 --- .../internal/client_multiplexer.py | 82 ++++++++----------- .../internal/managed_event_loop.py | 4 +- .../multiplexed_async_publisher_client.py | 16 ++-- .../multiplexed_async_subscriber_client.py | 34 ++++---- .../internal/multiplexed_publisher_client.py | 8 +- .../internal/multiplexed_subscriber_client.py | 58 +++++++------ .../cloudpubsub/internal/publisher_impl.py | 2 +- .../cloudpubsub/internal/subscriber_impl.py | 2 +- .../internal/wire/committer_impl.py | 17 +--- .../internal/wire/serial_batcher.py | 46 ++++++++--- .../wire/single_partition_publisher.py | 32 +++++--- google/cloud/pubsublite/types/paths.py | 2 +- testing/constraints-3.6.txt | 3 - .../internal/async_client_multiplexer_test.py | 26 +++--- .../internal/client_multiplexer_test.py | 27 +++--- 15 files changed, 177 insertions(+), 182 deletions(-) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/client_multiplexer.py b/google/cloud/pubsublite/cloudpubsub/internal/client_multiplexer.py index 549d065e..69bbe4ac 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/client_multiplexer.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/client_multiplexer.py @@ -16,40 +16,33 @@ import threading from typing import Generic, TypeVar, Callable, Dict, Awaitable -from google.api_core.exceptions import FailedPrecondition - _Key = TypeVar("_Key") _Client = TypeVar("_Client") class ClientMultiplexer(Generic[_Key, _Client]): - _OpenedClientFactory = Callable[[], _Client] + _OpenedClientFactory = Callable[[_Key], _Client] _ClientCloser = Callable[[_Client], None] + _factory: _OpenedClientFactory _closer: _ClientCloser _lock: threading.Lock _live_clients: Dict[_Key, _Client] def __init__( - self, closer: _ClientCloser = lambda client: client.__exit__(None, None, None) + self, + factory: _OpenedClientFactory, + closer: _ClientCloser = lambda client: client.__exit__(None, None, None), ): + self._factory = factory self._closer = closer self._lock = threading.Lock() self._live_clients = {} - def get_or_create(self, key: _Key, factory: _OpenedClientFactory) -> _Client: + def get_or_create(self, key: _Key) -> _Client: with self._lock: if key not in self._live_clients: - self._live_clients[key] = factory() - return self._live_clients[key] - - def create_or_fail(self, key: _Key, factory: _OpenedClientFactory) -> _Client: - with self._lock: - if key in self._live_clients: - raise FailedPrecondition( - f"Cannot create two clients with the same key. 
{_Key}" - ) - self._live_clients[key] = factory() + self._live_clients[key] = self._factory(key) return self._live_clients[key] def try_erase(self, key: _Key, client: _Client): @@ -75,52 +68,49 @@ def __exit__(self, exc_type, exc_val, exc_tb): class AsyncClientMultiplexer(Generic[_Key, _Client]): - _OpenedClientFactory = Callable[[], Awaitable[_Client]] + _OpenedClientFactory = Callable[[_Key], Awaitable[_Client]] _ClientCloser = Callable[[_Client], Awaitable[None]] + _factory: _OpenedClientFactory _closer: _ClientCloser - _lock: asyncio.Lock - _live_clients: Dict[_Key, _Client] + _live_clients: Dict[_Key, Awaitable[_Client]] def __init__( - self, closer: _ClientCloser = lambda client: client.__aexit__(None, None, None) + self, + factory: _OpenedClientFactory, + closer: _ClientCloser = lambda client: client.__aexit__(None, None, None), ): + self._factory = factory self._closer = closer self._live_clients = {} - async def get_or_create(self, key: _Key, factory: _OpenedClientFactory) -> _Client: - async with self._lock: - if key not in self._live_clients: - self._live_clients[key] = await factory() - return self._live_clients[key] - - async def create_or_fail(self, key: _Key, factory: _OpenedClientFactory) -> _Client: - async with self._lock: - if key in self._live_clients: - raise FailedPrecondition( - f"Cannot create two clients with the same key. {_Key}" - ) - self._live_clients[key] = await factory() - return self._live_clients[key] + async def get_or_create(self, key: _Key) -> _Client: + if key not in self._live_clients: + self._live_clients[key] = asyncio.ensure_future(self._factory(key)) + return await self._live_clients[key] async def try_erase(self, key: _Key, client: _Client): - async with self._lock: - if key not in self._live_clients: - return - current_client = self._live_clients[key] - if current_client is not client: - return - del self._live_clients[key] + if key not in self._live_clients: + return + client_future = self._live_clients[key] + current_client = await client_future + if current_client is not client: + return + # duplicate check after await that no one raced with us + if ( + key not in self._live_clients + or self._live_clients[key] is not client_future + ): + return + del self._live_clients[key] await self._closer(client) async def __aenter__(self): - self._lock = asyncio.Lock() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - live_clients: Dict[_Key, _Client] - async with self._lock: - live_clients = self._live_clients - self._live_clients = {} + live_clients: Dict[_Key, Awaitable[_Client]] + live_clients = self._live_clients + self._live_clients = {} for topic, client in live_clients.items(): - await self._closer(client) + await self._closer(await client) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py b/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py index cbc8f971..434ab995 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py @@ -22,9 +22,9 @@ class ManagedEventLoop(ContextManager): _loop: AbstractEventLoop _thread: Thread - def __init__(self): + def __init__(self, name=None): self._loop = new_event_loop() - self._thread = Thread(target=lambda: self._loop.run_forever()) + self._thread = Thread(target=lambda: self._loop.run_forever(), name=name) def __enter__(self): self._thread.start() diff --git a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_publisher_client.py 
b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_publisher_client.py index 8e1b766e..9cacd275 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_publisher_client.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_publisher_client.py @@ -38,7 +38,14 @@ class MultiplexedAsyncPublisherClient(AsyncPublisherClientInterface): def __init__(self, publisher_factory: AsyncPublisherFactory): self._publisher_factory = publisher_factory - self._multiplexer = AsyncClientMultiplexer() + self._multiplexer = AsyncClientMultiplexer( + lambda topic: self._create_and_open(topic) + ) + + async def _create_and_open(self, topic: TopicPath): + client = self._publisher_factory(topic) + await client.__aenter__() + return client @overrides async def publish( @@ -51,12 +58,7 @@ async def publish( if isinstance(topic, str): topic = TopicPath.parse(topic) - async def create_and_open(): - client = self._publisher_factory(topic) - await client.__aenter__() - return client - - publisher = await self._multiplexer.get_or_create(topic, create_and_open) + publisher = await self._multiplexer.get_or_create(topic) try: return await publisher.publish( data=data, ordering_key=ordering_key, **attrs diff --git a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py index 17115d59..4187af33 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py @@ -23,9 +23,6 @@ from google.cloud.pubsub_v1.subscriber.message import Message -from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import ( - AsyncClientMultiplexer, -) from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import ( AsyncSubscriberFactory, AsyncSingleSubscriber, @@ -66,11 +63,11 @@ def __aiter__(self): class MultiplexedAsyncSubscriberClient(AsyncSubscriberClientInterface): _underlying_factory: AsyncSubscriberFactory - _multiplexer: AsyncClientMultiplexer[SubscriptionPath, AsyncSingleSubscriber] + _live_clients: Set[AsyncSingleSubscriber] def __init__(self, underlying_factory: AsyncSubscriberFactory): self._underlying_factory = underlying_factory - self._multiplexer = AsyncClientMultiplexer() + self._live_clients = set() @overrides async def subscribe( @@ -82,25 +79,28 @@ async def subscribe( if isinstance(subscription, str): subscription = SubscriptionPath.parse(subscription) - async def create_and_open(): - client = self._underlying_factory( - subscription, fixed_partitions, per_partition_flow_control_settings - ) - await client.__aenter__() - return client - - subscriber = await self._multiplexer.get_or_create( - subscription, create_and_open + subscriber = self._underlying_factory( + subscription, fixed_partitions, per_partition_flow_control_settings ) + await subscriber.__aenter__() + self._live_clients.add(subscriber) + return _SubscriberAsyncIterator( - subscriber, lambda: self._multiplexer.try_erase(subscription, subscriber) + subscriber, lambda: self._try_remove_client(subscriber) ) @overrides async def __aenter__(self): - await self._multiplexer.__aenter__() return self + async def _try_remove_client(self, client: AsyncSingleSubscriber): + if client in self._live_clients: + self._live_clients.remove(client) + await client.__aexit__(None, None, None) + @overrides async def __aexit__(self, exc_type, exc_value, traceback): - await 
self._multiplexer.__aexit__(exc_type, exc_value, traceback) + live_clients = self._live_clients + self._live_clients = set() + for client in live_clients: + await client.__aexit__(None, None, None) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_publisher_client.py b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_publisher_client.py index 60f9246b..4adb7a75 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_publisher_client.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_publisher_client.py @@ -38,7 +38,9 @@ class MultiplexedPublisherClient(PublisherClientInterface): def __init__(self, publisher_factory: PublisherFactory): self._publisher_factory = publisher_factory - self._multiplexer = ClientMultiplexer() + self._multiplexer = ClientMultiplexer( + lambda topic: self._create_and_start_publisher(topic) + ) @overrides def publish( @@ -51,9 +53,7 @@ def publish( if isinstance(topic, str): topic = TopicPath.parse(topic) try: - publisher = self._multiplexer.get_or_create( - topic, lambda: self._create_and_start_publisher(topic) - ) + publisher = self._multiplexer.get_or_create(topic) except GoogleAPICallError as e: failed = Future() failed.set_exception(e) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_subscriber_client.py b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_subscriber_client.py index 928984ae..6978fd8b 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_subscriber_client.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_subscriber_client.py @@ -14,12 +14,10 @@ from concurrent.futures.thread import ThreadPoolExecutor from typing import Union, Optional, Set +from threading import Lock from google.cloud.pubsub_v1.subscriber.futures import StreamingPullFuture -from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import ( - ClientMultiplexer, -) from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import ( AsyncSubscriberFactory, ) @@ -40,22 +38,16 @@ class MultiplexedSubscriberClient(SubscriberClientInterface): _executor: ThreadPoolExecutor _underlying_factory: AsyncSubscriberFactory - _multiplexer: ClientMultiplexer[SubscriptionPath, StreamingPullFuture] + _lock: Lock + _live_clients: Set[StreamingPullFuture] def __init__( self, executor: ThreadPoolExecutor, underlying_factory: AsyncSubscriberFactory ): self._executor = executor self._underlying_factory = underlying_factory - - def cancel_streaming_pull_future(fut: StreamingPullFuture): - try: - fut.cancel() - fut.result() - except: # noqa: E722 - pass - - self._multiplexer = ClientMultiplexer(cancel_streaming_pull_future) + self._lock = Lock() + self._live_clients = set() @overrides def subscribe( @@ -68,28 +60,40 @@ def subscribe( if isinstance(subscription, str): subscription = SubscriptionPath.parse(subscription) - def create_and_open(): - underlying = self._underlying_factory( - subscription, fixed_partitions, per_partition_flow_control_settings - ) - subscriber = SubscriberImpl(underlying, callback, self._executor) - future = StreamingPullFuture(subscriber) - subscriber.__enter__() - return future - - future = self._multiplexer.create_or_fail(subscription, create_and_open) - future.add_done_callback( - lambda fut: self._multiplexer.try_erase(subscription, future) + underlying = self._underlying_factory( + subscription, fixed_partitions, per_partition_flow_control_settings ) + subscriber = SubscriberImpl(underlying, callback, self._executor) + future = 
StreamingPullFuture(subscriber) + subscriber.__enter__() + future.add_done_callback(lambda fut: self._try_remove_client(future)) return future + @staticmethod + def _cancel_streaming_pull_future(fut: StreamingPullFuture): + try: + fut.cancel() + fut.result() + except: # noqa: E722 + pass + + def _try_remove_client(self, future: StreamingPullFuture): + with self._lock: + if future not in self._live_clients: + return + self._live_clients.remove(future) + self._cancel_streaming_pull_future(future) + @overrides def __enter__(self): self._executor.__enter__() - self._multiplexer.__enter__() return self @overrides def __exit__(self, exc_type, exc_value, traceback): - self._multiplexer.__exit__(exc_type, exc_value, traceback) + with self._lock: + live_clients = self._live_clients + self._live_clients = set() + for client in live_clients: + self._cancel_streaming_pull_future(client) self._executor.__exit__(exc_type, exc_value, traceback) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py index 7f25e77d..52ff13e3 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/publisher_impl.py @@ -30,7 +30,7 @@ class SinglePublisherImpl(SinglePublisher): def __init__(self, underlying: AsyncSinglePublisher): super().__init__() - self._managed_loop = ManagedEventLoop() + self._managed_loop = ManagedEventLoop("PublisherLoopThread") self._underlying = underlying def publish( diff --git a/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py index 98473fe0..ed6c8368 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py @@ -54,7 +54,7 @@ def __init__( self._underlying = underlying self._callback = callback self._unowned_executor = unowned_executor - self._event_loop = ManagedEventLoop() + self._event_loop = ManagedEventLoop("SubscriberLoopThread") self._close_lock = threading.Lock() self._failure = None self._close_callback = None diff --git a/google/cloud/pubsublite/internal/wire/committer_impl.py b/google/cloud/pubsublite/internal/wire/committer_impl.py index 4bf0c021..d8f06f03 100644 --- a/google/cloud/pubsublite/internal/wire/committer_impl.py +++ b/google/cloud/pubsublite/internal/wire/committer_impl.py @@ -13,7 +13,7 @@ # limitations under the License. 
import asyncio -from typing import Optional, List, Iterable +from typing import Optional, List import logging @@ -28,10 +28,7 @@ ConnectionReinitializer, ) from google.cloud.pubsublite.internal.wire.connection import Connection -from google.cloud.pubsublite.internal.wire.serial_batcher import ( - SerialBatcher, - BatchTester, -) +from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher from google.cloud.pubsublite_v1 import Cursor from google.cloud.pubsublite_v1.types import ( StreamingCommitCursorRequest, @@ -49,7 +46,6 @@ class CommitterImpl( ConnectionReinitializer[ StreamingCommitCursorRequest, StreamingCommitCursorResponse ], - BatchTester[Cursor], ): _initial: InitialCommitCursorRequest _flush_seconds: float @@ -76,7 +72,7 @@ def __init__( self._initial = initial self._flush_seconds = flush_seconds self._connection = RetryingConnection(factory, self) - self._batcher = SerialBatcher(self) + self._batcher = SerialBatcher() self._outstanding_commits = [] self._receiver = None self._flusher = None @@ -167,9 +163,6 @@ async def wait_until_empty(self): async def commit(self, cursor: Cursor) -> None: future = self._batcher.add(cursor) - if self._batcher.should_flush(): - # always returns false currently, here in case this changes in the future. - await self._flush() await future async def reinitialize( @@ -199,7 +192,3 @@ async def reinitialize( req.commit.cursor = rollup[-1].request await connection.write(req) self._start_loopers() - - def test(self, requests: Iterable[Cursor]) -> bool: - # There is no bound on the number of outstanding cursors. - return False diff --git a/google/cloud/pubsublite/internal/wire/serial_batcher.py b/google/cloud/pubsublite/internal/wire/serial_batcher.py index 89cda847..90554d60 100644 --- a/google/cloud/pubsublite/internal/wire/serial_batcher.py +++ b/google/cloud/pubsublite/internal/wire/serial_batcher.py @@ -13,38 +13,56 @@ # limitations under the License. from abc import abstractmethod -from typing import Generic, List, Iterable +from typing import Generic, List, NamedTuple import asyncio +from overrides import overrides from google.cloud.pubsublite.internal.wire.connection import Request, Response from google.cloud.pubsublite.internal.wire.work_item import WorkItem -class BatchTester(Generic[Request]): - """A BatchTester determines whether a given batch of messages must be sent.""" +class BatchSize(NamedTuple): + element_count: int + byte_count: int + + def __add__(self, other: "BatchSize") -> "BatchSize": + return BatchSize( + self.element_count + other.element_count, self.byte_count + other.byte_count + ) + + +class RequestSizer(Generic[Request]): + """A RequestSizer determines the size of a request.""" @abstractmethod - def test(self, requests: Iterable[Request]) -> bool: + def get_size(self, request: Request) -> BatchSize: """ Args: - requests: The current outstanding batch. + request: A single request. - Returns: Whether that batch must be sent. 
+ Returns: The BatchSize of this request """ raise NotImplementedError() +class IgnoredRequestSizer(RequestSizer[Request]): + @overrides + def get_size(self, request: Request) -> BatchSize: + return BatchSize(0, 0) + + class SerialBatcher(Generic[Request, Response]): - _tester: BatchTester[Request] + _sizer: RequestSizer[Request] _requests: List[WorkItem[Request, Response]] # A list of outstanding requests + _batch_size: BatchSize - def __init__(self, tester: BatchTester[Request]): - self._tester = tester + def __init__(self, sizer: RequestSizer[Request] = IgnoredRequestSizer()): + self._sizer = sizer self._requests = [] + self._batch_size = BatchSize(0, 0) def add(self, request: Request) -> "asyncio.Future[Response]": - """Add a new request to this batcher. Callers must always call should_flush() after add, and flush() if that returns - true. + """Add a new request to this batcher. Args: request: The request to send. @@ -54,12 +72,14 @@ def add(self, request: Request) -> "asyncio.Future[Response]": """ item = WorkItem[Request, Response](request) self._requests.append(item) + self._batch_size += self._sizer.get_size(request) return item.response_future - def should_flush(self) -> bool: - return self._tester.test(item.request for item in self._requests) + def size(self) -> BatchSize: + return self._batch_size def flush(self) -> List[WorkItem[Request, Response]]: requests = self._requests self._requests = [] + self._batch_size = BatchSize(0, 0) return requests diff --git a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py index cc8cffe4..7e425e3f 100644 --- a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py +++ b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py @@ -13,8 +13,9 @@ # limitations under the License. 
import asyncio -from typing import Optional, List, Iterable +from typing import Optional, List +from overrides import overrides import logging from google.cloud.pubsub_v1.types import BatchSettings @@ -31,7 +32,8 @@ from google.cloud.pubsublite.internal.wire.connection import Connection from google.cloud.pubsublite.internal.wire.serial_batcher import ( SerialBatcher, - BatchTester, + RequestSizer, + BatchSize, ) from google.cloud.pubsublite.types import Partition, MessageMetadata from google.cloud.pubsublite_v1.types import ( @@ -55,7 +57,7 @@ class SinglePartitionPublisher( Publisher, ConnectionReinitializer[PublishRequest, PublishResponse], - BatchTester[PubSubMessage], + RequestSizer[PubSubMessage], ): _initial: InitialPublishRequest _batching_settings: BatchSettings @@ -162,10 +164,10 @@ async def _flush(self): self._fail_if_retrying_failed() async def publish(self, message: PubSubMessage) -> MessageMetadata: - cursor_future = self._batcher.add(message) - if self._batcher.should_flush(): + future = self._batcher.add(message) + if self._should_flush(): await self._flush() - return MessageMetadata(self._partition, await cursor_future) + return MessageMetadata(self._partition, await future) async def reinitialize( self, @@ -189,10 +191,14 @@ async def reinitialize( await connection.write(aggregate) self._start_loopers() - def test(self, requests: Iterable[PubSubMessage]) -> bool: - request_count = 0 - byte_count = 0 - for req in requests: - request_count += 1 - byte_count += PubSubMessage.pb(req).ByteSize() - return (request_count >= _MAX_MESSAGES) or (byte_count >= _MAX_BYTES) + @overrides + def get_size(self, request: PubSubMessage) -> BatchSize: + return BatchSize( + element_count=1, byte_count=PubSubMessage.pb(request).ByteSize() + ) + + def _should_flush(self) -> bool: + size = self._batcher.size() + return (size.element_count >= self._batching_settings.max_messages) or ( + size.byte_count >= self._batching_settings.max_bytes + ) diff --git a/google/cloud/pubsublite/types/paths.py b/google/cloud/pubsublite/types/paths.py index 9f60e182..eaa52fdf 100644 --- a/google/cloud/pubsublite/types/paths.py +++ b/google/cloud/pubsublite/types/paths.py @@ -105,4 +105,4 @@ def parse(to_parse: str) -> "ReservationPath": "Reservation path must be formatted like projects/{project_number}/locations/{location}/reservations/{name} but was instead " + to_parse ) - return ReservationPath(splits[1], CloudZone.parse(splits[3]), splits[5]) + return ReservationPath(splits[1], CloudRegion.parse(splits[3]), splits[5]) diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index 0ed3c3c6..6617fadd 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -5,6 +5,3 @@ # # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", # Then this file should have foo==1.14.0 -google-cloud-pubsub==2.1.0 -overrides==2.0.0 -packaging==14.3 diff --git a/tests/unit/pubsublite/cloudpubsub/internal/async_client_multiplexer_test.py b/tests/unit/pubsublite/cloudpubsub/internal/async_client_multiplexer_test.py index 926fdb0d..6541f155 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/async_client_multiplexer_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/async_client_multiplexer_test.py @@ -16,8 +16,6 @@ from asynctest.mock import call, CoroutineMock -from google.api_core.exceptions import FailedPrecondition - from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import ( AsyncClientMultiplexer, ) @@ -41,8 +39,8 @@ def client_closer(): @pytest.fixture() -def 
multiplexer(client_closer): - return AsyncClientMultiplexer(client_closer) +def multiplexer(client_factory, client_closer): + return AsyncClientMultiplexer(client_factory, client_closer) async def test_create( @@ -52,16 +50,12 @@ async def test_create( client2 = Client() async with multiplexer: client_factory.return_value = client1 - assert await multiplexer.create_or_fail(1, client_factory) is client1 - client_factory.assert_has_calls([call()]) + assert await multiplexer.get_or_create(1) is client1 + client_factory.assert_has_calls([call(1)]) client_factory.return_value = client2 - assert await multiplexer.get_or_create(1, client_factory) is client1 - with pytest.raises(FailedPrecondition): - await multiplexer.create_or_fail(1, client_factory) - assert await multiplexer.get_or_create(2, client_factory) is client2 - client_factory.assert_has_calls([call(), call()]) - with pytest.raises(FailedPrecondition): - await multiplexer.create_or_fail(2, client_factory) + assert await multiplexer.get_or_create(1) is client1 + assert await multiplexer.get_or_create(2) is client2 + client_factory.assert_has_calls([call(1), call(2)]) client_closer.assert_has_calls([call(client1), call(client2)], any_order=True) @@ -72,12 +66,12 @@ async def test_recreate( client2 = Client() async with multiplexer: client_factory.return_value = client1 - assert await multiplexer.create_or_fail(1, client_factory) is client1 - client_factory.assert_has_calls([call()]) + assert await multiplexer.get_or_create(1) is client1 + client_factory.assert_has_calls([call(1)]) client_factory.return_value = client2 await multiplexer.try_erase(1, client2) client_closer.assert_has_calls([]) await multiplexer.try_erase(1, client1) client_closer.assert_has_calls([call(client1)]) - assert await multiplexer.create_or_fail(1, client_factory) is client2 + assert await multiplexer.get_or_create(1) is client2 client_closer.assert_has_calls([call(client1), call(client2)]) diff --git a/tests/unit/pubsublite/cloudpubsub/internal/client_multiplexer_test.py b/tests/unit/pubsublite/cloudpubsub/internal/client_multiplexer_test.py index a3bc9305..697b024a 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/client_multiplexer_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/client_multiplexer_test.py @@ -16,9 +16,6 @@ from mock import MagicMock, call -# All test coroutines will be treated as marked. 
-from google.api_core.exceptions import FailedPrecondition - from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import ( ClientMultiplexer, ) @@ -39,8 +36,8 @@ def client_closer(): @pytest.fixture() -def multiplexer(client_closer): - return ClientMultiplexer(client_closer) +def multiplexer(client_factory, client_closer): + return ClientMultiplexer(client_factory, client_closer) def test_create( @@ -50,16 +47,12 @@ def test_create( client2 = Client() with multiplexer: client_factory.return_value = client1 - assert multiplexer.create_or_fail(1, client_factory) is client1 - client_factory.assert_has_calls([call()]) + assert multiplexer.get_or_create(1) is client1 + client_factory.assert_has_calls([call(1)]) client_factory.return_value = client2 - assert multiplexer.get_or_create(1, client_factory) is client1 - with pytest.raises(FailedPrecondition): - multiplexer.create_or_fail(1, client_factory) - assert multiplexer.get_or_create(2, client_factory) is client2 - client_factory.assert_has_calls([call(), call()]) - with pytest.raises(FailedPrecondition): - multiplexer.create_or_fail(2, client_factory) + assert multiplexer.get_or_create(1) is client1 + assert multiplexer.get_or_create(2) is client2 + client_factory.assert_has_calls([call(1), call(2)]) client_closer.assert_has_calls([call(client1), call(client2)], any_order=True) @@ -70,12 +63,12 @@ def test_recreate( client2 = Client() with multiplexer: client_factory.return_value = client1 - assert multiplexer.create_or_fail(1, client_factory) is client1 - client_factory.assert_has_calls([call()]) + assert multiplexer.get_or_create(1) is client1 + client_factory.assert_has_calls([call(1)]) client_factory.return_value = client2 multiplexer.try_erase(1, client2) client_closer.assert_has_calls([]) multiplexer.try_erase(1, client1) client_closer.assert_has_calls([call(client1)]) - assert multiplexer.create_or_fail(1, client_factory) is client2 + assert multiplexer.get_or_create(1) is client2 client_closer.assert_has_calls([call(client1), call(client2)])
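Note on the reworked `AsyncClientMultiplexer` above: `get_or_create()` now caches the pending factory future (via `asyncio.ensure_future`) before its first await, so concurrent lookups for the same key share one client rather than racing under a lock. Below is a minimal sketch of that behavior, assuming `google-cloud-pubsublite` at this commit is installed; the `factory`, `closer`, and key values are illustrative stand-ins, not part of the change itself.

```python
import asyncio

from google.cloud.pubsublite.cloudpubsub.internal.client_multiplexer import (
    AsyncClientMultiplexer,
)


async def main():
    created = []

    async def factory(key):
        await asyncio.sleep(0.1)  # simulate a slow client open (e.g. __aenter__)
        created.append(key)
        return object()  # stand-in for a real client

    async def closer(client):
        pass  # the default closer calls client.__aexit__; our stand-in has none

    async with AsyncClientMultiplexer(factory, closer) as multiplexer:
        # Both lookups resolve to the same client and the factory runs once,
        # because the pending future is cached before the first await.
        first, second = await asyncio.gather(
            multiplexer.get_or_create("topic-a"),
            multiplexer.get_or_create("topic-a"),
        )
        assert first is second
        assert created == ["topic-a"]


asyncio.run(main())
```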
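Similarly, the `SerialBatcher` change replaces the `BatchTester` callback with a `RequestSizer` that accumulates a running `BatchSize`, which `SinglePartitionPublisher` then compares against its batching settings in `_should_flush()`. A small usage sketch under the same assumption; `StrLenSizer` and the sample strings are hypothetical and only show how `size()` grows with `add()` and resets on `flush()`.

```python
import asyncio

from google.cloud.pubsublite.internal.wire.serial_batcher import (
    BatchSize,
    RequestSizer,
    SerialBatcher,
)


class StrLenSizer(RequestSizer[str]):
    """Hypothetical sizer: one element per request, byte_count = string length."""

    def get_size(self, request: str) -> BatchSize:
        return BatchSize(element_count=1, byte_count=len(request))


async def main():
    batcher = SerialBatcher(StrLenSizer())
    futures = [batcher.add("hello"), batcher.add("world!")]
    assert batcher.size() == BatchSize(element_count=2, byte_count=11)

    work_items = batcher.flush()  # drains the batch; size() resets to (0, 0)
    assert batcher.size() == BatchSize(element_count=0, byte_count=0)
    for item, response in zip(work_items, ["first", "second"]):
        item.response_future.set_result(response)
    assert await asyncio.gather(*futures) == ["first", "second"]


asyncio.run(main())
```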