# =============================================================================
# New (empty) package marker files added by this change:
#   google/cloud/pubsublite/cloudpubsub/__init__.py
#   google/cloud/pubsublite/cloudpubsub/internal/__init__.py
#   tests/unit/pubsublite/cloudpubsub/__init__.py
#   tests/unit/pubsublite/cloudpubsub/internal/__init__.py
# =============================================================================


# =============================================================================
# New file: google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py
# =============================================================================
from abc import abstractmethod
from typing import AsyncContextManager


class AckSetTracker(AsyncContextManager):
    """
    An AckSetTracker tracks disjoint acknowledged messages and commits them when a contiguous prefix of tracked
    offsets is aggregated.
    """

    @abstractmethod
    def track(self, offset: int):
        """
        Track the provided offset.

        Args:
          offset: the offset to track.

        Raises:
          GoogleAPICallError: On an invalid offset to track.
        """

    @abstractmethod
    async def ack(self, offset: int):
        """
        Acknowledge the message with the provided offset. The offset must have previously been tracked.

        Args:
          offset: the offset to acknowledge.

        Raises:
          GoogleAPICallError: On a commit failure.
        """


# =============================================================================
# New file: google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py
# =============================================================================
import queue
from collections import deque
from typing import Optional

from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite_v1 import Cursor


class AckSetTrackerImpl(AckSetTracker):
    """
    AckSetTracker that matches tracked offsets against acknowledged offsets and hands the committer a cursor
    each time a contiguous prefix of the tracked offsets has been acknowledged.
    """

    _committer: Committer

    # Tracked offsets that have not yet been matched with an ack, in strictly increasing order
    # (oldest on the left, most recently tracked on the right).
    _receipts: "deque[int]"
    # Acknowledged offsets that have not yet been matched with a receipt; the smallest is retrieved first.
    _acks: "queue.PriorityQueue[int]"

    def __init__(self, committer: Committer):
        """
        Args:
          committer: the committer that receives the cursor for every fully acknowledged prefix. It is entered
            and exited together with this tracker.
        """
        self._committer = committer
        self._receipts = deque()
        self._acks = queue.PriorityQueue()

    def track(self, offset: int):
        """
        Track the provided offset. Offsets must be tracked in strictly increasing order.

        Args:
          offset: the offset to track.

        Raises:
          FailedPrecondition: If offset is not greater than the most recently tracked offset that is still
            outstanding.
        """
        if self._receipts:
            # New receipts are appended on the right, so the last tracked offset is the rightmost element.
            last = self._receipts[-1]
            if last >= offset:
                raise FailedPrecondition(f"Tried to track message {offset} which is before last tracked message {last}.")
        self._receipts.append(offset)

    async def ack(self, offset: int):
        """
        Acknowledge the message with the provided offset and commit if that completes a contiguous prefix of the
        tracked offsets.

        Acknowledging an offset that is not currently tracked, or acknowledging the same offset twice, is a
        caller error that is not detected here.

        Args:
          offset: the offset to acknowledge. It must have previously been tracked.

        Raises:
          GoogleAPICallError: On a commit failure.
        """
        # Note: the *_nowait queue operations are used here and below to ensure that the matching logic is
        # executed without yielding to another coroutine in the event loop. The queue is unbounded so
        # put_nowait will never throw.
        self._acks.put_nowait(offset)
        prefix_acked_offset: Optional[int] = None
        while self._receipts and not self._acks.empty():
            receipt = self._receipts.popleft()
            ack = self._acks.get_nowait()
            if receipt == ack:
                prefix_acked_offset = receipt
                continue
            # The oldest outstanding receipt has not been acked yet. Put both values back where they came
            # from: the receipt must return to the FRONT of the deque so that tracking order is preserved.
            self._receipts.appendleft(receipt)
            self._acks.put_nowait(ack)
            break
        if prefix_acked_offset is None:
            return
        # Convert from last acked to first unacked.
        await self._committer.commit(Cursor(offset=prefix_acked_offset + 1))

    async def __aenter__(self):
        """Enter the underlying committer and return this tracker so `async with ... as tracker` works."""
        await self._committer.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit the underlying committer, forwarding the exception details unchanged."""
        await self._committer.__aexit__(exc_type, exc_value, traceback)


# =============================================================================
# New file: tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py
# =============================================================================
from asynctest.mock import MagicMock, call
import pytest

from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import AckSetTrackerImpl
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite_v1 import Cursor

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


@pytest.fixture()
def committer():
    """A mocked Committer whose async context manager entry returns the mock itself."""
    committer = MagicMock(spec=Committer)
    committer.__aenter__.return_value = committer
    return committer


@pytest.fixture()
def tracker(committer):
    """The tracker under test, wired to the mocked committer."""
    return AckSetTrackerImpl(committer)


async def test_track_and_aggregate_acks(committer, tracker: AckSetTracker):
    async with tracker:
        committer.__aenter__.assert_called_once()
        tracker.track(offset=1)
        tracker.track(offset=3)
        tracker.track(offset=5)
        tracker.track(offset=7)

        # Nothing may be committed until the oldest tracked offset (1) has been acked.
        # Note: assert_has_calls([]) would pass unconditionally, so assert_not_called() is used instead.
        committer.commit.assert_not_called()
        await tracker.ack(offset=3)
        committer.commit.assert_not_called()
        await tracker.ack(offset=5)
        committer.commit.assert_not_called()
        await tracker.ack(offset=1)
        committer.commit.assert_has_calls([call(Cursor(offset=6))])

        tracker.track(offset=8)
        await tracker.ack(offset=7)
        committer.commit.assert_has_calls([call(Cursor(offset=6)), call(Cursor(offset=8))])
    committer.__aexit__.assert_called_once()


async def test_track_out_of_order_fails(committer, tracker: AckSetTracker):
    async with tracker:
        tracker.track(offset=1)
        tracker.track(offset=3)
        # 2 is greater than the oldest outstanding offset but not greater than the last tracked one.
        with pytest.raises(FailedPrecondition):
            tracker.track(offset=2)
        committer.commit.assert_not_called()