Skip to content

Commit

Permalink
fix: Enforce that __enter__ is called on all user interfaces before u…
Browse files Browse the repository at this point in the history
…se (#70)

* chore: Add license headers to all files

* fix: Enforce that __enter__ is called on all user interfaces before use

* fix: Enforce that __enter__ is called on all user interfaces before use
  • Loading branch information
dpcollins-google committed Nov 18, 2020
1 parent b0407f6 commit eb8d63a
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 15 deletions.
11 changes: 11 additions & 0 deletions google/cloud/pubsublite/cloudpubsub/publisher_client.py
Expand Up @@ -36,6 +36,7 @@
from google.cloud.pubsublite.internal.constructable_from_service_account import (
ConstructableFromServiceAccount,
)
from google.cloud.pubsublite.internal.require_started import RequireStarted
from google.cloud.pubsublite.internal.wire.make_publisher import (
DEFAULT_BATCHING_SETTINGS as WIRE_DEFAULT_BATCHING,
)
Expand All @@ -52,6 +53,7 @@ class PublisherClient(PublisherClientInterface, ConstructableFromServiceAccount)
"""

_impl: PublisherClientInterface
_require_stared: RequireStarted

DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING
"""
Expand Down Expand Up @@ -83,6 +85,7 @@ def __init__(
transport=transport,
)
)
self._require_stared = RequireStarted()

@overrides
def publish(
Expand All @@ -92,18 +95,21 @@ def publish(
ordering_key: str = "",
**attrs: Mapping[str, str]
) -> "Future[str]":
self._require_stared.require_started()
return self._impl.publish(
topic=topic, data=data, ordering_key=ordering_key, **attrs
)

@overrides
def __enter__(self):
self._require_stared.__enter__()
self._impl.__enter__()
return self

@overrides
def __exit__(self, exc_type, exc_value, traceback):
self._impl.__exit__(exc_type, exc_value, traceback)
self._require_stared.__exit__(exc_type, exc_value, traceback)


class AsyncPublisherClient(
Expand All @@ -117,6 +123,7 @@ class AsyncPublisherClient(
"""

_impl: AsyncPublisherClientInterface
_require_stared: RequireStarted

DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING
"""
Expand Down Expand Up @@ -148,6 +155,7 @@ def __init__(
transport=transport,
)
)
self._require_stared = RequireStarted()

@overrides
async def publish(
Expand All @@ -157,15 +165,18 @@ async def publish(
ordering_key: str = "",
**attrs: Mapping[str, str]
) -> str:
self._require_stared.require_started()
return await self._impl.publish(
topic=topic, data=data, ordering_key=ordering_key, **attrs
)

@overrides
async def __aenter__(self):
self._require_stared.__enter__()
await self._impl.__aenter__()
return self

@overrides
async def __aexit__(self, exc_type, exc_value, traceback):
await self._impl.__aexit__(exc_type, exc_value, traceback)
self._require_stared.__exit__(exc_type, exc_value, traceback)
11 changes: 11 additions & 0 deletions google/cloud/pubsublite/cloudpubsub/subscriber_client.py
Expand Up @@ -39,6 +39,7 @@
from google.cloud.pubsublite.internal.constructable_from_service_account import (
ConstructableFromServiceAccount,
)
from google.cloud.pubsublite.internal.require_started import RequireStarted
from google.cloud.pubsublite.types import (
FlowControlSettings,
Partition,
Expand All @@ -56,6 +57,7 @@ class SubscriberClient(SubscriberClientInterface, ConstructableFromServiceAccoun
"""

_impl: SubscriberClientInterface
_require_started: RequireStarted

def __init__(
self,
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
client_options=client_options,
),
)
self._require_started = RequireStarted()

@overrides
def subscribe(
Expand All @@ -101,6 +104,7 @@ def subscribe(
per_partition_flow_control_settings: FlowControlSettings,
fixed_partitions: Optional[Set[Partition]] = None,
) -> StreamingPullFuture:
self._require_started.require_started()
return self._impl.subscribe(
subscription,
callback,
Expand All @@ -110,12 +114,14 @@ def subscribe(

@overrides
def __enter__(self):
self._require_started.__enter__()
self._impl.__enter__()
return self

@overrides
def __exit__(self, exc_type, exc_value, traceback):
self._impl.__exit__(exc_type, exc_value, traceback)
self._require_started.__exit__(exc_type, exc_value, traceback)


class AsyncSubscriberClient(
Expand All @@ -130,6 +136,7 @@ class AsyncSubscriberClient(
"""

_impl: AsyncSubscriberClientInterface
_require_started: RequireStarted

def __init__(
self,
Expand Down Expand Up @@ -161,6 +168,7 @@ def __init__(
client_options=client_options,
)
)
self._require_started = RequireStarted()

@overrides
async def subscribe(
Expand All @@ -169,15 +177,18 @@ async def subscribe(
per_partition_flow_control_settings: FlowControlSettings,
fixed_partitions: Optional[Set[Partition]] = None,
) -> AsyncIterator[Message]:
self._require_started.require_started()
return await self._impl.subscribe(
subscription, per_partition_flow_control_settings, fixed_partitions
)

@overrides
async def __aenter__(self):
self._require_started.__enter__()
await self._impl.__aenter__()
return self

@overrides
async def __aexit__(self, exc_type, exc_value, traceback):
await self._impl.__aexit__(exc_type, exc_value, traceback)
self._require_started.__exit__(exc_type, exc_value, traceback)
35 changes: 35 additions & 0 deletions google/cloud/pubsublite/internal/require_started.py
@@ -0,0 +1,35 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import ContextManager

from google.api_core.exceptions import FailedPrecondition


class RequireStarted(ContextManager):
def __init__(self):
self._started = False

def __enter__(self):
if self._started:
raise FailedPrecondition("__enter__ called twice.")
self._started = True
return self

def require_started(self):
if not self._started:
raise FailedPrecondition("__enter__ has never been called.")

def __exit__(self, exc_type, exc_value, traceback):
self.require_started()
29 changes: 14 additions & 15 deletions samples/snippets/subscriber_example.py
Expand Up @@ -59,21 +59,20 @@ def callback(message):
print(f"Received {message_data} of ordering key {message.ordering_key}.")
message.ack()

subscriber_client = SubscriberClient()

streaming_pull_future = subscriber_client.subscribe(
subscription_path,
callback=callback,
per_partition_flow_control_settings=per_partition_flow_control_settings,
)

print(f"Listening for messages on {str(subscription_path)}...")

try:
streaming_pull_future.result(timeout=timeout)
except TimeoutError or KeyboardInterrupt:
streaming_pull_future.cancel()
assert streaming_pull_future.done()
with SubscriberClient() as subscriber_client:
streaming_pull_future = subscriber_client.subscribe(
subscription_path,
callback=callback,
per_partition_flow_control_settings=per_partition_flow_control_settings,
)

print(f"Listening for messages on {str(subscription_path)}...")

try:
streaming_pull_future.result(timeout=timeout)
except TimeoutError or KeyboardInterrupt:
streaming_pull_future.cancel()
assert streaming_pull_future.done()
# [END pubsublite_quickstart_subscriber]


Expand Down

0 comments on commit eb8d63a

Please sign in to comment.