Skip to content

Commit

Permalink
fix: Race conditions and performance issues (#237)
Browse files Browse the repository at this point in the history
* fix: Race conditions and performance issues

There are two main retrying_connection race conditions fixed here:

1) Improper handling of cancelled write tasks can cause set_exception to be called when the task is already cancelled, which raises an InvalidStateError which is never caught by the existing code.
2) There is a race where if reinitialize() is called after queues are cycled, meaning a poller from the old instance of the class can add a message to the new queues. This has been fixed by splitting the ConnectionReinitializer interface into "stop_processing" and "reinitialize" parts.

Also fix other performance issues identified in profiles.

* fix: Race conditions and performance issues

There are two main retrying_connection race conditions fixed here:

1) Improper handling of cancelled write tasks can cause set_exception to be called when the task is already cancelled, which raises an InvalidStateError which is never caught by the existing code.
2) There is a race where if reinitialize() is called after queues are cycled, meaning a poller from the old instance of the class can add a message to the new queues. This has been fixed by splitting the ConnectionReinitializer interface into "stop_processing" and "reinitialize" parts.

Also fix other performance issues identified in profiles.
  • Loading branch information
dpcollins-google committed Sep 14, 2021
1 parent 90d2b58 commit ec76272
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 73 deletions.
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import queue
from collections import deque
from typing import Optional

from google.api_core.exceptions import FailedPrecondition

from google.cloud.pubsublite.cloudpubsub.internal.sorted_list import SortedList
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite_v1 import Cursor
Expand All @@ -27,13 +27,13 @@ class AckSetTrackerImpl(AckSetTracker):
_committer: Committer

_receipts: "deque[int]"
_acks: "queue.PriorityQueue[int]"
_acks: SortedList[int]

def __init__(self, committer: Committer):
super().__init__()
self._committer = committer
self._receipts = deque()
self._acks = queue.PriorityQueue()
self._acks = SortedList()

def track(self, offset: int):
if len(self._receipts) > 0:
Expand All @@ -45,25 +45,27 @@ def track(self, offset: int):
self._receipts.append(offset)

def ack(self, offset: int):
self._acks.put_nowait(offset)
self._acks.push(offset)
prefix_acked_offset: Optional[int] = None
while len(self._receipts) != 0 and not self._acks.empty():
receipt = self._receipts.popleft()
ack = self._acks.get_nowait()
ack = self._acks.peek()
if receipt == ack:
prefix_acked_offset = receipt
self._acks.pop()
continue
self._receipts.appendleft(receipt)
self._acks.put(ack)
break
if prefix_acked_offset is None:
return
# Convert from last acked to first unacked.
self._committer.commit(Cursor(offset=prefix_acked_offset + 1))
cursor = Cursor()
cursor._pb.offset = prefix_acked_offset + 1
self._committer.commit(cursor)

async def clear_and_commit(self):
self._receipts.clear()
self._acks = queue.PriorityQueue()
self._acks = SortedList()
await self._committer.wait_until_empty()

async def __aenter__(self):
Expand Down
Expand Up @@ -127,12 +127,12 @@ async def read(self) -> List[Message]:
raise e

def _handle_ack(self, message: requests.AckRequest):
self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=1,
allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes,
)
)
flow_control = FlowControlRequest()
flow_control._pb.allowed_messages = 1
flow_control._pb.allowed_bytes = self._messages_by_ack_id[
message.ack_id
].size_bytes
self._underlying.allow_flow(flow_control)
del self._messages_by_ack_id[message.ack_id]
# Always refill flow control tokens, but do not commit offsets from outdated generations.
ack_id = _AckId.parse(message.ack_id)
Expand Down
25 changes: 25 additions & 0 deletions google/cloud/pubsublite/cloudpubsub/internal/sorted_list.py
@@ -0,0 +1,25 @@
from typing import Generic, TypeVar, List, Optional
import heapq

_T = TypeVar("_T")


class SortedList(Generic[_T]):
_vals: List[_T]

def __init__(self):
self._vals = []

def push(self, val: _T):
heapq.heappush(self._vals, val)

def peek(self) -> Optional[_T]:
if self.empty():
return None
return self._vals[0]

def pop(self):
heapq.heappop(self._vals)

def empty(self) -> bool:
return not bool(self._vals)
16 changes: 10 additions & 6 deletions google/cloud/pubsublite/internal/wire/assigner_impl.py
Expand Up @@ -17,6 +17,8 @@

