Skip to content

Commit

Permalink
feat: Implement admin client. (#17)
Browse files Browse the repository at this point in the history
* feat: Implement AdminClient, which helps users perform admin operations in a given region.
  • Loading branch information
dpcollins-google committed Sep 15, 2020
1 parent 697df4a commit 3068da5
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 7 deletions.
73 changes: 73 additions & 0 deletions google/cloud/pubsublite/admin_client.py
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from google.api_core.client_options import ClientOptions
from google.protobuf.field_mask_pb2 import FieldMask

from google.cloud.pubsublite.endpoints import regional_endpoint
from google.cloud.pubsublite.internal.wire.admin_client_impl import AdminClientImpl
from google.cloud.pubsublite.location import CloudRegion
from google.cloud.pubsublite.paths import TopicPath, LocationPath, SubscriptionPath
from google.cloud.pubsublite_v1 import Topic, Subscription, AdminServiceClient
from google.auth.credentials import Credentials


class AdminClient(ABC):
@abstractmethod
def region(self) -> CloudRegion:
"""The region this client is for."""

@abstractmethod
def create_topic(self, topic: Topic) -> Topic:
"""Create a topic, returns the created topic."""

@abstractmethod
def get_topic(self, topic_path: TopicPath) -> Topic:
"""Get the topic object from the server."""

@abstractmethod
def get_topic_partition_count(self, topic_path: TopicPath) -> int:
"""Get the number of partitions in the provided topic."""

@abstractmethod
def list_topics(self, location_path: LocationPath) -> List[Topic]:
"""List the Pub/Sub lite topics that exist for a project in a given location."""

@abstractmethod
def update_topic(self, topic: Topic, update_mask: FieldMask) -> Topic:
"""Update the masked fields of the provided topic."""

@abstractmethod
def delete_topic(self, topic_path: TopicPath):
"""Delete a topic and all associated messages."""

@abstractmethod
def list_topic_subscriptions(self, topic_path: TopicPath):
"""List the subscriptions that exist for a given topic."""

@abstractmethod
def create_subscription(self, subscription: Subscription) -> Subscription:
"""Create a subscription, returns the created subscription."""

@abstractmethod
def get_subscription(self, subscription_path: SubscriptionPath) -> Subscription:
"""Get the subscription object from the server."""

@abstractmethod
def list_subscriptions(self, location_path: LocationPath) -> List[Subscription]:
"""List the Pub/Sub lite subscriptions that exist for a project in a given location."""

@abstractmethod
def update_subscription(self, subscription: Subscription, update_mask: FieldMask) -> Subscription:
"""Update the masked fields of the provided subscription."""

@abstractmethod
def delete_subscription(self, subscription_path: SubscriptionPath):
"""Delete a subscription and all associated messages."""


def make_admin_client(region: CloudRegion, credentials: Optional[Credentials] = None,
client_options: Optional[ClientOptions] = None) -> AdminClient:
if client_options is None:
client_options = ClientOptions(api_endpoint=regional_endpoint(region))
return AdminClientImpl(AdminServiceClient(client_options=client_options, credentials=credentials), region)
61 changes: 61 additions & 0 deletions google/cloud/pubsublite/internal/wire/admin_client_impl.py
@@ -0,0 +1,61 @@
from typing import List

from google.protobuf.field_mask_pb2 import FieldMask

from google.cloud.pubsublite.admin_client import AdminClient
from google.cloud.pubsublite.location import CloudRegion
from google.cloud.pubsublite.paths import SubscriptionPath, LocationPath, TopicPath
from google.cloud.pubsublite_v1 import Subscription, Topic, AdminServiceClient, TopicPartitions


class AdminClientImpl(AdminClient):
_underlying: AdminServiceClient
_region: CloudRegion

def __init__(self, underlying: AdminServiceClient, region: CloudRegion):
self._underlying = underlying
self._region = region

def region(self) -> CloudRegion:
return self._region

def create_topic(self, topic: Topic) -> Topic:
path = TopicPath.parse(topic.name)
return self._underlying.create_topic(parent=str(path.to_location_path()), topic=topic, topic_id=path.name)

def get_topic(self, topic_path: TopicPath) -> Topic:
return self._underlying.get_topic(name=str(topic_path))

def get_topic_partition_count(self, topic_path: TopicPath) -> int:
partitions: TopicPartitions = self._underlying.get_topic_partitions(name=str(topic_path))
return partitions.partition_count

def list_topics(self, location_path: LocationPath) -> List[Topic]:
return [x for x in self._underlying.list_topics(parent=str(location_path))]

def update_topic(self, topic: Topic, update_mask: FieldMask) -> Topic:
return self._underlying.update_topic(topic=topic, update_mask=update_mask)

def delete_topic(self, topic_path: TopicPath):
self._underlying.delete_topic(name=str(topic_path))

def list_topic_subscriptions(self, topic_path: TopicPath):
subscription_strings = [x for x in self._underlying.list_topic_subscriptions(name=str(topic_path))]
return [SubscriptionPath.parse(x) for x in subscription_strings]

def create_subscription(self, subscription: Subscription) -> Subscription:
path = SubscriptionPath.parse(subscription.name)
return self._underlying.create_subscription(parent=str(path.to_location_path()), subscription=subscription,
subscription_id=path.name)

def get_subscription(self, subscription_path: SubscriptionPath) -> Subscription:
return self._underlying.get_subscription(name=str(subscription_path))

def list_subscriptions(self, location_path: LocationPath) -> List[Subscription]:
return [x for x in self._underlying.list_subscriptions(parent=str(location_path))]

def update_subscription(self, subscription: Subscription, update_mask: FieldMask) -> Subscription:
return self._underlying.update_subscription(subscription=subscription, update_mask=update_mask)

def delete_subscription(self, subscription_path: SubscriptionPath):
self._underlying.delete_subscription(name=str(subscription_path))
12 changes: 5 additions & 7 deletions google/cloud/pubsublite/internal/wire/make_publisher.py
@@ -1,5 +1,6 @@
from typing import AsyncIterator, Mapping, Optional, MutableMapping

from google.cloud.pubsublite.admin_client import make_admin_client
from google.cloud.pubsublite.endpoints import regional_endpoint
from google.cloud.pubsublite.internal.wire.default_routing_policy import DefaultRoutingPolicy
from google.cloud.pubsublite.internal.wire.gapic_connection import GapicConnectionFactory
Expand All @@ -12,8 +13,6 @@
from google.cloud.pubsublite.routing_metadata import topic_routing_metadata
from google.cloud.pubsublite_v1 import InitialPublishRequest, PublishRequest
from google.cloud.pubsublite_v1.services.publisher_service import async_client
from google.cloud.pubsublite_v1.services.admin_service.client import AdminServiceClient
from google.cloud.pubsublite_v1.types.admin import GetTopicPartitionsRequest
from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials

Expand All @@ -40,17 +39,16 @@ def make_publisher(
Throws:
GoogleApiCallException on any error determining topic structure.
"""
admin_client = make_admin_client(region=topic.location.region, credentials=credentials, client_options=client_options)
if client_options is None:
client_options = ClientOptions(api_endpoint=regional_endpoint(topic.location.region))
client = async_client.PublisherServiceAsyncClient(
credentials=credentials, client_options=client_options) # type: ignore

admin_client = AdminServiceClient(credentials=credentials, client_options=client_options)
partitions = admin_client.get_topic_partitions(GetTopicPartitionsRequest(name=str(topic)))

clients: MutableMapping[Partition, Publisher] = {}

for partition in range(partitions.partition_count):
partition_count = admin_client.get_topic_partition_count(topic)
for partition in range(partition_count):
partition = Partition(partition)

def connection_factory(requests: AsyncIterator[PublishRequest]):
Expand All @@ -59,4 +57,4 @@ def connection_factory(requests: AsyncIterator[PublishRequest]):

clients[partition] = SinglePartitionPublisher(InitialPublishRequest(topic=str(topic), partition=partition.value),
batching_delay_secs, GapicConnectionFactory(connection_factory))
return RoutingPublisher(DefaultRoutingPolicy(partitions.partition_count), clients)
return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients)
10 changes: 10 additions & 0 deletions google/cloud/pubsublite/location.py
@@ -1,5 +1,7 @@
from typing import NamedTuple

from google.api_core.exceptions import InvalidArgument


class CloudRegion(NamedTuple):
name: str
Expand All @@ -11,3 +13,11 @@ class CloudZone(NamedTuple):

def __str__(self):
return f"{self.region.name}-{self.zone_id}"

@staticmethod
def parse(to_parse: str):
splits = to_parse.split('-')
if len(splits) != 3 or len(splits[2]) != 1:
raise InvalidArgument("Invalid zone name: " + to_parse)
region = CloudRegion(name=splits[0] + '-' + splits[1])
return CloudZone(region, zone_id=splits[2])
44 changes: 44 additions & 0 deletions google/cloud/pubsublite/paths.py
@@ -1,8 +1,18 @@
from typing import NamedTuple

from google.api_core.exceptions import InvalidArgument

from google.cloud.pubsublite.location import CloudZone


class LocationPath(NamedTuple):
project_number: int
location: CloudZone

def __str__(self):
return f"projects/{self.project_number}/locations/{self.location}"


class TopicPath(NamedTuple):
project_number: int
location: CloudZone
Expand All @@ -11,6 +21,23 @@ class TopicPath(NamedTuple):
def __str__(self):
return f"projects/{self.project_number}/locations/{self.location}/topics/{self.name}"

def to_location_path(self):
return LocationPath(self.project_number, self.location)

@staticmethod
def parse(to_parse: str) -> "TopicPath":
splits = to_parse.split("/")
if len(splits) != 6 or splits[0] != "projects" or splits[2] != "locations" or splits[4] != "topics":
raise InvalidArgument(
"Topic path must be formatted like projects/{project_number}/locations/{location}/topics/{name} but was instead " + to_parse)
project_number: int
try:
project_number = int(splits[1])
except ValueError:
raise InvalidArgument(
"Topic path must be formatted like projects/{project_number}/locations/{location}/topics/{name} but was instead " + to_parse)
return TopicPath(project_number, CloudZone.parse(splits[3]), splits[5])


class SubscriptionPath(NamedTuple):
project_number: int
Expand All @@ -19,3 +46,20 @@ class SubscriptionPath(NamedTuple):

def __str__(self):
return f"projects/{self.project_number}/locations/{self.location}/subscriptions/{self.name}"

def to_location_path(self):
return LocationPath(self.project_number, self.location)

@staticmethod
def parse(to_parse: str) -> "SubscriptionPath":
splits = to_parse.split("/")
if len(splits) != 6 or splits[0] != "projects" or splits[2] != "locations" or splits[4] != "subscriptions":
raise InvalidArgument(
"Subscription path must be formatted like projects/{project_number}/locations/{location}/subscriptions/{name} but was instead " + to_parse)
project_number: int
try:
project_number = int(splits[1])
except ValueError:
raise InvalidArgument(
"Subscription path must be formatted like projects/{project_number}/locations/{location}/subscriptions/{name} but was instead " + to_parse)
return SubscriptionPath(project_number, CloudZone.parse(splits[3]), splits[5])

0 comments on commit 3068da5

Please sign in to comment.