From 6bc4c612d5471911f82ee5ada9fb3a9307ee836f Mon Sep 17 00:00:00 2001 From: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> Date: Mon, 10 May 2021 16:34:25 -0400 Subject: [PATCH] fix: env formatiing (#379) --- google/cloud/aiplatform/training_jobs.py | 17 +++++++++--- tests/unit/aiplatform/test_training_jobs.py | 30 ++++++++++++++++----- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index f80174efdc..8ef054fc97 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -2121,7 +2121,10 @@ def _run( spec["pythonPackageSpec"]["args"] = args if environment_variables: - spec["pythonPackageSpec"]["env"] = environment_variables + spec["pythonPackageSpec"]["env"] = [ + {"name": key, "value": value} + for key, value in environment_variables.items() + ] ( training_task_inputs, @@ -2671,7 +2674,10 @@ def _run( spec["containerSpec"]["args"] = args if environment_variables: - spec["containerSpec"]["env"] = environment_variables + spec["containerSpec"]["env"] = [ + {"name": key, "value": value} + for key, value in environment_variables.items() + ] ( training_task_inputs, @@ -3734,7 +3740,7 @@ def run( Args: dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset,]): AI Platform to fit this training against. Custom training script should - retrieve datasets through passed in environement variables uris: + retrieve datasets through passed in environment variables uris: os.environ["AIP_TRAINING_DATA_URI"] os.environ["AIP_VALIDATION_DATA_URI"] @@ -3984,7 +3990,10 @@ def _run( spec["pythonPackageSpec"]["args"] = args if environment_variables: - spec["pythonPackageSpec"]["env"] = environment_variables + spec["pythonPackageSpec"]["env"] = [ + {"name": key, "value": value} + for key, value in environment_variables.items() + ] ( training_task_inputs, diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 44e662a36e..c3c0e33863 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -622,7 +622,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( ) true_args = _TEST_RUN_ARGS - true_env = _TEST_ENVIRONMENT_VARIABLES + true_env = [ + {"name": key, "value": value} + for key, value in _TEST_ENVIRONMENT_VARIABLES.items() + ] true_worker_pool_spec = { "replicaCount": _TEST_REPLICA_COUNT, @@ -777,7 +780,10 @@ def test_run_call_pipeline_service_create_with_bigquery_destination( model_from_job.wait() true_args = _TEST_RUN_ARGS - true_env = _TEST_ENVIRONMENT_VARIABLES + true_env = [ + {"name": key, "value": value} + for key, value in _TEST_ENVIRONMENT_VARIABLES.items() + ] true_worker_pool_spec = { "replicaCount": _TEST_REPLICA_COUNT, @@ -1049,7 +1055,10 @@ def test_run_call_pipeline_service_create_with_no_dataset( ) true_args = _TEST_RUN_ARGS - true_env = _TEST_ENVIRONMENT_VARIABLES + true_env = [ + {"name": key, "value": value} + for key, value in _TEST_ENVIRONMENT_VARIABLES.items() + ] true_worker_pool_spec = { "replicaCount": _TEST_REPLICA_COUNT, @@ -1297,7 +1306,10 @@ def test_run_call_pipeline_service_create_distributed_training( ) true_args = _TEST_RUN_ARGS - true_env = _TEST_ENVIRONMENT_VARIABLES + true_env = [ + {"name": key, "value": value} + for key, value in _TEST_ENVIRONMENT_VARIABLES.items() + ] true_worker_pool_spec = [ { @@ -1763,7 +1775,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( model_from_job.wait() true_args = _TEST_RUN_ARGS - true_env = _TEST_ENVIRONMENT_VARIABLES + true_env = [ + {"name": key, "value": value} + for key, value in _TEST_ENVIRONMENT_VARIABLES.items() + ] true_worker_pool_spec = { "replicaCount": _TEST_REPLICA_COUNT, @@ -2972,7 +2987,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( model_from_job.wait() true_args = _TEST_RUN_ARGS - true_env = _TEST_ENVIRONMENT_VARIABLES + true_env = [ + {"name": key, "value": value} + for key, value in _TEST_ENVIRONMENT_VARIABLES.items() + ] true_worker_pool_spec = { "replicaCount": _TEST_REPLICA_COUNT,