diff --git a/google/cloud/aiplatform/utils/tensorboard_utils.py b/google/cloud/aiplatform/utils/tensorboard_utils.py new file mode 100644 index 0000000000..d3cb1ef704 --- /dev/null +++ b/google/cloud/aiplatform/utils/tensorboard_utils.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Sequence, Dict +from google.cloud.aiplatform_v1beta1.services.tensorboard_service.client import ( + TensorboardServiceClient, +) + +_SERVING_DOMAIN = "tensorboard.googleusercontent.com" + + +def _parse_experiment_name(experiment_name: str) -> Dict[str, str]: + """Parses an experiment_name into its component segments. + + Args: + experiment_name: Resource name of the TensorboardExperiment. E.g. + "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1" + + Returns: + Components of the experiment name. + + Raises: + ValueError if the experiment_name is invalid. + """ + matched = TensorboardServiceClient.parse_tensorboard_experiment_path( + experiment_name + ) + if not matched: + raise ValueError(f"Invalid experiment name: {experiment_name}.") + return matched + + +def get_experiment_url(experiment_name: str) -> str: + """Get URL for comparing experiments. + + Args: + experiment_name: Resource name of the TensorboardExperiment. E.g. + "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1" + + Returns: + URL for the tensorboard web app. + """ + location = _parse_experiment_name(experiment_name)["location"] + name_for_url = experiment_name.replace("/", "+") + return f"https://{location}.{_SERVING_DOMAIN}/experiment/{name_for_url}" + + +def get_experiments_compare_url(experiment_names: Sequence[str]) -> str: + """Get URL for comparing experiments. + + Args: + experiment_names: Resource names of the TensorboardExperiments that needs to + be compared. + + Returns: + URL for the tensorboard web app. + """ + if len(experiment_names) < 2: + raise ValueError("At least two experiment_names are required.") + + locations = { + _parse_experiment_name(experiment_name)["location"] + for experiment_name in experiment_names + } + if len(locations) != 1: + raise ValueError( + f"Got experiments from different locations: {', '.join(locations)}." + ) + location = locations.pop() + + experiment_url_segments = [] + for idx, experiment_name in enumerate(experiment_names): + name_segments = _parse_experiment_name(experiment_name) + experiment_url_segments.append( + "{cnt}-{experiment}:{project}+{location}+{tensorboard}+{experiment}".format( + cnt=idx + 1, **name_segments + ) + ) + encoded_names = ",".join(experiment_url_segments) + return f"https://{location}.{_SERVING_DOMAIN}/compare/{encoded_names}" diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index ed85fb9f0a..bdc674ebc0 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -28,6 +28,7 @@ from google.cloud.aiplatform import compat from google.cloud.aiplatform import utils from google.cloud.aiplatform.utils import pipeline_utils +from google.cloud.aiplatform.utils import tensorboard_utils from google.cloud.aiplatform_v1beta1.services.model_service import ( client as model_service_client_v1beta1, @@ -454,3 +455,54 @@ def test_pipeline_utils_runtime_config_builder_parameter_not_found(self): my_builder.build() assert e.match(regexp=r"The pipeline parameter no_such_param is not found") + + +class TestTensorboardUtils: + def test_tensorboard_get_experiment_url(self): + actual = tensorboard_utils.get_experiment_url( + "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1" + ) + assert actual == ( + "https://asia-east1.tensorboard." + + "googleusercontent.com/experiment/projects+123+locations+asia-east1+tensorboards+456+experiments+exp1" + ) + + def test_get_experiments_url_bad_experiment_name(self): + with pytest.raises(ValueError, match="Invalid experiment name: foo-bar."): + tensorboard_utils.get_experiment_url("foo-bar") + + def test_tensorboard_get_experiments_compare_url(self): + actual = tensorboard_utils.get_experiments_compare_url( + ( + "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1", + "projects/123/locations/asia-east1/tensorboards/456/experiments/exp2", + ) + ) + assert actual == ( + "https://asia-east1.tensorboard." + + "googleusercontent.com/compare/1-exp1:123+asia-east1+456+exp1," + + "2-exp2:123+asia-east1+456+exp2" + ) + + def test_tensorboard_get_experiments_compare_url_fail_just_one_exp(self): + with pytest.raises( + ValueError, match="At least two experiment_names are required." + ): + tensorboard_utils.get_experiments_compare_url( + ("projects/123/locations/asia-east1/tensorboards/456/experiments/exp1",) + ) + + def test_tensorboard_get_experiments_compare_url_fail_diff_region(self): + with pytest.raises( + ValueError, match="Got experiments from different locations: asia-east.", + ): + tensorboard_utils.get_experiments_compare_url( + ( + "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1", + "projects/123/locations/asia-east2/tensorboards/456/experiments/exp2", + ) + ) + + def test_get_experiments_compare_url_bad_experiment_name(self): + with pytest.raises(ValueError, match="Invalid experiment name: foo-bar."): + tensorboard_utils.get_experiments_compare_url(("foo-bar", "foo-bar1"))