diff --git a/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py index 6970d1f2..4d4f52ce 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/make_subscriber.py @@ -17,6 +17,11 @@ from google.api_core.client_options import ClientOptions from google.auth.credentials import Credentials + +from google.cloud.pubsublite.cloudpubsub.message_transforms import ( + to_cps_subscribe_message, + add_id_to_cps_subscribe_transformer, +) from google.cloud.pubsublite.types import FlowControlSettings from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import ( AckSetTrackerImpl, @@ -28,10 +33,7 @@ from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import ( SinglePartitionSingleSubscriber, ) -from google.cloud.pubsublite.cloudpubsub.message_transformer import ( - MessageTransformer, - DefaultMessageTransformer, -) +from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer from google.cloud.pubsublite.cloudpubsub.nack_handler import ( NackHandler, DefaultNackHandler, @@ -149,7 +151,7 @@ def cursor_connection_factory( flow_control_settings, ack_set_tracker, nack_handler, - message_transformer, + add_id_to_cps_subscribe_transformer(partition, message_transformer), ) return factory @@ -200,7 +202,7 @@ def make_async_subscriber( if nack_handler is None: nack_handler = DefaultNackHandler() if message_transformer is None: - message_transformer = DefaultMessageTransformer() + message_transformer = MessageTransformer.of_callable(to_cps_subscribe_message) partition_subscriber_factory = _make_partition_subscriber_factory( subscription, transport, diff --git a/google/cloud/pubsublite/cloudpubsub/message_transformer.py b/google/cloud/pubsublite/cloudpubsub/message_transformer.py index 3dfb6c86..d0b75bd5 100644 --- a/google/cloud/pubsublite/cloudpubsub/message_transformer.py +++ b/google/cloud/pubsublite/cloudpubsub/message_transformer.py @@ -13,12 +13,11 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import Callable from google.pubsub_v1 import PubsubMessage +from overrides import overrides -from google.cloud.pubsublite.cloudpubsub.message_transforms import ( - to_cps_subscribe_message, -) from google.cloud.pubsublite_v1 import SequencedMessage @@ -39,7 +38,11 @@ def transform(self, source: SequencedMessage) -> PubsubMessage: """ pass + @staticmethod + def of_callable(transformer: Callable[[SequencedMessage], PubsubMessage]): + class CallableTransformer(MessageTransformer): + @overrides + def transform(self, source: SequencedMessage) -> PubsubMessage: + return transformer(source) -class DefaultMessageTransformer(MessageTransformer): - def transform(self, source: SequencedMessage) -> PubsubMessage: - return to_cps_subscribe_message(source) + return CallableTransformer() diff --git a/google/cloud/pubsublite/cloudpubsub/message_transforms.py b/google/cloud/pubsublite/cloudpubsub/message_transforms.py index ba6308e0..e3d9a895 100644 --- a/google/cloud/pubsublite/cloudpubsub/message_transforms.py +++ b/google/cloud/pubsublite/cloudpubsub/message_transforms.py @@ -18,6 +18,8 @@ from google.protobuf.timestamp_pb2 import Timestamp from google.pubsub_v1 import PubsubMessage +from google.cloud.pubsublite.cloudpubsub import MessageTransformer +from google.cloud.pubsublite.types import Partition, PublishMetadata from google.cloud.pubsublite_v1 import AttributeValues, SequencedMessage, PubSubMessage PUBSUB_LITE_EVENT_TIME = "x-goog-pubsublite-event-time" @@ -52,9 +54,23 @@ def _parse_attributes(values: AttributeValues) -> str: ) +def add_id_to_cps_subscribe_transformer( + partition: Partition, transformer: MessageTransformer +) -> MessageTransformer: + def add_id_to_message(source: SequencedMessage): + message: PubsubMessage = transformer.transform(source) + if message.message_id: + raise InvalidArgument( + "Message after transforming has the message_id field set." + ) + message.message_id = PublishMetadata(partition, source.cursor).encode() + return message + + return MessageTransformer.of_callable(add_id_to_message) + + def to_cps_subscribe_message(source: SequencedMessage) -> PubsubMessage: message: PubsubMessage = to_cps_publish_message(source.message) - message.message_id = str(source.cursor.offset) message.publish_time = source.publish_time return message diff --git a/google/cloud/pubsublite/cloudpubsub/subscriber_client.py b/google/cloud/pubsublite/cloudpubsub/subscriber_client.py index 96134938..13705beb 100644 --- a/google/cloud/pubsublite/cloudpubsub/subscriber_client.py +++ b/google/cloud/pubsublite/cloudpubsub/subscriber_client.py @@ -74,7 +74,7 @@ def __init__( Args: executor: A ThreadPoolExecutor to use. The client will shut it down on __exit__. If provided a single threaded executor, messages will be ordered per-partition, but take care that the callback does not block for too long as it will impede forward progress on all subscriptions. nack_handler: A handler for when `nack()` is called. The default NackHandler raises an exception and fails the subscribe stream. - message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages. + message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages. This may not return a message with "message_id" set. credentials: If provided, the credentials to use when connecting. transport: The transport to use. Must correspond to an asyncio transport. client_options: The client options to use when connecting. If used, must explicitly set `api_endpoint`. @@ -151,7 +151,7 @@ def __init__( Args: nack_handler: A handler for when `nack()` is called. The default NackHandler raises an exception and fails the subscribe stream. - message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages. + message_transformer: A transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages. This may not return a message with "message_id" set. credentials: If provided, the credentials to use when connecting. transport: The transport to use. Must correspond to an asyncio transport. client_options: The client options to use when connecting. If used, must explicitly set `api_endpoint`. diff --git a/samples/snippets/subscriber_example.py b/samples/snippets/subscriber_example.py index b4e32ee5..a8c9b5d3 100644 --- a/samples/snippets/subscriber_example.py +++ b/samples/snippets/subscriber_example.py @@ -21,6 +21,10 @@ import argparse +from google.pubsub_v1 import PubsubMessage + +from google.cloud.pubsublite.types import PublishMetadata + def receive_messages( project_number, cloud_region, zone_id, subscription_id, timeout=90 @@ -54,9 +58,10 @@ def receive_messages( bytes_outstanding=10 * 1024 * 1024, ) - def callback(message): + def callback(message: PubsubMessage): message_data = message.data.decode("utf-8") - print(f"Received {message_data} of ordering key {message.ordering_key}.") + metadata = PublishMetadata.decode(message.message_id) + print(f"Received {message_data} of ordering key {message.ordering_key} with id {metadata}.") message.ack() # SubscriberClient() must be used in a `with` block or have __enter__() called before use. diff --git a/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py b/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py index 79478a4f..43e71e90 100644 --- a/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py +++ b/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py @@ -19,12 +19,15 @@ from google.protobuf.timestamp_pb2 import Timestamp from google.pubsub_v1 import PubsubMessage +from google.cloud.pubsublite.cloudpubsub import MessageTransformer from google.cloud.pubsublite.cloudpubsub.message_transforms import ( PUBSUB_LITE_EVENT_TIME, to_cps_subscribe_message, encode_attribute_event_time, from_cps_publish_message, + add_id_to_cps_subscribe_transformer, ) +from google.cloud.pubsublite.types import Partition, PublishMetadata from google.cloud.pubsublite_v1 import ( SequencedMessage, Cursor, @@ -104,7 +107,6 @@ def test_subscribe_transform_correct(): Timestamp(seconds=55).ToDatetime() ), }, - message_id=str(10), publish_time=Timestamp(seconds=10), ) result = to_cps_subscribe_message( @@ -126,6 +128,66 @@ def test_subscribe_transform_correct(): assert result == expected +def test_wrapped_sets_id_error(): + wrapped = add_id_to_cps_subscribe_transformer( + Partition(1), + MessageTransformer.of_callable(lambda x: PubsubMessage(message_id="a")), + ) + with pytest.raises(InvalidArgument): + wrapped.transform( + SequencedMessage( + message=PubSubMessage( + data=b"xyz", + key=b"def", + event_time=Timestamp(seconds=55), + attributes={ + "x": AttributeValues(values=[b"abc"]), + "y": AttributeValues(values=[b"abc"]), + }, + ), + publish_time=Timestamp(seconds=10), + cursor=Cursor(offset=10), + size_bytes=10, + ) + ) + + +def test_wrapped_successful(): + wrapped = add_id_to_cps_subscribe_transformer( + Partition(1), MessageTransformer.of_callable(to_cps_subscribe_message) + ) + expected = PubsubMessage( + data=b"xyz", + ordering_key="def", + attributes={ + "x": "abc", + "y": "abc", + PUBSUB_LITE_EVENT_TIME: encode_attribute_event_time( + Timestamp(seconds=55).ToDatetime() + ), + }, + message_id=PublishMetadata(Partition(1), Cursor(offset=10)).encode(), + publish_time=Timestamp(seconds=10), + ) + result = wrapped.transform( + SequencedMessage( + message=PubSubMessage( + data=b"xyz", + key=b"def", + event_time=Timestamp(seconds=55), + attributes={ + "x": AttributeValues(values=[b"abc"]), + "y": AttributeValues(values=[b"abc"]), + }, + ), + publish_time=Timestamp(seconds=10), + cursor=Cursor(offset=10), + size_bytes=10, + ) + ) + assert result == expected + + def test_publish_invalid_event_time(): with pytest.raises(InvalidArgument): from_cps_publish_message(