feat(PipelineJob): allow PipelineSpec as param #774

Merged: 5 commits, Oct 19, 2021
Changes from 2 commits
44 changes: 34 additions & 10 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -107,8 +107,9 @@ def __init__(
display_name (str):
Required. The user-defined name of this Pipeline.
template_path (str):
Required. The path of PipelineJob JSON file. It can be a local path or a
Google Cloud Storage URI. Example: "gs://project.name"
Required. The path of PipelineJob or PipelineSpec JSON file. It
can be a local path or a Google Cloud Storage URI.
Example: "gs://project.name"
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
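
For context, here is a minimal usage sketch of what this change enables, modeled on the tests below; the project, bucket, and template path are hypothetical, not part of this PR:

from google.cloud import aiplatform
from google.cloud.aiplatform import pipeline_jobs

aiplatform.init(project="my-project", staging_bucket="gs://my-bucket")

# After this PR, template_path may point at either a full PipelineJob JSON
# or a bare PipelineSpec JSON (e.g. the output of a KFP v2 compiler).
job = pipeline_jobs.PipelineJob(
    display_name="my-pipeline",
    template_path="gs://my-bucket/pipeline_spec.json",  # hypothetical path
    parameter_values={"name_param": "hello"},
    enable_caching=True,
)
job.run(sync=False)
job.wait()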
@@ -165,15 +166,25 @@ def __init__(
self._parent = initializer.global_config.common_location_path(
project=project, location=location
)
pipeline_job = json_utils.load_json(
pipeline_json = json_utils.load_json(
template_path, self.project, self.credentials
)
# pipeline_json can be either a PipelineJob or a PipelineSpec.
if pipeline_json.get("pipelineSpec") is not None:
pipeline_job = pipeline_json
pipeline_spec_only = False
else:
pipeline_job = {
"pipelineSpec": pipeline_json,
"runtimeConfig": {},
}
pipeline_spec_only = True
pipeline_root = (
pipeline_root
or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
or initializer.global_config.staging_bucket
)

pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
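
To make the branching above concrete, here is a standalone sketch of the two template shapes the constructor now distinguishes (the spec shape is copied from the test constants further down; the normalize helper is illustrative, not part of the PR):

pipeline_spec = {
    "pipelineInfo": {"name": "my-pipeline"},
    "root": {
        "dag": {"tasks": {}},
        "inputDefinitions": {"parameters": {"name_param": {"type": "STRING"}}},
    },
    "components": {},
}

# A full PipelineJob template already wraps the spec and carries a runtime config.
pipeline_job_template = {"pipelineSpec": pipeline_spec, "runtimeConfig": {}}

def normalize(pipeline_json):
    # A bare PipelineSpec has no top-level "pipelineSpec" key, so it gets
    # wrapped in an empty PipelineJob envelope, mirroring the diff above.
    if pipeline_json.get("pipelineSpec") is not None:
        return pipeline_json
    return {"pipelineSpec": pipeline_json, "runtimeConfig": {}}

assert normalize(pipeline_spec) == pipeline_job_template
assert normalize(pipeline_job_template) is pipeline_job_template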
@@ -188,12 +199,25 @@ def __init__(
'"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
)

builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
pipeline_job
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
runtime_config_dict = builder.build()
if pipeline_spec_only:
Member: This can be covered by the else block. So you don't even need this flag.

Contributor (author): Updated, thanks!

Member: Sorry, I wasn't very clear with the last comment; I meant you could use PipelineRuntimeConfigBuilder.from_job_spec_json(pipeline_job) for both cases, something like:

if pipeline_json.get("pipelineSpec") is not None:
    pipeline_job = pipeline_json
    pipeline_root = ...
else:
    pipeline_job = {"pipelineSpec": pipeline_json, "runtimeConfig": {}}
    pipeline_root = ...

builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(pipeline_job)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
runtime_config_dict = builder.build()

This will result in shorter and cleaner code.

Contributor (author): Got it, thanks!

runtime_config_dict = pipeline_utils.PipelineRuntimeConfigBuilder(
pipeline_root=pipeline_root,
parameter_types={
key: value["type"]
for key, value in pipeline_job["pipelineSpec"]["root"]
.get("inputDefinitions", {})
.get("parameters", {})
.items()
},
parameter_values=parameter_values,
).build()
else:
builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
pipeline_job
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
runtime_config_dict = builder.build()
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(runtime_config_dict, runtime_config)
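
As a quick check of the dict comprehension in the pipeline_spec_only branch, here is how the parameter types fall out of the test spec (a standalone sketch reusing the spec shape from the test constants):

pipeline_spec = {
    "pipelineInfo": {"name": "my-pipeline"},
    "root": {
        "dag": {"tasks": {}},
        "inputDefinitions": {"parameters": {"name_param": {"type": "STRING"}}},
    },
    "components": {},
}

# Mirrors the comprehension above: pull each declared parameter type out of
# root.inputDefinitions.parameters, tolerating specs that define neither.
parameter_types = {
    key: value["type"]
    for key, value in pipeline_spec["root"]
    .get("inputDefinitions", {})
    .get("parameters", {})
    .items()
}
assert parameter_types == {"name_param": "STRING"}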

121 changes: 100 additions & 21 deletions tests/unit/aiplatform/test_pipeline_jobs.py
@@ -53,16 +53,17 @@
_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"

_TEST_PIPELINE_PARAMETER_VALUES = {"name_param": "hello"}
_TEST_PIPELINE_JOB_SPEC = {
"runtimeConfig": {},
"pipelineSpec": {
"pipelineInfo": {"name": "my-pipeline"},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {"parameters": {"name_param": {"type": "STRING"}}},
},
"components": {},
_TEST_PIPELINE_SPEC = {
"pipelineInfo": {"name": "my-pipeline"},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {"parameters": {"name_param": {"type": "STRING"}}},
},
"components": {},
}
_TEST_PIPELINE_JOB = {
"runtimeConfig": {},
"pipelineSpec": _TEST_PIPELINE_SPEC,
}

_TEST_PIPELINE_GET_METHOD_NAME = "get_fake_pipeline_job"
@@ -175,10 +176,23 @@ def mock_pipeline_service_list():


@pytest.fixture
def mock_load_json():
with patch.object(storage.Blob, "download_as_bytes") as mock_load_json:
mock_load_json.return_value = json.dumps(_TEST_PIPELINE_JOB_SPEC).encode()
yield mock_load_json
def mock_load_pipeline_job_json():
with patch.object(storage.Blob, "download_as_bytes") as mock_load_pipeline_job_json:
mock_load_pipeline_job_json.return_value = json.dumps(
_TEST_PIPELINE_JOB
).encode()
yield mock_load_pipeline_job_json


@pytest.fixture
def mock_load_pipeline_spec_json():
with patch.object(
storage.Blob, "download_as_bytes"
) as mock_load_pipeline_spec_json:
mock_load_pipeline_spec_json.return_value = json.dumps(
_TEST_PIPELINE_SPEC
).encode()
yield mock_load_pipeline_spec_json


class TestPipelineJob:
@@ -199,9 +213,68 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@pytest.mark.usefixtures("mock_load_json")
@pytest.mark.usefixtures("mock_load_pipeline_job_json")
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_pipeline_job_create(
self, mock_pipeline_service_create, mock_pipeline_service_get, sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
enable_caching=True,
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync,
)

if not sync:
job.wait()

expected_runtime_config_dict = {
"gcs_output_directory": _TEST_GCS_BUCKET_NAME,
"parameters": {"name_param": {"stringValue": "hello"}},
}
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": _TEST_PIPELINE_JOB["pipelineSpec"]["pipelineInfo"],
"root": _TEST_PIPELINE_JOB["pipelineSpec"]["root"],
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
)

mock_pipeline_service_create.assert_called_once_with(
parent=_TEST_PARENT,
pipeline_job=expected_gapic_pipeline_job,
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
)

mock_pipeline_service_get.assert_called_with(name=_TEST_PIPELINE_JOB_NAME)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
)
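
The expected runtime config above shows the STRING-typed name_param surfacing as a stringValue. Here is a hypothetical helper illustrating that mapping (the real conversion lives in pipeline_utils, which this diff does not show; the INT and DOUBLE rows assume the other Value fields of the v1beta1 proto):

_TYPE_TO_VALUE_FIELD = {
    "STRING": "stringValue",
    "INT": "intValue",
    "DOUBLE": "doubleValue",
}

def to_runtime_parameter(declared_type, value):
    # e.g. ("STRING", "hello") -> {"stringValue": "hello"}
    return {_TYPE_TO_VALUE_FIELD[declared_type]: value}

assert to_runtime_parameter("STRING", "hello") == {"stringValue": "hello"}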

@pytest.mark.usefixtures("mock_load_pipeline_spec_json")
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create(
def test_run_call_pipeline_service_pipeline_spec_create(
self, mock_pipeline_service_create, mock_pipeline_service_get, sync,
):
aiplatform.init(
@@ -238,8 +311,8 @@ def test_run_call_pipeline_service_create(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["pipelineInfo"],
"root": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["root"],
"pipelineInfo": _TEST_PIPELINE_JOB["pipelineSpec"]["pipelineInfo"],
"root": _TEST_PIPELINE_JOB["pipelineSpec"]["root"],
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -267,7 +340,9 @@ def test_get_pipeline_job(self, mock_pipeline_service_get):
assert isinstance(job, pipeline_jobs.PipelineJob)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_load_pipeline_job_json",
)
def test_cancel_pipeline_job(
self, mock_pipeline_service_cancel,
@@ -292,7 +367,9 @@ def test_cancel_pipeline_job(
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_load_pipeline_job_json",
)
def test_list_pipeline_job(self, mock_pipeline_service_list):
aiplatform.init(
@@ -315,7 +392,9 @@ def test_list_pipeline_job(self, mock_pipeline_service_list):
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_load_pipeline_job_json",
)
def test_cancel_pipeline_job_without_running(
self, mock_pipeline_service_cancel,
@@ -340,7 +419,7 @@ def test_cancel_pipeline_job_without_running(
@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get_with_fail",
"mock_load_json",
"mock_load_pipeline_job_json",
)
@pytest.mark.parametrize("sync", [True, False])
def test_pipeline_failure_raises(self, sync):