Skip to content

Commit

Permalink
feat: Add training_utils folder and environment_variables for training
Browse files Browse the repository at this point in the history
  • Loading branch information
mkovalski committed Oct 21, 2021
1 parent 0477f5a commit 141c008
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 0 deletions.
15 changes: 15 additions & 0 deletions google/cloud/aiplatform/training_utils/__init__.py
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
76 changes: 76 additions & 0 deletions google/cloud/aiplatform/training_utils/environment_variables.py
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os

from typing import Dict, Optional


def _json_helper(env_var: str) -> Optional[Dict]:
"""Helper to convert a dictionary represented as a string to a dictionary.
Args:
env_var (str):
Required. The name of the environment variable.
Returns:
A dictionary if the variable was found, None otherwise.
"""
env = os.environ.get(env_var)
if env is not None:
return json.loads(env)
else:
return None


# Cloud Storage URI of a directory intended for training data.
training_data_uri = os.environ.get("AIP_TRAINING_DATA_URI")

# Cloud Storage URI of a directory intended for validation data.
validation_data_uri = os.environ.get("AIP_VALIDATION_DATA_URI")

# Cloud Storage URI of a directory intended for test data.
test_data_uri = os.environ.get("AIP_TEST_DATA_URI")

# Cloud Storage URI of a directory intended for saving model artefacts.
model_dir = os.environ.get("AIP_MODEL_DIR")

# Cloud Storage URI of a directory intended for saving checkpoints.
checkpoint_dir = os.environ.get("AIP_CHECKPOINT_DIR")

# Cloud Storage URI of a directory intended for saving TensorBoard logs.
tensorboard_log_dir = os.environ.get("AIP_TENSORBOARD_LOG_DIR")

# json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#cluster-variables
cluster_spec = _json_helper("CLUSTER_SPEC")

This comment has been minimized.

Copy link
@Ark-kun

Ark-kun Sep 12, 2022

Contributor

What happens if CLUSTER_SPEC contains broken JSON data?


# json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#tf-config
tf_config = _json_helper("TF_CONFIG")

# Profiler port used for capturing profiling samples.
tf_profiler_port = os.environ.get("AIP_TF_PROFILER_PORT")

# API URI used for the tensorboard uploader.
tensorboard_api_uri = os.environ.get("AIP_TENSORBOARD_API_URI")

# The name of the tensorboard resource, in the form:
# `projects/{project_id}/locations/{location}/tensorboards/{tensorboard_name}`
tensorboard_resource_name = os.environ.get("AIP_TENSORBOARD_RESOURCE_NAME")

# The name given to the training job.
cloud_ml_job_id = os.environ.get("CLOUD_ML_JOB_ID")
194 changes: 194 additions & 0 deletions tests/unit/aiplatform/test_training_utils.py
@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from importlib import reload
import json
import os
import pytest

from google.cloud.aiplatform.training_utils import environment_variables
from unittest import mock

_TEST_TRAINING_DATA_URI = "gs://training-data-uri"
_TEST_VALIDATION_DATA_URI = "gs://test-validation-data-uri"
_TEST_TEST_DATA_URI = "gs://test-data-uri"
_TEST_MODEL_DIR = "gs://test-model-dir"
_TEST_CHECKPOINT_DIR = "gs://test-checkpoint-dir"
_TEST_TENSORBOARD_LOG_DIR = "gs://test-tensorboard-log-dir"
_TEST_CLUSTER_SPEC = """{
"cluster": {
"worker_pools":[
{
"index":0,
"replicas":[
"training-workerpool0-ab-0:2222"
]
},
{
"index":1,
"replicas":[
"training-workerpool1-ab-0:2222",
"training-workerpool1-ab-1:2222"
]
}
]
},
"environment": "cloud",
"task": {
"worker_pool_index":0,
"replica_index":0,
"trial":"TRIAL_ID"
}
}"""
_TEST_AIP_TF_PROFILER_PORT = "1234"
_TEST_TENSORBOARD_API_URI = "http://testuri.com"
_TEST_TENSORBOARD_RESOURCE_NAME = (
"projects/myproj/locations/us-central1/tensorboards/1234"
)
_TEST_CLOUD_ML_JOB_ID = "myjob"


class TestTrainingUtils:
@pytest.fixture
def mock_environment(self):
env_vars = {
"AIP_TRAINING_DATA_URI": _TEST_TRAINING_DATA_URI,
"AIP_VALIDATION_DATA_URI": _TEST_VALIDATION_DATA_URI,
"AIP_TEST_DATA_URI": _TEST_TEST_DATA_URI,
"AIP_MODEL_DIR": _TEST_MODEL_DIR,
"AIP_CHECKPOINT_DIR": _TEST_CHECKPOINT_DIR,
"AIP_TENSORBOARD_LOG_DIR": _TEST_TENSORBOARD_LOG_DIR,
"AIP_TF_PROFILER_PORT": _TEST_AIP_TF_PROFILER_PORT,
"AIP_TENSORBOARD_API_URI": _TEST_TENSORBOARD_API_URI,
"AIP_TENSORBOARD_RESOURCE_NAME": _TEST_TENSORBOARD_RESOURCE_NAME,
"CLOUD_ML_JOB_ID": _TEST_CLOUD_ML_JOB_ID,
"CLUSTER_SPEC": _TEST_CLUSTER_SPEC,
"TF_CONFIG": _TEST_CLUSTER_SPEC,
}
with mock.patch.dict(os.environ, env_vars, clear=True):
yield

@pytest.mark.usefixtures("mock_environment")
def test_training_data_uri(self):
reload(environment_variables)
assert environment_variables.training_data_uri == _TEST_TRAINING_DATA_URI

def test_training_data_uri_none(self):
reload(environment_variables)
assert environment_variables.training_data_uri is None

@pytest.mark.usefixtures("mock_environment")
def test_validation_data_uri(self):
reload(environment_variables)
assert environment_variables.validation_data_uri == _TEST_VALIDATION_DATA_URI

def test_validation_data_uri_none(self):
reload(environment_variables)
assert environment_variables.validation_data_uri is None

@pytest.mark.usefixtures("mock_environment")
def test_test_data_uri(self):
reload(environment_variables)
assert environment_variables.test_data_uri == _TEST_TEST_DATA_URI

def test_test_data_uri_none(self):
reload(environment_variables)
assert environment_variables.test_data_uri is None

@pytest.mark.usefixtures("mock_environment")
def test_model_dir(self):
reload(environment_variables)
assert environment_variables.model_dir == _TEST_MODEL_DIR

def test_model_dir_none(self):
reload(environment_variables)
assert environment_variables.model_dir is None

@pytest.mark.usefixtures("mock_environment")
def test_checkpoint_dir(self):
reload(environment_variables)
assert environment_variables.checkpoint_dir == _TEST_CHECKPOINT_DIR

def test_checkpoint_dir_none(self):
reload(environment_variables)
assert environment_variables.checkpoint_dir is None

@pytest.mark.usefixtures("mock_environment")
def test_tensorboard_log_dir(self):
reload(environment_variables)
assert environment_variables.tensorboard_log_dir == _TEST_TENSORBOARD_LOG_DIR

def test_tensorboard_log_dir_none(self):
reload(environment_variables)
assert environment_variables.tensorboard_log_dir is None

@pytest.mark.usefixtures("mock_environment")
def test_cluster_spec(self):
reload(environment_variables)
assert environment_variables.cluster_spec == json.loads(_TEST_CLUSTER_SPEC)

def test_cluster_spec_none(self):
reload(environment_variables)
assert environment_variables.cluster_spec is None

@pytest.mark.usefixtures("mock_environment")
def test_tf_config(self):
reload(environment_variables)
assert environment_variables.tf_config == json.loads(_TEST_CLUSTER_SPEC)

def test_tf_config_none(self):
reload(environment_variables)
assert environment_variables.tf_config is None

@pytest.mark.usefixtures("mock_environment")
def test_tf_profiler_port(self):
reload(environment_variables)
assert environment_variables.tf_profiler_port == _TEST_AIP_TF_PROFILER_PORT

def test_tf_profiler_port_none(self):
reload(environment_variables)
assert environment_variables.tf_profiler_port is None

@pytest.mark.usefixtures("mock_environment")
def test_tensorboard_api_uri(self):
reload(environment_variables)
assert environment_variables.tensorboard_api_uri == _TEST_TENSORBOARD_API_URI

def test_tensorboard_api_uri_none(self):
reload(environment_variables)
assert environment_variables.tensorboard_api_uri is None

@pytest.mark.usefixtures("mock_environment")
def test_tensorboard_resource_name(self):
reload(environment_variables)
assert (
environment_variables.tensorboard_resource_name
== _TEST_TENSORBOARD_RESOURCE_NAME
)

def test_tensorboard_resource_name_none(self):
reload(environment_variables)
assert environment_variables.tensorboard_resource_name is None

@pytest.mark.usefixtures("mock_environment")
def test_cloud_ml_job_id(self):
reload(environment_variables)
assert environment_variables.cloud_ml_job_id == _TEST_CLOUD_ML_JOB_ID

def test_cloud_ml_job_id_none(self):
reload(environment_variables)
assert environment_variables.cloud_ml_job_id is None

0 comments on commit 141c008

Please sign in to comment.