fix: performance issues with subscriber client #232

Merged
merged 2 commits on Sep 13, 2021
@@ -13,7 +13,7 @@
# limitations under the License.

from asyncio import Future, Queue, ensure_future
from typing import Callable, NamedTuple, Dict, Set, Optional
from typing import Callable, NamedTuple, Dict, List, Set, Optional

from google.cloud.pubsub_v1.subscriber.message import Message

@@ -41,7 +41,7 @@ class AssigningSingleSubscriber(AsyncSingleSubscriber, PermanentFailable):

# Lazily initialized to ensure they are initialized on the thread where __aenter__ is called.
_assigner: Optional[Assigner]
_messages: Optional["Queue[Message]"]
_batches: Optional["Queue[List[Message]]"]
_assign_poller: Future

def __init__(
@@ -58,14 +58,14 @@ def __init__(
self._assigner = None
self._subscriber_factory = subscriber_factory
self._subscribers = {}
self._messages = None
self._batches = None

async def read(self) -> Message:
return await self.await_unless_failed(self._messages.get())
async def read(self) -> List[Message]:
return await self.await_unless_failed(self._batches.get())

async def _subscribe_action(self, subscriber: AsyncSingleSubscriber):
message = await subscriber.read()
await self._messages.put(message)
batch = await subscriber.read()
await self._batches.put(batch)

async def _start_subscriber(self, partition: Partition):
new_subscriber = self._subscriber_factory(partition)
@@ -92,7 +92,7 @@ async def _assign_action(self):
await self._stop_subscriber(subscriber)

async def __aenter__(self):
self._messages = Queue()
self._batches = Queue()
self._assigner = self._assigner_factory()
await self._assigner.__aenter__()
self._assign_poller = ensure_future(self.run_poller(self._assign_action))
@@ -38,27 +38,17 @@
from overrides import overrides


class _SubscriberAsyncIterator(AsyncIterator):
_subscriber: AsyncSingleSubscriber
_on_failure: Callable[[], Awaitable[None]]

def __init__(
self,
subscriber: AsyncSingleSubscriber,
on_failure: Callable[[], Awaitable[None]],
):
self._subscriber = subscriber
self._on_failure = on_failure

async def __anext__(self) -> Message:
try:
return await self._subscriber.read()
except: # noqa: E722
await self._on_failure()
raise

def __aiter__(self):
return self
async def _iterate_subscriber(
subscriber: AsyncSingleSubscriber, on_failure: Callable[[], Awaitable[None]]
) -> AsyncIterator[Message]:
try:
while True:
batch = await subscriber.read()
for message in batch:
yield message
except: # noqa: E722
await on_failure()
raise


class MultiplexedAsyncSubscriberClient(AsyncSubscriberClientInterface):
@@ -85,7 +75,7 @@ async def subscribe(
await subscriber.__aenter__()
self._live_clients.add(subscriber)

return _SubscriberAsyncIterator(
return _iterate_subscriber(
subscriber, lambda: self._try_remove_client(subscriber)
)
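
An illustrative sketch (not part of this diff) of consuming the async generator that subscribe() now returns: the generator flattens each List[Message] batch from read() into individual messages, so callers still receive one Message at a time. The iterator variable below stands for whatever subscribe() returned.

async def consume(iterator):
    async for message in iterator:
        print(message.data)  # inspect the wrapped Cloud Pub/Sub message
        message.ack()        # ack each message exactly once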

@@ -13,15 +13,16 @@
# limitations under the License.

import asyncio
import json
from typing import Callable, Union, Dict, NamedTuple
from typing import Callable, Union, List, Dict, NamedTuple
import queue

from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.permanent_failable import adapt_error
from google.cloud.pubsublite.internal import fast_serialize
from google.cloud.pubsublite.types import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer
@@ -47,15 +48,13 @@ class _AckId(NamedTuple):
generation: int
offset: int

def str(self) -> str:
return json.dumps({"generation": self.generation, "offset": self.offset})
def encode(self) -> str:
return fast_serialize.dump([self.generation, self.offset])

@staticmethod
def parse(payload: str) -> "_AckId": # pytype: disable=invalid-annotation
loaded = json.loads(payload)
return _AckId(
generation=int(loaded["generation"]), offset=int(loaded["offset"]),
)
loaded = fast_serialize.load(payload)
return _AckId(generation=loaded[0], offset=loaded[1])


ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber]
@@ -99,26 +98,31 @@ async def handle_reset(self):
self._ack_generation_id += 1
await self._ack_set_tracker.clear_and_commit()

async def read(self) -> Message:
def _wrap_message(self, message: SequencedMessage.meta.pb) -> Message:
# Rewrap in the proto-plus-python wrapper for passing to the transform
rewrapped = SequencedMessage()
rewrapped._pb = message
cps_message = self._transformer.transform(rewrapped)
offset = message.cursor.offset
ack_id_str = _AckId(self._ack_generation_id, offset).encode()
self._ack_set_tracker.track(offset)
self._messages_by_ack_id[ack_id_str] = _SizedMessage(
cps_message, message.size_bytes
)
wrapped_message = Message(
cps_message._pb,
ack_id=ack_id_str,
delivery_attempt=0,
request_queue=self._queue,
)
return wrapped_message

async def read(self) -> List[Message]:
try:
message: SequencedMessage = await self.await_unless_failed(
self._underlying.read()
)
cps_message = self._transformer.transform(message)
offset = message.cursor.offset
ack_id = _AckId(self._ack_generation_id, offset)
self._ack_set_tracker.track(offset)
self._messages_by_ack_id[ack_id.str()] = _SizedMessage(
cps_message, message.size_bytes
)
wrapped_message = Message(
cps_message._pb,
ack_id=ack_id.str(),
delivery_attempt=0,
request_queue=self._queue,
)
return wrapped_message
except GoogleAPICallError as e:
latest_batch = await self.await_unless_failed(self._underlying.read())
return [self._wrap_message(message) for message in latest_batch]
except Exception as e:
e = adapt_error(e) # This could be from user code
self.fail(e)
raise e
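
A quick round-trip sketch (illustrative only, with arbitrary example values) of the new ack id encoding used by _wrap_message above: encode() now emits the comma-separated fast_serialize form instead of a JSON object, and parse() reverses it.

ack_id = _AckId(generation=3, offset=1042)
encoded = ack_id.encode()               # "3,1042" via fast_serialize.dump
assert _AckId.parse(encoded) == ack_id  # parse restores the original tuple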

@@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod, ABCMeta
from typing import AsyncContextManager, Callable, Set, Optional
from typing import AsyncContextManager, Callable, List, Set, Optional

from google.cloud.pubsub_v1.subscriber.message import Message

@@ -32,12 +32,13 @@ class AsyncSingleSubscriber(AsyncContextManager, metaclass=ABCMeta):
"""

@abstractmethod
async def read(self) -> Message:
async def read(self) -> List[Message]:
"""
Read the next message off of the stream.
Read the next batch off of the stream.

Returns:
The next message. ack() or nack() must eventually be called exactly once.
The next batch of messages. ack() or nack() must eventually be called
exactly once on each message.

Pub/Sub Lite does not support nack() by default - if you do call nack(), it will immediately fail the client
unless you have a NackHandler installed.
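
For illustration (not part of this diff), a minimal polling loop against this interface, assuming subscriber is an already-entered AsyncSingleSubscriber:

async def drain(subscriber: AsyncSingleSubscriber) -> None:
    while True:
        batch = await subscriber.read()  # read() now returns a List[Message]
        for message in batch:
            message.ack()                # ack exactly once per message
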
@@ -84,8 +84,8 @@ def _fail(self, error: GoogleAPICallError):
async def _poller(self):
try:
while True:
message = await self._underlying.read()
self._unowned_executor.submit(self._callback, message)
batch = await self._underlying.read()
self._unowned_executor.map(self._callback, batch)
except GoogleAPICallError as e: # noqa: F841 Flake8 thinks e is unused
self._unowned_executor.submit(lambda: self._fail(e)) # noqa: F821

82 changes: 58 additions & 24 deletions google/cloud/pubsublite/cloudpubsub/message_transforms.py
@@ -19,27 +19,42 @@
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.cloudpubsub import MessageTransformer
from google.cloud.pubsublite.internal import fast_serialize
from google.cloud.pubsublite.types import Partition, MessageMetadata
from google.cloud.pubsublite_v1 import AttributeValues, SequencedMessage, PubSubMessage

PUBSUB_LITE_EVENT_TIME = "x-goog-pubsublite-event-time"


def encode_attribute_event_time(dt: datetime.datetime) -> str:
ts = Timestamp()
ts.FromDatetime(dt)
return ts.ToJsonString()
def _encode_attribute_event_time_proto(ts: Timestamp) -> str:
return fast_serialize.dump([ts.seconds, ts.nanos])


def decode_attribute_event_time(attr: str) -> datetime.datetime:
def _decode_attribute_event_time_proto(attr: str) -> Timestamp:
try:
ts = Timestamp()
ts.FromJsonString(attr)
return ts.ToDatetime()
except ValueError:
loaded = fast_serialize.load(attr)
ts.seconds = loaded[0]
ts.nanos = loaded[1]
return ts
except Exception: # noqa: E722
raise InvalidArgument("Invalid value for event time attribute.")


def encode_attribute_event_time(dt: datetime.datetime) -> str:
ts = Timestamp()
ts.FromDatetime(dt.astimezone(datetime.timezone.utc))
return _encode_attribute_event_time_proto(ts)


def decode_attribute_event_time(attr: str) -> datetime.datetime:
return (
_decode_attribute_event_time_proto(attr)
.ToDatetime()
.replace(tzinfo=datetime.timezone.utc)
)


def _parse_attributes(values: AttributeValues) -> str:
if not len(values.values) == 1:
raise InvalidArgument(
@@ -58,25 +73,34 @@ def add_id_to_cps_subscribe_transformer(
partition: Partition, transformer: MessageTransformer
) -> MessageTransformer:
def add_id_to_message(source: SequencedMessage):
source_pb = source._pb
message: PubsubMessage = transformer.transform(source)
if message.message_id:
message_pb = message._pb
if message_pb.message_id:
raise InvalidArgument(
"Message after transforming has the message_id field set."
)
message.message_id = MessageMetadata(partition, source.cursor).encode()
message_pb.message_id = MessageMetadata._encode_parts(
partition.value, source_pb.cursor.offset
)
return message

return MessageTransformer.of_callable(add_id_to_message)


def to_cps_subscribe_message(source: SequencedMessage) -> PubsubMessage:
message: PubsubMessage = to_cps_publish_message(source.message)
message.publish_time = source.publish_time
return message
source_pb = source._pb
out_pb = _to_cps_publish_message_proto(source_pb.message)
out_pb.publish_time.CopyFrom(source_pb.publish_time)
out = PubsubMessage()
out._pb = out_pb
return out


def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage:
out = PubsubMessage()
def _to_cps_publish_message_proto(
source: PubSubMessage.meta.pb,
) -> PubsubMessage.meta.pb:
out = PubsubMessage.meta.pb()
try:
out.ordering_key = source.key.decode("utf-8")
except UnicodeError:
@@ -88,22 +112,32 @@ def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage:
out.data = source.data
for key, values in source.attributes.items():
out.attributes[key] = _parse_attributes(values)
if "event_time" in source:
out.attributes[PUBSUB_LITE_EVENT_TIME] = encode_attribute_event_time(
if source.HasField("event_time"):
out.attributes[PUBSUB_LITE_EVENT_TIME] = _encode_attribute_event_time_proto(
source.event_time
)
return out


def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage:
out = PubsubMessage()
out._pb = _to_cps_publish_message_proto(source._pb)
return out


def from_cps_publish_message(source: PubsubMessage) -> PubSubMessage:
source_pb = source._pb
out = PubSubMessage()
if PUBSUB_LITE_EVENT_TIME in source.attributes:
out.event_time = decode_attribute_event_time(
source.attributes[PUBSUB_LITE_EVENT_TIME]
out_pb = out._pb
if PUBSUB_LITE_EVENT_TIME in source_pb.attributes:
out_pb.event_time.CopyFrom(
_decode_attribute_event_time_proto(
source_pb.attributes[PUBSUB_LITE_EVENT_TIME]
)
)
out.data = source.data
out.key = source.ordering_key.encode("utf-8")
for key, value in source.attributes.items():
out_pb.data = source_pb.data
out_pb.key = source_pb.ordering_key.encode("utf-8")
for key, value in source_pb.attributes.items():
if key != PUBSUB_LITE_EVENT_TIME:
out.attributes[key] = AttributeValues(values=[value.encode("utf-8")])
out_pb.attributes[key].values.append(value.encode("utf-8"))
return out
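
A small sketch (not part of the change) of how the event-time attribute now round-trips: the new helpers serialize the timestamp as "seconds,nanos" rather than an RFC 3339 string. The datetime value below is just an example.

import datetime

event_time = datetime.datetime(2021, 9, 13, 12, 0, 0, tzinfo=datetime.timezone.utc)
encoded = encode_attribute_event_time(event_time)      # e.g. "1631534400,0"
assert decode_attribute_event_time(encoded) == event_time
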
13 changes: 13 additions & 0 deletions google/cloud/pubsublite/internal/fast_serialize.py
@@ -0,0 +1,13 @@
"""
A fast serialization method for lists of integers.
"""

from typing import List


def dump(data: List[int]) -> str:
return ",".join(str(x) for x in data)


def load(source: str) -> List[int]:
return [int(x) for x in source.split(",")]
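
A round-trip usage sketch of the helper above: dump and load are inverse operations on lists of integers.

payload = dump([3, 1042])          # "3,1042"
assert load(payload) == [3, 1042]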