From 0a09bb3170f06532d5e5d8e1b5f8f3fddd516f98 Mon Sep 17 00:00:00 2001 From: dpcollins-google <40498610+dpcollins-google@users.noreply.github.com> Date: Tue, 15 Sep 2020 16:37:24 -0400 Subject: [PATCH] feat: implement Flow control batcher (#15) * feat: Implement FlowControlBatcher This handles aggregating flow control requests without allowing them to get above the max int64 value. * Use correct request for comparison. --- .../internal/wire/flow_control_batcher.py | 67 +++++++++++++++++++ .../wire/flow_control_batcher_test.py | 28 ++++++++ 2 files changed, 95 insertions(+) create mode 100644 google/cloud/pubsublite/internal/wire/flow_control_batcher.py create mode 100644 tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py diff --git a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py new file mode 100644 index 00000000..7afce610 --- /dev/null +++ b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py @@ -0,0 +1,67 @@ +from typing import NamedTuple, List, Optional + +from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage + +_EXPEDITE_BATCH_REQUEST_RATIO = 0.5 +_MAX_INT64 = 0x7FFFFFFFFFFFFFFF + + +class _AggregateRequest: + request: FlowControlRequest + + def __init__(self): + self.request = FlowControlRequest() + + def __add__(self, other: FlowControlRequest): + self.request.allowed_bytes += other.allowed_bytes + self.request.allowed_bytes = min(self.request.allowed_bytes, _MAX_INT64) + self.request.allowed_messages += other.allowed_messages + self.request.allowed_messages = min(self.request.allowed_messages, _MAX_INT64) + return self + + +def _exceeds_expedite_ratio(pending: int, client: int): + if client <= 0: + return False + return (pending/client) >= _EXPEDITE_BATCH_REQUEST_RATIO + + +def _to_optional(req: FlowControlRequest) -> Optional[FlowControlRequest]: + if req.allowed_messages == 0 and req.allowed_bytes == 0: + return None + return req + + +class FlowControlBatcher: + _client_tokens: _AggregateRequest + _pending_tokens: _AggregateRequest + + def __init__(self): + self._client_tokens = _AggregateRequest() + self._pending_tokens = _AggregateRequest() + + def add(self, request: FlowControlRequest): + self._client_tokens += request + self._pending_tokens += request + + def on_messages(self, messages: List[SequencedMessage]): + byte_size = sum(message.size_bytes for message in messages) + self._client_tokens += FlowControlRequest(allowed_bytes=-byte_size, allowed_messages=-len(messages)) + + def request_for_restart(self) -> Optional[FlowControlRequest]: + self._pending_tokens = _AggregateRequest() + return _to_optional(self._client_tokens.request) + + def release_pending_request(self) -> Optional[FlowControlRequest]: + request = self._pending_tokens.request + self._pending_tokens = _AggregateRequest() + return _to_optional(request) + + def should_expedite(self): + pending_request = self._pending_tokens.request + client_request = self._client_tokens.request + if _exceeds_expedite_ratio(pending_request.allowed_bytes, client_request.allowed_bytes): + return True + if _exceeds_expedite_ratio(pending_request.allowed_messages, client_request.allowed_messages): + return True + return False diff --git a/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py b/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py new file mode 100644 index 00000000..45e78c70 --- /dev/null +++ b/tests/unit/pubsublite/internal/wire/flow_control_batcher_test.py @@ -0,0 +1,28 @@ +from google.cloud.pubsublite.internal.wire.flow_control_batcher import FlowControlBatcher +from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage + + +def test_restart_clears_send(): + batcher = FlowControlBatcher() + batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3)) + assert batcher.should_expedite() + to_send = batcher.release_pending_request() + assert to_send.allowed_bytes == 10 + assert to_send.allowed_messages == 3 + restart_1 = batcher.request_for_restart() + assert restart_1.allowed_bytes == 10 + assert restart_1.allowed_messages == 3 + assert not batcher.should_expedite() + assert batcher.release_pending_request() is None + + +def test_add_remove(): + batcher = FlowControlBatcher() + batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3)) + restart_1 = batcher.request_for_restart() + assert restart_1.allowed_bytes == 10 + assert restart_1.allowed_messages == 3 + batcher.on_messages([SequencedMessage(size_bytes=2), SequencedMessage(size_bytes=3)]) + restart_2 = batcher.request_for_restart() + assert restart_2.allowed_bytes == 5 + assert restart_2.allowed_messages == 1