Skip to content

Commit

Permalink
feat: Make message_id encode a PublishMetadata which includes the partition (#90)
Browse files Browse the repository at this point in the history

* feat: Make message_id encode a PublishMetadata which includes the partition

* feat: Make message_id encode a PublishMetadata which includes the partition
  • Loading branch information
dpcollins-google committed Feb 9, 2021
1 parent aa7105d commit 85944e7
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 18 deletions.
14 changes: 8 additions & 6 deletions google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py
Expand Up @@ -17,6 +17,11 @@

from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials

from google.cloud.pubsublite.cloudpubsub.message_transforms import (
to_cps_subscribe_message,
add_id_to_cps_subscribe_transformer,
)
from google.cloud.pubsublite.types import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import (
AckSetTrackerImpl,
Expand All @@ -28,10 +33,7 @@
from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import (
SinglePartitionSingleSubscriber,
)
from google.cloud.pubsublite.cloudpubsub.message_transformer import (
MessageTransformer,
DefaultMessageTransformer,
)
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer
from google.cloud.pubsublite.cloudpubsub.nack_handler import (
NackHandler,
DefaultNackHandler,
Expand Down Expand Up @@ -149,7 +151,7 @@ def cursor_connection_factory(
flow_control_settings,
ack_set_tracker,
nack_handler,
message_transformer,
add_id_to_cps_subscribe_transformer(partition, message_transformer),
)

return factory
Expand Down Expand Up @@ -200,7 +202,7 @@ def make_async_subscriber(
if nack_handler is None:
nack_handler = DefaultNackHandler()
if message_transformer is None:
message_transformer = DefaultMessageTransformer()
message_transformer = MessageTransformer.of_callable(to_cps_subscribe_message)
partition_subscriber_factory = _make_partition_subscriber_factory(
subscription,
transport,
Expand Down
15 changes: 9 additions & 6 deletions google/cloud/pubsublite/cloudpubsub/message_transformer.py
Expand Up @@ -13,12 +13,11 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Callable

from google.pubsub_v1 import PubsubMessage
from overrides import overrides

from google.cloud.pubsublite.cloudpubsub.message_transforms import (
to_cps_subscribe_message,
)
from google.cloud.pubsublite_v1 import SequencedMessage


Expand All @@ -39,7 +38,11 @@ def transform(self, source: SequencedMessage) -> PubsubMessage:
"""
pass

@staticmethod
def of_callable(
    transformer: Callable[[SequencedMessage], PubsubMessage]
) -> "MessageTransformer":
    """Wrap a plain callable in a MessageTransformer.

    Args:
      transformer: Called with each SequencedMessage handed to transform();
        whatever it returns is returned from transform() unchanged.

    Returns:
      A MessageTransformer whose transform() delegates to `transformer`.
    """

    # The adapter closes over `transformer`, so it needs no state of its own.
    class CallableTransformer(MessageTransformer):
        @overrides
        def transform(self, source: SequencedMessage) -> PubsubMessage:
            return transformer(source)

    return CallableTransformer()
18 changes: 17 additions & 1 deletion google/cloud/pubsublite/cloudpubsub/message_transforms.py
Expand Up @@ -18,6 +18,8 @@
from google.protobuf.timestamp_pb2 import Timestamp
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.cloudpubsub import MessageTransformer
from google.cloud.pubsublite.types import Partition, PublishMetadata
from google.cloud.pubsublite_v1 import AttributeValues, SequencedMessage, PubSubMessage

PUBSUB_LITE_EVENT_TIME = "x-goog-pubsublite-event-time"
Expand Down Expand Up @@ -52,9 +54,23 @@ def _parse_attributes(values: AttributeValues) -> str:
)


def add_id_to_cps_subscribe_transformer(
    partition: Partition, transformer: MessageTransformer
) -> MessageTransformer:
    """Decorate `transformer` so its output carries an encoded PublishMetadata id.

    Args:
      partition: The partition combined with each source message's cursor to
        build the PublishMetadata.
      transformer: The transformer to wrap. It must leave `message_id` unset.

    Returns:
      A MessageTransformer that runs `transformer` and then sets `message_id`
      on the result to the encoded PublishMetadata.

    Raises:
      InvalidArgument: From transform(), if the wrapped transformer returned a
        message whose `message_id` was already set.
    """

    def _attach_publish_metadata(source: SequencedMessage) -> PubsubMessage:
        transformed: PubsubMessage = transformer.transform(source)
        id_already_set = bool(transformed.message_id)
        if id_already_set:
            raise InvalidArgument(
                "Message after transforming has the message_id field set."
            )
        encoded_id = PublishMetadata(partition, source.cursor).encode()
        transformed.message_id = encoded_id
        return transformed

    return MessageTransformer.of_callable(_attach_publish_metadata)


def to_cps_subscribe_message(source: SequencedMessage) -> PubsubMessage:
    """Convert a received Pub/Sub Lite message into a Cloud Pub/Sub message.

    Args:
      source: The sequenced message to convert.

    Returns:
      The result of to_cps_publish_message(source.message) with `publish_time`
      copied from `source`. `message_id` is not set here.
    """
    message: PubsubMessage = to_cps_publish_message(source.message)
    # message_id is deliberately left unset: add_id_to_cps_subscribe_transformer
    # raises InvalidArgument when the wrapped transformer has already set it,
    # and it is that wrapper which assigns the id.
    message.publish_time = source.publish_time
    return message

Expand Down
4 changes: 2 additions & 2 deletions google/cloud/pubsublite/cloudpubsub/subscriber_client.py
Expand Up @@ -74,7 +74,7 @@ def __init__(
Args:
executor: A ThreadPoolExecutor to use. The client will shut it down on __exit__. If provided a single threaded executor, messages will be ordered per-partition, but take care that the callback does not block for too long as it will impede forward progress on all subscriptions.
nack_handler: A handler for when `nack()` is called. The default NackHandler raises an exception and fails the subscribe stream.
message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages.
message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages. This may not return a message with "message_id" set.
credentials: If provided, the credentials to use when connecting.
transport: The transport to use. Must correspond to an asyncio transport.
client_options: The client options to use when connecting. If used, must explicitly set `api_endpoint`.
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
Args:
nack_handler: A handler for when `nack()` is called. The default NackHandler raises an exception and fails the subscribe stream.
message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages.
message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages. This may not return a message with "message_id" set.
credentials: If provided, the credentials to use when connecting.
transport: The transport to use. Must correspond to an asyncio transport.
client_options: The client options to use when connecting. If used, must explicitly set `api_endpoint`.
Expand Down
9 changes: 7 additions & 2 deletions samples/snippets/subscriber_example.py
Expand Up @@ -21,6 +21,10 @@

import argparse

from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.types import PublishMetadata


def receive_messages(
project_number, cloud_region, zone_id, subscription_id, timeout=90
Expand Down Expand Up @@ -54,9 +58,10 @@ def receive_messages(
bytes_outstanding=10 * 1024 * 1024,
)

def callback(message: PubsubMessage):
    """Print a received message's payload, ordering key and decoded id, then ack it."""
    message_data = message.data.decode("utf-8")
    # Decode message_id with PublishMetadata.decode rather than printing the raw string.
    metadata = PublishMetadata.decode(message.message_id)
    print(
        f"Received {message_data} of ordering key {message.ordering_key} with id {metadata}."
    )
    message.ack()

# SubscriberClient() must be used in a `with` block or have __enter__() called before use.
Expand Down
64 changes: 63 additions & 1 deletion tests/unit/pubsublite/cloudpubsub/message_transforms_test.py
Expand Up @@ -19,12 +19,15 @@
from google.protobuf.timestamp_pb2 import Timestamp
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.cloudpubsub import MessageTransformer
from google.cloud.pubsublite.cloudpubsub.message_transforms import (
PUBSUB_LITE_EVENT_TIME,
to_cps_subscribe_message,
encode_attribute_event_time,
from_cps_publish_message,
add_id_to_cps_subscribe_transformer,
)
from google.cloud.pubsublite.types import Partition, PublishMetadata
from google.cloud.pubsublite_v1 import (
SequencedMessage,
Cursor,
Expand Down Expand Up @@ -104,7 +107,6 @@ def test_subscribe_transform_correct():
Timestamp(seconds=55).ToDatetime()
),
},
message_id=str(10),
publish_time=Timestamp(seconds=10),
)
result = to_cps_subscribe_message(
Expand All @@ -126,6 +128,66 @@ def test_subscribe_transform_correct():
assert result == expected


def test_wrapped_sets_id_error():
    """A wrapped transformer that sets message_id itself must be rejected."""
    wrapped = add_id_to_cps_subscribe_transformer(
        Partition(1),
        MessageTransformer.of_callable(lambda x: PubsubMessage(message_id="a")),
    )
    # Build the input outside the `raises` block so only transform() itself
    # can satisfy the expected exception.
    source = SequencedMessage(
        message=PubSubMessage(
            data=b"xyz",
            key=b"def",
            event_time=Timestamp(seconds=55),
            attributes={
                "x": AttributeValues(values=[b"abc"]),
                "y": AttributeValues(values=[b"abc"]),
            },
        ),
        publish_time=Timestamp(seconds=10),
        cursor=Cursor(offset=10),
        size_bytes=10,
    )
    with pytest.raises(InvalidArgument):
        wrapped.transform(source)


def test_wrapped_successful():
    """Wrapping to_cps_subscribe_message yields its output plus the encoded id."""
    partition = Partition(1)
    transformer_under_test = add_id_to_cps_subscribe_transformer(
        partition, MessageTransformer.of_callable(to_cps_subscribe_message)
    )
    source = SequencedMessage(
        message=PubSubMessage(
            data=b"xyz",
            key=b"def",
            event_time=Timestamp(seconds=55),
            attributes={
                "x": AttributeValues(values=[b"abc"]),
                "y": AttributeValues(values=[b"abc"]),
            },
        ),
        publish_time=Timestamp(seconds=10),
        cursor=Cursor(offset=10),
        size_bytes=10,
    )
    event_time_attribute = encode_attribute_event_time(
        Timestamp(seconds=55).ToDatetime()
    )

    actual = transformer_under_test.transform(source)

    assert actual == PubsubMessage(
        data=b"xyz",
        ordering_key="def",
        attributes={
            "x": "abc",
            "y": "abc",
            PUBSUB_LITE_EVENT_TIME: event_time_attribute,
        },
        message_id=PublishMetadata(partition, Cursor(offset=10)).encode(),
        publish_time=Timestamp(seconds=10),
    )


def test_publish_invalid_event_time():
with pytest.raises(InvalidArgument):
from_cps_publish_message(
Expand Down

0 comments on commit 85944e7

Please sign in to comment.