Skip to content

Commit

Permalink
add extract_project_location_from_resource_name, and re-instantiate c…
Browse files Browse the repository at this point in the history
…lient upon difference
  • Loading branch information
morgandu committed Apr 8, 2021
1 parent 0d9dccf commit 3907e4e
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
10 changes: 10 additions & 0 deletions google/cloud/aiplatform/base.py
Expand Up @@ -320,6 +320,16 @@ def _get_gca_resource(self, resource_name: str) -> proto.Message:
project=self.project,
location=self.location,
)
(
resource_project,
resource_location,
) = utils.extract_project_location_from_resource_name(
resource_name=resource_name, resource_noun=self._resource_noun
)
if resource_location != self.location:
self.api_client = self._instantiate_client(
resource_location, self.credentials
)

return getattr(self.api_client, self._getter_method)(name=resource_name)

Expand Down
29 changes: 29 additions & 0 deletions google/cloud/aiplatform/utils.py
Expand Up @@ -128,6 +128,35 @@ def extract_fields_from_resource_name(
return fields


def extract_project_location_from_resource_name(
resource_name: str, resource_noun: Optional[str] = None
) -> Optional[Tuple]:
"""Returns extracted fields from a fully-qualified resource name.
Returns None if name is invalid.
Args:
resource_name (str):
Required. A fully-qualified AI Platform (Unified) resource name
resource_noun (str):
A plural resource noun to validate the resource name against.
For example, you would pass "datasets" to validate
"projects/123/locations/us-central1/datasets/456".
Returns:
(Tuple):
A tuple containing 2 extracted fields from a resource name:
project, location. These fields can be used for
subsequent method calls in the SDK.
"""
fields = extract_fields_from_resource_name(resource_name, resource_noun)

if not fields:
return None

return fields.project, fields.location


def full_resource_name(
resource_name: str,
resource_noun: str,
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Expand Up @@ -21,6 +21,7 @@
from random import choice
from random import randint
from string import ascii_letters
from typing import Tuple

from google.api_core import client_options
from google.api_core import gapic_v1
Expand Down Expand Up @@ -56,6 +57,35 @@ def test_extract_fields_from_resource_name(resource_name: str, expected: bool):
assert expected == bool(utils.extract_fields_from_resource_name(resource_name))


@pytest.mark.parametrize(
"resource_name, expected",
[
(
"projects/123456/locations/us-central1/datasets/987654",
("123456", "us-central1"),
),
(
"projects/857392/locations/us-central1/trainingPipelines/347292",
("857392", "us-central1"),
),
(
"projects/acme-co-proj-1/locations/us-central1/datasets/123456",
("acme-co-proj-1", "us-central1"),
),
("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", None),
("project/123456/locations/us-central1/datasets/987654", None),
("project//locations//datasets/987654", None),
("locations/europe-west4/datasets/987654", None),
("987654", None),
],
)
def test_extract_project_location_from_resource_name(
resource_name: str, expected: Tuple
):
# Given a resource name and expected validity, test test_extract_project_location_from_resource_name()
assert expected == utils.extract_project_location_from_resource_name(resource_name)


@pytest.fixture
def generated_resource_fields():
generated_fields = utils.Fields(
Expand Down

0 comments on commit 3907e4e

Please sign in to comment.