Skip to content

Commit

Permalink
refactor: Support nested resources (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha-gitg committed Dec 16, 2021
1 parent 0482d06 commit 321cf9e
Show file tree
Hide file tree
Showing 24 changed files with 413 additions and 513 deletions.
97 changes: 76 additions & 21 deletions google/cloud/aiplatform/base.py
Expand Up @@ -397,7 +397,6 @@ class VertexAiResourceNoun(metaclass=abc.ABCMeta):
Subclasses require two class attributes:
client_class: The client to instantiate to interact with this resource noun.
_is_client_prediction_client: Flag to indicate if the client requires a prediction endpoint.
Subclass is required to populate private attribute _gca_resource which is the
service representation of the resource noun.
Expand All @@ -414,29 +413,43 @@ def client_class(cls) -> Type[utils.VertexAiServiceClientWithOverride]:
@property
@classmethod
@abc.abstractmethod
def _is_client_prediction_client(cls) -> bool:
"""Flag to indicate whether to use prediction endpoint with client."""
pass

@property
@abc.abstractmethod
def _getter_method(cls) -> str:
"""Name of getter method of client class for retrieving the
resource."""
pass

@property
@classmethod
@abc.abstractmethod
def _delete_method(cls) -> str:
"""Name of delete method of client class for deleting the resource."""
pass

@property
@classmethod
@abc.abstractmethod
def _resource_noun(cls) -> str:
"""Resource noun."""
pass

@property
@classmethod
@abc.abstractmethod
def _parse_resource_name_method(cls) -> str:
"""Method name on GAPIC client to parse a resource name."""
pass

@property
@classmethod
@abc.abstractmethod
def _format_resource_name_method(self) -> str:
"""Method name on GAPIC client to format a resource name."""
pass

# Override this value with staticmethod
# to use custom resource id validators per resource
_resource_id_validator: Optional[Callable[[str], None]] = None

def __init__(
self,
project: Optional[str] = None,
Expand Down Expand Up @@ -486,15 +499,48 @@ def _instantiate_client(
client_class=cls.client_class,
credentials=credentials,
location_override=location,
prediction_client=cls._is_client_prediction_client,
)

@classmethod
def _parse_resource_name(cls, resource_name: str) -> Dict[str, str]:
"""
Parses resource name into its component segments.
Args:
resource_name: Resource name of this resource.
Returns:
Dictionary of component segments.
"""
# gets the underlying wrapped gapic client class
return getattr(
cls.client_class.get_gapic_client_class(), cls._parse_resource_name_method
)(resource_name)

@classmethod
def _format_resource_name(cls, **kwargs: str) -> str:
"""
Formats a resource name using its component segments.
Args:
**kwargs: Resource name parts. Singular and snake case. ie:
format_resource_name(
project='my-project',
location='us-central1'
)
Returns:
Resource name.
"""
# gets the underlying wrapped gapic client class
return getattr(
cls.client_class.get_gapic_client_class(), cls._format_resource_name_method
)(**kwargs)

def _get_and_validate_project_location(
self,
resource_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
) -> Tuple:
) -> Tuple[str, str]:

"""Validate the project and location for the resource.
Expand All @@ -507,33 +553,42 @@ def _get_and_validate_project_location(
RuntimeError: If location is different from resource location
"""

fields = utils.extract_fields_from_resource_name(
resource_name, self._resource_noun
)
fields = self._parse_resource_name(resource_name)

if not fields:
return project, location

if location and fields.location != location:
if location and fields["location"] != location:
raise RuntimeError(
f"location {location} is provided, but different from "
f"the resource location {fields.location}"
f"the resource location {fields['location']}"
)

return fields.project, fields.location
return fields["project"], fields["location"]

def _get_gca_resource(
self,
resource_name: str,
parent_resource_name_fields: Optional[Dict[str, str]] = None,
) -> proto.Message:
"""Returns GAPIC service representation of client class resource.
def _get_gca_resource(self, resource_name: str) -> proto.Message:
"""Returns GAPIC service representation of client class resource."""
"""
Args:
resource_name (str):
Required. A fully-qualified resource name or ID.
resource_name (str): Required. A fully-qualified resource name or ID.
parent_resource_name_fields (Dict[str,str]):
Optional. Mapping of parent resource name key to values. These
will be used to compose the resource name if only resource ID is given.
Should not include project and location.
"""

resource_name = utils.full_resource_name(
resource_name=resource_name,
resource_noun=self._resource_noun,
parse_resource_name_method=self._parse_resource_name,
format_resource_name_method=self._format_resource_name,
project=self.project,
location=self.location,
parent_resource_name_fields=parent_resource_name_fields,
resource_id_validator=self._resource_id_validator,
)

return getattr(self.api_client, self._getter_method)(
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/aiplatform/datasets/dataset.py
Expand Up @@ -39,11 +39,12 @@ class _Dataset(base.VertexAiResourceNounWithFutureManager):
"""Managed dataset resource for Vertex AI."""

client_class = utils.DatasetClientWithOverride
_is_client_prediction_client = False
_resource_noun = "datasets"
_getter_method = "get_dataset"
_list_method = "list_datasets"
_delete_method = "delete_dataset"
_parse_resource_name_method = "parse_dataset_path"
_format_resource_name_method = "dataset_path"

_supported_metadata_schema_uris: Tuple[str] = ()

Expand Down
50 changes: 29 additions & 21 deletions google/cloud/aiplatform/featurestore/entity_type.py
Expand Up @@ -36,10 +36,22 @@ class EntityType(base.VertexAiResourceNounWithFutureManager):
client_class = utils.FeaturestoreClientWithOverride

_is_client_prediction_client = False
_resource_noun = None
_resource_noun = "entityTypes"
_getter_method = "get_entity_type"
_list_method = "list_entity_types"
_delete_method = "delete_entity_type"
_parse_resource_name_method = "parse_entity_type_path"
_format_resource_name_method = "entity_type_path"

@staticmethod
def _resource_id_validator(resource_id: str):
"""Validates resource ID.
Args:
resource_id(str):
The resource id to validate.
"""
featurestore_utils.validate_id(resource_id)

def __init__(
self,
Expand Down Expand Up @@ -81,31 +93,26 @@ def __init__(
credentials set in aiplatform.init.
"""

(
featurestore_id,
_,
) = featurestore_utils.validate_and_get_entity_type_resource_ids(
entity_type_name=entity_type_name, featurestore_id=featurestore_id
)

# TODO(b/208269923): Temporary workaround, update when base class supports nested resource
self._resource_noun = f"featurestores/{featurestore_id}/entityTypes"

super().__init__(
project=project,
location=location,
credentials=credentials,
resource_name=entity_type_name,
)
self._gca_resource = self._get_gca_resource(resource_name=entity_type_name)
self._gca_resource = self._get_gca_resource(
resource_name=entity_type_name,
parent_resource_name_fields={
featurestore.Featurestore._resource_noun: featurestore_id
}
if featurestore_id
else featurestore_id,
)

@property
def featurestore_name(self) -> str:
"""Full qualified resource name of the managed featurestore in which this EntityType is."""
entity_type_name_components = featurestore_utils.CompatFeaturestoreServiceClient.parse_entity_type_path(
path=self.resource_name
)
return featurestore_utils.CompatFeaturestoreServiceClient.featurestore_path(
entity_type_name_components = self._parse_resource_name(self.resource_name)
return featurestore.Featurestore._format_resource_name(
project=entity_type_name_components["project"],
location=entity_type_name_components["location"],
featurestore=entity_type_name_components["featurestore"],
Expand All @@ -128,12 +135,10 @@ def get_feature(self, feature_id: str) -> "featurestore.Feature":
Returns:
featurestore.Feature - The managed feature resource object.
"""
entity_type_name_components = featurestore_utils.CompatFeaturestoreServiceClient.parse_entity_type_path(
path=self.resource_name
)
entity_type_name_components = self._parse_resource_name(self.resource_name)

return featurestore.Feature(
feature_name=featurestore_utils.CompatFeaturestoreServiceClient.feature_path(
feature_name=featurestore.Feature._format_resource_name(
project=entity_type_name_components["project"],
location=entity_type_name_components["location"],
featurestore=entity_type_name_components["featurestore"],
Expand Down Expand Up @@ -299,9 +304,12 @@ def list(
credentials=credentials,
parent=utils.full_resource_name(
resource_name=featurestore_name,
resource_noun="featurestores",
resource_noun=featurestore.Featurestore._resource_noun,
parse_resource_name_method=featurestore.Featurestore._parse_resource_name,
format_resource_name_method=featurestore.Featurestore._format_resource_name,
project=project,
location=location,
resource_id_validator=featurestore.Featurestore._resource_id_validator,
),
)

Expand Down
71 changes: 41 additions & 30 deletions google/cloud/aiplatform/featurestore/feature.py
Expand Up @@ -36,10 +36,22 @@ class Feature(base.VertexAiResourceNounWithFutureManager):
client_class = utils.FeaturestoreClientWithOverride

_is_client_prediction_client = False
_resource_noun = None
_resource_noun = "features"
_getter_method = "get_feature"
_list_method = "list_features"
_delete_method = "delete_feature"
_parse_resource_name_method = "parse_feature_path"
_format_resource_name_method = "feature_path"

@staticmethod
def _resource_id_validator(resource_id: str):
"""Validates resource ID.
Args:
resource_id(str):
The resource id to validate.
"""
featurestore_utils.validate_id(resource_id)

def __init__(
self,
Expand Down Expand Up @@ -83,38 +95,37 @@ def __init__(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this Feature. Overrides
credentials set in aiplatform.init.
Raises:
ValueError: If only one of featurestore_id or entity_type_id is provided.
"""
(
featurestore_id,
entity_type_id,
_,
) = featurestore_utils.validate_and_get_feature_resource_ids(
feature_name=feature_name,
entity_type_id=entity_type_id,
featurestore_id=featurestore_id,
)

# TODO(b/208269923): Temporary workaround, update when base class supports nested resource
self._resource_noun = (
f"featurestores/{featurestore_id}/entityTypes/{entity_type_id}/features"
)
if bool(featurestore_id) != bool(entity_type_id):
raise ValueError(
"featurestore_id and entity_type_id must both be provided or ommitted."
)

super().__init__(
project=project,
location=location,
credentials=credentials,
resource_name=feature_name,
)
self._gca_resource = self._get_gca_resource(resource_name=feature_name)
self._gca_resource = self._get_gca_resource(
resource_name=feature_name,
parent_resource_name_fields={
featurestore.Featurestore._resource_noun: featurestore_id,
featurestore.EntityType._resource_noun: entity_type_id,
}
if featurestore_id
else featurestore_id,
)

@property
def featurestore_name(self) -> str:
"""Full qualified resource name of the managed featurestore in which this Feature is."""
feature_path_components = featurestore_utils.CompatFeaturestoreServiceClient.parse_feature_path(
path=self.resource_name
)
feature_path_components = self._parse_resource_name(self.resource_name)

return featurestore_utils.CompatFeaturestoreServiceClient.featurestore_path(
return featurestore.Featurestore._format_resource_name(
project=feature_path_components["project"],
location=feature_path_components["location"],
featurestore=feature_path_components["featurestore"],
Expand All @@ -131,11 +142,9 @@ def get_featurestore(self) -> "featurestore.Featurestore":
@property
def entity_type_name(self) -> str:
"""Full qualified resource name of the managed entityType in which this Feature is."""
feature_path_components = featurestore_utils.CompatFeaturestoreServiceClient.parse_feature_path(
path=self.resource_name
)
feature_path_components = self._parse_resource_name(self.resource_name)

return featurestore_utils.CompatFeaturestoreServiceClient.entity_type_path(
return featurestore.EntityType._format_resource_name(
project=feature_path_components["project"],
location=feature_path_components["location"],
featurestore=feature_path_components["featurestore"],
Expand Down Expand Up @@ -303,12 +312,6 @@ def list(
Returns:
List[Feature] - A list of managed feature resource objects
"""
(
featurestore_id,
entity_type_id,
) = featurestore_utils.validate_and_get_entity_type_resource_ids(
entity_type_name=entity_type_name, featurestore_id=featurestore_id,
)

return cls._list(
filter=filter,
Expand All @@ -318,9 +321,17 @@ def list(
credentials=credentials,
parent=utils.full_resource_name(
resource_name=entity_type_name,
resource_noun=f"featurestores/{featurestore_id}/entityTypes",
resource_noun=featurestore.EntityType._resource_noun,
parse_resource_name_method=featurestore.EntityType._parse_resource_name,
format_resource_name_method=featurestore.EntityType._format_resource_name,
parent_resource_name_fields={
featurestore.Featurestore._resource_noun: featurestore_id
}
if featurestore_id
else featurestore_id,
project=project,
location=location,
resource_id_validator=featurestore.EntityType._resource_id_validator,
),
)

Expand Down

0 comments on commit 321cf9e

Please sign in to comment.