From 674227d2e7ed4a4a4e180213dc1178dde7d65a3a Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Fri, 9 Apr 2021 16:46:34 -0700 Subject: [PATCH] feat: parse project location when passed full resource name to get apis (#297) --- google/cloud/aiplatform/base.py | 62 ++++++++++++++++++++- google/cloud/aiplatform/datasets/dataset.py | 5 +- google/cloud/aiplatform/jobs.py | 7 ++- google/cloud/aiplatform/models.py | 14 ++++- google/cloud/aiplatform/training_jobs.py | 5 +- tests/unit/aiplatform/test_datasets.py | 33 +++++++++++ tests/unit/aiplatform/test_training_jobs.py | 39 ++++++++++++- 7 files changed, 156 insertions(+), 9 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 70b35329d1..78f8807334 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -20,7 +20,7 @@ import functools import inspect import threading -from typing import Any, Callable, Dict, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union import proto @@ -266,6 +266,7 @@ def __init__( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, ): """Initializes class with project, location, and api_client. @@ -274,8 +275,14 @@ def __init__( location(str): The location of the resource noun. credentials(google.auth.crendentials.Crendentials): Optional custom credentials to use when accessing interacting with resource noun. + resource_name(str): A fully-qualified resource name or ID. """ + if resource_name: + project, location = self._get_and_validate_project_location( + resource_name=resource_name, project=project, location=location + ) + self.project = project or initializer.global_config.project self.location = location or initializer.global_config.location self.credentials = credentials or initializer.global_config.credentials @@ -306,6 +313,41 @@ def _instantiate_client( prediction_client=cls._is_client_prediction_client, ) + def _get_and_validate_project_location( + self, + resource_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + ) -> Tuple: + + """Validate the project and location for the resource. + + Args: + resource_name(str): Required. A fully-qualified resource name or ID. + project(str): Project of the resource noun. + location(str): The location of the resource noun. + + Raises: + RuntimeError if location is different from resource location + """ + + if not project and not location: + return project, location + + fields = utils.extract_fields_from_resource_name( + resource_name, self._resource_noun + ) + if not fields: + return project, location + + if location and fields.location != location: + raise RuntimeError( + f"location {location} is provided, but different from " + f"the resource location {fields.location}" + ) + + return fields.project, fields.location + def _get_gca_resource(self, resource_name: str) -> proto.Message: """Returns GAPIC service representation of client class resource.""" """ @@ -493,6 +535,7 @@ def __init__( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, ): """Initializes class with project, location, and api_client. @@ -502,9 +545,14 @@ def __init__( credentials(google.auth.crendentials.Crendentials): Optional. custom credentials to use when accessing interacting with resource noun. + resource_name(str): A fully-qualified resource name or ID. """ AiPlatformResourceNoun.__init__( - self, project=project, location=location, credentials=credentials + self, + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, ) FutureManager.__init__(self) @@ -514,6 +562,7 @@ def _empty_constructor( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, ) -> "AiPlatformResourceNounWithFutureManager": """Initializes with all attributes set to None. @@ -526,11 +575,18 @@ def _empty_constructor( credentials(google.auth.crendentials.Crendentials): Optional. custom credentials to use when accessing interacting with resource noun. + resource_name(str): A fully-qualified resource name or ID. Returns: An instance of this class with attributes set to None. """ self = cls.__new__(cls) - AiPlatformResourceNoun.__init__(self, project, location, credentials) + AiPlatformResourceNoun.__init__( + self, + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, + ) FutureManager.__init__(self) self._gca_resource = None return self diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 207c4a6f8d..00c03c4928 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -71,7 +71,10 @@ def __init__( """ super().__init__( - project=project, location=location, credentials=credentials, + project=project, + location=location, + credentials=credentials, + resource_name=dataset_name, ) self._gca_resource = self._get_gca_resource(resource_name=dataset_name) self._validate_metadata_schema_uri() diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 4f6fd6d094..104ce4fd96 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -107,7 +107,12 @@ def __init__( Custom credentials to use. If not set, credentials set in aiplatform.init will be used. """ - super().__init__(project=project, location=location, credentials=credentials) + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=job_name, + ) self._gca_resource = self._get_gca_resource(resource_name=job_name) @property diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index b19ace6d74..2440d182a5 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -99,7 +99,12 @@ def __init__( credentials set in aiplatform.init. """ - super().__init__(project=project, location=location, credentials=credentials) + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=endpoint_name, + ) self._gca_resource = self._get_gca_resource(resource_name=endpoint_name) self._prediction_client = self._instantiate_prediction_client( location=location or initializer.global_config.location, @@ -1144,7 +1149,12 @@ def __init__( credentials set in aiplatform.init will be used. """ - super().__init__(project=project, location=location, credentials=credentials) + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=model_name, + ) self._gca_resource = self._get_gca_resource(resource_name=model_name) # TODO(b/170979552) Add support for predict schemata diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 5ad46b9ddc..8cfe40f125 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -180,7 +180,10 @@ def get( # These parameters won't be used as user can not run the job again. # If they try, an exception will be raised. self = cls._empty_constructor( - project=project, location=location, credentials=credentials + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, ) self._gca_resource = self._get_gca_resource(resource_name=resource_name) diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 2ac7489d5f..52bc4327f2 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -48,6 +48,7 @@ _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_ALT_PROJECT = "test-project_alt" _TEST_ALT_LOCATION = "europe-west4" _TEST_INVALID_LOCATION = "us-central2" @@ -259,6 +260,38 @@ def test_init_dataset(self, get_dataset_mock): datasets.Dataset(dataset_name=_TEST_NAME) get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + def test_init_dataset_with_id_only_with_project_and_location( + self, get_dataset_mock + ): + aiplatform.init(project=_TEST_PROJECT) + datasets.Dataset( + dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_project_and_location(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets.Dataset( + dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets.Dataset( + dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_project_and_alt_location(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(RuntimeError): + datasets.Dataset( + dataset_name=_TEST_NAME, + project=_TEST_PROJECT, + location=_TEST_ALT_LOCATION, + ) + def test_init_dataset_with_id_only(self, get_dataset_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) datasets.Dataset(dataset_name=_TEST_ID) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 33d43321ef..07585d7c3a 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -65,7 +65,6 @@ _TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/" _TEST_LOCAL_SCRIPT_FILE_NAME = "____test____script.py" _TEST_LOCAL_SCRIPT_FILE_PATH = f"path/to/{_TEST_LOCAL_SCRIPT_FILE_NAME}" -_TEST_PROJECT = "test-project" _TEST_PYTHON_SOURCE = """ print('hello world') """ @@ -107,6 +106,8 @@ _TEST_NAME = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipelines/{_TEST_ID}" ) +_TEST_ALT_PROJECT = "test-project-alt" +_TEST_ALT_LOCATION = "europe-west4" _TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml" _TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml" @@ -1381,6 +1382,42 @@ def test_get_training_job_with_id_only(self, get_training_job_custom_mock): training_jobs.CustomTrainingJob.get(resource_name=_TEST_ID) get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + def test_get_training_job_with_id_only_with_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_alt_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_project_and_alt_location(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(RuntimeError): + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, + project=_TEST_PROJECT, + location=_TEST_ALT_LOCATION, + ) + @pytest.mark.parametrize("sync", [True, False]) def test_run_call_pipeline_service_create_with_nontabular_dataset( self,