Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update cps async clients #41

Merged
merged 2 commits into from Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,5 +1,5 @@
from asyncio import Future, Queue, ensure_future
from typing import Callable, NamedTuple, Dict, Set
from typing import Callable, NamedTuple, Dict, Set, Optional

from google.cloud.pubsub_v1.subscriber.message import Message

Expand All @@ -18,21 +18,31 @@ class _RunningSubscriber(NamedTuple):


class AssigningSubscriber(AsyncSubscriber, PermanentFailable):
_assigner: Assigner
_assigner_factory: Callable[[], Assigner]
_subscriber_factory: PartitionSubscriberFactory

_subscribers: Dict[Partition, _RunningSubscriber]
_messages: "Queue[Message]"

# Lazily initialized to ensure they are initialized on the thread where __aenter__ is called.
_assigner: Optional[Assigner]
_messages: Optional["Queue[Message]"]
_assign_poller: Future

def __init__(
self, assigner: Assigner, subscriber_factory: PartitionSubscriberFactory
self,
assigner_factory: Callable[[], Assigner],
subscriber_factory: PartitionSubscriberFactory,
):
"""
Accepts a factory for an Assigner instead of an Assigner because GRPC asyncio uses the current thread's event
loop.
"""
super().__init__()
self._assigner = assigner
self._assigner_factory = assigner_factory
self._assigner = None
self._subscriber_factory = subscriber_factory
self._subscribers = {}
self._messages = Queue()
self._messages = None

async def read(self) -> Message:
return await self.await_unless_failed(self._messages.get())
Expand Down Expand Up @@ -65,6 +75,8 @@ async def _assign_action(self):
del self._subscribers[partition]

async def __aenter__(self):
self._messages = Queue()
self._assigner = self._assigner_factory()
await self._assigner.__aenter__()
self._assign_poller = ensure_future(self.run_poller(self._assign_action))
return self
Expand Down
@@ -1,4 +1,4 @@
from typing import Mapping
from typing import Mapping, Callable, Optional

from google.pubsub_v1 import PubsubMessage

Expand All @@ -10,11 +10,17 @@


class AsyncPublisherImpl(AsyncPublisher):
_publisher: Publisher

def __init__(self, publisher: Publisher):
_publisher_factory: Callable[[], Publisher]
_publisher: Optional[Publisher]

def __init__(self, publisher_factory: Callable[[], Publisher]):
"""
Accepts a factory for a Publisher instead of a Publisher because GRPC asyncio uses the current thread's event
loop.
"""
super().__init__()
self._publisher = publisher
self._publisher_factory = publisher_factory
self._publisher = None

