Skip to content

Commit

Permalink
fix: env formatiing (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha-gitg committed May 10, 2021
1 parent 8945865 commit 6bc4c61
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
17 changes: 13 additions & 4 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -2121,7 +2121,10 @@ def _run(
spec["pythonPackageSpec"]["args"] = args

if environment_variables:
spec["pythonPackageSpec"]["env"] = environment_variables
spec["pythonPackageSpec"]["env"] = [
{"name": key, "value": value}
for key, value in environment_variables.items()
]

(
training_task_inputs,
Expand Down Expand Up @@ -2671,7 +2674,10 @@ def _run(
spec["containerSpec"]["args"] = args

if environment_variables:
spec["containerSpec"]["env"] = environment_variables
spec["containerSpec"]["env"] = [
{"name": key, "value": value}
for key, value in environment_variables.items()
]

(
training_task_inputs,
Expand Down Expand Up @@ -3734,7 +3740,7 @@ def run(
Args:
dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset,]):
AI Platform to fit this training against. Custom training script should
retrieve datasets through passed in environement variables uris:
retrieve datasets through passed in environment variables uris:
os.environ["AIP_TRAINING_DATA_URI"]
os.environ["AIP_VALIDATION_DATA_URI"]
Expand Down Expand Up @@ -3984,7 +3990,10 @@ def _run(
spec["pythonPackageSpec"]["args"] = args

if environment_variables:
spec["pythonPackageSpec"]["env"] = environment_variables
spec["pythonPackageSpec"]["env"] = [
{"name": key, "value": value}
for key, value in environment_variables.items()
]

(
training_task_inputs,
Expand Down
30 changes: 24 additions & 6 deletions tests/unit/aiplatform/test_training_jobs.py
Expand Up @@ -622,7 +622,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
)

true_args = _TEST_RUN_ARGS
true_env = _TEST_ENVIRONMENT_VARIABLES
true_env = [
{"name": key, "value": value}
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

true_worker_pool_spec = {
"replicaCount": _TEST_REPLICA_COUNT,
Expand Down Expand Up @@ -777,7 +780,10 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
model_from_job.wait()

true_args = _TEST_RUN_ARGS
true_env = _TEST_ENVIRONMENT_VARIABLES
true_env = [
{"name": key, "value": value}
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

true_worker_pool_spec = {
"replicaCount": _TEST_REPLICA_COUNT,
Expand Down Expand Up @@ -1049,7 +1055,10 @@ def test_run_call_pipeline_service_create_with_no_dataset(
)

true_args = _TEST_RUN_ARGS
true_env = _TEST_ENVIRONMENT_VARIABLES
true_env = [
{"name": key, "value": value}
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

true_worker_pool_spec = {
"replicaCount": _TEST_REPLICA_COUNT,
Expand Down Expand Up @@ -1297,7 +1306,10 @@ def test_run_call_pipeline_service_create_distributed_training(
)

true_args = _TEST_RUN_ARGS
true_env = _TEST_ENVIRONMENT_VARIABLES
true_env = [
{"name": key, "value": value}
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

true_worker_pool_spec = [
{
Expand Down Expand Up @@ -1763,7 +1775,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
model_from_job.wait()

true_args = _TEST_RUN_ARGS
true_env = _TEST_ENVIRONMENT_VARIABLES
true_env = [
{"name": key, "value": value}
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

true_worker_pool_spec = {
"replicaCount": _TEST_REPLICA_COUNT,
Expand Down Expand Up @@ -2972,7 +2987,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
model_from_job.wait()

true_args = _TEST_RUN_ARGS
true_env = _TEST_ENVIRONMENT_VARIABLES
true_env = [
{"name": key, "value": value}
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

true_worker_pool_spec = {
"replicaCount": _TEST_REPLICA_COUNT,
Expand Down

0 comments on commit 6bc4c61

Please sign in to comment.