diff --git a/google/cloud/pubsublite/cloudpubsub/publisher_client.py b/google/cloud/pubsublite/cloudpubsub/publisher_client.py index 2463d302..9ddd4da5 100644 --- a/google/cloud/pubsublite/cloudpubsub/publisher_client.py +++ b/google/cloud/pubsublite/cloudpubsub/publisher_client.py @@ -36,6 +36,7 @@ from google.cloud.pubsublite.internal.constructable_from_service_account import ( ConstructableFromServiceAccount, ) +from google.cloud.pubsublite.internal.require_started import RequireStarted from google.cloud.pubsublite.internal.wire.make_publisher import ( DEFAULT_BATCHING_SETTINGS as WIRE_DEFAULT_BATCHING, ) @@ -52,6 +53,7 @@ class PublisherClient(PublisherClientInterface, ConstructableFromServiceAccount) """ _impl: PublisherClientInterface + _require_stared: RequireStarted DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING """ @@ -83,6 +85,7 @@ def __init__( transport=transport, ) ) + self._require_stared = RequireStarted() @overrides def publish( @@ -92,18 +95,21 @@ def publish( ordering_key: str = "", **attrs: Mapping[str, str] ) -> "Future[str]": + self._require_stared.require_started() return self._impl.publish( topic=topic, data=data, ordering_key=ordering_key, **attrs ) @overrides def __enter__(self): + self._require_stared.__enter__() self._impl.__enter__() return self @overrides def __exit__(self, exc_type, exc_value, traceback): self._impl.__exit__(exc_type, exc_value, traceback) + self._require_stared.__exit__(exc_type, exc_value, traceback) class AsyncPublisherClient( @@ -117,6 +123,7 @@ class AsyncPublisherClient( """ _impl: AsyncPublisherClientInterface + _require_stared: RequireStarted DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING """ @@ -148,6 +155,7 @@ def __init__( transport=transport, ) ) + self._require_stared = RequireStarted() @overrides async def publish( @@ -157,15 +165,18 @@ async def publish( ordering_key: str = "", **attrs: Mapping[str, str] ) -> str: + self._require_stared.require_started() return await self._impl.publish( topic=topic, data=data, ordering_key=ordering_key, **attrs ) @overrides async def __aenter__(self): + self._require_stared.__enter__() await self._impl.__aenter__() return self @overrides async def __aexit__(self, exc_type, exc_value, traceback): await self._impl.__aexit__(exc_type, exc_value, traceback) + self._require_stared.__exit__(exc_type, exc_value, traceback) diff --git a/google/cloud/pubsublite/cloudpubsub/subscriber_client.py b/google/cloud/pubsublite/cloudpubsub/subscriber_client.py index 717369ef..96134938 100644 --- a/google/cloud/pubsublite/cloudpubsub/subscriber_client.py +++ b/google/cloud/pubsublite/cloudpubsub/subscriber_client.py @@ -39,6 +39,7 @@ from google.cloud.pubsublite.internal.constructable_from_service_account import ( ConstructableFromServiceAccount, ) +from google.cloud.pubsublite.internal.require_started import RequireStarted from google.cloud.pubsublite.types import ( FlowControlSettings, Partition, @@ -56,6 +57,7 @@ class SubscriberClient(SubscriberClientInterface, ConstructableFromServiceAccoun """ _impl: SubscriberClientInterface + _require_started: RequireStarted def __init__( self, @@ -92,6 +94,7 @@ def __init__( client_options=client_options, ), ) + self._require_started = RequireStarted() @overrides def subscribe( @@ -101,6 +104,7 @@ def subscribe( per_partition_flow_control_settings: FlowControlSettings, fixed_partitions: Optional[Set[Partition]] = None, ) -> StreamingPullFuture: + self._require_started.require_started() return self._impl.subscribe( subscription, callback, @@ -110,12 +114,14 @@ def subscribe( @overrides def __enter__(self): + self._require_started.__enter__() self._impl.__enter__() return self @overrides def __exit__(self, exc_type, exc_value, traceback): self._impl.__exit__(exc_type, exc_value, traceback) + self._require_started.__exit__(exc_type, exc_value, traceback) class AsyncSubscriberClient( @@ -130,6 +136,7 @@ class AsyncSubscriberClient( """ _impl: AsyncSubscriberClientInterface + _require_started: RequireStarted def __init__( self, @@ -161,6 +168,7 @@ def __init__( client_options=client_options, ) ) + self._require_started = RequireStarted() @overrides async def subscribe( @@ -169,15 +177,18 @@ async def subscribe( per_partition_flow_control_settings: FlowControlSettings, fixed_partitions: Optional[Set[Partition]] = None, ) -> AsyncIterator[Message]: + self._require_started.require_started() return await self._impl.subscribe( subscription, per_partition_flow_control_settings, fixed_partitions ) @overrides async def __aenter__(self): + self._require_started.__enter__() await self._impl.__aenter__() return self @overrides async def __aexit__(self, exc_type, exc_value, traceback): await self._impl.__aexit__(exc_type, exc_value, traceback) + self._require_started.__exit__(exc_type, exc_value, traceback) diff --git a/google/cloud/pubsublite/internal/require_started.py b/google/cloud/pubsublite/internal/require_started.py new file mode 100644 index 00000000..b7d4337b --- /dev/null +++ b/google/cloud/pubsublite/internal/require_started.py @@ -0,0 +1,35 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ContextManager + +from google.api_core.exceptions import FailedPrecondition + + +class RequireStarted(ContextManager): + def __init__(self): + self._started = False + + def __enter__(self): + if self._started: + raise FailedPrecondition("__enter__ called twice.") + self._started = True + return self + + def require_started(self): + if not self._started: + raise FailedPrecondition("__enter__ has never been called.") + + def __exit__(self, exc_type, exc_value, traceback): + self.require_started() diff --git a/samples/snippets/subscriber_example.py b/samples/snippets/subscriber_example.py index ed2fba41..a50c0f72 100644 --- a/samples/snippets/subscriber_example.py +++ b/samples/snippets/subscriber_example.py @@ -59,21 +59,20 @@ def callback(message): print(f"Received {message_data} of ordering key {message.ordering_key}.") message.ack() - subscriber_client = SubscriberClient() - - streaming_pull_future = subscriber_client.subscribe( - subscription_path, - callback=callback, - per_partition_flow_control_settings=per_partition_flow_control_settings, - ) - - print(f"Listening for messages on {str(subscription_path)}...") - - try: - streaming_pull_future.result(timeout=timeout) - except TimeoutError or KeyboardInterrupt: - streaming_pull_future.cancel() - assert streaming_pull_future.done() + with SubscriberClient() as subscriber_client: + streaming_pull_future = subscriber_client.subscribe( + subscription_path, + callback=callback, + per_partition_flow_control_settings=per_partition_flow_control_settings, + ) + + print(f"Listening for messages on {str(subscription_path)}...") + + try: + streaming_pull_future.result(timeout=timeout) + except TimeoutError or KeyboardInterrupt: + streaming_pull_future.cancel() + assert streaming_pull_future.done() # [END pubsublite_quickstart_subscriber]