Skip to content

Commit

Permalink
feat: Implement python retrying connection, which generically retries…
Browse files Browse the repository at this point in the history
… stream errors (#4)

* feat: Implement python retrying connection, which generically retries stream errors.

* fix: Add asynctest to tests_require.

* fix: Add class comments.
  • Loading branch information
dpcollins-google committed Aug 10, 2020
1 parent 4624ac7 commit 11c9a69
Show file tree
Hide file tree
Showing 20 changed files with 439 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -52,6 +52,9 @@ env/
coverage.xml
sponge_log.xml

# Pycharm virtual environment
venv/

# System test environment variables.
system_tests/local_test_setup

Expand Down
3 changes: 3 additions & 0 deletions google/__init__.py
@@ -0,0 +1,3 @@
# Declare ``google`` as a namespace package so that several separately
# installed distributions can all contribute modules beneath it.
try:
    import pkg_resources

    pkg_resources.declare_namespace(__name__)
except ImportError:
    # pkg_resources ships with setuptools, which is not guaranteed to be
    # installed at runtime; fall back to the standard-library mechanism.
    import pkgutil

    __path__ = pkgutil.extend_path(__path__, __name__)
3 changes: 3 additions & 0 deletions google/cloud/__init__.py
@@ -0,0 +1,3 @@
# Declare ``google.cloud`` as a namespace package so that several separately
# installed distributions can all contribute modules beneath it.
try:
    import pkg_resources

    pkg_resources.declare_namespace(__name__)
except ImportError:
    # pkg_resources ships with setuptools, which is not guaranteed to be
    # installed at runtime; fall back to the standard-library mechanism.
    import pkgutil

    __path__ = pkgutil.extend_path(__path__, __name__)
Empty file.
Empty file.
38 changes: 38 additions & 0 deletions google/cloud/pubsublite/internal/wire/connection.py
@@ -0,0 +1,38 @@
from typing import Generic, TypeVar, Coroutine, Any, AsyncContextManager
from abc import ABCMeta, abstractmethod
from google.api_core.exceptions import GoogleAPICallError

# Type variables for the message type a Connection writes (Request) and the
# message type it reads (Response).
Request = TypeVar('Request')
Response = TypeVar('Response')


class Connection(Generic[Request, Response], AsyncContextManager):
    """Abstract, bidirectional handle on an underlying stream.

    Requests are sent with 'write' and responses are received with 'read'.
    At most one call to 'read' may be outstanding at any time.
    """

    @abstractmethod
    async def write(self, request: Request) -> None:
        """Send one message on the stream.

        Args:
          request: The message to send.

        Raises:
          GoogleAPICallError: When the connection terminates in failure.
        """
        raise NotImplementedError()

    @abstractmethod
    async def read(self) -> Response:
        """Receive the next message from the stream.

        Returns:
          The message that was read.

        Raises:
          GoogleAPICallError: When the connection terminates in failure.
        """
        raise NotImplementedError()


class ConnectionFactory(Generic[Request, Response]):
    """Produces Connection objects; concrete factories override 'new'."""

    def new(self) -> Connection[Request, Response]:
        """Create a connection.

        Returns:
          A new Connection. This base implementation always raises instead.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()
20 changes: 20 additions & 0 deletions google/cloud/pubsublite/internal/wire/connection_reinitializer.py
@@ -0,0 +1,20 @@
from typing import Generic
from abc import ABCMeta, abstractmethod
from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response


class ConnectionReinitializer(Generic[Request, Response], metaclass=ABCMeta):
    """A class capable of reinitializing a connection after a new one has been created."""

    @abstractmethod
    async def reinitialize(self, connection: Connection[Request, Response]) -> None:
        """Reinitialize a connection.

        Declared as a coroutine because the caller awaits the result of this
        method; implementations must therefore be 'async def' as well.

        Args:
          connection: The connection to reinitialize

        Raises:
          GoogleAPICallError: If it fails to reinitialize.
        """
        raise NotImplementedError()


54 changes: 54 additions & 0 deletions google/cloud/pubsublite/internal/wire/gapic_connection.py
@@ -0,0 +1,54 @@
from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable
import asyncio

from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response, ConnectionFactory
from google.cloud.pubsublite.internal.wire.work_item import WorkItem
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable

# Generic type variable; not referenced by the definitions visible in this module.
T = TypeVar('T')


class GapicConnection(Connection[Request, Response], AsyncIterator[Request], PermanentFailable):
    """A Connection wrapping a gapic AsyncIterator[Request/Response] pair.

    The object is itself the async iterator of requests: every request passed
    to 'write' is handed out by '__anext__'. Responses are pulled from the
    iterator installed with 'set_response_it'.
    """
    # Single-slot hand-off between 'write' callers and the request iterator.
    _write_queue: 'asyncio.Queue[WorkItem[Request]]'
    # Source of responses; None until 'set_response_it' has been called.
    _response_it: Optional[AsyncIterator[Response]]

    def __init__(self):
        super().__init__()
        self._write_queue = asyncio.Queue(maxsize=1)
        # Initialised explicitly so the attribute always exists, matching its Optional annotation.
        self._response_it = None

    def set_response_it(self, response_it: AsyncIterator[Response]):
        """Install the iterator that 'read' pulls responses from.

        Args:
          response_it: The async iterator of responses.
        """
        self._response_it = response_it

    async def write(self, request: Request) -> None:
        """Queue a request and wait until it has been taken by '__anext__'.

        Args:
          request: The message to send.

        Raises:
          GoogleAPICallError: When the connection has failed.
        """
        item = WorkItem(request)
        await self.await_or_fail(self._write_queue.put(item))
        # Completed by __anext__ once the request has been handed out.
        await self.await_or_fail(item.response_future)

    async def read(self) -> Response:
        """Return the next response from the response iterator.

        Raises:
          GoogleAPICallError: When the connection has failed.
        """
        return await self.await_or_fail(self._response_it.__anext__())

    # The async context manager protocol awaits the results of __aenter__ and
    # __aexit__, so both must be coroutines for 'async with' to work.
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        pass

    async def __anext__(self) -> Request:
        """Hand out the next queued request and release the 'write' call waiting on it."""
        item: WorkItem[Request] = await self.await_or_fail(self._write_queue.get())
        item.response_future.set_result(None)
        return item.request

    def __aiter__(self) -> AsyncIterator[Request]:
        return self


class GapicConnectionFactory(ConnectionFactory[Request, Response]):
    """A ConnectionFactory that produces GapicConnections.

    Args:
      producer: Callable that receives the new connection (as the async
        iterator of requests) and returns the async iterable of responses.
    """
    # Annotated rather than assigned: assigning the typing alias itself made
    # the attribute an uncallable placeholder.
    _producer: Optional[Callable[[AsyncIterator[Request]], AsyncIterable[Response]]] = None

    def __init__(self, producer: Optional[Callable[[AsyncIterator[Request]], AsyncIterable[Response]]] = None):
        # Only overwrite when given, so a producer supplied by a subclass is preserved.
        if producer is not None:
            self._producer = producer

    def new(self) -> Connection[Request, Response]:
        """Create a GapicConnection wired to the responses returned by the producer.

        Returns:
          The new connection.

        Raises:
          TypeError: If no producer has been configured.
        """
        if self._producer is None:
            raise TypeError("GapicConnectionFactory requires a producer callable.")
        conn = GapicConnection[Request, Response]()
        response_iterable = self._producer(conn)
        conn.set_response_it(response_iterable.__aiter__())
        return conn

    # Original spelling of the method, kept so existing callers continue to work.
    New = new
31 changes: 31 additions & 0 deletions google/cloud/pubsublite/internal/wire/permanent_failable.py
@@ -0,0 +1,31 @@
import asyncio
from typing import Awaitable, TypeVar

from google.api_core.exceptions import GoogleAPICallError

# Result type of the awaitable passed to PermanentFailable.await_or_fail.
T = TypeVar('T')


class PermanentFailable:
    """A class that can experience permanent failures, with helpers for forwarding these to client actions."""
    # Never resolved with a result; an exception is set on it by 'fail'.
    _failure_task: asyncio.Future

    def __init__(self):
        self._failure_task = asyncio.Future()

    async def await_or_fail(self, awaitable: Awaitable[T]) -> T:
        """Await 'awaitable' unless this object has failed or fails in the meantime.

        Args:
          awaitable: The operation to wait for.

        Returns:
          The result of 'awaitable'.

        Raises:
          GoogleAPICallError: The permanent failure, if one was already recorded,
            is recorded while waiting, or is raised by 'awaitable' itself.
        """
        if self._failure_task.done():
            # The awaitable will never be scheduled; close a coroutine so it
            # does not trigger a "never awaited" warning.
            if asyncio.iscoroutine(awaitable):
                awaitable.close()
            raise self._failure_task.exception()
        task = asyncio.ensure_future(awaitable)
        try:
            done, _ = await asyncio.wait([task, self._failure_task], return_when=asyncio.FIRST_COMPLETED)
            if task in done:
                try:
                    return await task
                except GoogleAPICallError as e:
                    # An API error from the operation becomes the permanent failure.
                    self.fail(e)
            raise self._failure_task.exception()
        finally:
            # Covers both the failure path and cancellation of the caller: a
            # task left running could, for example, still consume a queue item
            # that nobody is waiting for any more.
            if not task.done():
                task.cancel()

    def fail(self, err: GoogleAPICallError):
        """Record 'err' as the permanent failure; only the first recorded failure is kept.

        Args:
          err: The error to record.
        """
        if not self._failure_task.done():
            self._failure_task.set_exception(err)
88 changes: 88 additions & 0 deletions google/cloud/pubsublite/internal/wire/retrying_connection.py
@@ -0,0 +1,88 @@
import asyncio

from typing import Awaitable
from google.api_core.exceptions import GoogleAPICallError, Cancelled
from google.cloud.pubsublite.status_codes import is_retryable
from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer
from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response, ConnectionFactory
from google.cloud.pubsublite.internal.wire.work_item import WorkItem
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable

# Bounds, in seconds, of the exponential backoff between reconnection attempts:
# the delay starts at the minimum, doubles per consecutive failure and is capped
# at the maximum.
_MIN_BACKOFF_SECS = .01
_MAX_BACKOFF_SECS = 10


class RetryingConnection(Connection[Request, Response], PermanentFailable):
    """A connection which performs retries on an underlying stream when experiencing retryable errors.

    Use as an async context manager. Entering starts a background task that
    repeatedly opens a connection from the factory, reinitializes it and pumps
    reads and writes through it. Exiting fails this connection with Cancelled
    and stops that task.
    """
    _connection_factory: ConnectionFactory[Request, Response]
    _reinitializer: ConnectionReinitializer[Request, Response]

    # Background task running _run_loop; created in __aenter__.
    _loop_task: asyncio.Future

    # Single-slot hand-off queues between write()/read() callers and the background loop.
    _write_queue: 'asyncio.Queue[WorkItem[Request]]'
    _read_queue: 'asyncio.Queue[Response]'

    def __init__(self, connection_factory: ConnectionFactory[Request, Response], reinitializer: ConnectionReinitializer[Request, Response]):
        """
        Args:
          connection_factory: Creates the underlying connection for every (re)connect.
          reinitializer: Invoked on every newly created connection before it is used.
        """
        super().__init__()
        self._connection_factory = connection_factory
        self._reinitializer = reinitializer
        self._write_queue = asyncio.Queue(maxsize=1)
        self._read_queue = asyncio.Queue(maxsize=1)

    async def __aenter__(self):
        self._loop_task = asyncio.ensure_future(self._run_loop())
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self.fail(Cancelled("Connection shutting down."))
        # Stop the background loop too; otherwise it keeps reconnecting after shutdown.
        self._loop_task.cancel()
        # asyncio.wait does not re-raise the task's outcome, so a cancelled
        # loop task does not surface as a CancelledError here.
        await asyncio.wait([self._loop_task])

    async def write(self, request: Request) -> None:
        """Queue a request and wait until it was written to an underlying connection.

        Args:
          request: The message to send.

        Raises:
          GoogleAPICallError: When the connection terminates in failure.
        """
        item = WorkItem(request)
        await self.await_or_fail(self._write_queue.put(item))
        return await self.await_or_fail(item.response_future)

    async def read(self) -> Response:
        """Return the next response received from an underlying connection.

        Raises:
          GoogleAPICallError: When the connection terminates in failure.
        """
        return await self.await_or_fail(self._read_queue.get())

    async def _run_loop(self):
        """
        Processes actions on this connection and handles retries until cancelled.
        """
        try:
            bad_retries = 0
            while True:
                try:
                    async with self._connection_factory.new() as connection:
                        await self._reinitializer.reinitialize(connection)
                        # A successful (re)initialization resets the backoff.
                        bad_retries = 0
                        await self._loop_connection(connection)
                except asyncio.CancelledError:
                    # Before Python 3.8 CancelledError derives from Exception;
                    # keep it out of the error handling below.
                    raise
                except GoogleAPICallError as e:
                    if not is_retryable(e):
                        self.fail(e)
                        return
                    await asyncio.sleep(min(_MAX_BACKOFF_SECS, _MIN_BACKOFF_SECS * (2**bad_retries)))
                    bad_retries += 1
                except Exception as e:
                    # Anything else has no status code to classify. Fail
                    # permanently so that readers and writers are released
                    # instead of waiting on a dead loop.
                    self.fail(e)
                    return

        except asyncio.CancelledError:
            return

    async def _loop_connection(self, connection: Connection[Request, Response]):
        """Pump queued writes into 'connection' and its responses into the read queue.

        Runs until an error is raised or the task is cancelled.

        Args:
          connection: The freshly initialized underlying connection.
        """
        read_task: Awaitable[Response] = asyncio.ensure_future(connection.read())
        write_task: Awaitable[WorkItem[Request]] = asyncio.ensure_future(self._write_queue.get())
        try:
            while True:
                done, _ = await asyncio.wait([write_task, read_task], return_when=asyncio.FIRST_COMPLETED)
                if write_task in done:
                    await self._handle_write(connection, await write_task)
                    write_task = asyncio.ensure_future(self._write_queue.get())
                if read_task in done:
                    await self._read_queue.put(await read_task)
                    read_task = asyncio.ensure_future(connection.read())
        finally:
            # Do not leave pending tasks behind: a stale get() on the write
            # queue would silently take the next request after a reconnect.
            for pending in (read_task, write_task):
                if not pending.done():
                    pending.cancel()

    @staticmethod
    async def _handle_write(connection: Connection[Request, Response], to_write: WorkItem[Request]):
        """Write one queued request and report the outcome on its future.

        Args:
          connection: The connection to write to.
          to_write: The queued request together with its completion future.

        Raises:
          GoogleAPICallError: Re-raised after being set on the item's future.
        """
        try:
            await connection.write(to_write.request)
        except GoogleAPICallError as e:
            # The future may already be done (e.g. cancelled), in which case
            # setting an outcome would raise InvalidStateError.
            if not to_write.response_future.done():
                to_write.response_future.set_exception(e)
            raise
        if not to_write.response_future.done():
            to_write.response_future.set_result(None)
raise e
14 changes: 14 additions & 0 deletions google/cloud/pubsublite/internal/wire/work_item.py
@@ -0,0 +1,14 @@
import asyncio
from typing import Generic, TypeVar

# Type of the request carried by a WorkItem.
T = TypeVar('T')


class WorkItem(Generic[T]):
    """Pairs a request with a future that is completed once the request has been handled."""
    request: T
    response_future: "asyncio.Future[None]"

    def __init__(self, request: T):
        """
        Args:
          request: The unit of work to carry.
        """
        # A fresh, unresolved future for every item.
        self.response_future = asyncio.Future()
        self.request = request
10 changes: 10 additions & 0 deletions google/cloud/pubsublite/status_codes.py
@@ -0,0 +1,10 @@
from grpc import StatusCode
from google.api_core.exceptions import GoogleAPICallError

# gRPC status codes for which is_retryable reports an error as retryable.
retryable_codes = {
StatusCode.DEADLINE_EXCEEDED, StatusCode.ABORTED, StatusCode.INTERNAL, StatusCode.UNAVAILABLE, StatusCode.UNKNOWN
}


def is_retryable(error: GoogleAPICallError) -> bool:
    """Report whether an error carries one of the retryable gRPC status codes.

    Args:
      error: The error to classify.

    Returns:
      True if the error's 'grpc_status_code' is in 'retryable_codes'. False
      otherwise, including for exceptions that have no such attribute.
    """
    # Callers may pass arbitrary exceptions; those without a status code are
    # treated as not retryable instead of raising AttributeError.
    return getattr(error, "grpc_status_code", None) in retryable_codes
Empty file.
8 changes: 8 additions & 0 deletions google/cloud/pubsublite/testing/test_utils.py
@@ -0,0 +1,8 @@
from typing import List, Union, Any


async def async_iterable(elts: List[Union[Any, Exception]]):
    """Async generator over 'elts' that raises the first Exception instance it reaches.

    Args:
      elts: Values to yield in order; an element that is an Exception instance
        is raised instead of being yielded.
    """
    for item in elts:
        if not isinstance(item, Exception):
            yield item
            continue
        raise item
12 changes: 8 additions & 4 deletions setup.py
Expand Up @@ -27,6 +27,11 @@
with io.open(readme_filename, encoding="utf-8") as readme_file:
readme = readme_file.read()

dependencies = [
"google-api-core >= 1.22.0",
"absl-py >= 0.9.0",
"proto-plus >= 0.4.0",
]

setuptools.setup(
name="google-cloud-pubsublite",
Expand All @@ -40,10 +45,9 @@
namespace_packages=("google", "google.cloud"),
platforms="Posix; MacOS X; Windows",
include_package_data=True,
install_requires=(
"google-api-core[grpc] >= 1.22.0, < 2.0.0dev",
"proto-plus >= 0.4.0",
),
install_requires=dependencies,
setup_requires=('pytest-runner',),
tests_require=['asynctest', 'pytest', 'pytest-asyncio'],
python_requires=">=3.6",
classifiers=[
"Development Status :: 4 - Beta",
Expand Down
Empty file.
Empty file.
Empty file.
40 changes: 40 additions & 0 deletions tests/unit/pubsublite/internal/wire/gapic_connection_test.py
@@ -0,0 +1,40 @@
import asyncio

import pytest
from google.api_core.exceptions import InternalServerError
from google.cloud.pubsublite.internal.wire.gapic_connection import GapicConnection
from google.cloud.pubsublite.testing.test_utils import async_iterable

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


async def test_read_error_fails():
    """An error from the response iterator is raised by that read and by every later read and write."""
    connection = GapicConnection[int, int]()
    connection.set_response_it(async_iterable([InternalServerError("abc")]))
    # The first read hits the error; the second shows that it sticks.
    for _ in range(2):
        with pytest.raises(InternalServerError):
            await connection.read()
    with pytest.raises(InternalServerError):
        await connection.write(3)


async def test_read_success():
    """Reads return the elements of the response iterator in order."""
    connection = GapicConnection[int, int]()
    connection.set_response_it(async_iterable([3, 4, 5]))
    received = []
    for _ in range(3):
        received.append(await connection.read())
    assert received == [3, 4, 5]


async def test_writes():
    """Each write completes only after its request has been pulled with __anext__, in write order."""
    connection = GapicConnection[int, int]()
    connection.set_response_it(async_iterable([]))
    first_write = asyncio.ensure_future(connection.write(1))
    second_write = asyncio.ensure_future(connection.write(2))
    # Nothing has been pulled from the connection yet, so both writes are pending.
    assert not first_write.done()
    assert not second_write.done()
    pulled = await connection.__anext__()
    assert pulled == 1
    await first_write
    # The second request is still waiting to be pulled.
    assert not second_write.done()
    pulled = await connection.__anext__()
    assert pulled == 2
    await second_write

0 comments on commit 11c9a69

Please sign in to comment.