diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py index 84b39c2ae8..7214e6247a 100644 --- a/google/cloud/aiplatform/pipeline_jobs.py +++ b/google/cloud/aiplatform/pipeline_jobs.py @@ -44,6 +44,10 @@ ] ) +_PIPELINE_ERROR_STATES = set( + [gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED] +) + # Vertex AI Pipelines service API job name relative name prefix pattern. _JOB_NAME_PATTERN = "{parent}/pipelineJobs/{job_id}" @@ -311,6 +315,13 @@ def _block_until_complete(self): previous_time = current_time time.sleep(wait) + # Error is only populated when the pipeline state is + # PIPELINE_STATE_FAILED or PIPELINE_STATE_CANCELLED. + if self._gca_resource.state in _PIPELINE_ERROR_STATES: + raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error) + else: + _LOGGER.log_action_completed_against_resource("run", "completed", self) + def cancel(self) -> None: """Starts asynchronous cancellation on the PipelineJob. The server makes a best effort to cancel the job, but success is not guaranteed. 
diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index 18dc692d38..e02082b646 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -133,6 +133,26 @@ def mock_pipeline_service_get(): yield mock_get_pipeline_job +@pytest.fixture +def mock_pipeline_service_get_with_fail(): + with mock.patch.object( + pipeline_service_client_v1beta1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_pipeline_job: + mock_get_pipeline_job.side_effect = [ + make_pipeline_job( + gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_RUNNING + ), + make_pipeline_job( + gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_RUNNING + ), + make_pipeline_job( + gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED + ), + ] + + yield mock_get_pipeline_job + + @pytest.fixture def mock_pipeline_service_cancel(): with mock.patch.object( @@ -269,3 +289,33 @@ def test_cancel_pipeline_job_without_running( job.cancel() assert e.match(regexp=r"PipelineJob has not been launched") + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get_with_fail", + "mock_load_json", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_pipeline_failure_raises(self, sync): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_ID, + template_path=_TEST_TEMPLATE_PATH, + job_id=_TEST_PIPELINE_JOB_ID, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + enable_caching=True, + ) + + with pytest.raises(RuntimeError): + job.run( + service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync, + ) + + if not sync: + job.wait()