Skip to content

Commit

Permalink
feat: add util functions to get URLs for Tensorboard web app. (#635)
Browse files Browse the repository at this point in the history
  • Loading branch information
karthikv2k committed Aug 23, 2021
1 parent 52a7b7c commit 8d88c00
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 0 deletions.
93 changes: 93 additions & 0 deletions google/cloud/aiplatform/utils/tensorboard_utils.py
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Sequence, Dict
from google.cloud.aiplatform_v1beta1.services.tensorboard_service.client import (
TensorboardServiceClient,
)

_SERVING_DOMAIN = "tensorboard.googleusercontent.com"


def _parse_experiment_name(experiment_name: str) -> Dict[str, str]:
"""Parses an experiment_name into its component segments.
Args:
experiment_name: Resource name of the TensorboardExperiment. E.g.
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"
Returns:
Components of the experiment name.
Raises:
ValueError if the experiment_name is invalid.
"""
matched = TensorboardServiceClient.parse_tensorboard_experiment_path(
experiment_name
)
if not matched:
raise ValueError(f"Invalid experiment name: {experiment_name}.")
return matched


def get_experiment_url(experiment_name: str) -> str:
"""Get URL for comparing experiments.
Args:
experiment_name: Resource name of the TensorboardExperiment. E.g.
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"
Returns:
URL for the tensorboard web app.
"""
location = _parse_experiment_name(experiment_name)["location"]
name_for_url = experiment_name.replace("/", "+")
return f"https://{location}.{_SERVING_DOMAIN}/experiment/{name_for_url}"


def get_experiments_compare_url(experiment_names: Sequence[str]) -> str:
"""Get URL for comparing experiments.
Args:
experiment_names: Resource names of the TensorboardExperiments that needs to
be compared.
Returns:
URL for the tensorboard web app.
"""
if len(experiment_names) < 2:
raise ValueError("At least two experiment_names are required.")

locations = {
_parse_experiment_name(experiment_name)["location"]
for experiment_name in experiment_names
}
if len(locations) != 1:
raise ValueError(
f"Got experiments from different locations: {', '.join(locations)}."
)
location = locations.pop()

experiment_url_segments = []
for idx, experiment_name in enumerate(experiment_names):
name_segments = _parse_experiment_name(experiment_name)
experiment_url_segments.append(
"{cnt}-{experiment}:{project}+{location}+{tensorboard}+{experiment}".format(
cnt=idx + 1, **name_segments
)
)
encoded_names = ",".join(experiment_url_segments)
return f"https://{location}.{_SERVING_DOMAIN}/compare/{encoded_names}"
52 changes: 52 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Expand Up @@ -28,6 +28,7 @@
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import pipeline_utils
from google.cloud.aiplatform.utils import tensorboard_utils

from google.cloud.aiplatform_v1beta1.services.model_service import (
client as model_service_client_v1beta1,
Expand Down Expand Up @@ -454,3 +455,54 @@ def test_pipeline_utils_runtime_config_builder_parameter_not_found(self):
my_builder.build()

assert e.match(regexp=r"The pipeline parameter no_such_param is not found")


class TestTensorboardUtils:
def test_tensorboard_get_experiment_url(self):
actual = tensorboard_utils.get_experiment_url(
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"
)
assert actual == (
"https://asia-east1.tensorboard."
+ "googleusercontent.com/experiment/projects+123+locations+asia-east1+tensorboards+456+experiments+exp1"
)

def test_get_experiments_url_bad_experiment_name(self):
with pytest.raises(ValueError, match="Invalid experiment name: foo-bar."):
tensorboard_utils.get_experiment_url("foo-bar")

def test_tensorboard_get_experiments_compare_url(self):
actual = tensorboard_utils.get_experiments_compare_url(
(
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp1",
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp2",
)
)
assert actual == (
"https://asia-east1.tensorboard."
+ "googleusercontent.com/compare/1-exp1:123+asia-east1+456+exp1,"
+ "2-exp2:123+asia-east1+456+exp2"
)

def test_tensorboard_get_experiments_compare_url_fail_just_one_exp(self):
with pytest.raises(
ValueError, match="At least two experiment_names are required."
):
tensorboard_utils.get_experiments_compare_url(
("projects/123/locations/asia-east1/tensorboards/456/experiments/exp1",)
)

def test_tensorboard_get_experiments_compare_url_fail_diff_region(self):
with pytest.raises(
ValueError, match="Got experiments from different locations: asia-east.",
):
tensorboard_utils.get_experiments_compare_url(
(
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp1",
"projects/123/locations/asia-east2/tensorboards/456/experiments/exp2",
)
)

def test_get_experiments_compare_url_bad_experiment_name(self):
with pytest.raises(ValueError, match="Invalid experiment name: foo-bar."):
tensorboard_utils.get_experiments_compare_url(("foo-bar", "foo-bar1"))

0 comments on commit 8d88c00

Please sign in to comment.