Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for increasing partitions in python #74

Merged
merged 3 commits into from Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -19,6 +19,7 @@ develop-eggs
.installed.cfg
lib
lib64
venv
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put this down with 'env/' and add the slash.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

__pycache__

# Installer logs
Expand Down
30 changes: 19 additions & 11 deletions google/cloud/pubsublite/internal/wire/make_publisher.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AsyncIterator, Mapping, Optional, MutableMapping
from typing import AsyncIterator, Mapping, Optional

from google.cloud.pubsub_v1.types import BatchSettings

Expand All @@ -25,8 +25,13 @@
GapicConnectionFactory,
)
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import (
PartitionCountWatchingPublisher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_publisher import RoutingPublisher
from google.cloud.pubsublite.internal.wire.single_partition_publisher import (
SinglePartitionPublisher,
)
Expand All @@ -37,14 +42,14 @@
from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials


DEFAULT_BATCHING_SETTINGS = BatchSettings(
max_bytes=(
3 * 1024 * 1024
), # 3 MiB to stay 1 MiB below GRPC's 4 MiB per-message limit.
max_messages=1000,
max_latency=0.05, # 50 ms
)
DEFAULT_PARTITION_POLL_PERIOD = 600 # ten minutes


def make_publisher(
Expand Down Expand Up @@ -87,21 +92,24 @@ def make_publisher(
credentials=credentials, transport=transport, client_options=client_options
) # type: ignore

clients: MutableMapping[Partition, Publisher] = {}

partition_count = admin_client.get_topic_partition_count(topic)
for partition in range(partition_count):
partition = Partition(partition)

def publisher_factory(partition: Partition):
def connection_factory(requests: AsyncIterator[PublishRequest]):
final_metadata = merge_metadata(
metadata, topic_routing_metadata(topic, partition)
)
return client.publish(requests, metadata=list(final_metadata.items()))

clients[partition] = SinglePartitionPublisher(
return SinglePartitionPublisher(
InitialPublishRequest(topic=str(topic), partition=partition.value),
per_partition_batching_settings,
GapicConnectionFactory(connection_factory),
)
return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients)

def policy_factory(partition_count: int):
return DefaultRoutingPolicy(partition_count)

return PartitionCountWatchingPublisher(
PartitionCountWatcherImpl(admin_client, topic, DEFAULT_PARTITION_POLL_PERIOD),
publisher_factory,
policy_factory,
)
22 changes: 22 additions & 0 deletions google/cloud/pubsublite/internal/wire/partition_count_watcher.py
@@ -0,0 +1,22 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import AsyncContextManager


class PartitionCountWatcher(AsyncContextManager):
@abstractmethod
async def get_partition_count(self) -> int:
raise NotImplementedError()
@@ -0,0 +1,75 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from concurrent.futures.thread import ThreadPoolExecutor
import asyncio

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.types import TopicPath
from google.api_core.exceptions import GoogleAPICallError


class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable):
_admin: AdminClientInterface
_topic_path: TopicPath
_duration: float
_any_success: bool
_thread: ThreadPoolExecutor
_queue: asyncio.Queue
_poll_partition_loop: asyncio.Future

def __init__(
self, admin: AdminClientInterface, topic_path: TopicPath, duration: float
):
super().__init__()
self._admin = admin
self._topic_path = topic_path
self._duration = duration
self._any_success = False
palmere-google marked this conversation as resolved.
Show resolved Hide resolved

async def __aenter__(self):
self._thread = ThreadPoolExecutor(max_workers=1)
self._queue = asyncio.Queue(maxsize=1)
self._poll_partition_loop = asyncio.ensure_future(
self.run_poller(self._poll_partition_loop)
)

async def __aexit__(self, exc_type, exc_val, exc_tb):
self._poll_partition_loop.cancel()
await wait_ignore_cancelled(self._poll_partition_loop)
self._thread.shutdown(wait=False)

def _get_partition_count_sync(self) -> int:
return self._admin.get_topic_partition_count(self._topic_path)

async def _poll_partition_loop(self):
try:
partition_count = await asyncio.get_event_loop().run_in_executor(
self._thread, self._get_partition_count_sync
)
self._any_success = True
await self._queue.put(partition_count)
except GoogleAPICallError as e:
if not self._any_success:
dpcollins-google marked this conversation as resolved.
Show resolved Hide resolved
raise e
logging.exception("Failed to retrieve partition count")
await asyncio.sleep(self._duration)

async def get_partition_count(self) -> int:
return await self.await_unless_failed(self._queue.get())
@@ -0,0 +1,94 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
from typing import Callable, Dict

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy
from google.cloud.pubsublite.types import PublishMetadata, Partition
from google.cloud.pubsublite_v1 import PubSubMessage


class PartitionCountWatchingPublisher(Publisher):
_publishers: Dict[Partition, Publisher]
_publisher_factory: Callable[[Partition], Publisher]
_policy_factory: Callable[[int], RoutingPolicy]
_watcher: PartitionCountWatcher
_partition_count_poller: asyncio.Future

def __init__(
self,
watcher: PartitionCountWatcher,
publisher_factory: Callable[[Partition], Publisher],
policy_factory: Callable[[int], RoutingPolicy],
):
self._publishers = {}
self._publisher_factory = publisher_factory
self._policy_factory = policy_factory
self._watcher = watcher

async def __aenter__(self):
try:
await self._watcher.__aenter__()
await self._poll_partition_count_action()
except Exception:
await self._watcher.__aexit__(*sys.exc_info())
raise
self._partition_count_poller = asyncio.ensure_future(
self._watch_partition_count()
)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
self._partition_count_poller.cancel()
await wait_ignore_cancelled(self._partition_count_poller)
await self._watcher.__aexit__(exc_type, exc_val, exc_tb)
for publisher in self._publishers.values():
await publisher.__aexit__(exc_type, exc_val, exc_tb)

async def _poll_partition_count_action(self):
partition_count = await self._watcher.get_partition_count()
await self._handle_partition_count_update(partition_count)

async def _watch_partition_count(self):
while True:
await self._poll_partition_count_action()

async def _handle_partition_count_update(self, partition_count: int):
current_count = len(self._publishers)
if current_count == partition_count:
return
if current_count > partition_count:
return

new_publishers = {
Partition(index): self._publisher_factory(Partition(index))
for index in range(current_count, partition_count)
}
await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run nox from the root directory before merging this.

routing_policy = self._policy_factory(partition_count)

self._publishers.update(new_publishers)
self._routing_policy = routing_policy

async def publish(self, message: PubSubMessage) -> PublishMetadata:
partition = self._routing_policy.route(message)
assert partition in self._publishers
publisher = self._publishers[partition]
return await publisher.publish(message)
@@ -0,0 +1,106 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import queue
import threading
from asynctest.mock import MagicMock
import pytest

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.testing.test_utils import Box
from google.cloud.pubsublite.types import Partition, TopicPath, CloudZone, CloudRegion
from google.api_core.exceptions import GoogleAPICallError

pytestmark = pytest.mark.asyncio


@pytest.fixture()
def mock_publishers():
return {Partition(i): MagicMock(spec=Publisher) for i in range(10)}


@pytest.fixture()
def topic():
zone = CloudZone(region=CloudRegion("a"), zone_id="a")
return TopicPath(project_number=1, location=zone, name="c")


@pytest.fixture()
def mock_admin():
admin = MagicMock(spec=AdminClientInterface)
return admin


@pytest.fixture()
def watcher(mock_admin, topic):
box = Box()

def set_box():
box.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001)

