diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
index 8bf06ccb49..84b39c2ae8 100644
--- a/google/cloud/aiplatform/pipeline_jobs.py
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -264,7 +264,7 @@ def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
     @property
     def _has_run(self) -> bool:
         """Helper property to check if this pipeline job has been run."""
-        return bool(self._gca_resource.name)
+        return bool(self._gca_resource.create_time)
 
     @property
     def has_failed(self) -> bool:
@@ -310,3 +310,19 @@ def _block_until_complete(self):
                 log_wait = min(log_wait * multiplier, max_wait)
                 previous_time = current_time
             time.sleep(wait)
+
+    def cancel(self) -> None:
+        """Starts asynchronous cancellation on the PipelineJob. The server
+        makes a best effort to cancel the job, but success is not guaranteed.
+        On successful cancellation, the PipelineJob is not deleted; instead it
+        becomes a job with state set to `CANCELLED`.
+
+        Raises:
+            RuntimeError: If this PipelineJob has not started running.
+        """
+        if not self._has_run:
+            raise RuntimeError(
+                "This PipelineJob has not been launched, use the `run()` method "
+                "to start. `cancel()` can only be called on a job that is running."
+            )
+        self.api_client.cancel_pipeline_job(name=self.resource_name)
diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py
index e156163348..18dc692d38 100644
--- a/tests/unit/aiplatform/test_pipeline_jobs.py
+++ b/tests/unit/aiplatform/test_pipeline_jobs.py
@@ -21,12 +21,11 @@
 from unittest import mock
 from importlib import reload
 from unittest.mock import patch
+from datetime import datetime
 
 from google.auth import credentials as auth_credentials
-
 from google.cloud import aiplatform
 from google.cloud import storage
-
 from google.cloud.aiplatform import pipeline_jobs
 from google.cloud.aiplatform import initializer
 from google.protobuf import json_format
@@ -72,6 +71,7 @@
 _TEST_PIPELINE_RESOURCE_NAME = (
     f"{_TEST_PARENT}/fakePipelineJobs/{_TEST_PIPELINE_JOB_ID}"
 )
+_TEST_PIPELINE_CREATE_TIME = datetime.now()
 
 
 @pytest.fixture
@@ -82,13 +82,16 @@ def mock_pipeline_service_create():
         mock_create_pipeline_job.return_value = gca_pipeline_job_v1beta1.PipelineJob(
             name=_TEST_PIPELINE_JOB_NAME,
             state=gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
+            create_time=_TEST_PIPELINE_CREATE_TIME,
         )
         yield mock_create_pipeline_job
 
 
 def make_pipeline_job(state):
     return gca_pipeline_job_v1beta1.PipelineJob(
-        name=_TEST_PIPELINE_JOB_NAME, state=state,
+        name=_TEST_PIPELINE_JOB_NAME,
+        state=state,
+        create_time=_TEST_PIPELINE_CREATE_TIME,
     )
 
 
@@ -130,6 +133,14 @@ def mock_pipeline_service_get():
         yield mock_get_pipeline_job
 
 
+@pytest.fixture
+def mock_pipeline_service_cancel():
+    with mock.patch.object(
+        pipeline_service_client_v1beta1.PipelineServiceClient, "cancel_pipeline_job"
+    ) as mock_cancel_pipeline_job:
+        yield mock_cancel_pipeline_job
+
+
 @pytest.fixture
 def mock_load_json():
     with patch.object(storage.Blob, "download_as_bytes") as mock_load_json:
@@ -155,13 +166,10 @@ def setup_method(self):
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
 
+    @pytest.mark.usefixtures("mock_load_json")
     @pytest.mark.parametrize("sync", [True, False])
     def test_run_call_pipeline_service_create(
-        self,
-        mock_pipeline_service_create,
-        mock_pipeline_service_get,
-        mock_load_json,
-        sync,
+        self, mock_pipeline_service_create, mock_pipeline_service_get, sync,
     ):
         aiplatform.init(
             project=_TEST_PROJECT,
@@ -213,3 +221,51 @@ def test_run_call_pipeline_service_create(
         assert job._gca_resource == make_pipeline_job(
             gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
         )
+
+    @pytest.mark.usefixtures(
+        "mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
+    )
+    def test_cancel_pipeline_job(
+        self, mock_pipeline_service_cancel,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_ID,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+        )
+
+        job.run()
+        job.cancel()
+
+        mock_pipeline_service_cancel.assert_called_once_with(
+            name=_TEST_PIPELINE_JOB_NAME
+        )
+
+    @pytest.mark.usefixtures(
+        "mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
+    )
+    def test_cancel_pipeline_job_without_running(
+        self, mock_pipeline_service_cancel,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_ID,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+        )
+
+        with pytest.raises(RuntimeError) as e:
+            job.cancel()
+
+        assert e.match(regexp=r"PipelineJob has not been launched")
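
For reviewers, a minimal, hypothetical usage sketch of the new cancel() surface follows. The project, location, staging bucket, and compiled template path are placeholder values, and the asynchronous run(sync=False) call is an assumption based on the parametrized sync test above; none of this is part of the diff itself.

    from google.cloud import aiplatform
    from google.cloud.aiplatform import pipeline_jobs

    # Placeholder project/location/bucket; substitute real values.
    aiplatform.init(
        project="my-project",
        location="us-central1",
        staging_bucket="gs://my-staging-bucket",
    )

    job = pipeline_jobs.PipelineJob(
        display_name="sample-pipeline",
        template_path="pipeline_spec.json",  # placeholder compiled pipeline spec
        job_id="sample-pipeline-job",
    )

    # Assumed: run(sync=False) returns once the job is created, so cancel()
    # can be issued while the job is still running. Calling cancel() before
    # run() raises RuntimeError, per test_cancel_pipeline_job_without_running.
    job.run(sync=False)
    job.cancel()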