Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for increasing partitions in python #74

Merged
merged 3 commits into from Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 18 additions & 11 deletions google/cloud/pubsublite/internal/wire/make_publisher.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AsyncIterator, Mapping, Optional, MutableMapping
from typing import AsyncIterator, Mapping, Optional

from google.cloud.pubsub_v1.types import BatchSettings

Expand All @@ -25,8 +25,13 @@
GapicConnectionFactory,
)
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import (
PartitionCountWatchingPublisher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_publisher import RoutingPublisher
from google.cloud.pubsublite.internal.wire.single_partition_publisher import (
SinglePartitionPublisher,
)
Expand All @@ -37,7 +42,6 @@
from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials


DEFAULT_BATCHING_SETTINGS = BatchSettings(
max_bytes=(
3 * 1024 * 1024
Expand Down Expand Up @@ -87,21 +91,24 @@ def make_publisher(
credentials=credentials, transport=transport, client_options=client_options
) # type: ignore

clients: MutableMapping[Partition, Publisher] = {}

partition_count = admin_client.get_topic_partition_count(topic)
for partition in range(partition_count):
partition = Partition(partition)

def publisher_factory(partition: Partition):
def connection_factory(requests: AsyncIterator[PublishRequest]):
final_metadata = merge_metadata(
metadata, topic_routing_metadata(topic, partition)
)
return client.publish(requests, metadata=list(final_metadata.items()))

clients[partition] = SinglePartitionPublisher(
return SinglePartitionPublisher(
InitialPublishRequest(topic=str(topic), partition=partition.value),
per_partition_batching_settings,
GapicConnectionFactory(connection_factory),
)
return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients)

def policy_factory(partition_count: int):
return DefaultRoutingPolicy(partition_count)

return PartitionCountWatchingPublisher(
PartitionCountWatcherImpl(admin_client, topic, 10),
palmere-google marked this conversation as resolved.
Show resolved Hide resolved
publisher_factory,
policy_factory,
)
22 changes: 22 additions & 0 deletions google/cloud/pubsublite/internal/wire/partition_count_watcher.py
@@ -0,0 +1,22 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import AsyncContextManager


class PartitionCountWatcher(AsyncContextManager):
    """Async context manager interface for observing a partition count.

    Concrete watchers must implement `get_partition_count`; the context
    manager protocol comes from `AsyncContextManager`.
    """

    @abstractmethod
    async def get_partition_count(self) -> int:
        """Return a partition count.

        Returns:
            The partition count as an int.

        Raises:
            NotImplementedError: Always, in this base declaration.
        """
        raise NotImplementedError()
@@ -0,0 +1,65 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures.thread import ThreadPoolExecutor
import asyncio

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.types import TopicPath
from google.api_core.exceptions import GoogleAPICallError


class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable):
    """Watcher that periodically asks an admin client for a topic's partition count.

    The blocking admin call is run on a dedicated single-worker thread pool and
    each successful result is handed to consumers through a one-slot
    `asyncio.Queue`.

    Args:
        admin: Client used to look up the partition count.
        topic_path: Topic whose partition count is polled.
        duration: Seconds to sleep after each poll attempt.
    """

    def __init__(
        self, admin: AdminClientInterface, topic_path: TopicPath, duration: float
    ):
        super().__init__()
        self._admin = admin
        self._topic_path = topic_path
        self._duration = duration
        # Becomes True after the first successful lookup; later API errors are
        # then tolerated instead of being raised.
        self._any_success = False

    async def __aenter__(self):
        """Create the worker thread and result queue, then start polling.

        Returns:
            This watcher, so `async with watcher as w` binds a usable object.
        """
        self._thread = ThreadPoolExecutor(max_workers=1)
        self._queue = asyncio.Queue(maxsize=1)
        # Keep the task under its own name. Storing it as
        # `self._poll_partition_loop` would shadow the coroutine method on the
        # instance, so a second `__aenter__` would hand a Future to run_poller.
        # NOTE(review): run_poller is inherited (presumably from
        # PermanentFailable) and is assumed to invoke the given coroutine
        # function repeatedly — confirm.
        self._poller = asyncio.ensure_future(
            self.run_poller(self._poll_partition_loop)
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Cancel the polling task and release the worker thread pool."""
        self._poller.cancel()
        await wait_ignore_cancelled(self._poller)
        # Do not block the event loop waiting for an in-flight admin call.
        self._thread.shutdown(wait=False)

    def _get_partition_count_sync(self) -> int:
        """Blocking partition-count lookup; executed on the worker thread."""
        return self._admin.get_topic_partition_count(self._topic_path)

    async def _poll_partition_loop(self):
        """Perform one poll: fetch the count, publish it to the queue, then sleep.

        Raises:
            GoogleAPICallError: If the lookup fails before any lookup has ever
                succeeded. Failures after a first success are ignored.
        """
        try:
            partition_count = await asyncio.get_event_loop().run_in_executor(
                self._thread, self._get_partition_count_sync
            )
            self._any_success = True
            # The queue holds one item, so this waits until the previous count
            # has been consumed.
            await self._queue.put(partition_count)
        except GoogleAPICallError:
            if not self._any_success:
                raise
        await asyncio.sleep(self._duration)

    async def get_partition_count(self) -> int:
        """Return the next polled partition count.

        NOTE(review): await_unless_failed is inherited; presumably it raises the
        watcher's permanent failure instead of waiting forever — confirm.
        """
        return await self.await_unless_failed(self._queue.get())
@@ -0,0 +1,95 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
import threading
from typing import Callable

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy
from google.cloud.pubsublite.types import PublishMetadata, Partition
from google.cloud.pubsublite_v1 import PubSubMessage


class PartitionCountWatchingPublisher(Publisher):
    """Publisher that grows its set of per-partition publishers as a topic grows.

    A `PartitionCountWatcher` supplies partition counts. Whenever a count
    larger than the number of known publishers arrives, publishers for the new
    partitions are created and entered, and the routing policy is rebuilt for
    the new count. Counts that are equal or smaller are ignored.

    Args:
        watcher: Source of partition counts.
        publisher_factory: Builds the publisher for a given partition.
        policy_factory: Builds the routing policy for a given partition count.
    """

    def __init__(
        self,
        watcher: PartitionCountWatcher,
        publisher_factory: Callable[[Partition], Publisher],
        policy_factory: Callable[[int], RoutingPolicy],
    ):
        self._publishers = {}
        # Guards _publishers and _routing_policy. It is a plain thread lock, so
        # it must never be held across an `await`.
        self._lock = threading.Lock()
        self._publisher_factory = publisher_factory
        self._policy_factory = policy_factory
        self._watcher = watcher

    async def __aenter__(self):
        """Enter the watcher, apply the first partition count, start watching.

        Returns:
            This publisher.

        Raises:
            Exception: Whatever entering the watcher or handling the first
                count raised; the watcher is exited before re-raising.
        """
        try:
            await self._watcher.__aenter__()
            await self._poll_partition_count_action()
        except Exception:
            await self._watcher.__aexit__(*sys.exc_info())
            raise
        self._partition_count_poller = asyncio.ensure_future(
            self._watch_partition_count()
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop watching, exit the watcher, then exit every partition publisher."""
        self._partition_count_poller.cancel()
        await wait_ignore_cancelled(self._partition_count_poller)
        await self._watcher.__aexit__(exc_type, exc_val, exc_tb)
        # Snapshot under the lock and await outside it. threading.Lock is not
        # reentrant, so awaiting while holding it would let a concurrent
        # publish() on the same loop block the loop thread and deadlock.
        with self._lock:
            publishers = list(self._publishers.values())
        for publisher in publishers:
            await publisher.__aexit__(exc_type, exc_val, exc_tb)

    async def _poll_partition_count_action(self):
        """Wait for one partition count from the watcher and apply it."""
        partition_count = await self._watcher.get_partition_count()
        await self._handle_partition_count_update(partition_count)

    async def _watch_partition_count(self):
        """Apply partition counts forever; stopped by cancelling the task."""
        while True:
            await self._poll_partition_count_action()

    async def _handle_partition_count_update(self, partition_count: int):
        """Add publishers for any partitions beyond those already known.

        Args:
            partition_count: The most recently observed partition count.
        """
        with self._lock:
            current_count = len(self._publishers)
        # An unchanged or smaller count leaves the current state untouched.
        if current_count >= partition_count:
            return

        new_publishers = {
            Partition(index): self._publisher_factory(Partition(index))
            for index in range(current_count, partition_count)
        }
        # Enter all new publishers concurrently before making them routable.
        await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()])
        routing_policy = self._policy_factory(partition_count)

        # Swap in publishers and policy together so publish() never routes to a
        # partition that has no publisher yet.
        with self._lock:
            self._publishers.update(new_publishers)
            self._routing_policy = routing_policy

    async def publish(self, message: PubSubMessage) -> PublishMetadata:
        """Route the message to a partition and publish it there.

        Args:
            message: The message to publish.

        Returns:
            The result of the selected partition publisher's `publish` call.
        """
        with self._lock:
            partition = self._routing_policy.route(message)
            assert partition in self._publishers
            publisher = self._publishers[partition]

        return await publisher.publish(message)
@@ -0,0 +1,106 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import queue
import threading
from asynctest.mock import MagicMock
import pytest

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.testing.test_utils import Box
from google.cloud.pubsublite.types import Partition, TopicPath, CloudZone, CloudRegion
from google.api_core.exceptions import GoogleAPICallError

pytestmark = pytest.mark.asyncio


@pytest.fixture()
def mock_publishers():
    """Map partitions 0 through 9 to mock publishers."""
    publishers = {}
    for index in range(10):
        publishers[Partition(index)] = MagicMock(spec=Publisher)
    return publishers


@pytest.fixture()
def topic():
    """Provide a fixed topic path for the tests."""
    return TopicPath(
        project_number=1,
        location=CloudZone(region=CloudRegion("a"), zone_id="a"),
        name="c",
    )


@pytest.fixture()
def mock_admin():
    """Provide a mock constrained to the AdminClientInterface spec."""
    return MagicMock(spec=AdminClientInterface)


@pytest.fixture()
def watcher(mock_admin, topic):
    """Build the watcher under test on a separate, short-lived thread.

    The instance is constructed off the calling thread and handed back through
    a Box, so the tests enter the watcher from a different thread than the one
    that constructed it.
    """
    result = Box()

    def construct():
        result.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001)

    # Construct the watcher on another thread.
    # NOTE(review): reviewers asked for a shared testing helper for this
    # pattern (a "run on thread" function taking a Callable[[], T] and
    # returning a T) so it can be reused in other tests.
    builder = threading.Thread(target=construct)
    builder.start()
    builder.join()
    return result.val


