Skip to content

Commit

Permalink
Add support for increasing partitions in python
Browse files Browse the repository at this point in the history
  • Loading branch information
palmere-google committed Dec 10, 2020
1 parent 9d81d52 commit ca3b545
Show file tree
Hide file tree
Showing 6 changed files with 461 additions and 9 deletions.
30 changes: 21 additions & 9 deletions google/cloud/pubsublite/internal/wire/make_publisher.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import AsyncIterator, Mapping, Optional, MutableMapping

from google.cloud.pubsub_v1.types import BatchSettings
Expand All @@ -25,6 +26,15 @@
GapicConnectionFactory,
)
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import (
PartitionCountWatchingPublisher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_publisher import RoutingPublisher
from google.cloud.pubsublite.internal.wire.single_partition_publisher import (
Expand All @@ -37,7 +47,6 @@
from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials


DEFAULT_BATCHING_SETTINGS = BatchSettings(
max_bytes=(
3 * 1024 * 1024
Expand Down Expand Up @@ -87,21 +96,24 @@ def make_publisher(
credentials=credentials, transport=transport, client_options=client_options
) # type: ignore

clients: MutableMapping[Partition, Publisher] = {}

partition_count = admin_client.get_topic_partition_count(topic)
for partition in range(partition_count):
partition = Partition(partition)

def publisher_factory(partition: Partition):
def connection_factory(requests: AsyncIterator[PublishRequest]):
final_metadata = merge_metadata(
metadata, topic_routing_metadata(topic, partition)
)
return client.publish(requests, metadata=list(final_metadata.items()))

clients[partition] = SinglePartitionPublisher(
return SinglePartitionPublisher(
InitialPublishRequest(topic=str(topic), partition=partition.value),
per_partition_batching_settings,
GapicConnectionFactory(connection_factory),
)
return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients)

def policy_factory(partition_count: int):
return DefaultRoutingPolicy(partition_count)

return PartitionCountWatchingPublisher(
PartitionCountWatcherImpl(admin_client, topic, 10),
publisher_factory,
policy_factory,
)
22 changes: 22 additions & 0 deletions google/cloud/pubsublite/internal/wire/partition_count_watcher.py
@@ -0,0 +1,22 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import AsyncContextManager


class PartitionCountWatcher(AsyncContextManager):
@abstractmethod
async def get_partition_count(self) -> int:
raise NotImplementedError()
@@ -0,0 +1,67 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from concurrent.futures.thread import ThreadPoolExecutor
import asyncio

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.types import TopicPath
from google.api_core.exceptions import GoogleAPICallError


class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable):
def __init__(
self, admin: AdminClientInterface, topic_path: TopicPath, duration: float
):
super().__init__()
self._admin = admin
self._topic_path = topic_path
self._duration = duration
self._any_success = False

async def __aenter__(self):
self._thread = ThreadPoolExecutor(max_workers=1)
self._queue = asyncio.Queue(maxsize=1)
self._poll_partition_loop = asyncio.ensure_future(
self.run_poller(self._poll_partition_loop)
)

async def __aexit__(self, exc_type, exc_val, exc_tb):
self._poll_partition_loop.cancel()
await wait_ignore_cancelled(self._poll_partition_loop)
self._thread.shutdown(wait=False)

def _get_partition_count_sync(self) -> int:
return self._admin.get_topic_partition_count(self._topic_path)

async def _poll_partition_loop(self):
try:
partition_count = await asyncio.get_event_loop().run_in_executor(
self._thread, self._get_partition_count_sync
)
self._any_success = True
await self._queue.put(partition_count)
except GoogleAPICallError as e:
if not self._any_success:
raise e
await asyncio.sleep(self._duration)

async def get_partition_count(self) -> int:
return await self.await_unless_failed(self._queue.get())
@@ -0,0 +1,95 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
import threading
from typing import Callable

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy
from google.cloud.pubsublite.types import PublishMetadata, Partition
from google.cloud.pubsublite_v1 import PubSubMessage


class PartitionCountWatchingPublisher(Publisher):
def __init__(
self,
watcher: PartitionCountWatcher,
publisher_factory: Callable[[Partition], Publisher],
policy_factory: Callable[[int], RoutingPolicy],
):
self._publishers = {}
self._lock = threading.Lock()
self._publisher_factory = publisher_factory
self._policy_factory = policy_factory
self._watcher = watcher

async def __aenter__(self):
try:
await self._watcher.__aenter__()
await self._poll_partition_count_action()
except Exception:
await self._watcher.__aexit__(*sys.exc_info())
raise
self._partition_count_poller = asyncio.ensure_future(
self._watch_partition_count()
)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
self._partition_count_poller.cancel()
await wait_ignore_cancelled(self._partition_count_poller)
await self._watcher.__aexit__(exc_type, exc_val, exc_tb)
with self._lock:
for publisher in self._publishers.values():
await publisher.__aexit__(exc_type, exc_val, exc_tb)

async def _poll_partition_count_action(self):
partition_count = await self._watcher.get_partition_count()
await self._handle_partition_count_update(partition_count)

async def _watch_partition_count(self):
while True:
await self._poll_partition_count_action()

async def _handle_partition_count_update(self, partition_count: int):
with self._lock:
current_count = len(self._publishers)
if current_count == partition_count:
return
if current_count > partition_count:
return

new_publishers = {
Partition(index): self._publisher_factory(Partition(index))
for index in range(current_count, partition_count)
}
await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()])
routing_policy = self._policy_factory(partition_count)

with self._lock:
self._publishers.update(new_publishers)
self._routing_policy = routing_policy

async def publish(self, message: PubSubMessage) -> PublishMetadata:
with self._lock:
partition = self._routing_policy.route(message)
assert partition in self._publishers
publisher = self._publishers[partition]

return await publisher.publish(message)
@@ -0,0 +1,116 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import concurrent
import queue
import threading
from asynctest.mock import MagicMock, CoroutineMock
import pytest

from google.cloud.pubsublite import AdminClientInterface
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
PartitionCountWatcher,
)
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
PartitionCountWatcherImpl,
)
from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import (
PartitionCountWatchingPublisher,
)
from google.cloud.pubsublite.internal.wire.publisher import Publisher
from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy
from google.cloud.pubsublite.testing.test_utils import Box, wire_queues
from google.cloud.pubsublite.types import Partition, TopicPath, CloudZone, CloudRegion
from google.cloud.pubsublite_v1 import PubSubMessage
from google.api_core.exceptions import GoogleAPICallError
from google.api_core.exceptions import GoogleAPICallError

