From 0d16f7fc330c3aa0b15545f7e5caf3a53515c2fe Mon Sep 17 00:00:00 2001
From: Marenz
Date: Tue, 13 Dec 2022 17:40:00 +0100
Subject: [PATCH] Simple ringbuffer with tests

Signed-off-by: Marenz
---
 src/frequenz/sdk/util/ringbuffer.py | 364 ++++++++++++++++++++++++++++
 tests/utils/ringbuffer.py           | 163 +++++++++++++
 2 files changed, 527 insertions(+)
 create mode 100644 src/frequenz/sdk/util/ringbuffer.py
 create mode 100644 tests/utils/ringbuffer.py

diff --git a/src/frequenz/sdk/util/ringbuffer.py b/src/frequenz/sdk/util/ringbuffer.py
new file mode 100644
index 000000000..315cb5d9c
--- /dev/null
+++ b/src/frequenz/sdk/util/ringbuffer.py
@@ -0,0 +1,364 @@
+# License: MIT
+# Copyright © 2022 Frequenz Energy-as-a-Service GmbH
+
+"""Ringbuffer implementation with focus on time & memory efficiency."""
+
+from __future__ import annotations
+
+from copy import deepcopy
+from datetime import datetime, timedelta
+from typing import Any, Generic, Sequence, TypeVar
+
+import numpy as np
+
+T = TypeVar("T")
+
+Container = list | np.ndarray
+
+
+class RingBuffer(Generic[T]):
+    """A ring buffer with a fixed size.
+
+    Should work with most backends, tested with list and np.ndarrays.
+    """
+
+    class NoElements(Exception):
+        """Exception class, thrown when no elements are available when calling pop()."""
+
+    def __init__(self, container: Container) -> None:
+        """Initialize the ring buffer with the given container.
+
+        Args:
+            container: Container to store the data in.
+        """
+        self._container = container
+        self._write_index = 0
+        self._read_index = 0
+        self._len = 0
+
+    def __len__(self) -> int:
+        """Get current amount of elements.
+
+        Returns:
+            the amount of items that this container currently holds.
+        """
+        return self._len
+
+    @property
+    def maxlen(self) -> int:
+        """Get the max length.
+
+        Returns:
+            the max amount of items this container can hold.
+        """
+        return len(self._container)
+
+    def push(self, value: T) -> int:
+        """Push a new value into the ring buffer.
+
+        Args:
+            value: value to push into the ring buffer.
+
+        Returns:
+            the index in the ringbuffer.
+        """
+        if self._len == len(self._container):
+            # Move read position one forward, dropping the oldest written value
+            self._read_index = self._wrap(self._read_index + 1)
+        else:
+            self._len += 1
+
+        self._container[self._write_index] = value
+        value_index = self._write_index
+        self._write_index = self._wrap(self._write_index + 1)
+
+        return value_index
+
+    def pop(self) -> T:
+        """Remove the oldest value from the ring buffer and return it.
+
+        Raises:
+            NoElements: when no elements exist to pop.
+
+        Returns:
+            Oldest value found in the ring buffer.
+        """
+        if self._len == 0:
+            raise RingBuffer.NoElements()
+
+        val = self._container[self._read_index]
+        self._read_index = self._wrap(self._read_index + 1)
+        self._len -= 1
+
+        return val
+
+    @property
+    def full(self) -> bool:
+        """Check if the container is full.
+
+        Returns:
+            True when the container is full.
+        """
+        return len(self) == len(self._container)
+
+    def __setitem__(self, index: int | slice, value: T | Sequence[T]) -> None:
+        """Write the given value to the requested position.
+
+        Args:
+            index: Position to write the value to.
+            value: Value to write.
+        """
+        self._container[index] = value
+
+    def __getitem__(self, index_or_slice: int | slice) -> T | Container:
+        """Request a value or slice.
+
+        Does not support wrap-around or copying of data.
+
+        Args:
+            index_or_slice: Index or slice specification of the requested data
+
+        Returns:
+            the value at the given index or value range at the given slice.
+        """
+        return self._container[index_or_slice]
+
+    def _wrap(self, index: int) -> int:
+        """Wrap the given index around so that it fits into the container."""
+        return index % len(self._container)
+
+
+class OrderedRingBuffer(Generic[T]):
+    """Time aware ringbuffer that keeps its entries sorted by time."""
+
+    def __init__(
+        self,
+        buffer: Any,
+        resolution_in_seconds: int,
+        window_border: datetime = datetime(1, 1, 1),
+    ) -> None:
+        """Initialize the time aware ringbuffer.
+
+        Args:
+            buffer: instance of a buffer container to use internally
+            resolution_in_seconds: resolution of the incoming timestamps in
+                seconds
+            window_border: datetime depicting point in time to use as border
+                beginning, useful to make data start at the beginning of the day or
+                hour.
+
+        Raises:
+            ValueError: when the resolution is not greater than zero.
+        """
+        if resolution_in_seconds <= 0:
+            raise ValueError("resolution_in_seconds must be greater than zero")
+
+        self._buffer = buffer
+        self._resolution_in_seconds = resolution_in_seconds
+        self._window_start = window_border
+
+        # Each entry is {"start": datetime, "end": datetime | None}. The end is
+        # exclusive and stays None for as long as the gap is still pending.
+        self._missing_windows: list[dict[str, Any]] = []
+        self._datetime_newest = datetime.min
+        self._datetime_oldest = datetime.max
+
+    @property
+    def maxlen(self) -> int:
+        """Get the max length.
+
+        Returns:
+            the max amount of items this container can hold.
+        """
+        return len(self._buffer)
+
+    def update(self, timestamp: datetime, value: T, missing: bool = False) -> None:
+        """Update the buffer with a new value for the given timestamp.
+
+        Args:
+            timestamp: Timestamp of the new value
+            value: value to add
+            missing: if true, the given timestamp will be recorded as missing.
+                The value will still be written.
+
+        Raises:
+            IndexError: when the timestamp is too old to fit into the time range
+                covered by the buffer.
+        """
+        # The buffer covers `maxlen` slots ending at the newest timestamp. Anything
+        # older would end up in a slot that already belongs to newer data.
+        newest = max(self._datetime_newest, timestamp)
+        oldest_allowed = newest - timedelta(
+            seconds=(self.maxlen - 1) * self._resolution_in_seconds
+        )
+        if timestamp < oldest_allowed:
+            raise IndexError(
+                f"Requested timestamp {timestamp} is "
+                f"outside the range [{oldest_allowed} - {newest}]"
+            )
+
+        # Update timestamps
+        self._datetime_newest = newest
+        self._datetime_oldest = max(
+            min(self._datetime_oldest, timestamp), oldest_allowed
+        )
+
+        # Update data
+        self._buffer[self.datetime_to_index(timestamp)] = value
+
+        # Update list of missing windows
+        #
+        # We always append to the last pending window.
+        # A window is pending when end is None
+        pending = (
+            len(self._missing_windows) > 0
+            and self._missing_windows[-1]["end"] is None
+        )
+        if missing:
+            # Create new if no pending window
+            if not pending:
+                self._missing_windows.append({"start": timestamp, "end": None})
+        elif pending:
+            # Finalize a pending window
+            self._missing_windows[-1]["end"] = timestamp
+
+        # Delete out-of-date windows
+        while self._missing_windows:
+            gap_end = self._missing_windows[0]["end"]
+            if gap_end is None or gap_end > self._datetime_oldest:
+                break
+            del self._missing_windows[0]
+
+    def datetime_to_index(self, timestamp: datetime) -> int:
+        """Convert the given timestamp to an index.
+
+        Throws an index error when the timestamp is not found within this
+        buffer.
+
+        Args:
+            timestamp: Timestamp to convert.
+
+        Raises:
+            IndexError: when requesting a timestamp outside the range this container holds
+
+        Returns:
+            index where the value for the given timestamp can be found.
+        """
+        if self._datetime_newest < timestamp or timestamp < self._datetime_oldest:
+            raise IndexError(
+                f"Requested timestamp {timestamp} is "
+                f"outside the range [{self._datetime_oldest} - {self._datetime_newest}]"
+            )
+
+        return self._wrap(self._slot(timestamp))
+
+    def window(self, start: datetime, end: datetime) -> Container:
+        """Request a view on the data between start timestamp and end timestamp.
+
+        The start is inclusive and the end is exclusive. To get a window that
+        contains the newest entry, `end` may point up to one resolution step
+        past the newest timestamp.
+
+        Will return a copy in the following cases:
+        * The requested time period is crossing the start/end of the buffer
+        * The requested time period contains missing entries.
+
+        This means, if the caller needs to modify the data to account for
+        missing entries, they can safely do so.
+
+        Args:
+            start: start time of the window
+            end: end time of the window
+
+        Raises:
+            IndexError: when the requested time period is not fully covered by
+                this buffer.
+
+        Returns:
+            the requested window
+        """
+        assert start < end
+
+        start_index = self.datetime_to_index(start)
+
+        # The end is exclusive, so it may lie one step past the newest entry
+        step = timedelta(seconds=self._resolution_in_seconds)
+        if end - self._datetime_newest > step:
+            raise IndexError(
+                f"Requested window end {end} is more than one step "
+                f"past the newest timestamp {self._datetime_newest}"
+            )
+
+        stop_index = start_index + self._slot(end) - self._slot(start)
+
+        # Requested window wraps around the ends
+        if stop_index > self.maxlen:
+            head = self._buffer[start_index:]
+            tail = self._buffer[: stop_index - self.maxlen]
+
+            if isinstance(self._buffer, list):
+                return head + tail
+            return np.concatenate((head, tail))
+
+        window = self._buffer[start_index:stop_index]
+
+        # Return a copy if there are missing entries in the data. The gaps and
+        # the requested window are both half-open intervals.
+        if any(
+            gap["start"] < end and (gap["end"] is None or start < gap["end"])
+            for gap in self._missing_windows
+        ):
+            return deepcopy(window)
+
+        return window
+
+    def _slot(self, timestamp: datetime) -> int:
+        """Count the resolution steps from the window border to the timestamp.
+
+        Args:
+            timestamp: Timestamp to convert.
+
+        Returns:
+            the unwrapped slot number, negative for timestamps before the border.
+        """
+        # Floor division of timedeltas is exact and keeps timestamps before the
+        # window border in order, unlike taking the absolute distance.
+        step = timedelta(seconds=self._resolution_in_seconds)
+        return (timestamp - self._window_start) // step
+
+    def _wrap(self, index: int) -> int:
+        """Normalize the given index to fit in the buffer by wrapping it around.
+
+        Args:
+            index: index to normalize.
+
+        Returns:
+            an index that will be within max_size.
+        """
+        return index % self.maxlen
+
+    def __getitem__(self, index_or_slice: int | slice) -> T | Sequence[T]:
+        """Get item or slice at requested position.
+
+        Args:
+            index_or_slice: Index or slice specification of the requested data.
+
+        Returns:
+            The requested value or slice.
+        """
+        return self._buffer[index_or_slice]
+
+    def __len__(self) -> int:
+        """Return the amount of items that this container currently holds.
+
+        This is the number of slots between the oldest and the newest timestamp,
+        both included.
+
+        Returns:
+            The length.
+        """
+        if self._datetime_newest == datetime.min:
+            return 0
+
+        newest_slot = self._slot(self._datetime_newest)
+        return newest_slot - self._slot(self._datetime_oldest) + 1
diff --git a/tests/utils/ringbuffer.py b/tests/utils/ringbuffer.py
new file mode 100644
index 000000000..4ba818f00
--- /dev/null
+++ b/tests/utils/ringbuffer.py
@@ -0,0 +1,163 @@
+# License: MIT
+# Copyright © 2022 Frequenz Energy-as-a-Service GmbH
+
+"""Tests for the `Ringbuffer` class."""
+
+import random
+from datetime import datetime
+from typing import TypeVar
+
+import numpy as np
+import pytest
+
+from frequenz.sdk.util.ringbuffer import OrderedRingBuffer, RingBuffer
+
+T = TypeVar("T")
+
+
+@pytest.mark.parametrize(
+    "buffer",
+    [
+        RingBuffer[int]([0] * 50000),
+        RingBuffer[int](np.empty(shape=(50000,), dtype=np.float64)),
+    ],
+)
+def test_simple_push_pop(buffer: RingBuffer[int]) -> None:
+    """Test simple pushing/popping of RingBuffer."""
+    for i in range(buffer.maxlen):
+        buffer.push(i)
+
+    assert len(buffer) == buffer.maxlen
+
+    for i in range(buffer.maxlen):
+        assert i == buffer.pop()
+
+
+@pytest.mark.parametrize(
+    "buffer",
+    [
+        RingBuffer[int]([0] * 50000),
+        RingBuffer[int](np.empty(shape=(50000,), dtype=np.float64)),
+    ],
+)
+def test_push_pop_over_limit(buffer: RingBuffer[int]) -> None:
+    """Test pushing over the limit and the expected loss of data."""
+    over_limit_pushes = 1000
+
+    for i in range(buffer.maxlen + over_limit_pushes):
+        buffer.push(i)
+
+    assert len(buffer) == buffer.maxlen
+
+    for i in range(buffer.maxlen):
+        assert i + over_limit_pushes == buffer.pop()
+
+    assert len(buffer) == 0
+
+
+@pytest.mark.parametrize(
+    "buffer, element_type",
+    [
+        (RingBuffer[int]([0] * 5000), int),
+        (RingBuffer[float](np.empty(shape=(5000,), dtype=np.float64)), float),
+    ],
+)
+def test_slicing(buffer: RingBuffer[T], element_type: type) -> None:
+    """Test slicing method."""
+    for i in range(buffer.maxlen):
+        buffer.push(element_type(i))
+
+    # Wrap in extra list() otherwise pytest complains about numpy arrays
+    # pylint: disable=protected-access
+    assert list(buffer._container[0:]) == list(range(buffer.maxlen))
+
+
+@pytest.mark.parametrize(
+    "buffer",
+    [
+        OrderedRingBuffer[int]([0] * 24, 1),
+        OrderedRingBuffer[float](np.empty(shape=(24,), dtype=np.float64), 1),
+    ],
+)
+def test_timestamp_ringbuffer(buffer: OrderedRingBuffer[float | int]) -> None:
+    """Test ordered ring buffer."""
+    size = buffer.maxlen
+
+    random.seed(0)
+
+    # Push in random order
+    for i in random.sample(range(size), size):
+        buffer.update(datetime.fromtimestamp(200 + i), i)
+
+    assert len(buffer) == size
+
+    # Check all possible window sizes and start positions. The end of a window
+    # is exclusive, so it may be at most one step past the newest timestamp.
+    for i in range(size):
+        for j in range(1, size - i + 1):
+            start = datetime.fromtimestamp(200 + i)
+            end = datetime.fromtimestamp(200 + j + i)
+
+            assert list(buffer.window(start, end)) == list(range(i, i + j))
+
+
+@pytest.mark.parametrize(
+    "buffer",
+    [
+        (OrderedRingBuffer[float]([0] * 24, 1)),
+        (OrderedRingBuffer[float](np.empty(shape=(24,), dtype=np.float64), 1)),
+    ],
+)
+def test_timestamp_ringbuffer_overwrite(buffer: OrderedRingBuffer[float | int]) -> None:
+    """Test overwrite behavior and correctness."""
+    size = buffer.maxlen
+
+    random.seed(0)
+
+    # Push in random order
+    for i in random.sample(range(size), size):
+        buffer.update(datetime.fromtimestamp(200 + i), i)
+
+    # Push the same amount twice
+    for i in random.sample(range(size), size):
+        buffer.update(datetime.fromtimestamp(200 + i), i * 2)
+
+    # Check all possible window sizes and start positions
+    for i in range(size):
+        for j in range(1, size - i + 1):
+            start = datetime.fromtimestamp(200 + i)
+            end = datetime.fromtimestamp(200 + j + i)
+
+            expected = list(range(2 * i, 2 * (i + j), 2))
+            assert list(buffer.window(start, end)) == expected
+
+
+def test_timestamp_ringbuffer_len_and_range() -> None:
+    """Test length reporting and rejection of outdated timestamps."""
+    buffer = OrderedRingBuffer[int]([0] * 4, 1)
+    assert len(buffer) == 0
+
+    buffer.update(datetime.fromtimestamp(200), 0)
+    assert len(buffer) == 1
+
+    buffer.update(datetime.fromtimestamp(205), 5)
+    assert len(buffer) == 4
+
+    # Only the 4 slots ending at the newest timestamp are covered
+    with pytest.raises(IndexError):
+        buffer.update(datetime.fromtimestamp(201), 1)
+
+
+def test_timestamp_ringbuffer_missing() -> None:
+    """Test that windows containing missing entries are returned as copies."""
+    data = np.zeros(shape=(8,), dtype=np.float64)
+    buffer = OrderedRingBuffer[float](data, 1)
+
+    buffer.update(datetime.fromtimestamp(200), 1.0)
+    buffer.update(datetime.fromtimestamp(201), np.nan, missing=True)
+    buffer.update(datetime.fromtimestamp(202), 3.0)
+
+    window = buffer.window(datetime.fromtimestamp(200), datetime.fromtimestamp(203))
+
+    assert len(window) == 3
+    assert not np.shares_memory(window, data)