Skip to content

Commit

Permalink
refactor: Pass last stream error to ConnectionReinitializer.reinitial…
Browse files Browse the repository at this point in the history
…ize (#154)
  • Loading branch information
tmdiep committed Jun 4, 2021
1 parent 4a29b92 commit ab3fd7f
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 9 deletions.
4 changes: 3 additions & 1 deletion google/cloud/pubsublite/internal/wire/assigner_impl.py
Expand Up @@ -104,7 +104,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

async def reinitialize(
self, connection: Connection[PartitionAssignmentRequest, PartitionAssignment]
self,
connection: Connection[PartitionAssignmentRequest, PartitionAssignment],
last_error: Optional[GoogleAPICallError],
):
self._outstanding_assignment = False
while not self._new_assignment.empty():
Expand Down
1 change: 1 addition & 0 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Expand Up @@ -167,6 +167,7 @@ async def reinitialize(
connection: Connection[
StreamingCommitCursorRequest, StreamingCommitCursorResponse
],
last_error: Optional[GoogleAPICallError],
):
await self._stop_loopers()
await connection.write(StreamingCommitCursorRequest(initial=self._initial))
Expand Down
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Generic
from typing import Generic, Optional
from abc import ABCMeta, abstractmethod
from google.api_core.exceptions import GoogleAPICallError
from google.cloud.pubsublite.internal.wire.connection import (
Connection,
Request,
Expand All @@ -25,12 +26,17 @@ class ConnectionReinitializer(Generic[Request, Response], metaclass=ABCMeta):
"""A class capable of reinitializing a connection after a new one has been created."""

@abstractmethod
def reinitialize(self, connection: Connection[Request, Response]):
def reinitialize(
self,
connection: Connection[Request, Response],
last_error: Optional[GoogleAPICallError],
):
"""Reinitialize a connection. Must ensure no calls to the associated RetryingConnection
occur until this completes.
Args:
connection: The connection to reinitialize
last_error: The last error that caused the stream to break
Raises:
GoogleAPICallError: If it fails to reinitialize.
Expand Down
Expand Up @@ -94,7 +94,7 @@ async def _run_loop(self):
)
self._read_queue = asyncio.Queue(maxsize=1)
self._write_queue = asyncio.Queue(maxsize=1)
await self._reinitializer.reinitialize(connection)
await self._reinitializer.reinitialize(connection, last_failure)
self._initialized_once.set()
bad_retries = 0
await self._loop_connection(connection)
Expand Down
Expand Up @@ -168,7 +168,9 @@ async def publish(self, message: PubSubMessage) -> MessageMetadata:
return MessageMetadata(self._partition, await cursor_future)

async def reinitialize(
self, connection: Connection[PublishRequest, PublishResponse]
self,
connection: Connection[PublishRequest, PublishResponse],
last_error: Optional[GoogleAPICallError],
):
await self._stop_loopers()
await connection.write(PublishRequest(initial_request=self._initial))
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/pubsublite/internal/wire/subscriber_impl.py
Expand Up @@ -146,7 +146,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

async def reinitialize(
self, connection: Connection[SubscribeRequest, SubscribeResponse]
self,
connection: Connection[SubscribeRequest, SubscribeResponse],
last_error: Optional[GoogleAPICallError],
):
self._reinitializing = True
await self._stop_loopers()
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/pubsublite/internal/wire/retrying_connection_test.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
from unittest.mock import call

from asynctest.mock import MagicMock, CoroutineMock
import pytest
Expand Down Expand Up @@ -69,8 +70,9 @@ def asyncio_sleep(monkeypatch):
async def test_permanent_error_on_reinitializer(
retrying_connection: Connection[int, int], reinitializer, default_connection
):
async def reinit_action(conn):
async def reinit_action(conn, last_error):
assert conn == default_connection
assert last_error is None
raise InvalidArgument("abc")

reinitializer.reinitialize.side_effect = reinit_action
Expand All @@ -82,8 +84,9 @@ async def reinit_action(conn):
async def test_successful_reinitialize(
retrying_connection: Connection[int, int], reinitializer, default_connection
):
async def reinit_action(conn):
async def reinit_action(conn, last_error):
assert conn == default_connection
assert last_error is None
return None

default_connection.read.return_value = 1
Expand Down Expand Up @@ -116,11 +119,15 @@ async def test_reinitialize_after_retryable(

default_connection.read.return_value = 1

await reinit_queues.results.put(InternalServerError("abc"))
error = InternalServerError("abc")
await reinit_queues.results.put(error)
await reinit_queues.results.put(None)
async with retrying_connection as _:
asyncio_sleep.assert_called_once_with(_MIN_BACKOFF_SECS)
assert reinitializer.reinitialize.call_count == 2
reinitializer.reinitialize.assert_has_calls(
[call(default_connection, None), call(default_connection, error)]
)
assert await retrying_connection.read() == 1
assert (
default_connection.read.call_count == 2
Expand Down

0 comments on commit ab3fd7f

Please sign in to comment.