Skip to content

Commit 2e627f8

Browse files
authored
fix: check if training_task_metadata is populated before logging backingCustomJob (#494)
1 parent 3b19fff commit 2e627f8

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,8 @@ def _prepare_training_task_inputs_and_output_dir(
12431243

12441244
def _wait_callback(self):
12451245
if (
1246-
self._gca_resource.training_task_metadata.get("backingCustomJob")
1246+
self._gca_resource.training_task_metadata
1247+
and self._gca_resource.training_task_metadata.get("backingCustomJob")
12471248
and not self._has_logged_custom_job
12481249
):
12491250
_LOGGER.info(f"View backing custom job:\n{self._custom_job_console_uri()}")

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,15 @@ def mock_pipeline_service_create():
443443
yield mock_create_training_pipeline
444444

445445

446-
def make_training_pipeline(state):
446+
def make_training_pipeline(state, add_training_task_metadata=True):
447447
return gca_training_pipeline.TrainingPipeline(
448448
name=_TEST_PIPELINE_RESOURCE_NAME,
449449
state=state,
450450
model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
451451
training_task_inputs={"tensorboard": _TEST_TENSORBOARD_RESOURCE_NAME},
452-
training_task_metadata={"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME},
452+
training_task_metadata={"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME}
453+
if add_training_task_metadata
454+
else None,
453455
)
454456

455457

@@ -460,7 +462,11 @@ def mock_pipeline_service_get():
460462
) as mock_get_training_pipeline:
461463
mock_get_training_pipeline.side_effect = [
462464
make_training_pipeline(
463-
gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING
465+
gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
466+
add_training_task_metadata=False,
467+
),
468+
make_training_pipeline(
469+
gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
464470
),
465471
make_training_pipeline(
466472
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED

0 commit comments

Comments
 (0)