pytestmark = pytest.mark.asyncio


@pytest.fixture()
def mock_publishers():
return {Partition(i): MagicMock(spec=Publisher) for i in range(10)}


@pytest.fixture()
def topic():
zone = CloudZone(region=CloudRegion("a"), zone_id="a")
return TopicPath(project_number=1, location=zone, name="c")


@pytest.fixture()
def mock_admin():
admin = MagicMock(spec=AdminClientInterface)
return admin


@pytest.fixture()
def watcher(mock_admin, topic):
box = Box()

def set_box():
box.val = PartitionCountWatcherImpl(mock_admin, topic, 0.001)

# Initialize publisher on another thread with a different event loop.
thread = threading.Thread(target=set_box)
thread.start()
thread.join()
return box.val


async def test_init(watcher, mock_admin, topic):
mock_admin.get_topic_partition_count.return_value = 2
async with watcher:
pass


async def test_get_count_first_failure(watcher, mock_admin, topic):
mock_admin.get_topic_partition_count.side_effect = GoogleAPICallError("error")
with pytest.raises(GoogleAPICallError):
async with watcher:
await watcher.get_partition_count()


async def test_get_multiple_counts(watcher, mock_admin, topic):
q = queue.Queue()
mock_admin.get_topic_partition_count.side_effect = q.get
async with watcher:
task1 = asyncio.ensure_future(watcher.get_partition_count())
task2 = asyncio.ensure_future(watcher.get_partition_count())
assert not task1.done()
assert not task2.done()
q.put(3)
assert await task1 == 3
assert not task2.done()
q.put(4)
assert await task2 == 4


async def test_subsequent_failures_ignored(watcher, mock_admin, topic):
q = queue.Queue()

def side_effect():
value = q.get()
if isinstance(value, Exception):
raise value
return value

mock_admin.get_topic_partition_count.side_effect = lambda x: side_effect()
async with watcher:
q.put(3)
assert await watcher.get_partition_count() == 3
q.put(GoogleAPICallError("error"))
q.put(4)
assert await watcher.get_partition_count() == 4

0 comments on commit ca3b545

Please sign in to comment.