# Initialize watcher on another thread with a different event loop.
thread = threading.Thread(target=set_box)
thread.start()
thread.join()
return box.val


async def test_init(watcher, mock_admin, topic):
mock_admin.get_topic_partition_count.return_value = 2
async with watcher:
pass


async def test_get_count_first_failure(watcher, mock_admin, topic):
mock_admin.get_topic_partition_count.side_effect = GoogleAPICallError("error")
with pytest.raises(GoogleAPICallError):
async with watcher:
await watcher.get_partition_count()


async def test_get_multiple_counts(watcher, mock_admin, topic):
q = queue.Queue()
mock_admin.get_topic_partition_count.side_effect = q.get
async with watcher:
task1 = asyncio.ensure_future(watcher.get_partition_count())
task2 = asyncio.ensure_future(watcher.get_partition_count())
assert not task1.done()
assert not task2.done()
q.put(3)
assert await task1 == 3
assert not task2.done()
q.put(4)
assert await task2 == 4


async def test_subsequent_failures_ignored(watcher, mock_admin, topic):
q = queue.Queue()

def side_effect():
value = q.get()
if isinstance(value, Exception):
raise value
return value

mock_admin.get_topic_partition_count.side_effect = lambda x: side_effect()
async with watcher:
q.put(3)
assert await watcher.get_partition_count() == 3
q.put(GoogleAPICallError("error"))
q.put(4)
assert await watcher.get_partition_count() == 4