async def test_init(watcher, mock_admin, topic):
    """The watcher can be entered and exited when the admin lookup succeeds."""
    mock_admin.get_topic_partition_count.return_value = 2
    async with watcher:
        pass


async def test_get_count_first_failure(watcher, mock_admin, topic):
    """An API error on the very first lookup surfaces to the caller."""
    api_error = GoogleAPICallError("error")
    mock_admin.get_topic_partition_count.side_effect = api_error
    with pytest.raises(GoogleAPICallError):
        async with watcher:
            await watcher.get_partition_count()


async def test_get_multiple_counts(watcher, mock_admin, topic):
    """Two pending waiters each receive one polled count, in request order."""
    pending_counts = queue.Queue()
    # The mocked admin call blocks on this queue until the test feeds a value.
    mock_admin.get_topic_partition_count.side_effect = pending_counts.get
    async with watcher:
        first = asyncio.ensure_future(watcher.get_partition_count())
        second = asyncio.ensure_future(watcher.get_partition_count())
        assert not first.done()
        assert not second.done()

        pending_counts.put(3)
        assert await first == 3
        assert not second.done()

        pending_counts.put(4)
        assert await second == 4


async def test_subsequent_failures_ignored(watcher, mock_admin, topic):
    """After one successful lookup, a later API error is skipped over."""
    scripted = queue.Queue()

    def next_result(x):
        # Replay the scripted outcomes: raise exceptions, return everything else.
        outcome = scripted.get()
        if isinstance(outcome, Exception):
            raise outcome
        return outcome

    mock_admin.get_topic_partition_count.side_effect = next_result
    async with watcher:
        scripted.put(3)
        assert await watcher.get_partition_count() == 3
        scripted.put(GoogleAPICallError("error"))
        scripted.put(4)
        assert await watcher.get_partition_count() == 4