Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add training_utils folder and environment_variables for training
- Loading branch information
Showing
3 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
76 changes: 76 additions & 0 deletions
76
google/cloud/aiplatform/training_utils/environment_variables.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import json | ||
import os | ||
|
||
from typing import Dict, Optional | ||
|
||
|
||
def _json_helper(env_var: str) -> Optional[Dict]:
    """Parse the JSON document stored in an environment variable.

    Args:
        env_var (str):
            Required. The name of the environment variable.

    Returns:
        The value decoded by ``json.loads`` (expected to be a dictionary) if
        the variable is set, None otherwise.

    Raises:
        json.JSONDecodeError:
            If the variable is set but does not contain valid JSON. The
            message names the offending environment variable.
    """
    raw_value = os.environ.get(env_var)
    if raw_value is None:
        return None

    try:
        return json.loads(raw_value)
    except json.JSONDecodeError as err:
        # Re-raise the same exception type so existing handlers keep working,
        # but say which variable is malformed: the default message
        # ("Expecting value: line 1 column 1 ...") gives no hint about where
        # the broken document came from.
        raise json.JSONDecodeError(
            f"Environment variable {env_var} does not contain valid JSON: "
            f"{err.msg}",
            err.doc,
            err.pos,
        ) from err
|
||
|
||
# Every setting below is read once, when this module is imported, and is None
# when the corresponding environment variable is not set.

# Directory holding the training data (Cloud Storage URI).
training_data_uri = os.getenv("AIP_TRAINING_DATA_URI")

# Directory holding the validation data (Cloud Storage URI).
validation_data_uri = os.getenv("AIP_VALIDATION_DATA_URI")

# Directory holding the test data (Cloud Storage URI).
test_data_uri = os.getenv("AIP_TEST_DATA_URI")

# Directory where model artifacts should be saved (Cloud Storage URI).
model_dir = os.getenv("AIP_MODEL_DIR")

# Directory where checkpoints should be saved (Cloud Storage URI).
checkpoint_dir = os.getenv("AIP_CHECKPOINT_DIR")

# Directory where TensorBoard logs should be saved (Cloud Storage URI).
tensorboard_log_dir = os.getenv("AIP_TENSORBOARD_LOG_DIR")

# Parsed contents of the CLUSTER_SPEC JSON document, described in
# https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#cluster-variables
cluster_spec = _json_helper("CLUSTER_SPEC")
This comment has been minimized.
Sorry, something went wrong. |
||
|
||
# Parsed contents of the TF_CONFIG JSON document, described in
# https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#tf-config
tf_config = _json_helper("TF_CONFIG")

# Port used for capturing profiling samples (kept as the raw string value).
tf_profiler_port = os.getenv("AIP_TF_PROFILER_PORT")

# API URI that the TensorBoard uploader talks to.
tensorboard_api_uri = os.getenv("AIP_TENSORBOARD_API_URI")

# Name of the TensorBoard resource, in the form:
# `projects/{project_id}/locations/{location}/tensorboards/{tensorboard_name}`
tensorboard_resource_name = os.getenv("AIP_TENSORBOARD_RESOURCE_NAME")

