diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py
index 78f8807334..b9648be89d 100644
--- a/google/cloud/aiplatform/base.py
+++ b/google/cloud/aiplatform/base.py
@@ -17,12 +17,13 @@
 
 import abc
 from concurrent import futures
+import datetime
 import functools
 import inspect
-import threading
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
 
 import proto
+import threading
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 from google.auth import credentials as auth_credentials
 from google.cloud.aiplatform import initializer
@@ -249,6 +250,12 @@ def _getter_method(cls) -> str:
         """Name of getter method of client class for retrieving the resource."""
         pass
 
+    @property
+    @abc.abstractmethod
+    def _list_method(cls) -> str:
+        """Name of list method of client class for listing resources."""
+        pass
+
     @property
     @abc.abstractmethod
     def _delete_method(cls) -> str:
@@ -385,6 +392,17 @@ def display_name(self) -> str:
         """Display name of this resource."""
         return self._gca_resource.display_name
 
+    @property
+    def create_time(self) -> datetime.datetime:
+        """Time this resource was created."""
+        return self._gca_resource.create_time
+
+    @property
+    def update_time(self) -> datetime.datetime:
+        """Time this resource was last updated."""
+        self._sync_gca_resource()
+        return self._gca_resource.update_time
+
     def __repr__(self) -> str:
         return f"{object.__repr__(self)} \nresource name: {self.resource_name}"
 
@@ -617,6 +635,219 @@ def _sync_object_with_future_result(
             if value:
                 setattr(self, attribute, value)
 
+    def _construct_sdk_resource_from_gapic(
+        self,
+        gapic_resource: proto.Message,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> AiPlatformResourceNoun:
+        """Given a GAPIC resource object, return the SDK representation.
+
+        Args:
+            gapic_resource (proto.Message):
+                A GAPIC representation of an AI Platform resource, usually
+                returned by a get_* or list_* API call.
+            project (str):
+                Optional. Project to construct SDK object from. If not set,
+                project set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to construct SDK object from. If not set,
+                location set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to construct SDK object.
+                Overrides credentials set in aiplatform.init.
+
+        Returns:
+            AiPlatformResourceNoun:
+                An initialized SDK object that represents the GAPIC type.
+        """
+        sdk_resource = self._empty_constructor(
+            project=project, location=location, credentials=credentials
+        )
+        sdk_resource._gca_resource = gapic_resource
+        return sdk_resource
+
+    # TODO(b/144545165): Improve documentation for list filtering once available
+    # TODO(b/184910159): Expose `page_size` field in list method
+    @classmethod
+    def _list(
+        cls,
+        cls_filter: Callable[[proto.Message], bool] = lambda _: True,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[AiPlatformResourceNoun]:
+        """Private method to list all instances of this AI Platform Resource.
+        Takes a `cls_filter` arg to filter the results to a particular SDK
+        resource subclass.
+
+        Args:
+            cls_filter (Callable[[proto.Message], bool]):
+                A function that takes one argument, a GAPIC resource, and returns
+                a bool. If the function returns False, that resource will be
+                excluded from the returned list. Example usage:
+                cls_filter = lambda obj: obj.metadata in cls.valid_metadatas
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[AiPlatformResourceNoun] - A list of SDK resource objects
+        """
+        self = cls._empty_constructor(
+            project=project, location=location, credentials=credentials
+        )
+
+        # Fetch credentials once and re-use for all `_empty_constructor()` calls
+        creds = initializer.global_config.credentials
+
+        resource_list_method = getattr(self.api_client, self._list_method)
+
+        list_request = {
+            "parent": initializer.global_config.common_location_path(
+                project=project, location=location
+            ),
+            "filter": filter,
+        }
+
+        if order_by:
+            list_request["order_by"] = order_by
+
+        resource_list = resource_list_method(request=list_request) or []
+
+        return [
+            self._construct_sdk_resource_from_gapic(
+                gapic_resource, project=project, location=location, credentials=creds
+            )
+            for gapic_resource in resource_list
+            if cls_filter(gapic_resource)
+        ]
+
+    @classmethod
+    def _list_with_local_order(
+        cls,
+        cls_filter: Callable[[proto.Message], bool] = lambda _: True,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[AiPlatformResourceNoun]:
+        """Private method to list all instances of this AI Platform Resource.
+        Takes a `cls_filter` arg to filter the results to a particular SDK
+        resource subclass. Provides client-side sorting when a list API
+        doesn't support `order_by`.
+
+        Args:
+            cls_filter (Callable[[proto.Message], bool]):
+                A function that takes one argument, a GAPIC resource, and returns
+                a bool. If the function returns False, that resource will be
+                excluded from the returned list. Example usage:
+                cls_filter = lambda obj: obj.metadata in cls.valid_metadatas
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[AiPlatformResourceNoun] - A list of SDK resource objects
+        """
+
+        li = cls._list(
+            cls_filter=cls_filter,
+            filter=filter,
+            order_by=None,  # This method will handle the ordering locally
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
+        if order_by:
+            desc = "desc" in order_by
+            order_by = order_by.replace("desc", "")
+            order_by = order_by.split(",")
+
+            li.sort(
+                key=lambda x: tuple(getattr(x, field.strip()) for field in order_by),
+                reverse=desc,
+            )
+
+        return li
+
+    @classmethod
+    def list(
+        cls,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[AiPlatformResourceNoun]:
+        """List all instances of this AI Platform Resource.
+
+        Example Usage:
+
+        aiplatform.BatchPredictionJob.list(
+            filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
+        )
+
+        aiplatform.Model.list(order_by="create_time desc, display_name")
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[AiPlatformResourceNoun] - A list of SDK resource objects
+        """
+
+        return cls._list(
+            filter=filter,
+            order_by=order_by,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
     @optional_sync()
     def delete(self, sync: bool = True) -> None:
         """Deletes this AI Platform resource. WARNING: This deletion is permanent.
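The client-side ordering in `_list_with_local_order` is easy to see in isolation: strip a trailing `desc`, split the remaining field names on commas, then sort by the resulting attribute tuple. A minimal, runnable sketch of that logic, where the `Resource` namedtuple is a hypothetical stand-in for an SDK object exposing the sortable fields:

```python
import datetime
from collections import namedtuple

# Hypothetical stand-in for an SDK resource with the sortable fields.
Resource = namedtuple("Resource", ["display_name", "create_time"])

resources = [
    Resource("b", datetime.datetime(2021, 4, 1)),
    Resource("a", datetime.datetime(2021, 4, 3)),
    Resource("c", datetime.datetime(2021, 4, 2)),
]

order_by = "create_time desc"

# Mirror the diff's logic: strip "desc", split on commas, sort by attribute tuple.
desc = "desc" in order_by
fields = order_by.replace("desc", "").split(",")

resources.sort(
    key=lambda x: tuple(getattr(x, field.strip()) for field in fields),
    reverse=desc,
)

print([r.display_name for r in resources])  # ['a', 'c', 'b']
```

Note that a single `desc` anywhere in the string reverses the whole tuple comparison, so a mixed specification like `"create_time desc, display_name"` sorts every listed field descending rather than applying the direction per field.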
diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py
index 00c03c4928..999b80c5e2 100644
--- a/google/cloud/aiplatform/datasets/dataset.py
+++ b/google/cloud/aiplatform/datasets/dataset.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import Optional, Sequence, Dict, Tuple, Union
+from typing import Optional, Sequence, Dict, Tuple, Union, List
 
 from google.api_core import operation
 from google.auth import credentials as auth_credentials
@@ -40,9 +40,10 @@ class Dataset(base.AiPlatformResourceNounWithFutureManager):
     _is_client_prediction_client = False
     _resource_noun = "datasets"
     _getter_method = "get_dataset"
+    _list_method = "list_datasets"
     _delete_method = "delete_dataset"
 
-    _supported_metadata_schema_uris: Optional[Tuple[str]] = None
+    _supported_metadata_schema_uris: Tuple[str] = ()
 
     def __init__(
         self,
@@ -494,3 +495,57 @@ def export_data(self, output_dir: str) -> Sequence[str]:
 
     def update(self):
         raise NotImplementedError("Update dataset has not been implemented yet")
+
+    @classmethod
+    def list(
+        cls,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[base.AiPlatformResourceNoun]:
+        """List all instances of this Dataset resource.
+
+        Example Usage:
+
+        aiplatform.TabularDataset.list(
+            filter='labels.my_key="my_value"',
+            order_by='display_name'
+        )
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[base.AiPlatformResourceNoun] - A list of Dataset resource objects
+        """
+
+        dataset_subclass_filter = (
+            lambda gapic_obj: gapic_obj.metadata_schema_uri
+            in cls._supported_metadata_schema_uris
+        )
+
+        return cls._list_with_local_order(
+            cls_filter=dataset_subclass_filter,
+            filter=filter,
+            order_by=order_by,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
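Because `_supported_metadata_schema_uris` is now a concrete tuple, each Dataset subclass can hand `_list_with_local_order` a `cls_filter` that drops GAPIC datasets with non-matching schemas. A hedged usage sketch, where the project ID and label values are placeholders:

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

# Only datasets whose metadata_schema_uri is in
# TabularDataset._supported_metadata_schema_uris are returned;
# ordering is applied client-side by _list_with_local_order.
datasets = aiplatform.TabularDataset.list(
    filter='labels.my_key="my_value"',  # hypothetical label
    order_by="create_time desc",
)

for ds in datasets:
    print(ds.display_name, ds.create_time)
```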
diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py
index 104ce4fd96..1a5083ee52 100644
--- a/google/cloud/aiplatform/jobs.py
+++ b/google/cloud/aiplatform/jobs.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import Iterable, Optional, Union, Sequence, Dict
+from typing import Iterable, Optional, Union, Sequence, Dict, List
 
 import abc
 import sys
@@ -173,6 +173,53 @@ def _block_until_complete(self):
         if self.state in _JOB_ERROR_STATES:
             raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
 
+    @classmethod
+    def list(
+        cls,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[base.AiPlatformResourceNoun]:
+        """List all instances of this Job Resource.
+
+        Example Usage:
+
+        aiplatform.BatchPredictionJob.list(
+            filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
+        )
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[AiPlatformResourceNoun] - A list of Job resource objects
+        """
+
+        return cls._list_with_local_order(
+            filter=filter,
+            order_by=order_by,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
     def cancel(self) -> None:
         """Cancels this Job. Success of cancellation is not guaranteed. Use
         `Job.state` property to verify if cancellation was successful."""
@@ -183,6 +230,7 @@ class BatchPredictionJob(_Job):
 
     _resource_noun = "batchPredictionJobs"
     _getter_method = "get_batch_prediction_job"
+    _list_method = "list_batch_prediction_jobs"
     _cancel_method = "cancel_batch_prediction_job"
     _delete_method = "delete_batch_prediction_job"
     _job_type = "batch-predictions"
@@ -704,6 +752,7 @@ def iter_outputs(
 class CustomJob(_Job):
     _resource_noun = "customJobs"
     _getter_method = "get_custom_job"
+    _list_method = "list_custom_jobs"
     _cancel_method = "cancel_custom_job"
     _delete_method = "delete_custom_job"
     _job_type = "training"
@@ -713,6 +762,7 @@ class CustomJob(_Job):
 class DataLabelingJob(_Job):
     _resource_noun = "dataLabelingJobs"
     _getter_method = "get_data_labeling_job"
+    _list_method = "list_data_labeling_jobs"
     _cancel_method = "cancel_data_labeling_job"
     _delete_method = "delete_data_labeling_job"
     _job_type = "labeling-tasks"
@@ -722,6 +772,7 @@ class DataLabelingJob(_Job):
 class HyperparameterTuningJob(_Job):
     _resource_noun = "hyperparameterTuningJobs"
     _getter_method = "get_hyperparameter_tuning_job"
+    _list_method = "list_hyperparameter_tuning_jobs"
     _cancel_method = "cancel_hyperparameter_tuning_job"
     _delete_method = "delete_hyperparameter_tuning_job"
     pass
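Job subclasses route through `_list_with_local_order` because the job list RPCs do not accept a server-side `order_by`. A usage sketch under the same placeholder project:

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

# The filter is evaluated server-side; the ordering happens locally.
succeeded_jobs = aiplatform.BatchPredictionJob.list(
    filter='state="JOB_STATE_SUCCEEDED"',
    order_by="create_time desc",
)

for job in succeeded_jobs:
    print(job.display_name, job.state)
```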
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index 2440d182a5..60935e50b5 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -24,6 +24,7 @@
 from google.cloud.aiplatform import explain
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform import jobs
+from google.cloud.aiplatform import models
 from google.cloud.aiplatform import utils
 
 from google.cloud.aiplatform.compat.services import endpoint_service_client
@@ -72,6 +73,7 @@ class Endpoint(base.AiPlatformResourceNounWithFutureManager):
     _is_client_prediction_client = False
     _resource_noun = "endpoints"
     _getter_method = "get_endpoint"
+    _list_method = "list_endpoints"
     _delete_method = "delete_endpoint"
 
     def __init__(
@@ -1055,6 +1057,53 @@ def explain(
             explanations=explain_response.explanations,
         )
 
+    @classmethod
+    def list(
+        cls,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List["models.Endpoint"]:
+        """List all Endpoint resource instances.
+
+        Example Usage:
+
+        aiplatform.Endpoint.list(
+            filter='labels.my_label="my_label_value" OR display_name!="old_endpoint"',
+        )
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[models.Endpoint] - A list of Endpoint resource objects
+        """
+
+        return cls._list_with_local_order(
+            filter=filter,
+            order_by=order_by,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
     def list_models(
         self,
     ) -> Sequence[
@@ -1112,6 +1161,7 @@ class Model(base.AiPlatformResourceNounWithFutureManager):
     _is_client_prediction_client = False
     _resource_noun = "models"
     _getter_method = "get_model"
+    _list_method = "list_models"
     _delete_method = "delete_model"
 
     @property
@@ -1865,3 +1915,50 @@ def batch_predict(
             encryption_spec_key_name=encryption_spec_key_name,
             sync=sync,
         )
+
+    @classmethod
+    def list(
+        cls,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List["models.Model"]:
+        """List all Model resource instances.
+
+        Example Usage:
+
+        aiplatform.Model.list(
+            filter='labels.my_label="my_label_value" AND display_name="my_model"',
+        )
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[models.Model] - A list of Model resource objects
+        """
+
+        return cls._list(
+            filter=filter,
+            order_by=order_by,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
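The two classes in models.py split along the same line: `Model.list` forwards `order_by` to the API, while `Endpoint.list` strips it from the request and sorts locally. A sketch with placeholder project and label values:

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

# Server-side ordering: order_by is included in the list request.
my_models = aiplatform.Model.list(order_by="create_time desc, display_name")

# Client-side ordering: order_by is handled by _list_with_local_order.
endpoints = aiplatform.Endpoint.list(
    filter='labels.env="prod"',  # hypothetical label
    order_by="display_name",
)

print(len(my_models), len(endpoints))
```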
diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index 8cfe40f125..02264f9244 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -55,6 +55,7 @@
 
 import proto
 
+
 logging.basicConfig(level=logging.INFO, stream=sys.stdout)
 _LOGGER = logging.getLogger(__name__)
 
@@ -74,6 +75,7 @@ class _TrainingJob(base.AiPlatformResourceNounWithFutureManager):
     _is_client_prediction_client = False
     _resource_noun = "trainingPipelines"
     _getter_method = "get_training_pipeline"
+    _list_method = "list_training_pipelines"
     _delete_method = "delete_training_pipeline"
 
     def __init__(
@@ -693,6 +695,60 @@ def _assert_has_run(self) -> bool:
             )
             return False
 
+    @classmethod
+    def list(
+        cls,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List["base.AiPlatformResourceNoun"]:
+        """List all instances of this TrainingJob resource.
+
+        Example Usage:
+
+        aiplatform.CustomTrainingJob.list(
+            filter='display_name="experiment_a27"',
+            order_by='create_time desc'
+        )
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[AiPlatformResourceNoun] - A list of TrainingJob resource objects
+        """
+
+        training_job_subclass_filter = (
+            lambda gapic_obj: gapic_obj.training_task_definition
+            in cls._supported_training_schemas
+        )
+
+        return cls._list_with_local_order(
+            cls_filter=training_job_subclass_filter,
+            filter=filter,
+            order_by=order_by,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
     def cancel(self) -> None:
         """Starts asynchronous cancellation on the TrainingJob. The server
         makes a best effort to cancel the job, but success is not guaranteed.
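As with Dataset subclasses, each `_TrainingJob` subclass filters the shared `trainingPipelines` collection down to its own `training_task_definition` schemas. A usage sketch, with the filter value borrowed from the docstring example and a placeholder project:

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

# CustomTrainingJob.list() only returns pipelines whose
# training_task_definition is in _supported_training_schemas,
# so e.g. AutoML pipelines in the same project are excluded.
training_jobs = aiplatform.CustomTrainingJob.list(
    filter='display_name="experiment_a27"',
    order_by="create_time desc",
)

for job in training_jobs:
    print(job.display_name, job.create_time)
```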
diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py
index 52bc4327f2..f49c38e62f 100644
--- a/tests/unit/aiplatform/test_datasets.py
+++ b/tests/unit/aiplatform/test_datasets.py
@@ -110,6 +110,27 @@
 # misc
 _TEST_OUTPUT_DIR = "gs://my-output-bucket"
 
+_TEST_DATASET_LIST = [
+    gca_dataset.Dataset(
+        display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
+    ),
+    gca_dataset.Dataset(
+        display_name="d", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR
+    ),
+    gca_dataset.Dataset(
+        display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
+    ),
+    gca_dataset.Dataset(
+        display_name="e", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT
+    ),
+    gca_dataset.Dataset(
+        display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
+    ),
+]
+
+_TEST_LIST_FILTER = 'display_name="abc"'
+_TEST_LIST_ORDER_BY = "create_time desc"
+
 
 @pytest.fixture
 def get_dataset_mock():
@@ -246,6 +267,15 @@ def export_data_mock():
         yield export_data_mock
 
 
+@pytest.fixture
+def list_datasets_mock():
+    with patch.object(
+        dataset_service_client.DatasetServiceClient, "list_datasets"
+    ) as list_datasets_mock:
+        list_datasets_mock.return_value = _TEST_DATASET_LIST
+        yield list_datasets_mock
+
+
 # TODO(b/171333554): Move reusable test fixtures to conftest.py file
 class TestDataset:
     def setup_method(self):
@@ -723,18 +753,19 @@ class TestTabularDataset:
     def setup_method(self):
         reload(initializer)
         reload(aiplatform)
+        aiplatform.init(project=_TEST_PROJECT)
 
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
 
     def test_init_dataset_tabular(self, get_dataset_tabular_mock):
-        aiplatform.init(project=_TEST_PROJECT)
+
         datasets.TabularDataset(dataset_name=_TEST_NAME)
         get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME)
 
     @pytest.mark.usefixtures("get_dataset_image_mock")
     def test_init_dataset_non_tabular(self):
-        aiplatform.init(project=_TEST_PROJECT)
+
         with pytest.raises(ValueError):
             datasets.TabularDataset(dataset_name=_TEST_NAME)
@@ -770,7 +801,6 @@ def test_create_dataset_with_default_encryption_key(
 
     @pytest.mark.usefixtures("get_dataset_tabular_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_create_dataset(self, create_dataset_mock, sync):
-        aiplatform.init(project=_TEST_PROJECT)
 
         my_dataset = datasets.TabularDataset.create(
             display_name=_TEST_DISPLAY_NAME,
@@ -797,13 +827,28 @@
     @pytest.mark.usefixtures("get_dataset_tabular_mock")
     def test_no_import_data_method(self):
-        aiplatform.init(project=_TEST_PROJECT)
 
         my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
 
         with pytest.raises(NotImplementedError):
             my_dataset.import_data()
 
+    def test_list_dataset(self, list_datasets_mock):
+
+        ds_list = aiplatform.TabularDataset.list(
+            filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY
+        )
+
+        list_datasets_mock.assert_called_once_with(
+            request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER}
+        )
+
+        # Ensure returned list is smaller since it filtered out non-tabular datasets
+        assert len(ds_list) < len(_TEST_DATASET_LIST)
+
+        for ds in ds_list:
+            assert type(ds) == aiplatform.TabularDataset
+
 
 class TestTextDataset:
     def setup_method(self):
diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py
index 7b18e1e497..ea74c89e5e 100644
--- a/tests/unit/aiplatform/test_endpoints.py
+++ b/tests/unit/aiplatform/test_endpoints.py
@@ -19,6 +19,7 @@
 
 from unittest import mock
 from importlib import reload
+from datetime import datetime, timedelta
 
 from google.api_core import operation as ga_operation
 from google.auth import credentials as auth_credentials
@@ -138,6 +139,23 @@
 )
 
 
+_TEST_ENDPOINT_LIST = [
+    gca_endpoint.Endpoint(
+        display_name="aac", create_time=datetime.now() - timedelta(minutes=15)
+    ),
+    gca_endpoint.Endpoint(
+        display_name="aab", create_time=datetime.now() - timedelta(minutes=5)
+    ),
+    gca_endpoint.Endpoint(
+        display_name="aaa", create_time=datetime.now() - timedelta(minutes=10)
+    ),
+]
+
+_TEST_LIST_FILTER = 'display_name="abc"'
+_TEST_LIST_ORDER_BY_CREATE_TIME = "create_time desc"
+_TEST_LIST_ORDER_BY_DISPLAY_NAME = "display_name"
+
+
 @pytest.fixture
 def get_endpoint_mock():
     with mock.patch.object(
@@ -264,6 +282,15 @@ def sdk_undeploy_all_mock():
         yield sdk_undeploy_all_mock
 
 
+@pytest.fixture
+def list_endpoints_mock():
+    with mock.patch.object(
+        endpoint_service_client.EndpointServiceClient, "list_endpoints"
+    ) as list_endpoints_mock:
+        list_endpoints_mock.return_value = _TEST_ENDPOINT_LIST
+        yield list_endpoints_mock
+
+
 @pytest.fixture
 def create_client_mock():
     with mock.patch.object(
@@ -307,6 +334,7 @@ class TestEndpoint:
     def setup_method(self):
         reload(initializer)
         reload(aiplatform)
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
 
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
@@ -974,6 +1002,46 @@ def test_undeploy_all(self, sdk_private_undeploy_mock, sync):
             any_order=True,
         )
 
+    def test_list_endpoint_order_by_time(self, list_endpoints_mock):
+        """Test call to Endpoint.list() and ensure list is returned in descending order of create_time"""
+
+        ep_list = aiplatform.Endpoint.list(
+            filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_CREATE_TIME
+        )
+
+        # `order_by` is not passed to API since it is not an accepted field
+        list_endpoints_mock.assert_called_once_with(
+            request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER}
+        )
+
+        assert len(ep_list) == len(_TEST_ENDPOINT_LIST)
+
+        for ep in ep_list:
+            assert type(ep) == aiplatform.Endpoint
+
+        assert ep_list[0].create_time > ep_list[1].create_time > ep_list[2].create_time
+
+    def test_list_endpoint_order_by_display_name(self, list_endpoints_mock):
+        """Test call to Endpoint.list() and ensure list is returned in order of display_name"""
+
+        ep_list = aiplatform.Endpoint.list(
+            filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_DISPLAY_NAME
+        )
+
+        # `order_by` is not passed to API since it is not an accepted field
+        list_endpoints_mock.assert_called_once_with(
+            request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER}
+        )
+
+        assert len(ep_list) == len(_TEST_ENDPOINT_LIST)
+
+        for ep in ep_list:
+            assert type(ep) == aiplatform.Endpoint
+
+        assert (
+            ep_list[0].display_name < ep_list[1].display_name < ep_list[2].display_name
+        )
+
     @pytest.mark.usefixtures("get_endpoint_with_models_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_delete_endpoint_without_force(
diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py
index 53fe9d2d0a..8d466d6741 100644
--- a/tests/unit/aiplatform/test_jobs.py
+++ b/tests/unit/aiplatform/test_jobs.py
@@ -141,6 +141,7 @@
 )
 
 _TEST_JOB_GET_METHOD_NAME = "get_fake_job"
+_TEST_JOB_LIST_METHOD_NAME = "list_fake_job"
 _TEST_JOB_CANCEL_METHOD_NAME = "cancel_fake_job"
 _TEST_JOB_DELETE_METHOD_NAME = "delete_fake_job"
 _TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/fakeJobs/{_TEST_ID}"
@@ -170,6 +171,7 @@ class FakeJob(jobs._Job):
     _job_type = "fake-job"
     _resource_noun = "fakeJobs"
     _getter_method = _TEST_JOB_GET_METHOD_NAME
+    _list_method = _TEST_JOB_LIST_METHOD_NAME
     _cancel_method = _TEST_JOB_CANCEL_METHOD_NAME
     _delete_method = _TEST_JOB_DELETE_METHOD_NAME
     resource_name = _TEST_JOB_RESOURCE_NAME
diff --git a/tests/unit/aiplatform/test_lro.py b/tests/unit/aiplatform/test_lro.py
index 0ce2e85594..d5844b572d 100644
--- a/tests/unit/aiplatform/test_lro.py
+++ b/tests/unit/aiplatform/test_lro.py
@@ -50,6 +50,7 @@ class AiPlatformResourceNounImpl(base.AiPlatformResourceNoun):
     _is_client_prediction_client = False
     _resource_noun = None
     _getter_method = None
+    _list_method = None
     _delete_method = None
diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py
index ff0310a003..47b000d189 100644
--- a/tests/unit/aiplatform/test_models.py
+++ b/tests/unit/aiplatform/test_models.py
@@ -155,6 +155,8 @@
     _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID
 )
 
+_TEST_OUTPUT_DIR = "gs://my-output-bucket"
+
 
 @pytest.fixture
 def get_endpoint_mock():
@@ -367,6 +369,7 @@ class TestModel:
     def setup_method(self):
         importlib.reload(initializer)
         importlib.reload(aiplatform)
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
 
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
@@ -469,8 +472,6 @@ def test_upload_uploads_and_gets_model(
 
     def test_upload_raises_with_impartial_explanation_spec(self):
 
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
-
         with pytest.raises(ValueError) as e:
             models.Model.upload(
                 display_name=_TEST_MODEL_NAME,
@@ -652,7 +653,7 @@ def test_upload_uploads_and_gets_model_with_custom_location(
     @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_deploy(self, deploy_model_mock, sync):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
         test_endpoint = models.Endpoint(_TEST_ID)
@@ -681,7 +682,7 @@ def test_deploy(self, deploy_model_mock, sync):
     )
     @pytest.mark.parametrize("sync", [True, False])
     def test_deploy_no_endpoint(self, deploy_model_mock, sync):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
         test_endpoint = test_model.deploy(sync=sync)
@@ -708,7 +709,7 @@ def test_deploy_no_endpoint(self, deploy_model_mock, sync):
     )
     @pytest.mark.parametrize("sync", [True, False])
     def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
         test_endpoint = test_model.deploy(
             machine_type=_TEST_MACHINE_TYPE,
@@ -789,7 +790,7 @@ def test_deploy_no_endpoint_with_explanations(
         "get_endpoint_mock", "get_model_mock", "create_endpoint_mock"
     )
     def test_deploy_raises_with_impartial_explanation_spec(self):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
 
         with pytest.raises(ValueError) as e:
@@ -855,9 +856,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
     def test_batch_predict_gcs_source_and_dest(
         self, create_batch_prediction_job_mock, sync
     ):
-        aiplatform.init(
-            project=_TEST_PROJECT, location=_TEST_LOCATION,
-        )
+
         test_model = models.Model(_TEST_ID)
 
         # Make SDK batch_predict method call
@@ -899,7 +898,7 @@ def test_batch_predict_gcs_source_and_dest(
     def test_batch_predict_gcs_source_bq_dest(
         self, create_batch_prediction_job_mock, sync
     ):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
 
         # Make SDK batch_predict method call
@@ -1012,7 +1011,7 @@ def test_batch_predict_with_all_args(
 
     @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
     def test_batch_predict_no_source(self, create_batch_prediction_job_mock):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
 
         # Make SDK batch_predict method call without source
@@ -1026,7 +1025,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock):
 
     @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
     def test_batch_predict_two_sources(self, create_batch_prediction_job_mock):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
 
         # Make SDK batch_predict method call with two sources
@@ -1042,7 +1041,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock):
 
     @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
     def test_batch_predict_no_destination(self):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
 
         # Make SDK batch_predict method call without destination
@@ -1056,7 +1055,7 @@ def test_batch_predict_no_destination(self):
 
     @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
     def test_batch_predict_wrong_instance_format(self):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
 
         # Make SDK batch_predict method call
@@ -1072,7 +1071,7 @@ def test_batch_predict_wrong_instance_format(self):
 
     @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
     def test_batch_predict_wrong_prediction_format(self):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
 
         # Make SDK batch_predict method call
@@ -1089,7 +1088,7 @@ def test_batch_predict_wrong_prediction_format(self):
 
     @pytest.mark.usefixtures("get_model_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_delete_model(self, delete_model_mock, sync):
-        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
         test_model = models.Model(_TEST_ID)
         test_model.delete(sync=sync)
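test_models.py gains the `aiplatform.init` call in `setup_method` but no coverage for `Model.list()` yet. A test in the same pattern as the `list_datasets_mock` and `list_endpoints_mock` fixtures above might look like the sketch below; `gca_model`, `model_service_client`, and `_TEST_PARENT` are assumed to exist in test_models.py the way they do in the sibling test modules, and `_TEST_MODEL_LIST` is hypothetical:

```python
_TEST_MODEL_LIST = [
    gca_model.Model(display_name="b"),
    gca_model.Model(display_name="a"),
    gca_model.Model(display_name="c"),
]


@pytest.fixture
def list_models_mock():
    with mock.patch.object(
        model_service_client.ModelServiceClient, "list_models"
    ) as list_models_mock:
        list_models_mock.return_value = _TEST_MODEL_LIST
        yield list_models_mock


def test_list_models(list_models_mock):
    model_list = aiplatform.Model.list(order_by="display_name")

    # Model.list() trusts the API for ordering, so order_by is forwarded
    # in the request rather than applied client-side.
    list_models_mock.assert_called_once_with(
        request={
            "parent": _TEST_PARENT,
            "filter": None,
            "order_by": "display_name",
        }
    )

    assert len(model_list) == len(_TEST_MODEL_LIST)
```

Since `Model.list` delegates to `_list` rather than `_list_with_local_order`, the assertion checks that `order_by` reaches the request instead of expecting a locally sorted result.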