import logging

from overrides import overrides

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.retrying_connection import (
Expand Down Expand Up @@ -103,15 +105,17 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_receiver()
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

async def reinitialize(
self,
connection: Connection[PartitionAssignmentRequest, PartitionAssignment],
last_error: Optional[GoogleAPICallError],
):
@overrides
async def stop_processing(self, error: GoogleAPICallError):
await self._stop_receiver()
self._outstanding_assignment = False
while not self._new_assignment.empty():
self._new_assignment.get_nowait()
await self._stop_receiver()

@overrides
async def reinitialize(
self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment],
):
await connection.write(PartitionAssignmentRequest(initial=self._initial))
self._start_receiver()

Expand Down
9 changes: 7 additions & 2 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -17,6 +17,8 @@

import logging

from overrides import overrides

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite.internal.wire.retrying_connection import (
Expand Down Expand Up @@ -152,14 +154,17 @@ def commit(self, cursor: Cursor) -> None:
raise self._connection.error()
self._next_to_commit = cursor

@overrides
async def stop_processing(self, error: GoogleAPICallError):
await self._stop_loopers()

@overrides
async def reinitialize(
self,
connection: Connection[
StreamingCommitCursorRequest, StreamingCommitCursorResponse
],
last_error: Optional[GoogleAPICallError],
):
await self._stop_loopers()
await connection.write(StreamingCommitCursorRequest(initial=self._initial))
response = await connection.read()
if "initial" not in response:
Expand Down
19 changes: 13 additions & 6 deletions google/cloud/pubsublite/internal/wire/connection_reinitializer.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Generic, Optional
from typing import Generic
from abc import ABCMeta, abstractmethod
from google.api_core.exceptions import GoogleAPICallError
from google.cloud.pubsublite.internal.wire.connection import (
Expand All @@ -26,17 +26,24 @@ class ConnectionReinitializer(Generic[Request, Response], metaclass=ABCMeta):
"""A class capable of reinitializing a connection after a new one has been created."""

@abstractmethod
def reinitialize(
self,
connection: Connection[Request, Response],
last_error: Optional[GoogleAPICallError],
async def stop_processing(self, error: GoogleAPICallError):
"""Tear down internal state processing the current connection in
response to a stream error.
Args:
error: The error that caused the stream to break
"""
raise NotImplementedError()

@abstractmethod
async def reinitialize(
self, connection: Connection[Request, Response],
):
"""Reinitialize a connection. Must ensure no calls to the associated RetryingConnection
occur until this completes.
Args:
connection: The connection to reinitialize
last_error: The last error that caused the stream to break
Raises:
GoogleAPICallError: If it fails to reinitialize.
Expand Down
70 changes: 44 additions & 26 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
Expand Up @@ -14,11 +14,16 @@

import asyncio
from asyncio import Future
import logging
import traceback

from typing import Optional
from google.api_core.exceptions import GoogleAPICallError, Cancelled
from google.api_core.exceptions import Cancelled
from google.cloud.pubsublite.internal.wire.permanent_failable import adapt_error
from google.cloud.pubsublite.internal.status_codes import is_retryable
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wait_ignore_cancelled import (
wait_ignore_errors,
wait_ignore_cancelled,
)
from google.cloud.pubsublite.internal.wire.connection_reinitializer import (
ConnectionReinitializer,
)
Expand Down Expand Up @@ -66,6 +71,8 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_val, exc_tb):
self.fail(Cancelled("Connection shutting down."))
self._loop_task.cancel()
await wait_ignore_errors(self._loop_task)

async def write(self, request: Request) -> None:
item = WorkItem(request)
Expand All @@ -79,46 +86,56 @@ async def _run_loop(self):
"""
Processes actions on this connection and handles retries until cancelled.
"""
last_failure: Optional[GoogleAPICallError] = None
try:
bad_retries = 0
while True:
while not self.error():
try:
conn_fut = self._connection_factory.new()
async with (await conn_fut) as connection:
# Needs to happen prior to reinitialization to clear outstanding waiters.
if last_failure is not None:
while not self._write_queue.empty():
self._write_queue.get_nowait().response_future.set_exception(
last_failure
)
self._read_queue = asyncio.Queue(maxsize=1)
self._write_queue = asyncio.Queue(maxsize=1)
await self._reinitializer.reinitialize(
connection, last_failure # pytype: disable=name-error
connection # pytype: disable=name-error
)
self._initialized_once.set()
bad_retries = 0
await self._loop_connection(
connection # pytype: disable=name-error
)
except GoogleAPICallError as e:
last_failure = e
except Exception as e:
if self.error():
return
e = adapt_error(e)
logging.debug(
"Saw a stream failure. Cause: \n%s", traceback.format_exc()
)
if not is_retryable(e):
self.fail(e)
return
await asyncio.sleep(
min(_MAX_BACKOFF_SECS, _MIN_BACKOFF_SECS * (2 ** bad_retries))
try:
await self._reinitializer.stop_processing(e)
except Exception as stop_error:
self.fail(adapt_error(stop_error))
return
while not self._write_queue.empty():
response_future = self._write_queue.get_nowait().response_future
if not response_future.cancelled():
response_future.set_exception(e)
self._read_queue = asyncio.Queue(maxsize=1)
self._write_queue = asyncio.Queue(maxsize=1)
await wait_ignore_cancelled(
asyncio.sleep(
min(
_MAX_BACKOFF_SECS,
_MIN_BACKOFF_SECS * (2 ** bad_retries),
)
)
)
bad_retries += 1

except asyncio.CancelledError:
return
except Exception as e:
import traceback

traceback.print_exc()
print(e)
logging.error(
"Saw a stream failure which was unhandled. Cause: \n%s",
traceback.format_exc(),
)
self.fail(adapt_error(e))

async def _loop_connection(self, connection: Connection[Request, Response]):
read_task: "Future[Response]" = asyncio.ensure_future(connection.read())
Expand Down Expand Up @@ -149,6 +166,7 @@ async def _handle_write(
try:
await connection.write(to_write.request)
to_write.response_future.set_result(None)
except GoogleAPICallError as e:
except Exception as e:
e = adapt_error(e)
to_write.response_future.set_exception(e)
raise e
Expand Up @@ -169,12 +169,14 @@ async def publish(self, message: PubSubMessage) -> MessageMetadata:
await self._flush()
return MessageMetadata(self._partition, await future)

@overrides
async def stop_processing(self, error: GoogleAPICallError):
await self._stop_loopers()

@overrides
async def reinitialize(
self,
connection: Connection[PublishRequest, PublishResponse],
last_error: Optional[GoogleAPICallError],
self, connection: Connection[PublishRequest, PublishResponse],
):
await self._stop_loopers()
await connection.write(PublishRequest(initial_request=self._initial))
response = await connection.read()
if "initial_response" not in response:
Expand Down
18 changes: 9 additions & 9 deletions google/cloud/pubsublite/internal/wire/subscriber_impl.py
Expand Up @@ -17,6 +17,7 @@
from typing import Optional, List

from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition
from overrides import overrides

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.internal.wire.connection import (
Expand Down Expand Up @@ -56,7 +57,6 @@ class SubscriberImpl(

_outstanding_flow_control: FlowControlBatcher

_reinitializing: bool
_last_received_offset: Optional[int]

_message_queue: "asyncio.Queue[List[SequencedMessage.meta.pb]]"
Expand Down Expand Up @@ -154,14 +154,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

async def reinitialize(
self,
connection: Connection[SubscribeRequest, SubscribeResponse],
last_error: Optional[GoogleAPICallError],
):
self._reinitializing = True
@overrides
async def stop_processing(self, error: GoogleAPICallError):
await self._stop_loopers()
if last_error and is_reset_signal(last_error):
if is_reset_signal(error):
# Discard undelivered messages and refill flow control tokens.
while not self._message_queue.empty():
batch: List[SequencedMessage.meta.pb] = self._message_queue.get_nowait()
Expand All @@ -174,6 +170,11 @@ async def reinitialize(

await self._reset_handler.handle_reset()
self._last_received_offset = None

@overrides
async def reinitialize(
self, connection: Connection[SubscribeRequest, SubscribeResponse]
):
initial = deepcopy(self._base_initial)
if self._last_received_offset is not None:
initial.initial_location = SeekRequest(
Expand All @@ -195,7 +196,6 @@ async def reinitialize(
tokens = self._outstanding_flow_control.request_for_restart()
if tokens is not None:
await connection.write(SubscribeRequest(flow_control=tokens))
self._reinitializing = False
self._start_loopers()

async def read(self) -> List[SequencedMessage.meta.pb]:
Expand Down

0 comments on commit ec76272

Please sign in to comment.