Skip to content

Commit

Permalink
fix: create pipeline job with user-specified job id (#567)
Browse files Browse the repository at this point in the history
* fix: create pipeline job with user-specified job id

* Add comments for pipelineJob name not used for service
  • Loading branch information
ji-yaqi committed Jul 26, 2021
1 parent fe5c702 commit df68ec3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
6 changes: 5 additions & 1 deletion google/cloud/aiplatform/pipeline_jobs.py
Expand Up @@ -249,8 +249,12 @@ def run(

_LOGGER.log_create_with_lro(self.__class__)

# PipelineJob.name is not used by pipeline service
pipeline_job_id = self._gca_resource.name.split("/")[-1]
self._gca_resource = self.api_client.create_pipeline_job(
parent=self._parent, pipeline_job=self._gca_resource
parent=self._parent,
pipeline_job=self._gca_resource,
pipeline_job_id=pipeline_job_id,
)

_LOGGER.log_create_complete_with_getter(
Expand Down
31 changes: 17 additions & 14 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Expand Up @@ -40,6 +40,7 @@

_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_PIPELINE_JOB_DISPLAY_NAME = "sample-pipeline-job-display-name"
_TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111"
_TEST_GCS_BUCKET_NAME = "my-bucket"
_TEST_CREDENTIALS = auth_credentials.AnonymousCredentials()
Expand Down Expand Up @@ -199,7 +200,7 @@ def test_run_call_pipeline_service_create(
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_ID,
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
Expand All @@ -222,7 +223,7 @@ def test_run_call_pipeline_service_create(

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
display_name=_TEST_PIPELINE_JOB_ID,
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
name=_TEST_PIPELINE_JOB_NAME,
pipeline_spec={
"components": {},
Expand All @@ -233,7 +234,9 @@ def test_run_call_pipeline_service_create(
)

mock_pipeline_service_create.assert_called_once_with(
parent=_TEST_PARENT, pipeline_job=expected_gapic_pipeline_job,
parent=_TEST_PARENT,
pipeline_job=expected_gapic_pipeline_job,
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
)

mock_pipeline_service_get.assert_called_with(name=_TEST_PIPELINE_JOB_NAME)
Expand All @@ -242,6 +245,14 @@ def test_run_call_pipeline_service_create(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.usefixtures("mock_pipeline_service_get")
def test_get_pipeline_job(self, mock_pipeline_service_get):
aiplatform.init(project=_TEST_PROJECT)
job = pipeline_jobs.PipelineJob.get(resource_name=_TEST_PIPELINE_JOB_ID)

mock_pipeline_service_get.assert_called_once_with(name=_TEST_PIPELINE_JOB_NAME)
assert isinstance(job, pipeline_jobs.PipelineJob)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
)
Expand All @@ -255,7 +266,7 @@ def test_cancel_pipeline_job(
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_ID,
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
)
Expand All @@ -267,14 +278,6 @@ def test_cancel_pipeline_job(
name=_TEST_PIPELINE_JOB_NAME
)

@pytest.mark.usefixtures("mock_pipeline_service_get")
def test_get_training_job(self, mock_pipeline_service_get):
aiplatform.init(project=_TEST_PROJECT)
job = pipeline_jobs.PipelineJob.get(resource_name=_TEST_PIPELINE_JOB_ID)

mock_pipeline_service_get.assert_called_once_with(name=_TEST_PIPELINE_JOB_NAME)
assert isinstance(job, pipeline_jobs.PipelineJob)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
)
Expand All @@ -288,7 +291,7 @@ def test_cancel_pipeline_job_without_running(
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_ID,
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
)
Expand All @@ -313,7 +316,7 @@ def test_pipeline_failure_raises(self, sync):
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_ID,
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
Expand Down

0 comments on commit df68ec3

Please sign in to comment.