diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
index b805b08e26..e719c0d5fd 100644
--- a/google/cloud/aiplatform/pipeline_jobs.py
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -107,8 +107,9 @@ def __init__(
             display_name (str):
                 Required. The user-defined name of this Pipeline.
             template_path (str):
-                Required. The path of PipelineJob JSON file. It can be a local path or a
-                Google Cloud Storage URI. Example: "gs://project.name"
+                Required. The path of PipelineJob or PipelineSpec JSON file. It
+                can be a local path or a Google Cloud Storage URI.
+                Example: "gs://project.name"
             job_id (str):
                 Optional. The unique ID of the job run.
                 If not specified, pipeline name + timestamp will be used.
@@ -165,14 +166,37 @@ def __init__(
         self._parent = initializer.global_config.common_location_path(
             project=project, location=location
         )
-        pipeline_job = json_utils.load_json(
+        pipeline_json = json_utils.load_json(
             template_path, self.project, self.credentials
         )
-        pipeline_root = (
-            pipeline_root
-            or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
-            or initializer.global_config.staging_bucket
+        # pipeline_json can be either a PipelineJob or a PipelineSpec.
+        if pipeline_json.get("pipelineSpec") is not None:
+            pipeline_job = pipeline_json
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
+                or initializer.global_config.staging_bucket
+            )
+        else:
+            pipeline_job = {
+                "pipelineSpec": pipeline_json,
+                "runtimeConfig": {},
+            }
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or initializer.global_config.staging_bucket
+            )
+        builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
+            pipeline_job
         )
+        builder.update_pipeline_root(pipeline_root)
+        builder.update_runtime_parameters(parameter_values)
+        runtime_config_dict = builder.build()
+
+        runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(runtime_config_dict, runtime_config)
 
         pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
         self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
@@ -188,15 +212,6 @@ def __init__(
                 '"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
             )
 
-        builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
-            pipeline_job
-        )
-        builder.update_pipeline_root(pipeline_root)
-        builder.update_runtime_parameters(parameter_values)
-        runtime_config_dict = builder.build()
-        runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
-        json_format.ParseDict(runtime_config_dict, runtime_config)
-
         if enable_caching is not None:
             _set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
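
Taken together, the pipeline_jobs.py change means template_path now accepts
either a compiled PipelineJob JSON or a bare PipelineSpec JSON. A minimal
usage sketch of both paths (the project ID, bucket, and file names below are
hypothetical, and the templates are assumed to be compiled beforehand):

    from google.cloud import aiplatform
    from google.cloud.aiplatform import pipeline_jobs

    aiplatform.init(
        project="my-project",             # hypothetical project ID
        staging_bucket="gs://my-bucket",  # final pipeline-root fallback
    )

    # A compiled PipelineJob JSON: has top-level "pipelineSpec"/"runtimeConfig"
    # keys, so runtimeConfig.gcsOutputDirectory also joins the fallback chain.
    job = pipeline_jobs.PipelineJob(
        display_name="my-pipeline",
        template_path="gs://my-bucket/pipeline_job.json",
    )

    # A bare PipelineSpec JSON: no "pipelineSpec" key, so the constructor now
    # wraps it as {"pipelineSpec": ..., "runtimeConfig": {}} before building
    # the runtime config.
    job = pipeline_jobs.PipelineJob(
        display_name="my-pipeline",
        template_path="gs://my-bucket/pipeline_spec.json",
    )
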
"root": { + "dag": {"tasks": {}}, + "inputDefinitions": {"parameters": {"name_param": {"type": "STRING"}}}, }, + "components": {}, +} +_TEST_PIPELINE_JOB = { + "runtimeConfig": {}, + "pipelineSpec": _TEST_PIPELINE_SPEC, } _TEST_PIPELINE_GET_METHOD_NAME = "get_fake_pipeline_job" @@ -175,10 +176,23 @@ def mock_pipeline_service_list(): @pytest.fixture -def mock_load_json(): - with patch.object(storage.Blob, "download_as_bytes") as mock_load_json: - mock_load_json.return_value = json.dumps(_TEST_PIPELINE_JOB_SPEC).encode() - yield mock_load_json +def mock_load_pipeline_job_json(): + with patch.object(storage.Blob, "download_as_bytes") as mock_load_pipeline_job_json: + mock_load_pipeline_job_json.return_value = json.dumps( + _TEST_PIPELINE_JOB + ).encode() + yield mock_load_pipeline_job_json + + +@pytest.fixture +def mock_load_pipeline_spec_json(): + with patch.object( + storage.Blob, "download_as_bytes" + ) as mock_load_pipeline_spec_json: + mock_load_pipeline_spec_json.return_value = json.dumps( + _TEST_PIPELINE_SPEC + ).encode() + yield mock_load_pipeline_spec_json class TestPipelineJob: @@ -199,9 +213,68 @@ def setup_method(self): def teardown_method(self): initializer.global_pool.shutdown(wait=True) - @pytest.mark.usefixtures("mock_load_json") + @pytest.mark.usefixtures("mock_load_pipeline_job_json") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_pipeline_job_create( + self, mock_pipeline_service_create, mock_pipeline_service_get, sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_TEMPLATE_PATH, + job_id=_TEST_PIPELINE_JOB_ID, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + enable_caching=True, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync, + ) + + if not sync: + job.wait() + + expected_runtime_config_dict = { + "gcs_output_directory": _TEST_GCS_BUCKET_NAME, + "parameters": {"name_param": {"stringValue": "hello"}}, + } + runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(expected_runtime_config_dict, runtime_config) + + # Construct expected request + expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + pipeline_spec={ + "components": {}, + "pipelineInfo": _TEST_PIPELINE_JOB["pipelineSpec"]["pipelineInfo"], + "root": _TEST_PIPELINE_JOB["pipelineSpec"]["root"], + }, + runtime_config=runtime_config, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=_TEST_PARENT, + pipeline_job=expected_gapic_pipeline_job, + pipeline_job_id=_TEST_PIPELINE_JOB_ID, + ) + + mock_pipeline_service_get.assert_called_with(name=_TEST_PIPELINE_JOB_NAME) + + assert job._gca_resource == make_pipeline_job( + gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + + @pytest.mark.usefixtures("mock_load_pipeline_spec_json") @pytest.mark.parametrize("sync", [True, False]) - def test_run_call_pipeline_service_create( + def test_run_call_pipeline_service_pipeline_spec_create( self, mock_pipeline_service_create, mock_pipeline_service_get, sync, ): aiplatform.init( @@ -238,8 +311,8 @@ def test_run_call_pipeline_service_create( display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, pipeline_spec={ "components": {}, - 
"pipelineInfo": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["pipelineInfo"], - "root": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["root"], + "pipelineInfo": _TEST_PIPELINE_JOB["pipelineSpec"]["pipelineInfo"], + "root": _TEST_PIPELINE_JOB["pipelineSpec"]["root"], }, runtime_config=runtime_config, service_account=_TEST_SERVICE_ACCOUNT, @@ -267,7 +340,9 @@ def test_get_pipeline_job(self, mock_pipeline_service_get): assert isinstance(job, pipeline_jobs.PipelineJob) @pytest.mark.usefixtures( - "mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json", + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_load_pipeline_job_json", ) def test_cancel_pipeline_job( self, mock_pipeline_service_cancel, @@ -292,7 +367,9 @@ def test_cancel_pipeline_job( ) @pytest.mark.usefixtures( - "mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json", + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_load_pipeline_job_json", ) def test_list_pipeline_job(self, mock_pipeline_service_list): aiplatform.init( @@ -315,7 +392,9 @@ def test_list_pipeline_job(self, mock_pipeline_service_list): ) @pytest.mark.usefixtures( - "mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json", + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_load_pipeline_job_json", ) def test_cancel_pipeline_job_without_running( self, mock_pipeline_service_cancel, @@ -340,7 +419,7 @@ def test_cancel_pipeline_job_without_running( @pytest.mark.usefixtures( "mock_pipeline_service_create", "mock_pipeline_service_get_with_fail", - "mock_load_json", + "mock_load_pipeline_job_json", ) @pytest.mark.parametrize("sync", [True, False]) def test_pipeline_failure_raises(self, sync):