Skip to content

Commit

Permalink
feat: add cancel method to pipeline client (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
ji-yaqi committed Jun 21, 2021
1 parent 74627ba commit 3b19fff
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 9 deletions.
18 changes: 17 additions & 1 deletion google/cloud/aiplatform/pipeline_jobs.py
Expand Up @@ -264,7 +264,7 @@ def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
@property
def _has_run(self) -> bool:
"""Helper property to check if this pipeline job has been run."""
return bool(self._gca_resource.name)
return bool(self._gca_resource.create_time)

@property
def has_failed(self) -> bool:
Expand Down Expand Up @@ -310,3 +310,19 @@ def _block_until_complete(self):
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
time.sleep(wait)

def cancel(self) -> None:
"""Starts asynchronous cancellation on the PipelineJob. The server
makes a best effort to cancel the job, but success is not guaranteed.
On successful cancellation, the PipelineJob is not deleted; instead it
becomes a job with state set to `CANCELLED`.
Raises:
RuntimeError: If this PipelineJob has not started running.
"""
if not self._has_run:
raise RuntimeError(
"This PipelineJob has not been launched, use the `run()` method "
"to start. `cancel()` can only be called on a job that is running."
)
self.api_client.cancel_pipeline_job(name=self.resource_name)
72 changes: 64 additions & 8 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Expand Up @@ -21,12 +21,11 @@
from unittest import mock
from importlib import reload
from unittest.mock import patch
from datetime import datetime

from google.auth import credentials as auth_credentials

from google.cloud import aiplatform
from google.cloud import storage

from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform import initializer
from google.protobuf import json_format
Expand Down Expand Up @@ -72,6 +71,7 @@
_TEST_PIPELINE_RESOURCE_NAME = (
f"{_TEST_PARENT}/fakePipelineJobs/{_TEST_PIPELINE_JOB_ID}"
)
_TEST_PIPELINE_CREATE_TIME = datetime.now()


@pytest.fixture
Expand All @@ -82,13 +82,16 @@ def mock_pipeline_service_create():
mock_create_pipeline_job.return_value = gca_pipeline_job_v1beta1.PipelineJob(
name=_TEST_PIPELINE_JOB_NAME,
state=gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
create_time=_TEST_PIPELINE_CREATE_TIME,
)
yield mock_create_pipeline_job


def make_pipeline_job(state):
return gca_pipeline_job_v1beta1.PipelineJob(
name=_TEST_PIPELINE_JOB_NAME, state=state,
name=_TEST_PIPELINE_JOB_NAME,
state=state,
create_time=_TEST_PIPELINE_CREATE_TIME,
)


Expand Down Expand Up @@ -130,6 +133,14 @@ def mock_pipeline_service_get():
yield mock_get_pipeline_job


@pytest.fixture
def mock_pipeline_service_cancel():
with mock.patch.object(
pipeline_service_client_v1beta1.PipelineServiceClient, "cancel_pipeline_job"
) as mock_cancel_pipeline_job:
yield mock_cancel_pipeline_job


@pytest.fixture
def mock_load_json():
with patch.object(storage.Blob, "download_as_bytes") as mock_load_json:
Expand All @@ -155,13 +166,10 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@pytest.mark.usefixtures("mock_load_json")
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_load_json,
sync,
self, mock_pipeline_service_create, mock_pipeline_service_get, sync,
):
aiplatform.init(
project=_TEST_PROJECT,
Expand Down Expand Up @@ -213,3 +221,51 @@ def test_run_call_pipeline_service_create(
assert job._gca_resource == make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
)
def test_cancel_pipeline_job(
self, mock_pipeline_service_cancel,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_ID,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
)

job.run()
job.cancel()

mock_pipeline_service_cancel.assert_called_once_with(
name=_TEST_PIPELINE_JOB_NAME
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
)
def test_cancel_pipeline_job_without_running(
self, mock_pipeline_service_cancel,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_ID,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
)

with pytest.raises(RuntimeError) as e:
job.cancel()

assert e.match(regexp=r"PipelineJob has not been launched")

0 comments on commit 3b19fff

Please sign in to comment.