From 2e627f876e1d7dd03e5d6bd2e81e6234e361a9df Mon Sep 17 00:00:00 2001 From: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> Date: Mon, 21 Jun 2021 18:55:42 -0400 Subject: [PATCH] fix: check if training_task_metadata is populated before logging backingCustomJob (#494) --- google/cloud/aiplatform/training_jobs.py | 3 ++- tests/unit/aiplatform/test_training_jobs.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 99f4f088a5..0b66c74fc1 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1243,7 +1243,8 @@ def _prepare_training_task_inputs_and_output_dir( def _wait_callback(self): if ( - self._gca_resource.training_task_metadata.get("backingCustomJob") + self._gca_resource.training_task_metadata + and self._gca_resource.training_task_metadata.get("backingCustomJob") and not self._has_logged_custom_job ): _LOGGER.info(f"View backing custom job:\n{self._custom_job_console_uri()}") diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index b160204d7d..0995e0cb95 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -443,13 +443,15 @@ def mock_pipeline_service_create(): yield mock_create_training_pipeline -def make_training_pipeline(state): +def make_training_pipeline(state, add_training_task_metadata=True): return gca_training_pipeline.TrainingPipeline( name=_TEST_PIPELINE_RESOURCE_NAME, state=state, model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), training_task_inputs={"tensorboard": _TEST_TENSORBOARD_RESOURCE_NAME}, - training_task_metadata={"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME}, + training_task_metadata={"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME} + if add_training_task_metadata + else None, ) @@ -460,7 +462,11 @@ def mock_pipeline_service_get(): ) as mock_get_training_pipeline: mock_get_training_pipeline.side_effect = [ make_training_pipeline( - gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING + gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + add_training_task_metadata=False, + ), + make_training_pipeline( + gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, ), make_training_pipeline( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED