Skip to content

Commit

Permalink
fix: log pipeline completion and raise pipeline failures (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha-gitg committed Jul 8, 2021
1 parent f6f9a97 commit 2508fe9
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
11 changes: 11 additions & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
Expand Up @@ -44,6 +44,10 @@
]
)

# Pipeline states that mark a finished run as failed.
_PIPELINE_ERROR_STATES = {
    gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED,
}

# Vertex AI Pipelines service API job name relative name prefix pattern.
_JOB_NAME_PATTERN = "{parent}/pipelineJobs/{job_id}"

Expand Down Expand Up @@ -311,6 +315,13 @@ def _block_until_complete(self):
previous_time = current_time
time.sleep(wait)

# Error is only populated when the pipeline state is
# PIPELINE_STATE_FAILED or PIPELINE_STATE_CANCELLED.
if self._gca_resource.state in _PIPELINE_ERROR_STATES:
raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
else:
_LOGGER.log_action_completed_against_resource("run", "completed", self)

def cancel(self) -> None:
"""Starts asynchronous cancellation on the PipelineJob. The server
makes a best effort to cancel the job, but success is not guaranteed.
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Expand Up @@ -133,6 +133,26 @@ def mock_pipeline_service_get():
yield mock_get_pipeline_job


@pytest.fixture
def mock_pipeline_service_get_with_fail():
    """Patch ``get_pipeline_job`` so polling sees RUNNING twice, then FAILED.

    Yields:
        The mock that replaces ``PipelineServiceClient.get_pipeline_job``
        for the duration of the test.
    """
    pipeline_state = gca_pipeline_state_v1beta1.PipelineState
    # One entry per successive call to the patched method, in order.
    polled_states = (
        pipeline_state.PIPELINE_STATE_RUNNING,
        pipeline_state.PIPELINE_STATE_RUNNING,
        pipeline_state.PIPELINE_STATE_FAILED,
    )

    get_patcher = mock.patch.object(
        pipeline_service_client_v1beta1.PipelineServiceClient, "get_pipeline_job"
    )
    with get_patcher as mock_get_pipeline_job:
        mock_get_pipeline_job.side_effect = [
            make_pipeline_job(state) for state in polled_states
        ]

        yield mock_get_pipeline_job


@pytest.fixture
def mock_pipeline_service_cancel():
with mock.patch.object(
Expand Down Expand Up @@ -269,3 +289,33 @@ def test_cancel_pipeline_job_without_running(
job.cancel()

assert e.match(regexp=r"PipelineJob has not been launched")

@pytest.mark.usefixtures(
    "mock_pipeline_service_create",
    "mock_pipeline_service_get_with_fail",
    "mock_load_json",
)
@pytest.mark.parametrize("sync", [True, False])
def test_pipeline_failure_raises(self, sync):
    """Running a pipeline whose polled state ends in FAILED raises RuntimeError.

    Args:
        sync: Forwarded to ``PipelineJob.run``; the test is parametrized over
            both ``True`` and ``False``.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
    )

    job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_ID,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )

    run_kwargs = {
        "service_account": _TEST_SERVICE_ACCOUNT,
        "network": _TEST_NETWORK,
        "sync": sync,
    }

    with pytest.raises(RuntimeError):
        job.run(**run_kwargs)

        # Presumably run() returns before the failure is observed when sync
        # is False, so wait() is what surfaces the error -- verify against
        # PipelineJob.run.
        if not sync:
            job.wait()

0 comments on commit 2508fe9

Please sign in to comment.