fix: performance issues with subscriber client (#232)
The primary change here is to access the raw proto in as many places as possible, to work around proto-plus-python performance issues.
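
As a rough illustration of that first point (a hedged sketch, not code from this commit): every attribute read on a proto-plus wrapper goes through a marshalling layer, while the raw protobuf message that each wrapper exposes via ._pb does not, so hot paths get cheaper when they touch the raw proto directly.

# Hypothetical micro-benchmark (field choice and counts are illustrative):
# proto-plus attribute access pays a per-read marshalling cost that
# raw `._pb` access avoids.
import timeit

from google.cloud.pubsublite_v1 import SequencedMessage

msg = SequencedMessage(size_bytes=42)

via_wrapper = timeit.timeit(lambda: msg.size_bytes, number=100_000)
via_raw_pb = timeit.timeit(lambda: msg._pb.size_bytes, number=100_000)
print(f"proto-plus: {via_wrapper:.3f}s, raw proto: {via_raw_pb:.3f}s")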

The secondary change is to reduce asyncio overhead by propagating batches of messages through more layers of code.
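
A minimal sketch of that batching idea (illustrative names, not this commit's actual API): when read() hands back a whole batch, the event-loop hop is paid once per batch and the per-message fan-out becomes a plain loop.

import asyncio
from typing import List

async def read_batch(queue: "asyncio.Queue[List[str]]") -> List[str]:
    # One await per batch: the asyncio overhead is paid once here...
    return await queue.get()

async def consume(queue: "asyncio.Queue[List[str]]") -> None:
    batch = await read_batch(queue)
    for message in batch:  # ...and the per-message work needs no further awaits.
        print(message)

async def main() -> None:
    queue: "asyncio.Queue[List[str]]" = asyncio.Queue()
    queue.put_nowait(["a", "b", "c"])
    await consume(queue)

asyncio.run(main())
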
dpcollins-google committed Sep 13, 2021
1 parent 2bb209c commit 78a47b2
Showing 18 changed files with 280 additions and 202 deletions.
@@ -13,7 +13,7 @@
# limitations under the License.

from asyncio import Future, Queue, ensure_future
from typing import Callable, NamedTuple, Dict, Set, Optional
from typing import Callable, NamedTuple, Dict, List, Set, Optional

from google.cloud.pubsub_v1.subscriber.message import Message

@@ -41,7 +41,7 @@ class AssigningSingleSubscriber(AsyncSingleSubscriber, PermanentFailable):

# Lazily initialized to ensure they are initialized on the thread where __aenter__ is called.
_assigner: Optional[Assigner]
_messages: Optional["Queue[Message]"]
_batches: Optional["Queue[List[Message]]"]
_assign_poller: Future

def __init__(
@@ -58,14 +58,14 @@ def __init__(
self._assigner = None
self._subscriber_factory = subscriber_factory
self._subscribers = {}
self._messages = None
self._batches = None

async def read(self) -> Message:
return await self.await_unless_failed(self._messages.get())
async def read(self) -> List[Message]:
return await self.await_unless_failed(self._batches.get())

async def _subscribe_action(self, subscriber: AsyncSingleSubscriber):
message = await subscriber.read()
await self._messages.put(message)
batch = await subscriber.read()
await self._batches.put(batch)

async def _start_subscriber(self, partition: Partition):
new_subscriber = self._subscriber_factory(partition)
@@ -92,7 +92,7 @@ async def _assign_action(self):
await self._stop_subscriber(subscriber)

async def __aenter__(self):
self._messages = Queue()
self._batches = Queue()
self._assigner = self._assigner_factory()
await self._assigner.__aenter__()
self._assign_poller = ensure_future(self.run_poller(self._assign_action))
@@ -38,27 +38,17 @@
from overrides import overrides


class _SubscriberAsyncIterator(AsyncIterator):
_subscriber: AsyncSingleSubscriber
_on_failure: Callable[[], Awaitable[None]]

def __init__(
self,
subscriber: AsyncSingleSubscriber,
on_failure: Callable[[], Awaitable[None]],
):
self._subscriber = subscriber
self._on_failure = on_failure

async def __anext__(self) -> Message:
try:
return await self._subscriber.read()
except: # noqa: E722
await self._on_failure()
raise

def __aiter__(self):
return self
async def _iterate_subscriber(
subscriber: AsyncSingleSubscriber, on_failure: Callable[[], Awaitable[None]]
) -> AsyncIterator[Message]:
try:
while True:
batch = await subscriber.read()
for message in batch:
yield message
except: # noqa: E722
await on_failure()
raise


class MultiplexedAsyncSubscriberClient(AsyncSubscriberClientInterface):
@@ -85,7 +75,7 @@ async def subscribe(
await subscriber.__aenter__()
self._live_clients.add(subscriber)

return _SubscriberAsyncIterator(
return _iterate_subscriber(
subscriber, lambda: self._try_remove_client(subscriber)
)

@@ -13,15 +13,16 @@
# limitations under the License.

import asyncio
import json
from typing import Callable, Union, Dict, NamedTuple
from typing import Callable, Union, List, Dict, NamedTuple
import queue

from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.permanent_failable import adapt_error
from google.cloud.pubsublite.internal import fast_serialize
from google.cloud.pubsublite.types import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer
@@ -47,15 +48,13 @@ class _AckId(NamedTuple):
generation: int
offset: int

def str(self) -> str:
return json.dumps({"generation": self.generation, "offset": self.offset})
def encode(self) -> str:
return fast_serialize.dump([self.generation, self.offset])

@staticmethod
def parse(payload: str) -> "_AckId": # pytype: disable=invalid-annotation
loaded = json.loads(payload)
return _AckId(
generation=int(loaded["generation"]), offset=int(loaded["offset"]),
)
loaded = fast_serialize.load(payload)
return _AckId(generation=loaded[0], offset=loaded[1])


ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber]
@@ -99,26 +98,31 @@ async def handle_reset(self):
self._ack_generation_id += 1
await self._ack_set_tracker.clear_and_commit()

async def read(self) -> Message:
def _wrap_message(self, message: SequencedMessage.meta.pb) -> Message:
# Rewrap in the proto-plus-python wrapper for passing to the transform
rewrapped = SequencedMessage()
rewrapped._pb = message
cps_message = self._transformer.transform(rewrapped)
offset = message.cursor.offset
ack_id_str = _AckId(self._ack_generation_id, offset).encode()
self._ack_set_tracker.track(offset)
self._messages_by_ack_id[ack_id_str] = _SizedMessage(
cps_message, message.size_bytes
)
wrapped_message = Message(
cps_message._pb,
ack_id=ack_id_str,
delivery_attempt=0,
request_queue=self._queue,
)
return wrapped_message

async def read(self) -> List[Message]:
try:
message: SequencedMessage = await self.await_unless_failed(
self._underlying.read()
)
cps_message = self._transformer.transform(message)
offset = message.cursor.offset
ack_id = _AckId(self._ack_generation_id, offset)
self._ack_set_tracker.track(offset)
self._messages_by_ack_id[ack_id.str()] = _SizedMessage(
cps_message, message.size_bytes
)
wrapped_message = Message(
cps_message._pb,
ack_id=ack_id.str(),
delivery_attempt=0,
request_queue=self._queue,
)
return wrapped_message
except GoogleAPICallError as e:
latest_batch = await self.await_unless_failed(self._underlying.read())
return [self._wrap_message(message) for message in latest_batch]
except Exception as e:
e = adapt_error(e) # This could be from user code
self.fail(e)
raise e

@@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod, ABCMeta
from typing import AsyncContextManager, Callable, Set, Optional
from typing import AsyncContextManager, Callable, List, Set, Optional

from google.cloud.pubsub_v1.subscriber.message import Message

@@ -32,12 +32,13 @@ class AsyncSingleSubscriber(AsyncContextManager, metaclass=ABCMeta):
"""

@abstractmethod
async def read(self) -> Message:
async def read(self) -> List[Message]:
"""
Read the next message off of the stream.
Read the next batch off of the stream.
Returns:
The next message. ack() or nack() must eventually be called exactly once.
The next batch of messages. ack() or nack() must eventually be called
exactly once on each message.
Pub/Sub Lite does not support nack() by default; if you do call nack(), it will immediately fail the client
unless you have a NackHandler installed.
@@ -84,8 +84,8 @@ def _fail(self, error: GoogleAPICallError):
async def _poller(self):
try:
while True:
message = await self._underlying.read()
self._unowned_executor.submit(self._callback, message)
batch = await self._underlying.read()
self._unowned_executor.map(self._callback, batch)
except GoogleAPICallError as e: # noqa: F841 Flake8 thinks e is unused
self._unowned_executor.submit(lambda: self._fail(e)) # noqa: F821

82 changes: 58 additions & 24 deletions google/cloud/pubsublite/cloudpubsub/message_transforms.py
@@ -19,27 +19,42 @@
from google.pubsub_v1 import PubsubMessage

from google.cloud.pubsublite.cloudpubsub import MessageTransformer
from google.cloud.pubsublite.internal import fast_serialize
from google.cloud.pubsublite.types import Partition, MessageMetadata
from google.cloud.pubsublite_v1 import AttributeValues, SequencedMessage, PubSubMessage

PUBSUB_LITE_EVENT_TIME = "x-goog-pubsublite-event-time"


def encode_attribute_event_time(dt: datetime.datetime) -> str:
ts = Timestamp()
ts.FromDatetime(dt)
return ts.ToJsonString()
def _encode_attribute_event_time_proto(ts: Timestamp) -> str:
return fast_serialize.dump([ts.seconds, ts.nanos])


def decode_attribute_event_time(attr: str) -> datetime.datetime:
def _decode_attribute_event_time_proto(attr: str) -> Timestamp:
try:
ts = Timestamp()
ts.FromJsonString(attr)
return ts.ToDatetime()
except ValueError:
loaded = fast_serialize.load(attr)
ts.seconds = loaded[0]
ts.nanos = loaded[1]
return ts
except Exception: # noqa: E722
raise InvalidArgument("Invalid value for event time attribute.")


def encode_attribute_event_time(dt: datetime.datetime) -> str:
ts = Timestamp()
ts.FromDatetime(dt.astimezone(datetime.timezone.utc))
return _encode_attribute_event_time_proto(ts)


def decode_attribute_event_time(attr: str) -> datetime.datetime:
return (
_decode_attribute_event_time_proto(attr)
.ToDatetime()
.replace(tzinfo=datetime.timezone.utc)
)


def _parse_attributes(values: AttributeValues) -> str:
if not len(values.values) == 1:
raise InvalidArgument(
@@ -58,25 +73,34 @@ def add_id_to_cps_subscribe_transformer(
partition: Partition, transformer: MessageTransformer
) -> MessageTransformer:
def add_id_to_message(source: SequencedMessage):
source_pb = source._pb
message: PubsubMessage = transformer.transform(source)
if message.message_id:
message_pb = message._pb
if message_pb.message_id:
raise InvalidArgument(
"Message after transforming has the message_id field set."
)
message.message_id = MessageMetadata(partition, source.cursor).encode()
message_pb.message_id = MessageMetadata._encode_parts(
partition.value, source_pb.cursor.offset
)
return message

return MessageTransformer.of_callable(add_id_to_message)


def to_cps_subscribe_message(source: SequencedMessage) -> PubsubMessage:
message: PubsubMessage = to_cps_publish_message(source.message)
message.publish_time = source.publish_time
return message
source_pb = source._pb
out_pb = _to_cps_publish_message_proto(source_pb.message)
out_pb.publish_time.CopyFrom(source_pb.publish_time)
out = PubsubMessage()
out._pb = out_pb
return out


def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage:
out = PubsubMessage()
def _to_cps_publish_message_proto(
source: PubSubMessage.meta.pb,
) -> PubsubMessage.meta.pb:
out = PubsubMessage.meta.pb()
try:
out.ordering_key = source.key.decode("utf-8")
except UnicodeError:
@@ -88,22 +112,32 @@ def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage:
out.data = source.data
for key, values in source.attributes.items():
out.attributes[key] = _parse_attributes(values)
if "event_time" in source:
out.attributes[PUBSUB_LITE_EVENT_TIME] = encode_attribute_event_time(
if source.HasField("event_time"):
out.attributes[PUBSUB_LITE_EVENT_TIME] = _encode_attribute_event_time_proto(
source.event_time
)
return out


def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage:
out = PubsubMessage()
out._pb = _to_cps_publish_message_proto(source._pb)
return out


def from_cps_publish_message(source: PubsubMessage) -> PubSubMessage:
source_pb = source._pb
out = PubSubMessage()
if PUBSUB_LITE_EVENT_TIME in source.attributes:
out.event_time = decode_attribute_event_time(
source.attributes[PUBSUB_LITE_EVENT_TIME]
out_pb = out._pb
if PUBSUB_LITE_EVENT_TIME in source_pb.attributes:
out_pb.event_time.CopyFrom(
_decode_attribute_event_time_proto(
source_pb.attributes[PUBSUB_LITE_EVENT_TIME]
)
)
out.data = source.data
out.key = source.ordering_key.encode("utf-8")
for key, value in source.attributes.items():
out_pb.data = source_pb.data
out_pb.key = source_pb.ordering_key.encode("utf-8")
for key, value in source_pb.attributes.items():
if key != PUBSUB_LITE_EVENT_TIME:
out.attributes[key] = AttributeValues(values=[value.encode("utf-8")])
out_pb.attributes[key].values.append(value.encode("utf-8"))
return out
13 changes: 13 additions & 0 deletions google/cloud/pubsublite/internal/fast_serialize.py
@@ -0,0 +1,13 @@
"""
A fast serialization method for lists of integers.
"""

from typing import List


def dump(data: List[int]) -> str:
return ",".join(str(x) for x in data)


def load(source: str) -> List[int]:
return [int(x) for x in source.split(",")]
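
As a usage sketch for the new helpers (values made up): an ack id of generation 3 at offset 1774 round-trips as a comma-joined string, which is far cheaper than json.dumps/json.loads on a dict.

from google.cloud.pubsublite.internal import fast_serialize

payload = fast_serialize.dump([3, 1774])
assert payload == "3,1774"
assert fast_serialize.load(payload) == [3, 1774]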