# The name given to the training job.
cloud_ml_job_id = os.getenv("CLOUD_ML_JOB_ID")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from importlib import reload | ||
import json | ||
import os | ||
import pytest | ||
|
||
from google.cloud.aiplatform.training_utils import environment_variables | ||
from unittest import mock | ||
|
||
# Fixture values injected into the mocked process environment; each test
# compares the reloaded module attribute against the matching constant.
_TEST_TRAINING_DATA_URI = "gs://training-data-uri"
_TEST_VALIDATION_DATA_URI = "gs://test-validation-data-uri"
_TEST_TEST_DATA_URI = "gs://test-data-uri"
_TEST_MODEL_DIR = "gs://test-model-dir"
_TEST_CHECKPOINT_DIR = "gs://test-checkpoint-dir"
_TEST_TENSORBOARD_LOG_DIR = "gs://test-tensorboard-log-dir"
# JSON document used as the value of both CLUSTER_SPEC and TF_CONFIG in the
# mocked environment; tests compare against json.loads() of this text.
_TEST_CLUSTER_SPEC = """{
"cluster": {
"worker_pools":[
{
"index":0,
"replicas":[
"training-workerpool0-ab-0:2222"
]
},
{
"index":1,
"replicas":[
"training-workerpool1-ab-0:2222",
"training-workerpool1-ab-1:2222"
]
}
]
},
"environment": "cloud",
"task": {
"worker_pool_index":0,
"replica_index":0,
"trial":"TRIAL_ID"
}
}"""
# Environment variable values are always strings, so the port is one too.
_TEST_AIP_TF_PROFILER_PORT = "1234"
_TEST_TENSORBOARD_API_URI = "http://testuri.com"
_TEST_TENSORBOARD_RESOURCE_NAME = (
    "projects/myproj/locations/us-central1/tensorboards/1234"
)
_TEST_CLOUD_ML_JOB_ID = "myjob"
|
||
|
||
class TestTrainingUtils:
    """Checks that ``environment_variables`` mirrors the process environment.

    The module reads ``os.environ`` when it is imported, so every test first
    patches the environment through a fixture and then reloads the module to
    re-evaluate its attributes.
    """

    @pytest.fixture
    def mock_environment(self):
        """Replace the environment with one where every variable is set."""
        env_vars = {
            "AIP_TRAINING_DATA_URI": _TEST_TRAINING_DATA_URI,
            "AIP_VALIDATION_DATA_URI": _TEST_VALIDATION_DATA_URI,
            "AIP_TEST_DATA_URI": _TEST_TEST_DATA_URI,
            "AIP_MODEL_DIR": _TEST_MODEL_DIR,
            "AIP_CHECKPOINT_DIR": _TEST_CHECKPOINT_DIR,
            "AIP_TENSORBOARD_LOG_DIR": _TEST_TENSORBOARD_LOG_DIR,
            "AIP_TF_PROFILER_PORT": _TEST_AIP_TF_PROFILER_PORT,
            "AIP_TENSORBOARD_API_URI": _TEST_TENSORBOARD_API_URI,
            "AIP_TENSORBOARD_RESOURCE_NAME": _TEST_TENSORBOARD_RESOURCE_NAME,
            "CLOUD_ML_JOB_ID": _TEST_CLOUD_ML_JOB_ID,
            "CLUSTER_SPEC": _TEST_CLUSTER_SPEC,
            "TF_CONFIG": _TEST_CLUSTER_SPEC,
        }
        with mock.patch.dict(os.environ, env_vars, clear=True):
            yield

    @pytest.fixture
    def mock_empty_environment(self):
        """Replace the environment with an empty one.

        The ``*_none`` tests must not see variables inherited from the host
        process (for example a real ``TF_CONFIG``); without this isolation
        they would fail wherever such a variable happens to be set.
        """
        with mock.patch.dict(os.environ, {}, clear=True):
            yield

    @pytest.mark.usefixtures("mock_environment")
    def test_training_data_uri(self):
        reload(environment_variables)
        assert environment_variables.training_data_uri == _TEST_TRAINING_DATA_URI

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_training_data_uri_none(self):
        reload(environment_variables)
        assert environment_variables.training_data_uri is None

    @pytest.mark.usefixtures("mock_environment")
    def test_validation_data_uri(self):
        reload(environment_variables)
        assert environment_variables.validation_data_uri == _TEST_VALIDATION_DATA_URI

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_validation_data_uri_none(self):
        reload(environment_variables)
        assert environment_variables.validation_data_uri is None

    @pytest.mark.usefixtures("mock_environment")
    def test_test_data_uri(self):
        reload(environment_variables)
        assert environment_variables.test_data_uri == _TEST_TEST_DATA_URI

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_test_data_uri_none(self):
        reload(environment_variables)
        assert environment_variables.test_data_uri is None

    @pytest.mark.usefixtures("mock_environment")
    def test_model_dir(self):
        reload(environment_variables)
        assert environment_variables.model_dir == _TEST_MODEL_DIR

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_model_dir_none(self):
        reload(environment_variables)
        assert environment_variables.model_dir is None

    @pytest.mark.usefixtures("mock_environment")
    def test_checkpoint_dir(self):
        reload(environment_variables)
        assert environment_variables.checkpoint_dir == _TEST_CHECKPOINT_DIR

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_checkpoint_dir_none(self):
        reload(environment_variables)
        assert environment_variables.checkpoint_dir is None

    @pytest.mark.usefixtures("mock_environment")
    def test_tensorboard_log_dir(self):
        reload(environment_variables)
        assert environment_variables.tensorboard_log_dir == _TEST_TENSORBOARD_LOG_DIR

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_tensorboard_log_dir_none(self):
        reload(environment_variables)
        assert environment_variables.tensorboard_log_dir is None

    @pytest.mark.usefixtures("mock_environment")
    def test_cluster_spec(self):
        reload(environment_variables)
        assert environment_variables.cluster_spec == json.loads(_TEST_CLUSTER_SPEC)

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_cluster_spec_none(self):
        reload(environment_variables)
        assert environment_variables.cluster_spec is None

    @pytest.mark.usefixtures("mock_environment")
    def test_tf_config(self):
        reload(environment_variables)
        assert environment_variables.tf_config == json.loads(_TEST_CLUSTER_SPEC)

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_tf_config_none(self):
        reload(environment_variables)
        assert environment_variables.tf_config is None

    @pytest.mark.usefixtures("mock_environment")
    def test_tf_profiler_port(self):
        reload(environment_variables)
        assert environment_variables.tf_profiler_port == _TEST_AIP_TF_PROFILER_PORT

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_tf_profiler_port_none(self):
        reload(environment_variables)
        assert environment_variables.tf_profiler_port is None

    @pytest.mark.usefixtures("mock_environment")
    def test_tensorboard_api_uri(self):
        reload(environment_variables)
        assert environment_variables.tensorboard_api_uri == _TEST_TENSORBOARD_API_URI

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_tensorboard_api_uri_none(self):
        reload(environment_variables)
        assert environment_variables.tensorboard_api_uri is None

    @pytest.mark.usefixtures("mock_environment")
    def test_tensorboard_resource_name(self):
        reload(environment_variables)
        assert (
            environment_variables.tensorboard_resource_name
            == _TEST_TENSORBOARD_RESOURCE_NAME
        )

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_tensorboard_resource_name_none(self):
        reload(environment_variables)
        assert environment_variables.tensorboard_resource_name is None

    @pytest.mark.usefixtures("mock_environment")
    def test_cloud_ml_job_id(self):
        reload(environment_variables)
        assert environment_variables.cloud_ml_job_id == _TEST_CLOUD_ML_JOB_ID

    @pytest.mark.usefixtures("mock_empty_environment")
    def test_cloud_ml_job_id_none(self):
        reload(environment_variables)
        assert environment_variables.cloud_ml_job_id is None
What happens if `CLUSTER_SPEC` contains broken JSON data?