From 3907e4e874b0188658ea8f07cef3cdbe639a4397 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Wed, 7 Apr 2021 18:16:18 -0700 Subject: [PATCH] add extract_project_location_from_resource_name, and re-instantiate client upon difference --- google/cloud/aiplatform/base.py | 10 ++++++++++ google/cloud/aiplatform/utils.py | 29 ++++++++++++++++++++++++++++ tests/unit/aiplatform/test_utils.py | 30 +++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 70b35329d1..f0629eaa19 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -320,6 +320,16 @@ def _get_gca_resource(self, resource_name: str) -> proto.Message: project=self.project, location=self.location, ) + ( + resource_project, + resource_location, + ) = utils.extract_project_location_from_resource_name( + resource_name=resource_name, resource_noun=self._resource_noun + ) + if resource_location != self.location: + self.api_client = self._instantiate_client( + resource_location, self.credentials + ) return getattr(self.api_client, self._getter_method)(name=resource_name) diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index ec39038942..58e135dcf1 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -128,6 +128,35 @@ def extract_fields_from_resource_name( return fields +def extract_project_location_from_resource_name( + resource_name: str, resource_noun: Optional[str] = None +) -> Optional[Tuple]: + """Returns extracted fields from a fully-qualified resource name. + Returns None if name is invalid. + + Args: + resource_name (str): + Required. A fully-qualified AI Platform (Unified) resource name + + resource_noun (str): + A plural resource noun to validate the resource name against. + For example, you would pass "datasets" to validate + "projects/123/locations/us-central1/datasets/456". + + Returns: + (Tuple): + A tuple containing 2 extracted fields from a resource name: + project, location. These fields can be used for + subsequent method calls in the SDK. + """ + fields = extract_fields_from_resource_name(resource_name, resource_noun) + + if not fields: + return None + + return fields.project, fields.location + + def full_resource_name( resource_name: str, resource_noun: str, diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 3032475069..9985165ece 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -21,6 +21,7 @@ from random import choice from random import randint from string import ascii_letters +from typing import Tuple from google.api_core import client_options from google.api_core import gapic_v1 @@ -56,6 +57,35 @@ def test_extract_fields_from_resource_name(resource_name: str, expected: bool): assert expected == bool(utils.extract_fields_from_resource_name(resource_name)) +@pytest.mark.parametrize( + "resource_name, expected", + [ + ( + "projects/123456/locations/us-central1/datasets/987654", + ("123456", "us-central1"), + ), + ( + "projects/857392/locations/us-central1/trainingPipelines/347292", + ("857392", "us-central1"), + ), + ( + "projects/acme-co-proj-1/locations/us-central1/datasets/123456", + ("acme-co-proj-1", "us-central1"), + ), + ("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", None), + ("project/123456/locations/us-central1/datasets/987654", None), + ("project//locations//datasets/987654", None), + ("locations/europe-west4/datasets/987654", None), + ("987654", None), + ], +) +def test_extract_project_location_from_resource_name( + resource_name: str, expected: Tuple +): + # Given a resource name and expected validity, test test_extract_project_location_from_resource_name() + assert expected == utils.extract_project_location_from_resource_name(resource_name) + + @pytest.fixture def generated_resource_fields(): generated_fields = utils.Fields(