feat: Add list() method to all resource nouns #294

Merged
merged 12 commits into from Apr 11, 2021
143 changes: 139 additions & 4 deletions google/cloud/aiplatform/base.py
@@ -17,16 +17,18 @@

import abc
from concurrent import futures
import datetime
import functools
import inspect
import proto
import threading
from typing import Any, Callable, Dict, Optional, Sequence, Union

import proto

from google.auth import credentials as auth_credentials
from google.cloud import aiplatform
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import initializer
from google.protobuf import field_mask_pb2 as field_mask


class FutureManager(metaclass=abc.ABCMeta):
@@ -232,7 +234,7 @@ class AiPlatformResourceNoun(metaclass=abc.ABCMeta):
@property
@classmethod
@abc.abstractmethod
def client_class(cls) -> utils.AiPlatformServiceClient:
def client_class(cls) -> "utils.AiPlatformServiceClient":
"""Client class required to interact with resource."""
pass

@@ -249,6 +251,12 @@ def _getter_method(cls) -> str:
"""Name of getter method of client class for retrieving the resource."""
pass

@property
@abc.abstractmethod
def _list_method(cls) -> str:
"""Name of list method of client class for listing resources."""
pass

@property
@abc.abstractmethod
def _delete_method(cls) -> str:
@@ -287,7 +295,7 @@ def _instantiate_client(
cls,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> utils.AiPlatformServiceClient:
) -> "utils.AiPlatformServiceClient":
"""Helper method to instantiate service client for resource noun.

Args:
@@ -343,6 +351,16 @@ def display_name(self) -> str:
"""Display name of this resource."""
return self._gca_resource.display_name

@property
def create_time(self) -> datetime.datetime:
"""Time this resource was created."""
return self._gca_resource.create_time

@property
def update_time(self) -> datetime.datetime:
"""Time this resource was last updated."""
return self._gca_resource.update_time

def __repr__(self) -> str:
return f"{object.__repr__(self)} \nresource name: {self.resource_name}"

@@ -561,6 +579,123 @@ def _sync_object_with_future_result(
if value:
setattr(self, attribute, value)

def _construct_sdk_resource_from_gapic(
self,
gapic_resource: proto.Message,
credentials: Optional[auth_credentials.Credentials] = None,
) -> AiPlatformResourceNoun:
"""Given a GAPIC object, return the SDK representation."""
sdk_resource = self._empty_constructor(credentials=credentials)
sdk_resource._gca_resource = gapic_resource
return sdk_resource

# TODO(b/144545165) - Improve documentation for list filtering once available
@classmethod
def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
page_size: Optional[int] = None,
read_mask: Optional[field_mask.FieldMask] = None,
) -> Sequence[AiPlatformResourceNoun]:
"""List all instances of this AI Platform Resource.

Example Usage:

aiplatform.BatchPredictionJob.list(
filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
)

aiplatform.Model.list(order_by="create_time desc, display_name")

Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
page_size (int):
Optional. The standard list page size.
read_mask (field_mask.FieldMask):
Optional. Mask specifying which fields to read.

Returns:
Sequence[AiPlatformResourceNoun] - A list of SDK resource objects
"""
_UNSUPPORTED_LIST_ORDER_BY_TYPES = (
Member commented:

Strong preference to break down this conditional logic at the subclass level instead.

One option is to have two private list methods and pass in the filtering attributes.

class AiPlatformResourceNoun:
    def _list(...,
              order_by: Optional[str] = None,
              gapic_field_filter_name: Optional[str] = None,
              cls_field_filter_name: Optional[str] = None):
        ...
        cls_filter_schema = (
            getattr(cls, cls_field_filter_name, None)
            if cls_field_filter_name
            else set()
        )
        final_list = [
            self._construct_sdk_resource_from_gapic(
                gapic_resource, credentials=creds
            )
            for gapic_resource in resource_list
            if not gapic_field_filter_name
            or getattr(gapic_resource, gapic_field_filter_name) in cls_filter_schema
        ]

    def _list_with_local_order(..., order_by: Optional[str] = None):
        li = cls._list(..., order_by=None)
        # order here
        return li

    def list(...):
        return cls._list(...)

class _TrainingJob:
    def list(...):
        return cls._list_with_local_order(
            ...,
            order_by,
            gapic_field_filter_name='training_task_definition',
            cls_field_filter_name='_supported_training_schemas',
        )  # could just pass in the cls attribute as well

This will be more extensible in the future and allow external teams to use the list functionality without needing to change the base class.

Contributor (Author) replied:

Thanks for putting together this example! Adopting that option, but the private _list method will take a single cls_filter Callable that takes a gca_resource and returns a bool deciding whether to include or exclude a given GAPIC object. Lmk if you have any concerns with that approach.
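For reference, a minimal standalone sketch of that Callable-based shape (all names below are illustrative assumptions about the follow-up, not the merged API):

from typing import Callable, List

# Toy stand-in so the sketch runs on its own; the real code would pass GAPIC
# messages through the request logic shown in the diff above.
class FakeGapicResource:
    def __init__(self, training_task_definition: str) -> None:
        self.training_task_definition = training_task_definition


def _list(
    resource_list: List[FakeGapicResource],
    cls_filter: Callable[[FakeGapicResource], bool] = lambda _: True,
) -> List[FakeGapicResource]:
    # One shared helper: each subclass injects its own predicate instead of
    # the base class hard-coding per-subclass conditionals.
    return [resource for resource in resource_list if cls_filter(resource)]


# A _TrainingJob-style caller filters on its supported training schemas:
supported_schemas = {"custom_task", "automl_tabular"}
jobs = _list(
    [FakeGapicResource("custom_task"), FakeGapicResource("unrelated")],
    cls_filter=lambda gca: gca.training_task_definition in supported_schemas,
)
assert [j.training_task_definition for j in jobs] == ["custom_task"]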

aiplatform.jobs._Job,
aiplatform.models.Endpoint,
aiplatform.models.Model,
aiplatform.training_jobs._TrainingJob,
)

self = cls._empty_constructor()

creds = initializer.global_config.credentials

resource_list_method = getattr(self.api_client, self._list_method)
order_locally = False

list_request = {
"parent": initializer.global_config.common_location_path(),
"filter": filter,
"page_size": page_size,
"read_mask": read_mask,
}

# If list method does not offer `order_by` field, order locally
if order_by and issubclass(type(self), _UNSUPPORTED_LIST_ORDER_BY_TYPES):
order_locally = True
elif order_by:
list_request["order_by"] = order_by

resource_list = resource_list_method(request=list_request) or []
Member commented:

Why is the empty list needed here?

Contributor (Author) replied:

I ran into a few rare instances where the service returned None instead of an empty list; this just serves as a fail-safe against that.


# Only return objects specific to the calling subclass,
# for example TabularDataset.list() only lists TabularDatasets
if issubclass(type(self), aiplatform.datasets.Dataset):
final_list = [
self._construct_sdk_resource_from_gapic(
gapic_resource, credentials=creds
)
for gapic_resource in resource_list
if gapic_resource.metadata_schema_uri
in self._supported_metadata_schema_uris
]

elif issubclass(type(self), aiplatform.training_jobs._TrainingJob):
final_list = [
self._construct_sdk_resource_from_gapic(
gapic_resource, credentials=creds
)
for gapic_resource in resource_list
if gapic_resource.training_task_definition
in self._supported_training_schemas
]

else:
final_list = [
self._construct_sdk_resource_from_gapic(
gapic_resource, credentials=creds
)
for gapic_resource in resource_list
]

# Client-side sorting when API doesn't support `order_by`
if order_locally:
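# Note: a single "desc" token reverses the entire multi-field sort.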
desc = "desc" in order_by
order_by = order_by.replace("desc", "")
order_by = order_by.split(",")

final_list.sort(
key=lambda x: tuple(getattr(x, field.strip()) for field in order_by),
reverse=desc,
)

return final_list

@optional_sync()
def delete(self, sync: bool = True) -> None:
"""Deletes this AI Platform resource. WARNING: This deletion is permament.
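Taken together, the base.py changes give every resource noun a uniform list() entry point. A minimal usage sketch (project, location, and filter values are illustrative assumptions, not from this PR):

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Models support server-side ordering via `order_by`.
models = aiplatform.Model.list(order_by="create_time desc")

# Job classes are in _UNSUPPORTED_LIST_ORDER_BY_TYPES, so this sorts client-side.
jobs = aiplatform.BatchPredictionJob.list(
    filter='state="JOB_STATE_SUCCEEDED"',
    order_by="create_time desc",
)

for model in models:
    print(model.resource_name, model.display_name, model.create_time)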
3 changes: 2 additions & 1 deletion google/cloud/aiplatform/datasets/dataset.py
@@ -40,9 +40,10 @@ class Dataset(base.AiPlatformResourceNounWithFutureManager):
_is_client_prediction_client = False
_resource_noun = "datasets"
_getter_method = "get_dataset"
_list_method = "list_datasets"
_delete_method = "delete_dataset"

_supported_metadata_schema_uris: Optional[Tuple[str]] = None
_supported_metadata_schema_uris: Optional[Tuple[str]] = ()

def __init__(
self,
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/jobs.py
@@ -176,6 +176,7 @@ class BatchPredictionJob(_Job):

_resource_noun = "batchPredictionJobs"
_getter_method = "get_batch_prediction_job"
_list_method = "list_batch_prediction_jobs"
_cancel_method = "cancel_batch_prediction_job"
_delete_method = "delete_batch_prediction_job"
_job_type = "batch-predictions"
@@ -676,6 +677,7 @@ def iter_outputs(
class CustomJob(_Job):
_resource_noun = "customJobs"
_getter_method = "get_custom_job"
_list_method = "list_custom_job"
_cancel_method = "cancel_custom_job"
_delete_method = "delete_custom_job"
_job_type = "training"
@@ -685,6 +687,7 @@ class CustomJob(_Job):
class DataLabelingJob(_Job):
_resource_noun = "dataLabelingJobs"
_getter_method = "get_data_labeling_job"
_list_method = "list_data_labeling_jobs"
_cancel_method = "cancel_data_labeling_job"
_delete_method = "delete_data_labeling_job"
_job_type = "labeling-tasks"
@@ -694,6 +697,7 @@ class DataLabelingJob(_Job):
class HyperparameterTuningJob(_Job):
_resource_noun = "hyperparameterTuningJobs"
_getter_method = "get_hyperparameter_tuning_job"
_list_method = "list_hyperparameter_tuning_jobs"
_cancel_method = "cancel_hyperparameter_tuning_job"
_delete_method = "delete_hyperparameter_tuning_job"
pass
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/models.py
@@ -71,6 +71,7 @@ class Endpoint(base.AiPlatformResourceNounWithFutureManager):
_is_client_prediction_client = False
_resource_noun = "endpoints"
_getter_method = "get_endpoint"
_list_method = "list_endpoints"
_delete_method = "delete_endpoint"

def __init__(
@@ -1083,6 +1084,7 @@ class Model(base.AiPlatformResourceNounWithFutureManager):
_is_client_prediction_client = False
_resource_noun = "models"
_getter_method = "get_model"
_list_method = "list_models"
_delete_method = "delete_model"

@property
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
@@ -59,6 +59,7 @@

import proto


logging.basicConfig(level=logging.INFO, stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)

@@ -77,6 +78,7 @@ class _TrainingJob(base.AiPlatformResourceNounWithFutureManager):
_is_client_prediction_client = False
_resource_noun = "trainingPipelines"
_getter_method = "get_training_pipeline"
_list_method = "list_training_pipelines"
_delete_method = "delete_training_pipeline"

def __init__(
50 changes: 46 additions & 4 deletions tests/unit/aiplatform/test_datasets.py
@@ -106,6 +106,21 @@
# misc
_TEST_OUTPUT_DIR = "gs://my-output-bucket"

_TEST_TABULAR_DATASET_LIST = [
GapicDataset(
display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
),
GapicDataset(
display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
),
GapicDataset(
display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
),
]

_TEST_LIST_FILTER = 'display_name="abc"'
_TEST_LIST_ORDER_BY = "create_time desc"


@pytest.fixture
def get_dataset_mock():
@@ -224,6 +239,13 @@ def export_data_mock():
yield export_data_mock


@pytest.fixture
def list_datasets_mock():
with patch.object(DatasetServiceClient, "list_datasets") as list_datasets_mock:
list_datasets_mock.return_value = _TEST_TABULAR_DATASET_LIST
yield list_datasets_mock


# TODO(b/171333554): Move reusable test fixtures to conftest.py file
class TestDataset:
def setup_method(self):
@@ -669,18 +691,19 @@ class TestTabularDataset:
def setup_method(self):
reload(initializer)
reload(aiplatform)
aiplatform.init(project=_TEST_PROJECT)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_init_dataset_tabular(self, get_dataset_tabular_mock):
aiplatform.init(project=_TEST_PROJECT)

datasets.TabularDataset(dataset_name=_TEST_NAME)
get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME)

@pytest.mark.usefixtures("get_dataset_image_mock")
def test_init_dataset_non_tabular(self):
aiplatform.init(project=_TEST_PROJECT)

with pytest.raises(ValueError):
datasets.TabularDataset(dataset_name=_TEST_NAME)

@@ -716,7 +739,6 @@ def test_create_dataset_with_default_encryption_key(
@pytest.mark.usefixtures("get_dataset_tabular_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_dataset(self, create_dataset_mock, sync):
aiplatform.init(project=_TEST_PROJECT)

my_dataset = datasets.TabularDataset.create(
display_name=_TEST_DISPLAY_NAME,
@@ -743,13 +765,33 @@ def test_create_dataset(self, create_dataset_mock, sync):

@pytest.mark.usefixtures("get_dataset_tabular_mock")
def test_no_import_data_method(self):
aiplatform.init(project=_TEST_PROJECT)

my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

with pytest.raises(NotImplementedError):
my_dataset.import_data()

def test_list_dataset(self, list_datasets_mock):

ds_list = aiplatform.TabularDataset.list(
filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY
)

list_datasets_mock.assert_called_once_with(
request={
"parent": _TEST_PARENT,
"filter": _TEST_LIST_FILTER,
"order_by": _TEST_LIST_ORDER_BY,
"page_size": None,
"read_mask": None,
}
)

assert len(ds_list) == len(_TEST_TABULAR_DATASET_LIST)

for ds in ds_list:
assert type(ds) == aiplatform.TabularDataset


class TestTextDataset:
def setup_method(self):
2 changes: 2 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
@@ -131,6 +131,7 @@
)

_TEST_JOB_GET_METHOD_NAME = "get_fake_job"
_TEST_JOB_LIST_METHOD_NAME = "list_fake_job"
_TEST_JOB_CANCEL_METHOD_NAME = "cancel_fake_job"
_TEST_JOB_DELETE_METHOD_NAME = "delete_fake_job"
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/fakeJobs/{_TEST_ID}"
@@ -160,6 +161,7 @@ class FakeJob(jobs._Job):
_job_type = "fake-job"
_resource_noun = "fakeJobs"
_getter_method = _TEST_JOB_GET_METHOD_NAME
_list_method = _TEST_JOB_LIST_METHOD_NAME
_cancel_method = _TEST_JOB_CANCEL_METHOD_NAME
_delete_method = _TEST_JOB_DELETE_METHOD_NAME
resource_name = _TEST_JOB_RESOURCE_NAME
1 change: 1 addition & 0 deletions tests/unit/aiplatform/test_lro.py
@@ -49,6 +49,7 @@ class AiPlatformResourceNounImpl(base.AiPlatformResourceNoun):
_is_client_prediction_client = False
_resource_noun = None
_getter_method = None
_list_method = None
_delete_method = None