async def publish(
self, data: bytes, ordering_key: str = "", **attrs: Mapping[str, str]
Expand All @@ -26,6 +32,7 @@ async def publish(
return (await self._publisher.publish(psl_message)).encode()

async def __aenter__(self):
self._publisher = self._publisher_factory()
await self._publisher.__aenter__()
return self

Expand Down
11 changes: 7 additions & 4 deletions google/cloud/pubsublite/cloudpubsub/make_publisher.py
Expand Up @@ -40,10 +40,13 @@ def make_async_publisher(
GoogleApiCallException on any error determining topic structure.
"""
metadata = merge_metadata(pubsub_context(framework="CLOUD_PUBSUB_SHIM"), metadata)
underlying = make_wire_publisher(
topic, batching_delay_secs, credentials, client_options, metadata
)
return AsyncPublisherImpl(underlying)

def underlying_factory():
return make_wire_publisher(
topic, batching_delay_secs, credentials, client_options, metadata
)

return AsyncPublisherImpl(underlying_factory)


def make_publisher(
Expand Down
12 changes: 7 additions & 5 deletions google/cloud/pubsublite/cloudpubsub/make_subscriber.py
@@ -1,5 +1,5 @@
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Mapping, Set, AsyncIterator
from typing import Optional, Mapping, Set, AsyncIterator, Callable
from uuid import uuid4

from google.api_core.client_options import ClientOptions
Expand Down Expand Up @@ -170,14 +170,16 @@ def make_async_subscriber(
client_options = ClientOptions(
api_endpoint=regional_endpoint(subscription.location.region)
)
assigner: Assigner
assigner_factory: Callable[[], Assigner]
if fixed_partitions:
assigner = FixedSetAssigner(fixed_partitions)
assigner_factory = lambda: FixedSetAssigner(fixed_partitions) # noqa: E731
else:
assignment_client = PartitionAssignmentServiceAsyncClient(
credentials=credentials, client_options=client_options
) # type: ignore
assigner = _make_dynamic_assigner(subscription, assignment_client, metadata)
assigner_factory = lambda: _make_dynamic_assigner( # noqa: E731
subscription, assignment_client, metadata
)

subscribe_client = SubscriberServiceAsyncClient(
credentials=credentials, client_options=client_options
Expand All @@ -196,7 +198,7 @@ def make_async_subscriber(
nack_handler,
message_transformer,
)
return AssigningSubscriber(assigner, partition_subscriber_factory)
return AssigningSubscriber(assigner_factory, partition_subscriber_factory)


def make_subscriber(
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/pubsublite/internal/wire/connection.py
Expand Up @@ -34,5 +34,5 @@ async def read(self) -> Response:
class ConnectionFactory(Generic[Request, Response]):
"""A factory for producing Connections."""

def new(self) -> Connection[Request, Response]:
async def new(self) -> Connection[Request, Response]:
raise NotImplementedError()
18 changes: 11 additions & 7 deletions google/cloud/pubsublite/internal/wire/gapic_connection.py
@@ -1,4 +1,4 @@
from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable
from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable, Awaitable
import asyncio

from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition
Expand Down Expand Up @@ -44,10 +44,10 @@ async def read(self) -> Response:
self.fail(e)
raise self.error()

def __aenter__(self):
async def __aenter__(self):
return self

def __aexit__(self, exc_type, exc_value, traceback) -> None:
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass

async def __anext__(self) -> Request:
Expand All @@ -64,15 +64,19 @@ def __aiter__(self) -> AsyncIterator[Response]:
class GapicConnectionFactory(ConnectionFactory[Request, Response]):
"""A ConnectionFactory that produces GapicConnections."""

_producer = Callable[[AsyncIterator[Request]], AsyncIterable[Response]]
_producer = Callable[[AsyncIterator[Request]], Awaitable[AsyncIterable[Response]]]

def __init__(
self, producer: Callable[[AsyncIterator[Request]], AsyncIterable[Response]]
self,
producer: Callable[
[AsyncIterator[Request]], Awaitable[AsyncIterable[Response]]
],
):
self._producer = producer

def new(self) -> Connection[Request, Response]:
async def new(self) -> Connection[Request, Response]:
conn = GapicConnection[Request, Response]()
response_iterable = self._producer(conn)
response_fut = self._producer(conn)
response_iterable = await response_fut
conn.set_response_it(response_iterable.__aiter__())
return conn
8 changes: 7 additions & 1 deletion google/cloud/pubsublite/internal/wire/retrying_connection.py
Expand Up @@ -65,7 +65,8 @@ async def _run_loop(self):
bad_retries = 0
while True:
try:
async with self._connection_factory.new() as connection:
conn_fut = self._connection_factory.new()
async with (await conn_fut) as connection:
# Needs to happen prior to reinitialization to clear outstanding waiters.
if last_failure is not None:
while not self._write_queue.empty():
Expand All @@ -89,6 +90,11 @@ async def _run_loop(self):

except asyncio.CancelledError:
return
except Exception as e:
import traceback

traceback.print_exc()
print(e)

async def _loop_connection(self, connection: Connection[Request, Response]):
read_task: Awaitable[Response] = asyncio.ensure_future(connection.read())
Expand Down
13 changes: 11 additions & 2 deletions google/cloud/pubsublite_v1/services/cursor_service/async_client.py
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -103,7 +112,7 @@ def streaming_commit_cursor(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[cursor.StreamingCommitCursorResponse]:
) -> Awaitable[AsyncIterable[cursor.StreamingCommitCursorResponse]]:
r"""Establishes a stream with the server for managing
committed cursors.

Expand Down
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -107,7 +116,7 @@ def assign_partitions(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[subscriber.PartitionAssignment]:
) -> Awaitable[AsyncIterable[subscriber.PartitionAssignment]]:
r"""Assign partitions for this client to handle for the
specified subscription.
The client must send an
Expand Down
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -103,7 +112,7 @@ def publish(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[publisher.PublishResponse]:
) -> Awaitable[AsyncIterable[publisher.PublishResponse]]:
r"""Establishes a stream with the server for publishing
messages. Once the stream is initialized, the client
publishes messages by sending publish requests on the
Expand All @@ -125,7 +134,7 @@ def publish(
sent along with the request as metadata.

Returns:
AsyncIterable[~.publisher.PublishResponse]:
Awaitable[AsyncIterable[~.publisher.PublishResponse]]:
Response to a PublishRequest.
"""

Expand Down
Expand Up @@ -18,7 +18,16 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, AsyncIterable, AsyncIterator, Sequence, Tuple, Type, Union
from typing import (
Dict,
AsyncIterable,
Awaitable,
AsyncIterator,
Sequence,
Tuple,
Type,
Union,
)
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -100,7 +109,7 @@ def subscribe(
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> AsyncIterable[subscriber.SubscribeResponse]:
) -> Awaitable[AsyncIterable[subscriber.SubscribeResponse]]:
r"""Establishes a stream with the server for receiving
messages.

Expand Down
@@ -1,6 +1,7 @@
from typing import Set

from asynctest.mock import MagicMock, call
import threading
import pytest
from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsub_v1.subscriber.message import Message
Expand All @@ -13,7 +14,7 @@
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite.testing.test_utils import wire_queues
from google.cloud.pubsublite.testing.test_utils import wire_queues, Box

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio
Expand All @@ -36,7 +37,16 @@ def subscriber_factory():

@pytest.fixture()
def subscriber(assigner, subscriber_factory):
return AssigningSubscriber(assigner, subscriber_factory)
box = Box()

def set_box():
box.val = AssigningSubscriber(lambda: assigner, subscriber_factory)

# Initialize AssigningSubscriber on another thread with a different event loop.
thread = threading.Thread(target=set_box)
thread.start()
thread.join()
return box.val


async def test_init(subscriber, assigner):
Expand Down