Skip to content

Commit

Permalink
fix: Enable MetadataStore to use credentials when aiplatfrom.init pas…
Browse files Browse the repository at this point in the history
…sed experiment and credentials. (#460)
  • Loading branch information
sasha-gitg committed Jun 3, 2021
1 parent b4211f2 commit e7bf0d8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 7 deletions.
14 changes: 8 additions & 6 deletions google/cloud/aiplatform/initializer.py
Expand Up @@ -92,11 +92,19 @@ def init(
if metadata.metadata_service.experiment_name:
logging.info("project/location updated, reset Metadata config.")
metadata.metadata_service.reset()

if project:
self._project = project
if location:
utils.validate_region(location)
self._location = location
if staging_bucket:
self._staging_bucket = staging_bucket
if credentials:
self._credentials = credentials
if encryption_spec_key_name:
self._encryption_spec_key_name = encryption_spec_key_name

if experiment:
metadata.metadata_service.set_experiment(
experiment=experiment, description=experiment_description
Expand All @@ -105,12 +113,6 @@ def init(
raise ValueError(
"Experiment name needs to be set in `init` in order to add experiment descriptions."
)
if staging_bucket:
self._staging_bucket = staging_bucket
if credentials:
self._credentials = credentials
if encryption_spec_key_name:
self._encryption_spec_key_name = encryption_spec_key_name

def get_encryption_spec(
self,
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/metadata/metadata_store.py
Expand Up @@ -205,7 +205,7 @@ def _get(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Optional[_MetadataStore]":
) -> Optional["_MetadataStore"]:
"""Returns a MetadataStore resource.
Args:
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/aiplatform/test_metadata.py
Expand Up @@ -16,10 +16,13 @@
#

from importlib import reload
from unittest import mock
from unittest.mock import patch, call

import pytest
from google.api_core import exceptions
from google.api_core import operation
from google.auth import credentials

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -106,6 +109,32 @@ def get_metadata_store_mock():
yield get_metadata_store_mock


@pytest.fixture
def get_metadata_store_mock_raise_not_found_exception():
with patch.object(
MetadataServiceClient, "get_metadata_store"
) as get_metadata_store_mock:
get_metadata_store_mock.side_effect = [
exceptions.NotFound("Test store not found."),
GapicMetadataStore(name=_TEST_METADATASTORE,),
]

yield get_metadata_store_mock


@pytest.fixture
def create_metadata_store_mock():
with patch.object(
MetadataServiceClient, "create_metadata_store"
) as create_metadata_store_mock:
create_metadata_store_lro_mock = mock.Mock(operation.Operation)
create_metadata_store_lro_mock.result.return_value = GapicMetadataStore(
name=_TEST_METADATASTORE,
)
create_metadata_store_mock.return_value = create_metadata_store_lro_mock
yield create_metadata_store_mock


@pytest.fixture
def get_context_mock():
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
Expand Down Expand Up @@ -364,6 +393,54 @@ def test_init_experiment_with_existing_metadataStore_and_context(
get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE)
get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME)

def test_init_experiment_with_credentials(
self, get_metadata_store_mock, get_context_mock
):
creds = credentials.AnonymousCredentials()

aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
experiment=_TEST_EXPERIMENT,
credentials=creds,
)

assert (
metadata.metadata_service._experiment.api_client._transport._credentials
== creds
)

get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE)
get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME)

def test_init_and_get_metadata_store_with_credentials(
self, get_metadata_store_mock
):
creds = credentials.AnonymousCredentials()

aiplatform.init(
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds
)

store = metadata._MetadataStore.get_or_create()

assert store.api_client._transport._credentials == creds

@pytest.mark.usefixtures(
"get_metadata_store_mock_raise_not_found_exception",
"create_metadata_store_mock",
)
def test_init_and_get_then_create_metadata_store_with_credentials(self):
creds = credentials.AnonymousCredentials()

aiplatform.init(
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds
)

store = metadata._MetadataStore.get_or_create()

assert store.api_client._transport._credentials == creds

def test_init_experiment_with_existing_description(
self, get_metadata_store_mock, get_context_mock
):
Expand Down

0 comments on commit e7bf0d8

Please sign in to comment.