-
Notifications
You must be signed in to change notification settings - Fork 12
/
flow_control_batcher.py
67 lines (49 loc) · 2.21 KB
/
flow_control_batcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import NamedTuple, List, Optional
from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage
_EXPEDITE_BATCH_REQUEST_RATIO = 0.5
_MAX_INT64 = 0x7FFFFFFFFFFFFFFF
class _AggregateRequest:
    """Running total of flow-control tokens stored in one FlowControlRequest.

    Tokens are accumulated with ``+=``. Each field saturates at ``_MAX_INT64``
    instead of leaving the signed 64-bit range.
    """

    request: FlowControlRequest

    def __init__(self):
        self.request = FlowControlRequest()

    def __add__(self, other: FlowControlRequest):
        """Add ``other``'s tokens to this aggregate in place and return ``self``.

        Returning ``self`` makes ``agg += req`` keep the same aggregate object.

        Args:
            other: the request whose ``allowed_bytes`` / ``allowed_messages``
                are added (negative values deduct tokens).
        """
        # Sum in unbounded Python ints and clamp *before* writing back:
        # protobuf int64 fields reject out-of-range assignments (ValueError),
        # so clamping after the assignment could never take effect.
        allowed_bytes = self.request.allowed_bytes + other.allowed_bytes
        allowed_messages = self.request.allowed_messages + other.allowed_messages
        self.request.allowed_bytes = min(allowed_bytes, _MAX_INT64)
        self.request.allowed_messages = min(allowed_messages, _MAX_INT64)
        return self
def _exceeds_expedite_ratio(pending: int, client: int):
if client <= 0:
return False
return (pending/client) >= _EXPEDITE_BATCH_REQUEST_RATIO
def _to_optional(req: FlowControlRequest) -> Optional[FlowControlRequest]:
    """Return ``req`` itself, or None when it grants neither messages nor bytes."""
    grants_nothing = not req.allowed_messages and not req.allowed_bytes
    return None if grants_nothing else req
class FlowControlBatcher:
    """Tracks flow-control tokens for a subscriber.

    Two running totals are kept:

    * ``_client_tokens``: tokens granted through :meth:`add`, minus the count
      and ``size_bytes`` of messages reported through :meth:`on_messages`.
      This is what :meth:`request_for_restart` returns.
    * ``_pending_tokens``: tokens granted through :meth:`add` that have not yet
      been handed out by :meth:`release_pending_request` or discarded by
      :meth:`request_for_restart`.
    """

    _client_tokens: _AggregateRequest
    _pending_tokens: _AggregateRequest

    def __init__(self):
        self._client_tokens = _AggregateRequest()
        self._pending_tokens = _AggregateRequest()

    def add(self, request: FlowControlRequest):
        """Record newly granted tokens in both the client and pending totals."""
        self._client_tokens += request
        self._pending_tokens += request

    def on_messages(self, messages: List[SequencedMessage]):
        """Deduct the number and total ``size_bytes`` of ``messages`` from the client tokens."""
        byte_size = sum(message.size_bytes for message in messages)
        self._client_tokens += FlowControlRequest(
            allowed_bytes=-byte_size, allowed_messages=-len(messages)
        )

    def request_for_restart(self) -> Optional[FlowControlRequest]:
        """Discard pending tokens and return the client's outstanding tokens.

        Returns:
            A new FlowControlRequest holding the client tokens, with negative
            fields raised to 0, or None when nothing remains to grant. The
            result is a copy, so later :meth:`add` / :meth:`on_messages` calls
            do not modify a request the caller already holds.
        """
        self._pending_tokens = _AggregateRequest()
        return _to_optional(self._snapshot(self._client_tokens.request))

    def release_pending_request(self) -> Optional[FlowControlRequest]:
        """Hand out the accumulated pending tokens and start a fresh pending total.

        Returns:
            The pending request, or None when it grants nothing.
        """
        request = self._pending_tokens.request
        # Replace (rather than reset) the aggregate so the returned object is
        # no longer referenced by this batcher.
        self._pending_tokens = _AggregateRequest()
        return _to_optional(request)

    def should_expedite(self):
        """Return True when pending bytes or messages reach the expedite ratio of the client totals."""
        pending_request = self._pending_tokens.request
        client_request = self._client_tokens.request
        if _exceeds_expedite_ratio(pending_request.allowed_bytes, client_request.allowed_bytes):
            return True
        if _exceeds_expedite_ratio(pending_request.allowed_messages, client_request.allowed_messages):
            return True
        return False

    @staticmethod
    def _snapshot(request: FlowControlRequest) -> FlowControlRequest:
        """Return a detached copy of ``request`` with negative fields raised to 0.

        Client tokens go negative when :meth:`on_messages` deducts more than
        was granted. FlowControlRequest grants must be >= 0, so a negative
        total is reported as no grant for that field.
        """
        return FlowControlRequest(
            allowed_bytes=max(request.allowed_bytes, 0),
            allowed_messages=max(request.allowed_messages, 0),
        )