Skip to content

Commit

Permalink
feat: default to custom job display name if experiment name looks lik…
Browse files Browse the repository at this point in the history
…e a custom job ID (#833)

Co-authored-by: Yicheng Fang <yichengfang@google.com>
  • Loading branch information
yfang1 and Yicheng Fang committed Nov 17, 2021
1 parent e0fc3d9 commit 8b9376e
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
24 changes: 22 additions & 2 deletions google/cloud/aiplatform/tensorboard/uploader_main.py
Expand Up @@ -28,8 +28,10 @@
from tensorboard.plugins.image import metadata as images_metadata
from tensorboard.plugins.graph import metadata as graphs_metadata

from google.api_core import exceptions
from google.cloud import storage
from google.cloud import aiplatform
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform.tensorboard import uploader
from google.cloud.aiplatform.utils import TensorboardClientWithOverride

Expand Down Expand Up @@ -123,9 +125,14 @@ def main(argv):
exitcode=0,
)

experiment_name = FLAGS.experiment_name
experiment_display_name = get_experiment_display_name_with_override(
experiment_name, FLAGS.experiment_display_name, project_id, region
)

tb_uploader = uploader.TensorBoardUploader(
experiment_name=FLAGS.experiment_name,
experiment_display_name=FLAGS.experiment_display_name,
experiment_name=experiment_name,
experiment_display_name=experiment_display_name,
tensorboard_resource_name=tensorboard.name,
blob_storage_bucket=blob_storage_bucket,
blob_storage_folder=blob_storage_folder,
Expand All @@ -149,6 +156,19 @@ def main(argv):
tb_uploader.start_uploading()


def get_experiment_display_name_with_override(
experiment_name, experiment_display_name, project_id, region
):
if experiment_name.isdecimal() and not experiment_display_name:
try:
return jobs.CustomJob.get(
resource_name=experiment_name, project=project_id, location=region,
).display_name
except exceptions.NotFound:
return experiment_display_name
return experiment_display_name


def flags_parser(args):
# Plumbs the flags defined in this file to the main module, mostly for the
# console script wrapper tb-gcp-uploader.
Expand Down
124 changes: 124 additions & 0 deletions tests/unit/aiplatform/test_uploader_main.py
@@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest

from importlib import reload
from unittest.mock import patch

from google.api_core import exceptions
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.tensorboard import uploader_main
from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat
from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client

_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_CUSTOM_JOB_ID = "445768"
_TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_CUSTOM_JOB_ID}"
_TEST_CUSTOM_JOBS_DISPLAY_NAME = "a custom job display name"
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME = "someDisplayName"


def _get_custom_job_proto(state=None, name=None):
custom_job_proto = gca_custom_job_compat.CustomJob()
custom_job_proto.name = name
custom_job_proto.state = state
custom_job_proto.display_name = _TEST_CUSTOM_JOBS_DISPLAY_NAME
return custom_job_proto


@pytest.fixture
def get_custom_job_mock_not_found():
with patch.object(
job_service_client.JobServiceClient, "get_custom_job"
) as get_custom_job_mock:
get_custom_job_mock.side_effect = exceptions.NotFound("not found")
yield get_custom_job_mock


@pytest.fixture
def get_custom_job_mock():
with patch.object(
job_service_client.JobServiceClient, "get_custom_job"
) as get_custom_job_mock:
get_custom_job_mock.side_effect = [
_get_custom_job_proto(
name=_TEST_CUSTOM_JOB_NAME,
state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED,
),
]
yield get_custom_job_mock


class TestUploaderMain:
def setup_method(self):
reload(initializer)
reload(aiplatform)
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_get_default_custom_job_display_name(self, get_custom_job_mock):
aiplatform.init(project=_TEST_PROJECT)
assert (
uploader_main.get_experiment_display_name_with_override(
_TEST_CUSTOM_JOB_ID, None, _TEST_PROJECT, _TEST_LOCATION
)
== _TEST_CUSTOM_JOBS_DISPLAY_NAME
)

def test_non_decimal_experiment_name(self, get_custom_job_mock):
aiplatform.init(project=_TEST_PROJECT)
assert (
uploader_main.get_experiment_display_name_with_override(
"someExperimentName",
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME,
_TEST_PROJECT,
_TEST_LOCATION,
)
== _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME
)
get_custom_job_mock.assert_not_called()

def test_display_name_already_specified(self, get_custom_job_mock):
aiplatform.init(project=_TEST_PROJECT)
assert (
uploader_main.get_experiment_display_name_with_override(
_TEST_CUSTOM_JOB_ID,
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME,
_TEST_PROJECT,
_TEST_LOCATION,
)
== _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME
)
get_custom_job_mock.assert_not_called()

def test_custom_job_not_found(self, get_custom_job_mock_not_found):
aiplatform.init(project=_TEST_PROJECT)
assert (
uploader_main.get_experiment_display_name_with_override(
_TEST_CUSTOM_JOB_ID,
_TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME,
_TEST_PROJECT,
_TEST_LOCATION,
)
== _TEST_PASSED_IN_EXPERIMENT_DISPLAY_NAME
)

0 comments on commit 8b9376e

Please sign in to comment.