Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement admin client. #17

Merged
merged 8 commits into from Sep 15, 2020
73 changes: 73 additions & 0 deletions google/cloud/pubsublite/admin_client.py
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from google.api_core.client_options import ClientOptions
from google.protobuf.field_mask_pb2 import FieldMask

from google.cloud.pubsublite.endpoints import regional_endpoint
from google.cloud.pubsublite.internal.wire.admin_client_impl import AdminClientImpl
from google.cloud.pubsublite.location import CloudRegion
from google.cloud.pubsublite.paths import TopicPath, LocationPath, SubscriptionPath
from google.cloud.pubsublite_v1 import Topic, Subscription, AdminServiceClient
from google.auth.credentials import Credentials


class AdminClient(ABC):
@abstractmethod
def region(self) -> CloudRegion:
"""The region this client is for."""

@abstractmethod
def create_topic(self, topic: Topic) -> Topic:
"""Create a topic, returns the created topic."""

@abstractmethod
def get_topic(self, topic_path: TopicPath) -> Topic:
"""Get the topic object from the server."""

@abstractmethod
def get_topic_partition_count(self, topic_path: TopicPath) -> int:
"""Get the number of partitions in the provided topic."""

@abstractmethod
def list_topics(self, location_path: LocationPath) -> List[Topic]:
"""List the Pub/Sub lite topics that exist for a project in a given location."""

@abstractmethod
def update_topic(self, topic: Topic, update_mask: FieldMask) -> Topic:
"""Update the masked fields of the provided topic."""

@abstractmethod
def delete_topic(self, topic_path: TopicPath):
"""Delete a topic and all associated messages."""

@abstractmethod
def list_topic_subscriptions(self, topic_path: TopicPath):
"""List the subscriptions that exist for a given topic."""

@abstractmethod
def create_subscription(self, subscription: Subscription) -> Subscription:
"""Create a subscription, returns the created subscription."""

@abstractmethod
def get_subscription(self, subscription_path: SubscriptionPath) -> Subscription:
"""Get the subscription object from the server."""

@abstractmethod
def list_subscriptions(self, location_path: LocationPath) -> List[Subscription]:
"""List the Pub/Sub lite subscriptions that exist for a project in a given location."""

@abstractmethod
def update_subscription(self, subscription: Subscription, update_mask: FieldMask) -> Subscription:
"""Update the masked fields of the provided subscription."""

@abstractmethod
def delete_subscription(self, subscription_path: SubscriptionPath):
"""Delete a subscription and all associated messages."""


def make_admin_client(region: CloudRegion, credentials: Optional[Credentials] = None,
client_options: Optional[ClientOptions] = None) -> AdminClient:
if client_options is None:
client_options = ClientOptions(api_endpoint=regional_endpoint(region))
return AdminClientImpl(AdminServiceClient(client_options=client_options, credentials=credentials), region)
61 changes: 61 additions & 0 deletions google/cloud/pubsublite/internal/wire/admin_client_impl.py
@@ -0,0 +1,61 @@
from typing import List

from google.protobuf.field_mask_pb2 import FieldMask

from google.cloud.pubsublite.admin_client import AdminClient
from google.cloud.pubsublite.location import CloudRegion
from google.cloud.pubsublite.paths import SubscriptionPath, LocationPath, TopicPath
from google.cloud.pubsublite_v1 import Subscription, Topic, AdminServiceClient, TopicPartitions


class AdminClientImpl(AdminClient):
_underlying: AdminServiceClient
_region: CloudRegion

def __init__(self, underlying: AdminServiceClient, region: CloudRegion):
self._underlying = underlying
self._region = region

def region(self) -> CloudRegion:
return self._region

def create_topic(self, topic: Topic) -> Topic:
path = TopicPath.parse(topic.name)
return self._underlying.create_topic(parent=str(path.to_location_path()), topic=topic, topic_id=path.name)

def get_topic(self, topic_path: TopicPath) -> Topic:
return self._underlying.get_topic(name=str(topic_path))

def get_topic_partition_count(self, topic_path: TopicPath) -> int:
partitions: TopicPartitions = self._underlying.get_topic_partitions(name=str(topic_path))
return partitions.partition_count

def list_topics(self, location_path: LocationPath) -> List[Topic]:
return [x for x in self._underlying.list_topics(parent=str(location_path))]

def update_topic(self, topic: Topic, update_mask: FieldMask) -> Topic:
return self._underlying.update_topic(topic=topic, update_mask=update_mask)

def delete_topic(self, topic_path: TopicPath):
self._underlying.delete_topic(name=str(topic_path))

def list_topic_subscriptions(self, topic_path: TopicPath):
subscription_strings = [x for x in self._underlying.list_topic_subscriptions(name=str(topic_path))]
return [SubscriptionPath.parse(x) for x in subscription_strings]

def create_subscription(self, subscription: Subscription) -> Subscription:
path = SubscriptionPath.parse(subscription.name)
return self._underlying.create_subscription(parent=str(path.to_location_path()), subscription=subscription,
subscription_id=path.name)

def get_subscription(self, subscription_path: SubscriptionPath) -> Subscription:
return self._underlying.get_subscription(name=str(subscription_path))

def list_subscriptions(self, location_path: LocationPath) -> List[Subscription]:
return [x for x in self._underlying.list_subscriptions(parent=str(location_path))]

def update_subscription(self, subscription: Subscription, update_mask: FieldMask) -> Subscription:
return self._underlying.update_subscription(subscription=subscription, update_mask=update_mask)

def delete_subscription(self, subscription_path: SubscriptionPath):
self._underlying.delete_subscription(name=str(subscription_path))
12 changes: 5 additions & 7 deletions google/cloud/pubsublite/internal/wire/make_publisher.py
@@ -1,5 +1,6 @@
from typing import AsyncIterator, Mapping, Optional, MutableMapping

from google.cloud.pubsublite.admin_client import make_admin_client
from google.cloud.pubsublite.endpoints import regional_endpoint
from google.cloud.pubsublite.internal.wire.default_routing_policy import DefaultRoutingPolicy
from google.cloud.pubsublite.internal.wire.gapic_connection import GapicConnectionFactory
Expand All @@ -12,8 +13,6 @@
from google.cloud.pubsublite.routing_metadata import topic_routing_metadata
from google.cloud.pubsublite_v1 import InitialPublishRequest, PublishRequest
from google.cloud.pubsublite_v1.services.publisher_service import async_client
from google.cloud.pubsublite_v1.services.admin_service.client import AdminServiceClient
from google.cloud.pubsublite_v1.types.admin import GetTopicPartitionsRequest
from google.api_core.client_options import ClientOptions
from google.auth.credentials import Credentials

Expand All @@ -40,17 +39,16 @@ def make_publisher(
Throws:
GoogleApiCallException on any error determining topic structure.
"""
admin_client = make_admin_client(region=topic.location.region, credentials=credentials, client_options=client_options)
if client_options is None:
client_options = ClientOptions(api_endpoint=regional_endpoint(topic.location.region))
client = async_client.PublisherServiceAsyncClient(
credentials=credentials, client_options=client_options) # type: ignore

admin_client = AdminServiceClient(credentials=credentials, client_options=client_options)
partitions = admin_client.get_topic_partitions(GetTopicPartitionsRequest(name=str(topic)))

clients: MutableMapping[Partition, Publisher] = {}

for partition in range(partitions.partition_count):
partition_count = admin_client.get_topic_partition_count(topic)
for partition in range(partition_count):
partition = Partition(partition)

def connection_factory(requests: AsyncIterator[PublishRequest]):
Expand All @@ -59,4 +57,4 @@ def connection_factory(requests: AsyncIterator[PublishRequest]):

clients[partition] = SinglePartitionPublisher(InitialPublishRequest(topic=str(topic), partition=partition.value),
batching_delay_secs, GapicConnectionFactory(connection_factory))
return RoutingPublisher(DefaultRoutingPolicy(partitions.partition_count), clients)
return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients)
10 changes: 10 additions & 0 deletions google/cloud/pubsublite/location.py
@@ -1,5 +1,7 @@
from typing import NamedTuple

from google.api_core.exceptions import InvalidArgument


class CloudRegion(NamedTuple):
name: str
Expand All @@ -11,3 +13,11 @@ class CloudZone(NamedTuple):

def __str__(self):
return f"{self.region.name}-{self.zone_id}"

@staticmethod
def parse(to_parse: str):
splits = to_parse.split('-')
if len(splits) != 3 or len(splits[2]) != 1:
raise InvalidArgument("Invalid zone name: " + to_parse)
region = CloudRegion(name=splits[0] + '-' + splits[1])
return CloudZone(region, zone_id=splits[2])
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider checking splits[2] is a number

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nack. splits[2] is not a number, it is a single lowercase character. 'us-central1-a' for example.

44 changes: 44 additions & 0 deletions google/cloud/pubsublite/paths.py
@@ -1,8 +1,18 @@
from typing import NamedTuple

from google.api_core.exceptions import InvalidArgument

from google.cloud.pubsublite.location import CloudZone


class LocationPath(NamedTuple):
project_number: int
location: CloudZone

def __str__(self):
return f"projects/{self.project_number}/locations/{self.location}"


class TopicPath(NamedTuple):
project_number: int
location: CloudZone
Expand All @@ -11,6 +21,23 @@ class TopicPath(NamedTuple):
def __str__(self):
return f"projects/{self.project_number}/locations/{self.location}/topics/{self.name}"

def to_location_path(self):
return LocationPath(self.project_number, self.location)

@staticmethod
def parse(to_parse: str) -> "TopicPath":
splits = to_parse.split("/")
if len(splits) != 6 or splits[0] != "projects" or splits[2] != "locations" or splits[4] != "topics":
raise InvalidArgument(
"Topic path must be formatted like projects/{project_number}/locations/{location}/topics/{name} but was instead " + to_parse)
project_number: int
try:
project_number = int(splits[1])
except ValueError:
raise InvalidArgument(
"Topic path must be formatted like projects/{project_number}/locations/{location}/topics/{name} but was instead " + to_parse)
return TopicPath(project_number, CloudZone.parse(splits[3]), splits[5])


class SubscriptionPath(NamedTuple):
project_number: int
Expand All @@ -19,3 +46,20 @@ class SubscriptionPath(NamedTuple):

def __str__(self):
return f"projects/{self.project_number}/locations/{self.location}/subscriptions/{self.name}"

def to_location_path(self):
return LocationPath(self.project_number, self.location)

@staticmethod
def parse(to_parse: str) -> "SubscriptionPath":
splits = to_parse.split("/")
if len(splits) != 6 or splits[0] != "projects" or splits[2] != "locations" or splits[4] != "subscriptions":
raise InvalidArgument(
"Subscription path must be formatted like projects/{project_number}/locations/{location}/subscriptions/{name} but was instead " + to_parse)
project_number: int
try:
project_number = int(splits[1])
except ValueError:
raise InvalidArgument(
"Subscription path must be formatted like projects/{project_number}/locations/{location}/subscriptions/{name} but was instead " + to_parse)
return SubscriptionPath(project_number, CloudZone.parse(splits[3]), splits[5])