Skip to content

Commit

Permalink
feat: Upgrade Tensorboard from v1beta1 to v1 (#849)
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha-gitg committed Nov 18, 2021
1 parent 8b9376e commit c40ec85
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 100 deletions.
10 changes: 9 additions & 1 deletion google/cloud/aiplatform/compat/__init__.py
Expand Up @@ -77,13 +77,14 @@
types.specialist_pool = types.specialist_pool_v1beta1
types.specialist_pool_service = types.specialist_pool_service_v1beta1
types.study = types.study_v1beta1
types.training_pipeline = types.training_pipeline_v1beta1
types.tensorboard = types.tensorboard_v1beta1
types.tensorboard_service = types.tensorboard_service_v1beta1
types.tensorboard_data = types.tensorboard_data_v1beta1
types.tensorboard_experiment = types.tensorboard_experiment_v1beta1
types.tensorboard_run = types.tensorboard_run_v1beta1
types.tensorboard_service = types.tensorboard_service_v1beta1
types.tensorboard_time_series = types.tensorboard_time_series_v1beta1
types.training_pipeline = types.training_pipeline_v1beta1

if DEFAULT_VERSION == V1:

Expand Down Expand Up @@ -135,6 +136,13 @@
types.specialist_pool = types.specialist_pool_v1
types.specialist_pool_service = types.specialist_pool_service_v1
types.study = types.study_v1
types.tensorboard = types.tensorboard_v1
types.tensorboard_service = types.tensorboard_service_v1
types.tensorboard_data = types.tensorboard_data_v1
types.tensorboard_experiment = types.tensorboard_experiment_v1
types.tensorboard_run = types.tensorboard_run_v1
types.tensorboard_service = types.tensorboard_service_v1
types.tensorboard_time_series = types.tensorboard_time_series_v1
types.training_pipeline = types.training_pipeline_v1

__all__ = (
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/compat/services/__init__.py
Expand Up @@ -67,6 +67,9 @@
from google.cloud.aiplatform_v1.services.specialist_pool_service import (
client as specialist_pool_service_client_v1,
)
from google.cloud.aiplatform_v1.services.tensorboard_service import (
client as tensorboard_service_client_v1,
)

__all__ = (
# v1
Expand All @@ -78,6 +81,7 @@
pipeline_service_client_v1,
prediction_service_client_v1,
specialist_pool_service_client_v1,
tensorboard_service_client_v1,
# v1beta1
dataset_service_client_v1beta1,
endpoint_service_client_v1beta1,
Expand Down
18 changes: 14 additions & 4 deletions google/cloud/aiplatform/compat/types/__init__.py
Expand Up @@ -57,13 +57,13 @@
specialist_pool as specialist_pool_v1beta1,
specialist_pool_service as specialist_pool_service_v1beta1,
study as study_v1beta1,
training_pipeline as training_pipeline_v1beta1,
tensorboard as tensorboard_v1beta1,
tensorboard_data as tensorboard_data_v1beta1,
tensorboard_experiment as tensorboard_experiment_v1beta1,
tensorboard_run as tensorboard_run_v1beta1,
tensorboard_service as tensorboard_service_v1beta1,
tensorboard_time_series as tensorboard_time_series_v1beta1,
training_pipeline as training_pipeline_v1beta1,
)
from google.cloud.aiplatform_v1.types import (
accelerator_type as accelerator_type_v1,
Expand Down Expand Up @@ -107,6 +107,12 @@
specialist_pool as specialist_pool_v1,
specialist_pool_service as specialist_pool_service_v1,
study as study_v1,
tensorboard as tensorboard_v1,
tensorboard_data as tensorboard_data_v1,
tensorboard_experiment as tensorboard_experiment_v1,
tensorboard_run as tensorboard_run_v1,
tensorboard_service as tensorboard_service_v1,
tensorboard_time_series as tensorboard_time_series_v1,
training_pipeline as training_pipeline_v1,
)

Expand Down Expand Up @@ -152,6 +158,12 @@
prediction_service_v1,
specialist_pool_v1,
specialist_pool_service_v1,
tensorboard_v1,
tensorboard_data_v1,
tensorboard_experiment_v1,
tensorboard_run_v1,
tensorboard_service_v1,
tensorboard_time_series_v1,
training_pipeline_v1,
# v1beta1
accelerator_type_v1beta1,
Expand Down Expand Up @@ -194,13 +206,11 @@
prediction_service_v1beta1,
specialist_pool_v1beta1,
specialist_pool_service_v1beta1,
training_pipeline_v1beta1,
metadata_service_v1beta1,
tensorboard_v1beta1,
tensorboard_service_v1beta1,
tensorboard_data_v1beta1,
tensorboard_experiment_v1beta1,
tensorboard_run_v1beta1,
tensorboard_service_v1beta1,
tensorboard_time_series_v1beta1,
training_pipeline_v1beta1,
)
22 changes: 2 additions & 20 deletions google/cloud/aiplatform/jobs.py
Expand Up @@ -35,12 +35,10 @@
batch_prediction_job as gca_bp_job_compat,
completion_stats as gca_completion_stats,
custom_job as gca_custom_job_compat,
custom_job_v1beta1 as gca_custom_job_v1beta1,
explanation as gca_explanation_compat,
io as gca_io_compat,
job_state as gca_job_state,
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
hyperparameter_tuning_job_v1beta1 as gca_hyperparameter_tuning_job_v1beta1,
machine_resources as gca_machine_resources_compat,
study as gca_study_compat,
)
Expand Down Expand Up @@ -1388,17 +1386,11 @@ def run(
self._gca_resource.job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
v1beta1_gca_resource._pb.MergeFromString(
self._gca_resource._pb.SerializeToString()
)
self._gca_resource = v1beta1_gca_resource
self._gca_resource.job_spec.tensorboard = tensorboard

_LOGGER.log_create_with_lro(self.__class__)

version = "v1beta1" if tensorboard else "v1"
self._gca_resource = self.api_client.select_version(version).create_custom_job(
self._gca_resource = self.api_client.create_custom_job(
parent=self._parent, custom_job=self._gca_resource
)

Expand Down Expand Up @@ -1773,21 +1765,11 @@ def run(
self._gca_resource.trial_job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = (
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
)
v1beta1_gca_resource._pb.MergeFromString(
self._gca_resource._pb.SerializeToString()
)
self._gca_resource = v1beta1_gca_resource
self._gca_resource.trial_job_spec.tensorboard = tensorboard

_LOGGER.log_create_with_lro(self.__class__)

version = "v1beta1" if tensorboard else "v1"
self._gca_resource = self.api_client.select_version(
version
).create_hyperparameter_tuning_job(
self._gca_resource = self.api_client.create_hyperparameter_tuning_job(
parent=self._parent, hyperparameter_tuning_job=self._gca_resource
)

Expand Down
12 changes: 3 additions & 9 deletions google/cloud/aiplatform/tensorboard/tensorboard.py
Expand Up @@ -18,17 +18,13 @@
from typing import Optional, Sequence, Dict, Tuple

from google.auth import credentials as auth_credentials
from google.protobuf import field_mask_pb2

from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform.compat.types import tensorboard as gca_tensorboard
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils


from google.cloud.aiplatform.compat.types import tensorboard_v1beta1 as gca_tensorboard

from google.protobuf import field_mask_pb2

_LOGGER = base.Logger(__name__)


Expand Down Expand Up @@ -156,8 +152,7 @@ def create(
)

encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name,
select_version=compat.V1BETA1,
encryption_spec_key_name=encryption_spec_key_name
)

gapic_tensorboard = gca_tensorboard.Tensorboard(
Expand Down Expand Up @@ -254,7 +249,6 @@ def update(
if encryption_spec_key_name:
encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name,
select_version=compat.V1BETA1,
)
update_mask.append("encryption_spec")

Expand Down
6 changes: 5 additions & 1 deletion google/cloud/aiplatform/utils/__init__.py
Expand Up @@ -51,6 +51,7 @@
model_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
tensorboard_service_client_v1,
)

from google.cloud.aiplatform.compat.types import (
Expand All @@ -67,6 +68,7 @@
pipeline_service_client_v1beta1.PipelineServiceClient,
job_service_client_v1beta1.JobServiceClient,
metadata_service_client_v1beta1.MetadataServiceClient,
tensorboard_service_client_v1beta1.TensorboardServiceClient,
# v1
dataset_service_client_v1.DatasetServiceClient,
endpoint_service_client_v1.EndpointServiceClient,
Expand All @@ -75,6 +77,7 @@
prediction_service_client_v1.PredictionServiceClient,
pipeline_service_client_v1.PipelineServiceClient,
job_service_client_v1.JobServiceClient,
tensorboard_service_client_v1.TensorboardServiceClient,
)

RESOURCE_NAME_PATTERN = re.compile(
Expand Down Expand Up @@ -506,8 +509,9 @@ class MetadataClientWithOverride(ClientWithOverride):

class TensorboardClientWithOverride(ClientWithOverride):
_is_temporary = False
_default_version = compat.V1BETA1
_default_version = compat.DEFAULT_VERSION
_version_map = (
(compat.V1, tensorboard_service_client_v1.TensorboardServiceClient),
(compat.V1BETA1, tensorboard_service_client_v1beta1.TensorboardServiceClient),
)

Expand Down
44 changes: 44 additions & 0 deletions tests/system/aiplatform/test_tensorboard.py
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base


class TestTensorboard(e2e_base.TestEndToEnd):

_temp_prefix = "temp-vertex-sdk-e2e-test"

def test_create_and_get_tensorboard(self, shared_state):

aiplatform.init(
project=e2e_base._PROJECT, location=e2e_base._LOCATION,
)

display_name = self._make_display_name("tensorboard")

tb = aiplatform.Tensorboard.create(display_name=display_name)

shared_state["resources"] = [tb]

get_tb = aiplatform.Tensorboard(tb.resource_name)

assert tb.resource_name == get_tb.resource_name

list_tb = aiplatform.Tensorboard.list()

assert len(list_tb) > 0
51 changes: 17 additions & 34 deletions tests/unit/aiplatform/test_custom_job.py
Expand Up @@ -31,18 +31,12 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
from google.cloud.aiplatform.compat.types import (
custom_job_v1beta1 as gca_custom_job_v1beta1,
)
from google.cloud.aiplatform.compat.types import io as gca_io_compat
from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec_compat,
)
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
from google.cloud.aiplatform_v1beta1.services.job_service import (
client as job_service_client_v1beta1,
)

_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
Expand Down Expand Up @@ -114,29 +108,16 @@
)


def _get_custom_job_proto(state=None, name=None, error=None, version="v1"):
def _get_custom_job_proto(state=None, name=None, error=None):
custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
custom_job_proto.name = name
custom_job_proto.state = state
custom_job_proto.error = error

if version == "v1beta1":
v1beta1_custom_job_proto = gca_custom_job_v1beta1.CustomJob()
v1beta1_custom_job_proto._pb.MergeFromString(
custom_job_proto._pb.SerializeToString()
)
custom_job_proto = v1beta1_custom_job_proto
custom_job_proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME

return custom_job_proto


def _get_custom_job_proto_with_enable_web_access(
state=None, name=None, error=None, version="v1"
):
custom_job_proto = _get_custom_job_proto(
state=state, name=name, error=error, version=version
)
def _get_custom_job_proto_with_enable_web_access(state=None, name=None, error=None):
custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error)
custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS
if state == gca_job_state_compat.JobState.JOB_STATE_RUNNING:
custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS
Expand Down Expand Up @@ -260,24 +241,25 @@ def create_custom_job_mock_with_enable_web_access():


@pytest.fixture
def create_custom_job_mock_fail():
def create_custom_job_mock_with_tensorboard():
with mock.patch.object(
job_service_client.JobServiceClient, "create_custom_job"
) as create_custom_job_mock:
create_custom_job_mock.side_effect = RuntimeError("Mock fail")
custom_job_proto = _get_custom_job_proto(
name=_TEST_CUSTOM_JOB_NAME,
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
)
custom_job_proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME
create_custom_job_mock.return_value = custom_job_proto
yield create_custom_job_mock


@pytest.fixture
def create_custom_job_v1beta1_mock():
def create_custom_job_mock_fail():
with mock.patch.object(
job_service_client_v1beta1.JobServiceClient, "create_custom_job"
job_service_client.JobServiceClient, "create_custom_job"
) as create_custom_job_mock:
create_custom_job_mock.return_value = _get_custom_job_proto(
name=_TEST_CUSTOM_JOB_NAME,
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
version="v1beta1",
)
create_custom_job_mock.side_effect = RuntimeError("Mock fail")
yield create_custom_job_mock


Expand Down Expand Up @@ -573,7 +555,7 @@ def test_get_web_access_uris_job_succeeded(

@pytest.mark.parametrize("sync", [True, False])
def test_create_custom_job_with_tensorboard(
self, create_custom_job_v1beta1_mock, get_custom_job_mock, sync
self, create_custom_job_mock_with_tensorboard, get_custom_job_mock, sync
):

aiplatform.init(
Expand Down Expand Up @@ -601,9 +583,10 @@ def test_create_custom_job_with_tensorboard(

job.wait()

expected_custom_job = _get_custom_job_proto(version="v1beta1")
expected_custom_job = _get_custom_job_proto()
expected_custom_job.job_spec.tensorboard = _TEST_TENSORBOARD_NAME

create_custom_job_v1beta1_mock.assert_called_once_with(
create_custom_job_mock_with_tensorboard.assert_called_once_with(
parent=_TEST_PARENT, custom_job=expected_custom_job
)

Expand Down

0 comments on commit c40ec85

Please sign in to comment.