Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enforce that __enter__ is called on all user interfaces before use #70

Merged
merged 4 commits into from Nov 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions google/cloud/pubsublite/cloudpubsub/publisher_client.py
Expand Up @@ -36,6 +36,7 @@
from google.cloud.pubsublite.internal.constructable_from_service_account import (
ConstructableFromServiceAccount,
)
from google.cloud.pubsublite.internal.require_started import RequireStarted
from google.cloud.pubsublite.internal.wire.make_publisher import (
DEFAULT_BATCHING_SETTINGS as WIRE_DEFAULT_BATCHING,
)
Expand All @@ -52,6 +53,7 @@ class PublisherClient(PublisherClientInterface, ConstructableFromServiceAccount)
"""

_impl: PublisherClientInterface
_require_stared: RequireStarted

DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING
"""
Expand Down Expand Up @@ -83,6 +85,7 @@ def __init__(
transport=transport,
)
)
self._require_stared = RequireStarted()

@overrides
def publish(
Expand All @@ -92,18 +95,21 @@ def publish(
ordering_key: str = "",
**attrs: Mapping[str, str]
) -> "Future[str]":
self._require_stared.require_started()
return self._impl.publish(
topic=topic, data=data, ordering_key=ordering_key, **attrs
)

@overrides
def __enter__(self):
self._require_stared.__enter__()
self._impl.__enter__()
return self

@overrides
def __exit__(self, exc_type, exc_value, traceback):
self._impl.__exit__(exc_type, exc_value, traceback)
self._require_stared.__exit__(exc_type, exc_value, traceback)


class AsyncPublisherClient(
Expand All @@ -117,6 +123,7 @@ class AsyncPublisherClient(
"""

_impl: AsyncPublisherClientInterface
_require_stared: RequireStarted

DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING
"""
Expand Down Expand Up @@ -148,6 +155,7 @@ def __init__(
transport=transport,
)
)
self._require_stared = RequireStarted()

@overrides
async def publish(
Expand All @@ -157,15 +165,18 @@ async def publish(
ordering_key: str = "",
**attrs: Mapping[str, str]
) -> str:
self._require_stared.require_started()
return await self._impl.publish(
topic=topic, data=data, ordering_key=ordering_key, **attrs
)

@overrides
async def __aenter__(self):
self._require_stared.__enter__()
await self._impl.__aenter__()
return self

@overrides
async def __aexit__(self, exc_type, exc_value, traceback):
await self._impl.__aexit__(exc_type, exc_value, traceback)
self._require_stared.__exit__(exc_type, exc_value, traceback)
11 changes: 11 additions & 0 deletions google/cloud/pubsublite/cloudpubsub/subscriber_client.py
Expand Up @@ -39,6 +39,7 @@
from google.cloud.pubsublite.internal.constructable_from_service_account import (
ConstructableFromServiceAccount,
)
from google.cloud.pubsublite.internal.require_started import RequireStarted
from google.cloud.pubsublite.types import (
FlowControlSettings,
Partition,
Expand All @@ -56,6 +57,7 @@ class SubscriberClient(SubscriberClientInterface, ConstructableFromServiceAccoun
"""

_impl: SubscriberClientInterface
_require_started: RequireStarted

def __init__(
self,
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
client_options=client_options,
),
)
self._require_started = RequireStarted()

@overrides
def subscribe(
Expand All @@ -101,6 +104,7 @@ def subscribe(
per_partition_flow_control_settings: FlowControlSettings,
fixed_partitions: Optional[Set[Partition]] = None,
) -> StreamingPullFuture:
self._require_started.require_started()
return self._impl.subscribe(
subscription,
callback,
Expand All @@ -110,12 +114,14 @@ def subscribe(

@overrides
def __enter__(self):
self._require_started.__enter__()
self._impl.__enter__()
return self

@overrides
def __exit__(self, exc_type, exc_value, traceback):
self._impl.__exit__(exc_type, exc_value, traceback)
self._require_started.__exit__(exc_type, exc_value, traceback)


class AsyncSubscriberClient(
Expand All @@ -130,6 +136,7 @@ class AsyncSubscriberClient(
"""

_impl: AsyncSubscriberClientInterface
_require_started: RequireStarted

def __init__(
self,
Expand Down Expand Up @@ -161,6 +168,7 @@ def __init__(
client_options=client_options,
)
)
self._require_started = RequireStarted()

@overrides
async def subscribe(
Expand All @@ -169,15 +177,18 @@ async def subscribe(
per_partition_flow_control_settings: FlowControlSettings,
fixed_partitions: Optional[Set[Partition]] = None,
) -> AsyncIterator[Message]:
self._require_started.require_started()
return await self._impl.subscribe(
subscription, per_partition_flow_control_settings, fixed_partitions
)

@overrides
async def __aenter__(self):
self._require_started.__enter__()
await self._impl.__aenter__()
return self

@overrides
async def __aexit__(self, exc_type, exc_value, traceback):
await self._impl.__aexit__(exc_type, exc_value, traceback)
self._require_started.__exit__(exc_type, exc_value, traceback)
35 changes: 35 additions & 0 deletions google/cloud/pubsublite/internal/require_started.py
@@ -0,0 +1,35 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import ContextManager

from google.api_core.exceptions import FailedPrecondition


class RequireStarted(ContextManager):
def __init__(self):
self._started = False

def __enter__(self):
if self._started:
raise FailedPrecondition("__enter__ called twice.")
self._started = True
return self

def require_started(self):
if not self._started:
raise FailedPrecondition("__enter__ has never been called.")

def __exit__(self, exc_type, exc_value, traceback):
self.require_started()
29 changes: 14 additions & 15 deletions samples/snippets/subscriber_example.py
Expand Up @@ -59,21 +59,20 @@ def callback(message):
print(f"Received {message_data} of ordering key {message.ordering_key}.")
message.ack()

subscriber_client = SubscriberClient()

streaming_pull_future = subscriber_client.subscribe(
subscription_path,
callback=callback,
per_partition_flow_control_settings=per_partition_flow_control_settings,
)

print(f"Listening for messages on {str(subscription_path)}...")

try:
streaming_pull_future.result(timeout=timeout)
except TimeoutError or KeyboardInterrupt:
streaming_pull_future.cancel()
assert streaming_pull_future.done()
with SubscriberClient() as subscriber_client:
streaming_pull_future = subscriber_client.subscribe(
subscription_path,
callback=callback,
per_partition_flow_control_settings=per_partition_flow_control_settings,
)

print(f"Listening for messages on {str(subscription_path)}...")

try:
streaming_pull_future.result(timeout=timeout)
except TimeoutError or KeyboardInterrupt:
streaming_pull_future.cancel()
assert streaming_pull_future.done()
# [END pubsublite_quickstart_subscriber]


Expand Down