Skip to content

Commit

Permalink
feat: Add list() method to all resource nouns (#294)
Browse files Browse the repository at this point in the history
### [Colab for manual testing](https://colab.research.google.com/drive/14iLNaJEyZaPGebCgJZxgUJJeU9r7z_Fu)

### Summary of Changes
- Added a `_list_method` property to `AiPlatformResourceNoun` to store the name of the GAPIC list method for each noun
- Added `create_time` and `update_time` properties to `AiPlatformResourceNoun`
- Added a single `list()` method that takes four optional fields and returns a list of SDK types
    - All of the fields except `order_by` are available in every GAPIC list method
    - Added local sorting for GAPIC list methods that do not take `order_by`
- Added 3 unit tests to check correct GAPIC calls and local sorting
- Added `aiplatform.init()` to test class setup and dropped it from some unit tests

Fixes [b/183498826](http://b/183498826) 🦕
  • Loading branch information
vinnysenthil committed Apr 11, 2021
1 parent 674227d commit 3ec9386
Show file tree
Hide file tree
Showing 10 changed files with 630 additions and 25 deletions.
235 changes: 233 additions & 2 deletions google/cloud/aiplatform/base.py
Expand Up @@ -17,12 +17,13 @@

import abc
from concurrent import futures
import datetime
import functools
import inspect
import threading
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

import proto
import threading
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -249,6 +250,12 @@ def _getter_method(cls) -> str:
"""Name of getter method of client class for retrieving the resource."""
pass

@property
@abc.abstractmethod
def _list_method(cls) -> str:
    """Name of the API client method that lists resources of this noun.

    `_list` resolves this name on the API client with `getattr` before
    issuing the list request.
    """

@property
@abc.abstractmethod
def _delete_method(cls) -> str:
Expand Down Expand Up @@ -385,6 +392,17 @@ def display_name(self) -> str:
"""Display name of this resource."""
return self._gca_resource.display_name

@property
def create_time(self) -> datetime.datetime:
    """Creation time as recorded on the cached GAPIC resource."""
    gca_resource = self._gca_resource
    return gca_resource.create_time

@property
def update_time(self) -> datetime.datetime:
    """Last-update time, read after `_sync_gca_resource()` has run."""
    # `_sync_gca_resource()` is defined elsewhere; presumably it refreshes
    # the cached GAPIC resource so the value read below is not stale.
    self._sync_gca_resource()
    synced_resource = self._gca_resource
    return synced_resource.update_time

def __repr__(self) -> str:
    """Return the default object repr followed by the resource name."""
    default_repr = object.__repr__(self)
    return f"{default_repr} \nresource name: {self.resource_name}"

Expand Down Expand Up @@ -617,6 +635,219 @@ def _sync_object_with_future_result(
if value:
setattr(self, attribute, value)

def _construct_sdk_resource_from_gapic(
    self,
    gapic_resource: proto.Message,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> AiPlatformResourceNoun:
    """Wrap a GAPIC resource object in its SDK representation.

    A new SDK object is created through `_empty_constructor` and the given
    GAPIC resource is attached to it as `_gca_resource`.

    Args:
        gapic_resource (proto.Message):
            A GAPIC representation of an AI Platform resource, usually
            retrieved by a get_* or in a list_* API call.
        project (str):
            Optional. Project to construct SDK object from. If not set,
            project set in aiplatform.init will be used.
        location (str):
            Optional. Location to construct SDK object from. If not set,
            location set in aiplatform.init will be used.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials to use to construct SDK object.
            Overrides credentials set in aiplatform.init.
    Returns:
        AiPlatformResourceNoun:
            An initialized SDK object that represents GAPIC type.
    """
    constructor_kwargs = {
        "project": project,
        "location": location,
        "credentials": credentials,
    }
    wrapper = self._empty_constructor(**constructor_kwargs)
    # Attach the already-fetched GAPIC payload to the new SDK object.
    wrapper._gca_resource = gapic_resource
    return wrapper

# TODO(b/144545165): Improve documentation for list filtering once available
# TODO(b/184910159): Expose `page_size` field in list method
@classmethod
def _list(
    cls,
    cls_filter: Callable[[proto.Message], bool] = lambda _: True,
    filter: Optional[str] = None,
    order_by: Optional[str] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[AiPlatformResourceNoun]:
    """Private method to list all instances of this AI Platform Resource,
    takes a `cls_filter` arg to filter to a particular SDK resource subclass.

    Args:
        cls_filter (Callable[[proto.Message], bool]):
            A function that takes one argument, a GAPIC resource, and returns
            a bool. If the function returns False, that resource will be
            excluded from the returned list. Example usage:
                cls_filter = lambda obj: obj.metadata in cls.valid_metadatas
        filter (str):
            Optional. An expression for filtering the results of the request.
            For field names both snake_case and camelCase are supported.
        order_by (str):
            Optional. A comma-separated list of fields to order by, sorted in
            ascending order. Use "desc" after a field name for descending.
            Supported fields: `display_name`, `create_time`, `update_time`.
            Only added to the request when non-empty.
        project (str):
            Optional. Project to retrieve list from. If not set, project
            set in aiplatform.init will be used.
        location (str):
            Optional. Location to retrieve list from. If not set, location
            set in aiplatform.init will be used.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials to use to retrieve list. Overrides
            credentials set in aiplatform.init. The same credentials are
            attached to every returned SDK object.
    Returns:
        List[AiPlatformResourceNoun] - A list of SDK resource objects
    """
    # An empty instance of this noun supplies the API client for the list
    # call and acts as the factory for each returned SDK object.
    noun = cls._empty_constructor(
        project=project, location=location, credentials=credentials
    )

    # Resolve credentials once and re-use them for every constructed
    # resource. Caller-supplied credentials must win over the global
    # configuration; otherwise the returned objects would silently fall
    # back to the credentials set in aiplatform.init.
    creds = (
        credentials
        if credentials is not None
        else initializer.global_config.credentials
    )

    resource_list_method = getattr(noun.api_client, noun._list_method)

    list_request = {
        "parent": initializer.global_config.common_location_path(
            project=project, location=location
        ),
        "filter": filter,
    }

    # `order_by` is only included when set, since not every list request
    # type accepts the field.
    if order_by:
        list_request["order_by"] = order_by

    resource_list = resource_list_method(request=list_request) or []

    return [
        noun._construct_sdk_resource_from_gapic(
            gapic_resource, project=project, location=location, credentials=creds
        )
        for gapic_resource in resource_list
        if cls_filter(gapic_resource)
    ]

@classmethod
def _list_with_local_order(
    cls,
    cls_filter: Callable[[proto.Message], bool] = lambda _: True,
    filter: Optional[str] = None,
    order_by: Optional[str] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[AiPlatformResourceNoun]:
    """Private method to list all instances of this AI Platform Resource,
    takes a `cls_filter` arg to filter to a particular SDK resource subclass.
    Provides client-side sorting when a list API doesn't support `order_by`.

    Args:
        cls_filter (Callable[[proto.Message], bool]):
            A function that takes one argument, a GAPIC resource, and returns
            a bool. If the function returns False, that resource will be
            excluded from the returned list. Example usage:
                cls_filter = lambda obj: obj.metadata in cls.valid_metadatas
        filter (str):
            Optional. An expression for filtering the results of the request.
            For field names both snake_case and camelCase are supported.
        order_by (str):
            Optional. A comma-separated list of fields to order by, sorted in
            ascending order. Use "desc" after a field name for descending,
            e.g. "create_time desc, display_name". Each field's direction
            is applied independently. If not set, the list is returned in
            the order produced by the API.
            Supported fields: `display_name`, `create_time`, `update_time`
        project (str):
            Optional. Project to retrieve list from. If not set, project
            set in aiplatform.init will be used.
        location (str):
            Optional. Location to retrieve list from. If not set, location
            set in aiplatform.init will be used.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials to use to retrieve list. Overrides
            credentials set in aiplatform.init.
    Returns:
        List[AiPlatformResourceNoun] - A list of SDK resource objects
    Raises:
        ValueError: If an `order_by` clause is not of the form
            "<field>", "<field> asc" or "<field> desc".
    """
    # Parse `order_by` into (field, descending) pairs before calling the
    # API so a malformed expression fails fast. Clauses are tokenised on
    # whitespace instead of stripping the substring "desc", which would
    # corrupt any field name that happens to contain it.
    sort_keys = []
    for clause in (order_by or "").split(","):
        tokens = clause.split()
        if not tokens:
            continue  # tolerate empty clauses such as a trailing comma
        direction = tokens[1].lower() if len(tokens) > 1 else "asc"
        if len(tokens) > 2 or direction not in ("asc", "desc"):
            raise ValueError(f"Invalid order_by clause: {clause.strip()!r}")
        sort_keys.append((tokens[0], direction == "desc"))

    resources = cls._list(
        cls_filter=cls_filter,
        filter=filter,
        order_by=None,  # This method will handle the ordering locally
        project=project,
        location=location,
        credentials=credentials,
    )

    # list.sort() is stable, so sorting by the least-significant key first
    # and the most-significant key last yields a multi-key ordering in
    # which every key keeps its own ascending/descending direction.
    for field, descending in reversed(sort_keys):
        resources.sort(
            key=lambda resource, field=field: getattr(resource, field),
            reverse=descending,
        )

    return resources

@classmethod
def list(
    cls,
    filter: Optional[str] = None,
    order_by: Optional[str] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[AiPlatformResourceNoun]:
    """List all instances of this AI Platform Resource.

    Public entry point that forwards every argument unchanged to `_list`.

    Example Usage:
        aiplatform.BatchPredictionJobs.list(
            filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
        )
        aiplatform.Model.list(order_by="create_time desc, display_name")

    Args:
        filter (str):
            Optional. An expression for filtering the results of the request.
            For field names both snake_case and camelCase are supported.
        order_by (str):
            Optional. A comma-separated list of fields to order by, sorted in
            ascending order. Use "desc" after a field name for descending.
            Supported fields: `display_name`, `create_time`, `update_time`
        project (str):
            Optional. Project to retrieve list from. If not set, project
            set in aiplatform.init will be used.
        location (str):
            Optional. Location to retrieve list from. If not set, location
            set in aiplatform.init will be used.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials to use to retrieve list. Overrides
            credentials set in aiplatform.init.
    Returns:
        List[AiPlatformResourceNoun] - A list of SDK resource objects
    """
    list_kwargs = {
        "filter": filter,
        "order_by": order_by,
        "project": project,
        "location": location,
        "credentials": credentials,
    }
    return cls._list(**list_kwargs)

@optional_sync()
def delete(self, sync: bool = True) -> None:
"""Deletes this AI Platform resource. WARNING: This deletion is permament.
Expand Down
59 changes: 57 additions & 2 deletions google/cloud/aiplatform/datasets/dataset.py
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import Optional, Sequence, Dict, Tuple, Union
from typing import Optional, Sequence, Dict, Tuple, Union, List

from google.api_core import operation
from google.auth import credentials as auth_credentials
Expand All @@ -40,9 +40,10 @@ class Dataset(base.AiPlatformResourceNounWithFutureManager):
_is_client_prediction_client = False
_resource_noun = "datasets"
_getter_method = "get_dataset"
_list_method = "list_datasets"
_delete_method = "delete_dataset"

_supported_metadata_schema_uris: Optional[Tuple[str]] = None
_supported_metadata_schema_uris: Tuple[str] = ()

def __init__(
self,
Expand Down Expand Up @@ -494,3 +495,57 @@ def export_data(self, output_dir: str) -> Sequence[str]:

def update(self):
    """Placeholder for updating a dataset; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    message = "Update dataset has not been implemented yet"
    raise NotImplementedError(message)

@classmethod
def list(
    cls,
    filter: Optional[str] = None,
    order_by: Optional[str] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[base.AiPlatformResourceNoun]:
    """List all instances of this Dataset resource.

    Only datasets whose `metadata_schema_uri` appears in this class's
    `_supported_metadata_schema_uris` are returned. Ordering is delegated
    to `_list_with_local_order`.

    Example Usage:
        aiplatform.TabularDataset.list(
            filter='labels.my_key="my_value"',
            order_by='display_name'
        )

    Args:
        filter (str):
            Optional. An expression for filtering the results of the request.
            For field names both snake_case and camelCase are supported.
        order_by (str):
            Optional. A comma-separated list of fields to order by, sorted in
            ascending order. Use "desc" after a field name for descending.
            Supported fields: `display_name`, `create_time`, `update_time`
        project (str):
            Optional. Project to retrieve list from. If not set, project
            set in aiplatform.init will be used.
        location (str):
            Optional. Location to retrieve list from. If not set, location
            set in aiplatform.init will be used.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials to use to retrieve list. Overrides
            credentials set in aiplatform.init.
    Returns:
        List[base.AiPlatformResourceNoun] - A list of Dataset resource objects
    """

    def _matches_dataset_subclass(gapic_obj):
        # The class attribute is read on every call, so each subclass's own
        # tuple of supported schema URIs decides what is kept.
        schema_uri = gapic_obj.metadata_schema_uri
        return schema_uri in cls._supported_metadata_schema_uris

    return cls._list_with_local_order(
        cls_filter=_matches_dataset_subclass,
        filter=filter,
        order_by=order_by,
        project=project,
        location=location,
        credentials=credentials,
    )

0 comments on commit 3ec9386

Please sign in to comment.