Skip to content

Commit

Permalink
feat: Add support for increasing partitions in python (#74)
Browse files Browse the repository at this point in the history
* Add support for increasing partitions in python

* updates to address comments

* addressing comments
  • Loading branch information
palmere-google committed Dec 14, 2020
1 parent b5ffc42 commit e117a1a
Show file tree
Hide file tree
Showing 8 changed files with 453 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -50,6 +50,7 @@ docs.metadata

# Virtual environment
env/
venv/
coverage.xml
sponge_log.xml

Expand Down
30 changes: 19 additions & 11 deletions google/cloud/pubsublite/internal/wire/make_publisher.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AsyncIterator, Mapping, Optional, MutableMapping
from typing import AsyncIterator, Mapping, Optional

from google.cloud.pubsub_v1.types import BatchSettings

Expand All @@ -25,8 +25,13 @@
GapicConnectionFactory,
)
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import (
PartitionCountWatchingPublisher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_publisher import RoutingPublisher
from google.cloud.pubsublite.internal.wire.single_partition_publisher import (
SinglePartitionPublisher,
)
Expand All @@ -37,14 +42,14 @@
from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials


# Default batch settings for this module.
DEFAULT_BATCHING_SETTINGS = BatchSettings(
    max_bytes=(
        3 * 1024 * 1024
    ),  # 3 MiB to stay 1 MiB below GRPC's 4 MiB per-message limit.
    max_messages=1000,
    max_latency=0.05,  # 50 ms
)
# Seconds between partition count polls; handed to the partition count watcher
# as its sleep duration.
DEFAULT_PARTITION_POLL_PERIOD = 600  # ten minutes


def make_publisher(
Expand Down Expand Up @@ -87,21 +92,24 @@ def make_publisher(
credentials=credentials, transport=transport, client_options=client_options
) # type: ignore

clients: MutableMapping[Partition, Publisher] = {}

partition_count = admin_client.get_topic_partition_count(topic)
for partition in range(partition_count):
partition = Partition(partition)

def publisher_factory(partition: Partition):
def connection_factory(requests: AsyncIterator[PublishRequest]):
final_metadata = merge_metadata(
metadata, topic_routing_metadata(topic, partition)
)
return client.publish(requests, metadata=list(final_metadata.items()))

clients[partition] = SinglePartitionPublisher(
return SinglePartitionPublisher(
InitialPublishRequest(topic=str(topic), partition=partition.value),
per_partition_batching_settings,
GapicConnectionFactory(connection_factory),
)
return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients)

def policy_factory(partition_count: int):
return DefaultRoutingPolicy(partition_count)

return PartitionCountWatchingPublisher(
PartitionCountWatcherImpl(admin_client, topic, DEFAULT_PARTITION_POLL_PERIOD),
publisher_factory,
policy_factory,
)
22 changes: 22 additions & 0 deletions google/cloud/pubsublite/internal/wire/partition_count_watcher.py
@@ -0,0 +1,22 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import AsyncContextManager


class PartitionCountWatcher(AsyncContextManager):
    """Async context manager interface that reports a partition count."""

    @abstractmethod
    async def get_partition_count(self) -> int:
        """Return a partition count.

        Abstract: concrete watchers must override this coroutine. The base
        implementation always raises NotImplementedError.
        """
        raise NotImplementedError()
@@ -0,0 +1,75 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from concurrent.futures.thread import ThreadPoolExecutor
import asyncio

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.types import TopicPath
from google.api_core.exceptions import GoogleAPICallError


class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable):
    """PartitionCountWatcher that polls an admin client for a topic's partition count.

    Each poll runs the blocking admin call on a dedicated single-worker thread
    pool, places the result on a size-one queue, and then sleeps for
    `duration` seconds. A GoogleAPICallError is re-raised only while no poll
    has succeeded yet; later failures are logged and polling continues.
    """

    _admin: AdminClientInterface
    _topic_path: TopicPath
    _duration: float
    _any_success: bool
    _thread: ThreadPoolExecutor
    _queue: asyncio.Queue
    # Future driving the polling loop. Deliberately NOT named
    # `_poll_partition_loop`: an instance attribute with that name would shadow
    # the coroutine method below after the first `__aenter__`.
    _partition_loop_poller: asyncio.Future

    def __init__(
        self, admin: AdminClientInterface, topic_path: TopicPath, duration: float
    ):
        """Store the polling configuration; no polling starts until `__aenter__`.

        Args:
            admin: Client whose `get_topic_partition_count` is called on each poll.
            topic_path: Topic whose partition count is watched.
            duration: Seconds to sleep after each poll attempt.
        """
        super().__init__()
        self._admin = admin
        self._topic_path = topic_path
        self._duration = duration
        self._any_success = False

    async def __aenter__(self):
        """Create the worker thread pool and queue, start polling, return self."""
        self._thread = ThreadPoolExecutor(max_workers=1)
        # maxsize=1 makes the poller wait until the previous count is consumed.
        self._queue = asyncio.Queue(maxsize=1)
        # NOTE(review): run_poller is inherited from PermanentFailable and is
        # presumably what invokes the poll action repeatedly - confirm there.
        self._partition_loop_poller = asyncio.ensure_future(
            self.run_poller(self._poll_partition_loop)
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Cancel the polling future, wait for it, then release the thread pool."""
        self._partition_loop_poller.cancel()
        await wait_ignore_cancelled(self._partition_loop_poller)
        # wait=False: do not block the event loop on an in-flight admin call.
        self._thread.shutdown(wait=False)

    def _get_partition_count_sync(self) -> int:
        """Blocking admin call; executed on the worker thread."""
        return self._admin.get_topic_partition_count(self._topic_path)

    async def _poll_partition_loop(self):
        """Run one poll attempt, then sleep for the configured duration.

        Raises:
            GoogleAPICallError: if the admin call fails before any poll has
                ever succeeded.
        """
        try:
            partition_count = await asyncio.get_event_loop().run_in_executor(
                self._thread, self._get_partition_count_sync
            )
            self._any_success = True
            await self._queue.put(partition_count)
        except GoogleAPICallError:
            if not self._any_success:
                # Bare raise keeps the original traceback intact.
                raise
            logging.exception("Failed to retrieve partition count")
        await asyncio.sleep(self._duration)

    async def get_partition_count(self) -> int:
        """Wait for and return the next polled partition count.

        NOTE(review): await_unless_failed comes from PermanentFailable and
        presumably raises instead if the watcher has failed - confirm there.
        """
        return await self.await_unless_failed(self._queue.get())
@@ -0,0 +1,94 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
from typing import Callable, Dict

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy
from google.cloud.pubsublite.types import PublishMetadata, Partition
from google.cloud.pubsublite_v1 import PubSubMessage


class PartitionCountWatchingPublisher(Publisher):
    """Publisher that adds per-partition publishers as the partition count grows.

    A PartitionCountWatcher supplies partition counts. For every partition
    index not yet covered, a publisher is built with `publisher_factory` and
    entered, and the routing policy is rebuilt with `policy_factory`. Counts
    that are equal to or lower than the current number of publishers are
    ignored.
    """

    _publishers: Dict[Partition, Publisher]
    _publisher_factory: Callable[[Partition], Publisher]
    _policy_factory: Callable[[int], RoutingPolicy]
    _watcher: PartitionCountWatcher
    _partition_count_poller: asyncio.Future

    def __init__(
        self,
        watcher: PartitionCountWatcher,
        publisher_factory: Callable[[Partition], Publisher],
        policy_factory: Callable[[int], RoutingPolicy],
    ):
        """Store the collaborators; nothing is started until `__aenter__`.

        Args:
            watcher: Source of partition counts.
            publisher_factory: Builds the publisher for a single partition.
            policy_factory: Builds a routing policy for a given partition count.
        """
        self._watcher = watcher
        self._policy_factory = policy_factory
        self._publisher_factory = publisher_factory
        self._publishers = {}

    async def __aenter__(self):
        """Enter the watcher, apply the first count, then start background watching."""
        try:
            await self._watcher.__aenter__()
            # Handle the first count synchronously so publishers and a routing
            # policy exist before this method returns.
            await self._poll_partition_count_action()
        except Exception:
            await self._watcher.__aexit__(*sys.exc_info())
            raise
        watch_task = asyncio.ensure_future(self._watch_partition_count())
        self._partition_count_poller = watch_task
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop watching, exit the watcher, then exit each publisher in turn."""
        self._partition_count_poller.cancel()
        await wait_ignore_cancelled(self._partition_count_poller)
        await self._watcher.__aexit__(exc_type, exc_val, exc_tb)
        for child in self._publishers.values():
            await child.__aexit__(exc_type, exc_val, exc_tb)

    async def _poll_partition_count_action(self):
        """Fetch one partition count from the watcher and apply it."""
        latest_count = await self._watcher.get_partition_count()
        await self._handle_partition_count_update(latest_count)

    async def _watch_partition_count(self):
        """Apply partition counts forever; stopped by cancellation in `__aexit__`."""
        while True:
            await self._poll_partition_count_action()

    async def _handle_partition_count_update(self, partition_count: int):
        """Create and enter publishers for partitions beyond the current count."""
        existing = len(self._publishers)
        # Unchanged or decreased counts require no work.
        if partition_count <= existing:
            return

        added = {}
        for index in range(existing, partition_count):
            added[Partition(index)] = self._publisher_factory(Partition(index))
        await asyncio.gather(*(pub.__aenter__() for pub in added.values()))
        policy = self._policy_factory(partition_count)

        # Publish the new state only after every new publisher has been
        # entered and the policy has been built.
        self._publishers.update(added)
        self._routing_policy = policy

    async def publish(self, message: PubSubMessage) -> PublishMetadata:
        """Route the message to a partition and publish it via that partition's publisher."""
        partition = self._routing_policy.route(message)
        assert partition in self._publishers
        return await self._publishers[partition].publish(message)
16 changes: 15 additions & 1 deletion google/cloud/pubsublite/testing/test_utils.py
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.

import asyncio
from typing import List, Union, Any, TypeVar, Generic, Optional
import threading
from typing import List, Union, Any, TypeVar, Generic, Optional, Callable

from asynctest import CoroutineMock

Expand Down Expand Up @@ -62,3 +63,16 @@ def wire_queues(mock: CoroutineMock) -> QueuePair:

class Box(Generic[T]):
    """Generic holder with a single `val` slot.

    `val` is only annotated, not assigned, so reading it before it has been
    set on an instance raises AttributeError.
    """

    val: Optional[T]


def run_on_thread(func: Callable[[], T]) -> T:
    """Run `func` to completion on a new thread and return its result.

    Args:
        func: Zero-argument callable to execute on the helper thread.

    Returns:
        Whatever `func` returned.

    Raises:
        Exception: any exception raised by `func` is re-raised in the caller,
            instead of being lost on the helper thread and surfacing as an
            unrelated AttributeError on the unset result box.
    """
    box = Box()
    failure = Box()
    failure.val = None

    def set_box():
        try:
            box.val = func()
        except Exception as e:  # captured here, re-raised in the calling thread
            failure.val = e

    # Initialize watcher on another thread with a different event loop.
    thread = threading.Thread(target=set_box)
    thread.start()
    thread.join()
    if failure.val is not None:
        raise failure.val
    return box.val
@@ -0,0 +1,95 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import queue
from asynctest.mock import MagicMock
import pytest

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.testing.test_utils import run_on_thread
from google.cloud.pubsublite.types import Partition, TopicPath
from google.api_core.exceptions import GoogleAPICallError

# Module-level marker: applies pytest's asyncio mark to every test in this file.
pytestmark = pytest.mark.asyncio


@pytest.fixture()
def mock_publishers():
    """Provide ten Publisher-spec'd mocks keyed by Partition 0 through 9."""
    publishers = {}
    for index in range(10):
        publishers[Partition(index)] = MagicMock(spec=Publisher)
    return publishers


@pytest.fixture()
def topic():
    """Provide the parsed topic path used by the tests."""
    raw_path = "projects/1/locations/us-central1-a/topics/topic"
    return TopicPath.parse(raw_path)


@pytest.fixture()
def mock_admin():
    """Provide a mock constrained to the AdminClientInterface spec."""
    return MagicMock(spec=AdminClientInterface)


@pytest.fixture()
def watcher(mock_admin, topic):
    """Provide a watcher constructed on a separate thread with a 1 ms poll period."""

    def build_watcher():
        return PartitionCountWatcherImpl(mock_admin, topic, 0.001)

    return run_on_thread(build_watcher)


async def test_init(watcher, mock_admin, topic):
    """Entering and leaving the watcher context completes without raising."""
    mock_admin.get_topic_partition_count.return_value = 2
    async with watcher:
        pass


async def test_get_count_first_failure(watcher, mock_admin, topic):
    """A GoogleAPICallError from the admin client before any success is raised."""
    # Every admin call fails, so no poll ever succeeds.
    mock_admin.get_topic_partition_count.side_effect = GoogleAPICallError("error")
    with pytest.raises(GoogleAPICallError):
        async with watcher:
            await watcher.get_partition_count()


async def test_get_multiple_counts(watcher, mock_admin, topic):
    """Two pending get_partition_count calls each receive one polled value, in order."""
    q = queue.Queue()
    # q.get is invoked with the mock's call arguments and blocks the mocked
    # admin call until the test puts a value on the queue.
    mock_admin.get_topic_partition_count.side_effect = q.get
    async with watcher:
        task1 = asyncio.ensure_future(watcher.get_partition_count())
        task2 = asyncio.ensure_future(watcher.get_partition_count())
        # Nothing has been supplied yet, so neither request can have finished.
        assert not task1.done()
        assert not task2.done()
        q.put(3)
        assert await task1 == 3
        # Only one value was supplied; the second request is still waiting.
        assert not task2.done()
        q.put(4)
        assert await task2 == 4


async def test_subsequent_failures_ignored(watcher, mock_admin, topic):
    """After one successful poll, an admin failure is skipped and the next count is delivered."""
    q = queue.Queue()

    def side_effect():
        # Queue entries are either counts to return or exceptions to raise.
        value = q.get()
        if isinstance(value, Exception):
            raise value
        return value

    # The mock is called with one positional argument, which is discarded here.
    mock_admin.get_topic_partition_count.side_effect = lambda x: side_effect()
    async with watcher:
        q.put(3)
        assert await watcher.get_partition_count() == 3
        # The failure is queued ahead of the next count; only the count should
        # reach the caller.
        q.put(GoogleAPICallError("error"))
        q.put(4)
        assert await watcher.get_partition_count() == 4

0 comments on commit e117a1a

Please sign in to comment.