diff --git a/google/cloud/aiplatform/tensorboard/uploader_main.py b/google/cloud/aiplatform/tensorboard/uploader_main.py index 7c868fcfb7..304e46cfb7 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_main.py +++ b/google/cloud/aiplatform/tensorboard/uploader_main.py @@ -28,8 +28,10 @@ from tensorboard.plugins.image import metadata as images_metadata from tensorboard.plugins.graph import metadata as graphs_metadata +from google.api_core import exceptions from google.cloud import storage from google.cloud import aiplatform +from google.cloud.aiplatform import jobs from google.cloud.aiplatform.tensorboard import uploader from google.cloud.aiplatform.utils import TensorboardClientWithOverride @@ -123,9 +125,14 @@ def main(argv): exitcode=0, ) + experiment_name = FLAGS.experiment_name + experiment_display_name = get_experiment_display_name_with_override( + experiment_name, FLAGS.experiment_display_name, project_id, region + ) + tb_uploader = uploader.TensorBoardUploader( - experiment_name=FLAGS.experiment_name, - experiment_display_name=FLAGS.experiment_display_name, + experiment_name=experiment_name, + experiment_display_name=experiment_display_name, tensorboard_resource_name=tensorboard.name, blob_storage_bucket=blob_storage_bucket, blob_storage_folder=blob_storage_folder, @@ -149,6 +156,19 @@ def main(argv): tb_uploader.start_uploading() +def get_experiment_display_name_with_override( + experiment_name, experiment_display_name, project_id, region +): + if experiment_name.isdecimal() and not experiment_display_name: + try: + return jobs.CustomJob.get( + resource_name=experiment_name, project=project_id, location=region, + ).display_name + except exceptions.NotFound: + return experiment_display_name + return experiment_display_name + + def flags_parser(args): # Plumbs the flags defined in this file to the main module, mostly for the # console script wrapper tb-gcp-uploader. diff --git a/tests/unit/aiplatform/test_uploader_main.py b/tests/unit/aiplatform/test_uploader_main.py new file mode 100644 index 0000000000..79c86b22fc --- /dev/null +++ b/tests/unit/aiplatform/test_uploader_main.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from importlib import reload +from unittest.mock import patch + +from google.api_core import exceptions +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.tensorboard import uploader_main +from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat +from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_CUSTOM_JOB_ID = "445768" +_TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_CUSTOM_JOB_ID}" +_TEST_CUSTOM_JOBS_DISPLAY_NAME = "a custom job display name" +_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME = "someDisplayName" + + +def _get_custom_job_proto(state=None, name=None): + custom_job_proto = gca_custom_job_compat.CustomJob() + custom_job_proto.name = name + custom_job_proto.state = state + custom_job_proto.display_name = _TEST_CUSTOM_JOBS_DISPLAY_NAME + return custom_job_proto + + +@pytest.fixture +def get_custom_job_mock_not_found(): + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.side_effect = exceptions.NotFound("not found") + yield get_custom_job_mock + + +@pytest.fixture +def get_custom_job_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.side_effect = [ + _get_custom_job_proto( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_custom_job_mock + + +class TestUploaderMain: + def setup_method(self): + reload(initializer) + reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_get_default_custom_job_display_name(self, get_custom_job_mock): + aiplatform.init(project=_TEST_PROJECT) + assert ( + uploader_main.get_experiment_display_name_with_override( + _TEST_CUSTOM_JOB_ID, None, _TEST_PROJECT, _TEST_LOCATION + ) + == _TEST_CUSTOM_JOBS_DISPLAY_NAME + ) + + def test_non_decimal_experiment_name(self, get_custom_job_mock): + aiplatform.init(project=_TEST_PROJECT) + assert ( + uploader_main.get_experiment_display_name_with_override( + "someExperimentName", + _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME, + _TEST_PROJECT, + _TEST_LOCATION, + ) + == _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME + ) + get_custom_job_mock.assert_not_called() + + def test_display_name_already_specified(self, get_custom_job_mock): + aiplatform.init(project=_TEST_PROJECT) + assert ( + uploader_main.get_experiment_display_name_with_override( + _TEST_CUSTOM_JOB_ID, + _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME, + _TEST_PROJECT, + _TEST_LOCATION, + ) + == _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME + ) + get_custom_job_mock.assert_not_called() + + def test_custom_job_not_found(self, get_custom_job_mock_not_found): + aiplatform.init(project=_TEST_PROJECT) + assert ( + uploader_main.get_experiment_display_name_with_override( + _TEST_CUSTOM_JOB_ID, + _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME, + _TEST_PROJECT, + _TEST_LOCATION, + ) + == _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME + )