From ec7627295bbd26a4ac910f88f4dadc5492b48fa3 Mon Sep 17 00:00:00 2001 From: dpcollins-google <40498610+dpcollins-google@users.noreply.github.com> Date: Tue, 14 Sep 2021 15:13:02 -0400 Subject: [PATCH] fix: Race conditions and performance issues (#237) * fix: Race conditions and performance issues There are two main retrying_connection race conditions fixed here: 1) Improper handling of cancelled write tasks can cause set_exception to be called when the task is already cancelled, which raises an InvalidStateError which is never caught by the existing code. 2) There is a race where if reinitialize() is called after queues are cycled, meaning a poller from the old instance of the class can add a message to the new queues. This has been fixed by splitting the ConnectionReinitializer interface into "stop_processing" and "reinitialize" parts. Also fix other performance issues identified in profiles. * fix: Race conditions and performance issues There are two main retrying_connection race conditions fixed here: 1) Improper handling of cancelled write tasks can cause set_exception to be called when the task is already cancelled, which raises an InvalidStateError which is never caught by the existing code. 2) There is a race where if reinitialize() is called after queues are cycled, meaning a poller from the old instance of the class can add a message to the new queues. This has been fixed by splitting the ConnectionReinitializer interface into "stop_processing" and "reinitialize" parts. Also fix other performance issues identified in profiles. --- .../internal/ack_set_tracker_impl.py | 18 ++--- .../internal/single_partition_subscriber.py | 12 ++-- .../cloudpubsub/internal/sorted_list.py | 25 +++++++ .../pubsublite/internal/wire/assigner_impl.py | 16 +++-- .../internal/wire/committer_impl.py | 9 ++- .../internal/wire/connection_reinitializer.py | 19 +++-- .../internal/wire/retrying_connection.py | 70 ++++++++++++------- .../wire/single_partition_publisher.py | 10 +-- .../internal/wire/subscriber_impl.py | 18 ++--- .../internal/wire/retrying_connection_test.py | 31 ++++++-- 10 files changed, 155 insertions(+), 73 deletions(-) create mode 100644 google/cloud/pubsublite/cloudpubsub/internal/sorted_list.py diff --git a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py index b1e6aedd..3222994e 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import queue from collections import deque from typing import Optional from google.api_core.exceptions import FailedPrecondition +from google.cloud.pubsublite.cloudpubsub.internal.sorted_list import SortedList from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker from google.cloud.pubsublite.internal.wire.committer import Committer from google.cloud.pubsublite_v1 import Cursor @@ -27,13 +27,13 @@ class AckSetTrackerImpl(AckSetTracker): _committer: Committer _receipts: "deque[int]" - _acks: "queue.PriorityQueue[int]" + _acks: SortedList[int] def __init__(self, committer: Committer): super().__init__() self._committer = committer self._receipts = deque() - self._acks = queue.PriorityQueue() + self._acks = SortedList() def track(self, offset: int): if len(self._receipts) > 0: @@ -45,25 +45,27 @@ def track(self, offset: int): self._receipts.append(offset) def ack(self, offset: int): - self._acks.put_nowait(offset) + self._acks.push(offset) prefix_acked_offset: Optional[int] = None while len(self._receipts) != 0 and not self._acks.empty(): receipt = self._receipts.popleft() - ack = self._acks.get_nowait() + ack = self._acks.peek() if receipt == ack: prefix_acked_offset = receipt + self._acks.pop() continue self._receipts.appendleft(receipt) - self._acks.put(ack) break if prefix_acked_offset is None: return # Convert from last acked to first unacked. - self._committer.commit(Cursor(offset=prefix_acked_offset + 1)) + cursor = Cursor() + cursor._pb.offset = prefix_acked_offset + 1 + self._committer.commit(cursor) async def clear_and_commit(self): self._receipts.clear() - self._acks = queue.PriorityQueue() + self._acks = SortedList() await self._committer.wait_until_empty() async def __aenter__(self): diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py index 801adfd9..296b4aac 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py @@ -127,12 +127,12 @@ async def read(self) -> List[Message]: raise e def _handle_ack(self, message: requests.AckRequest): - self._underlying.allow_flow( - FlowControlRequest( - allowed_messages=1, - allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes, - ) - ) + flow_control = FlowControlRequest() + flow_control._pb.allowed_messages = 1 + flow_control._pb.allowed_bytes = self._messages_by_ack_id[ + message.ack_id + ].size_bytes + self._underlying.allow_flow(flow_control) del self._messages_by_ack_id[message.ack_id] # Always refill flow control tokens, but do not commit offsets from outdated generations. ack_id = _AckId.parse(message.ack_id) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/sorted_list.py b/google/cloud/pubsublite/cloudpubsub/internal/sorted_list.py new file mode 100644 index 00000000..a64e25d5 --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/internal/sorted_list.py @@ -0,0 +1,25 @@ +from typing import Generic, TypeVar, List, Optional +import heapq + +_T = TypeVar("_T") + + +class SortedList(Generic[_T]): + _vals: List[_T] + + def __init__(self): + self._vals = [] + + def push(self, val: _T): + heapq.heappush(self._vals, val) + + def peek(self) -> Optional[_T]: + if self.empty(): + return None + return self._vals[0] + + def pop(self): + heapq.heappop(self._vals) + + def empty(self) -> bool: + return not bool(self._vals) diff --git a/google/cloud/pubsublite/internal/wire/assigner_impl.py b/google/cloud/pubsublite/internal/wire/assigner_impl.py index 17ee55ea..b411ad7d 100644 --- a/google/cloud/pubsublite/internal/wire/assigner_impl.py +++ b/google/cloud/pubsublite/internal/wire/assigner_impl.py @@ -17,6 +17,8 @@ import logging +from overrides import overrides + from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.assigner import Assigner from google.cloud.pubsublite.internal.wire.retrying_connection import ( @@ -103,15 +105,17 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._stop_receiver() await self._connection.__aexit__(exc_type, exc_val, exc_tb) - async def reinitialize( - self, - connection: Connection[PartitionAssignmentRequest, PartitionAssignment], - last_error: Optional[GoogleAPICallError], - ): + @overrides + async def stop_processing(self, error: GoogleAPICallError): + await self._stop_receiver() self._outstanding_assignment = False while not self._new_assignment.empty(): self._new_assignment.get_nowait() - await self._stop_receiver() + + @overrides + async def reinitialize( + self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment], + ): await connection.write(PartitionAssignmentRequest(initial=self._initial)) self._start_receiver() diff --git a/google/cloud/pubsublite/internal/wire/committer_impl.py b/google/cloud/pubsublite/internal/wire/committer_impl.py index cafaa344..b96bf331 100644 --- a/google/cloud/pubsublite/internal/wire/committer_impl.py +++ b/google/cloud/pubsublite/internal/wire/committer_impl.py @@ -17,6 +17,8 @@ import logging +from overrides import overrides + from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.committer import Committer from google.cloud.pubsublite.internal.wire.retrying_connection import ( @@ -152,14 +154,17 @@ def commit(self, cursor: Cursor) -> None: raise self._connection.error() self._next_to_commit = cursor + @overrides + async def stop_processing(self, error: GoogleAPICallError): + await self._stop_loopers() + + @overrides async def reinitialize( self, connection: Connection[ StreamingCommitCursorRequest, StreamingCommitCursorResponse ], - last_error: Optional[GoogleAPICallError], ): - await self._stop_loopers() await connection.write(StreamingCommitCursorRequest(initial=self._initial)) response = await connection.read() if "initial" not in response: diff --git a/google/cloud/pubsublite/internal/wire/connection_reinitializer.py b/google/cloud/pubsublite/internal/wire/connection_reinitializer.py index c02713f0..f0fbb083 100644 --- a/google/cloud/pubsublite/internal/wire/connection_reinitializer.py +++ b/google/cloud/pubsublite/internal/wire/connection_reinitializer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Generic, Optional +from typing import Generic from abc import ABCMeta, abstractmethod from google.api_core.exceptions import GoogleAPICallError from google.cloud.pubsublite.internal.wire.connection import ( @@ -26,17 +26,24 @@ class ConnectionReinitializer(Generic[Request, Response], metaclass=ABCMeta): """A class capable of reinitializing a connection after a new one has been created.""" @abstractmethod - def reinitialize( - self, - connection: Connection[Request, Response], - last_error: Optional[GoogleAPICallError], + async def stop_processing(self, error: GoogleAPICallError): + """Tear down internal state processing the current connection in + response to a stream error. + + Args: + error: The error that caused the stream to break + """ + raise NotImplementedError() + + @abstractmethod + async def reinitialize( + self, connection: Connection[Request, Response], ): """Reinitialize a connection. Must ensure no calls to the associated RetryingConnection occur until this completes. Args: connection: The connection to reinitialize - last_error: The last error that caused the stream to break Raises: GoogleAPICallError: If it fails to reinitialize. diff --git a/google/cloud/pubsublite/internal/wire/retrying_connection.py b/google/cloud/pubsublite/internal/wire/retrying_connection.py index ae35930a..2b3c7377 100644 --- a/google/cloud/pubsublite/internal/wire/retrying_connection.py +++ b/google/cloud/pubsublite/internal/wire/retrying_connection.py @@ -14,11 +14,16 @@ import asyncio from asyncio import Future +import logging +import traceback -from typing import Optional -from google.api_core.exceptions import GoogleAPICallError, Cancelled +from google.api_core.exceptions import Cancelled +from google.cloud.pubsublite.internal.wire.permanent_failable import adapt_error from google.cloud.pubsublite.internal.status_codes import is_retryable -from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors +from google.cloud.pubsublite.internal.wait_ignore_cancelled import ( + wait_ignore_errors, + wait_ignore_cancelled, +) from google.cloud.pubsublite.internal.wire.connection_reinitializer import ( ConnectionReinitializer, ) @@ -66,6 +71,8 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): self.fail(Cancelled("Connection shutting down.")) + self._loop_task.cancel() + await wait_ignore_errors(self._loop_task) async def write(self, request: Request) -> None: item = WorkItem(request) @@ -79,46 +86,56 @@ async def _run_loop(self): """ Processes actions on this connection and handles retries until cancelled. """ - last_failure: Optional[GoogleAPICallError] = None try: bad_retries = 0 - while True: + while not self.error(): try: conn_fut = self._connection_factory.new() async with (await conn_fut) as connection: - # Needs to happen prior to reinitialization to clear outstanding waiters. - if last_failure is not None: - while not self._write_queue.empty(): - self._write_queue.get_nowait().response_future.set_exception( - last_failure - ) - self._read_queue = asyncio.Queue(maxsize=1) - self._write_queue = asyncio.Queue(maxsize=1) await self._reinitializer.reinitialize( - connection, last_failure # pytype: disable=name-error + connection # pytype: disable=name-error ) self._initialized_once.set() bad_retries = 0 await self._loop_connection( connection # pytype: disable=name-error ) - except GoogleAPICallError as e: - last_failure = e + except Exception as e: + if self.error(): + return + e = adapt_error(e) + logging.debug( + "Saw a stream failure. Cause: \n%s", traceback.format_exc() + ) if not is_retryable(e): self.fail(e) return - await asyncio.sleep( - min(_MAX_BACKOFF_SECS, _MIN_BACKOFF_SECS * (2 ** bad_retries)) + try: + await self._reinitializer.stop_processing(e) + except Exception as stop_error: + self.fail(adapt_error(stop_error)) + return + while not self._write_queue.empty(): + response_future = self._write_queue.get_nowait().response_future + if not response_future.cancelled(): + response_future.set_exception(e) + self._read_queue = asyncio.Queue(maxsize=1) + self._write_queue = asyncio.Queue(maxsize=1) + await wait_ignore_cancelled( + asyncio.sleep( + min( + _MAX_BACKOFF_SECS, + _MIN_BACKOFF_SECS * (2 ** bad_retries), + ) + ) ) bad_retries += 1 - - except asyncio.CancelledError: - return except Exception as e: - import traceback - - traceback.print_exc() - print(e) + logging.error( + "Saw a stream failure which was unhandled. Cause: \n%s", + traceback.format_exc(), + ) + self.fail(adapt_error(e)) async def _loop_connection(self, connection: Connection[Request, Response]): read_task: "Future[Response]" = asyncio.ensure_future(connection.read()) @@ -149,6 +166,7 @@ async def _handle_write( try: await connection.write(to_write.request) to_write.response_future.set_result(None) - except GoogleAPICallError as e: + except Exception as e: + e = adapt_error(e) to_write.response_future.set_exception(e) raise e diff --git a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py index 7e425e3f..dd7f57c5 100644 --- a/google/cloud/pubsublite/internal/wire/single_partition_publisher.py +++ b/google/cloud/pubsublite/internal/wire/single_partition_publisher.py @@ -169,12 +169,14 @@ async def publish(self, message: PubSubMessage) -> MessageMetadata: await self._flush() return MessageMetadata(self._partition, await future) + @overrides + async def stop_processing(self, error: GoogleAPICallError): + await self._stop_loopers() + + @overrides async def reinitialize( - self, - connection: Connection[PublishRequest, PublishResponse], - last_error: Optional[GoogleAPICallError], + self, connection: Connection[PublishRequest, PublishResponse], ): - await self._stop_loopers() await connection.write(PublishRequest(initial_request=self._initial)) response = await connection.read() if "initial_response" not in response: diff --git a/google/cloud/pubsublite/internal/wire/subscriber_impl.py b/google/cloud/pubsublite/internal/wire/subscriber_impl.py index 9cbb9fff..3212e1b4 100644 --- a/google/cloud/pubsublite/internal/wire/subscriber_impl.py +++ b/google/cloud/pubsublite/internal/wire/subscriber_impl.py @@ -17,6 +17,7 @@ from typing import Optional, List from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition +from overrides import overrides from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors from google.cloud.pubsublite.internal.wire.connection import ( @@ -56,7 +57,6 @@ class SubscriberImpl( _outstanding_flow_control: FlowControlBatcher - _reinitializing: bool _last_received_offset: Optional[int] _message_queue: "asyncio.Queue[List[SequencedMessage.meta.pb]]" @@ -154,14 +154,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._stop_loopers() await self._connection.__aexit__(exc_type, exc_val, exc_tb) - async def reinitialize( - self, - connection: Connection[SubscribeRequest, SubscribeResponse], - last_error: Optional[GoogleAPICallError], - ): - self._reinitializing = True + @overrides + async def stop_processing(self, error: GoogleAPICallError): await self._stop_loopers() - if last_error and is_reset_signal(last_error): + if is_reset_signal(error): # Discard undelivered messages and refill flow control tokens. while not self._message_queue.empty(): batch: List[SequencedMessage.meta.pb] = self._message_queue.get_nowait() @@ -174,6 +170,11 @@ async def reinitialize( await self._reset_handler.handle_reset() self._last_received_offset = None + + @overrides + async def reinitialize( + self, connection: Connection[SubscribeRequest, SubscribeResponse] + ): initial = deepcopy(self._base_initial) if self._last_received_offset is not None: initial.initial_location = SeekRequest( @@ -195,7 +196,6 @@ async def reinitialize( tokens = self._outstanding_flow_control.request_for_restart() if tokens is not None: await connection.write(SubscribeRequest(flow_control=tokens)) - self._reinitializing = False self._start_loopers() async def read(self) -> List[SequencedMessage.meta.pb]: diff --git a/tests/unit/pubsublite/internal/wire/retrying_connection_test.py b/tests/unit/pubsublite/internal/wire/retrying_connection_test.py index 63560b1f..f682547c 100644 --- a/tests/unit/pubsublite/internal/wire/retrying_connection_test.py +++ b/tests/unit/pubsublite/internal/wire/retrying_connection_test.py @@ -17,7 +17,7 @@ from asynctest.mock import MagicMock, CoroutineMock import pytest -from google.api_core.exceptions import InternalServerError, InvalidArgument +from google.api_core.exceptions import InternalServerError, InvalidArgument, Unknown from google.cloud.pubsublite.internal.wire.connection import ( Connection, ConnectionFactory, @@ -70,9 +70,8 @@ def asyncio_sleep(monkeypatch): async def test_permanent_error_on_reinitializer( retrying_connection: Connection[int, int], reinitializer, default_connection ): - async def reinit_action(conn, last_error): + async def reinit_action(conn): assert conn == default_connection - assert last_error is None raise InvalidArgument("abc") reinitializer.reinitialize.side_effect = reinit_action @@ -84,9 +83,8 @@ async def reinit_action(conn, last_error): async def test_successful_reinitialize( retrying_connection: Connection[int, int], reinitializer, default_connection ): - async def reinit_action(conn, last_error): + async def reinit_action(conn): assert conn == default_connection - assert last_error is None return None default_connection.read.return_value = 1 @@ -126,9 +124,30 @@ async def test_reinitialize_after_retryable( asyncio_sleep.assert_called_once_with(_MIN_BACKOFF_SECS) assert reinitializer.reinitialize.call_count == 2 reinitializer.reinitialize.assert_has_calls( - [call(default_connection, None), call(default_connection, error)] + [call(default_connection), call(default_connection)] ) + reinitializer.stop_processing.assert_called_once_with(error) assert await retrying_connection.read() == 1 assert ( default_connection.read.call_count == 2 ) # re-call to read once first completes + + +async def test_reinitialize_stop_processing_fails( + retrying_connection: Connection[int, int], + reinitializer, + default_connection, + asyncio_sleep, +): + reinit_queues = wire_queues(reinitializer.reinitialize) + + default_connection.read.return_value = 1 + + error = InternalServerError("abc") + await reinit_queues.results.put(error) + reinitializer.stop_processing.side_effect = Exception("can't stop me") + with pytest.raises(Unknown): + async with retrying_connection as _: + pass + reinitializer.reinitialize.assert_called_once_with(default_connection) + reinitializer.stop_processing.assert_called_once_with(error)