feat: PipelineJob switch to v1 API from v1beta1 API (#750)
* PipelineJob switch to v1 API

* format

* Update test_pipeline_jobs.py

Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
chensun and sasha-gitg committed Nov 2, 2021
1 parent 49aaa87 commit 8db7e0c
Showing 4 changed files with 46 additions and 45 deletions.
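For context, this commit only changes which GAPIC surface backs aiplatform.PipelineJob; the high-level SDK interface is unchanged. A minimal usage sketch follows (project, region, and GCS paths are illustrative, not taken from this commit):

    from google.cloud import aiplatform

    # Illustrative values; substitute your own project, region, and GCS paths.
    aiplatform.init(project="example-project", location="us-central1")

    job = aiplatform.PipelineJob(
        display_name="example-pipeline",
        template_path="gs://example-bucket/pipeline_spec.json",
        pipeline_root="gs://example-bucket/pipeline_root",
    )

    # run() blocks and polls get_pipeline_job until the job reaches a terminal
    # PipelineState; after this commit those calls go through the v1 API.
    job.run(sync=True)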
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
@@ -100,6 +100,7 @@
model_evaluation_slice as model_evaluation_slice_v1,
model_service as model_service_v1,
operation as operation_v1,
pipeline_job as pipeline_job_v1,
pipeline_service as pipeline_service_v1,
pipeline_state as pipeline_state_v1,
prediction_service as prediction_service_v1,
@@ -145,6 +146,7 @@
model_evaluation_slice_v1,
model_service_v1,
operation_v1,
pipeline_job_v1,
pipeline_service_v1,
pipeline_state_v1,
prediction_service_v1,
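The compat layer simply re-exports the versioned GAPIC type modules under stable aliases, so after this change the new pipeline_job_v1 alias and the underlying google.cloud.aiplatform_v1 module should refer to the same message types. A quick illustrative check, assuming the compat module re-exports the v1 module directly as the diff above suggests:

    from google.cloud.aiplatform.compat.types import pipeline_job_v1
    from google.cloud.aiplatform_v1.types import pipeline_job as gca_pipeline_job_v1

    # Both names resolve to the same proto-plus message class.
    assert pipeline_job_v1.PipelineJob is gca_pipeline_job_v1.PipelineJob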
24 changes: 11 additions & 13 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -29,23 +29,23 @@
from google.protobuf import json_format

from google.cloud.aiplatform.compat.types import (
pipeline_job_v1beta1 as gca_pipeline_job_v1beta1,
pipeline_state_v1beta1 as gca_pipeline_state_v1beta1,
pipeline_job_v1 as gca_pipeline_job_v1,
pipeline_state_v1 as gca_pipeline_state_v1,
)

_LOGGER = base.Logger(__name__)

_PIPELINE_COMPLETE_STATES = set(
[
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED,
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_CANCELLED,
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_PAUSED,
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED,
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_FAILED,
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_CANCELLED,
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_PAUSED,
]
)

_PIPELINE_ERROR_STATES = set(
[gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED]
[gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_FAILED]
)

# Pattern for valid names used as a Vertex resource name.
@@ -195,7 +195,7 @@ def __init__(
builder.update_runtime_parameters(parameter_values)
runtime_config_dict = builder.build()

runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(runtime_config_dict, runtime_config)

pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
@@ -215,7 +215,7 @@ def __init__(
if enable_caching is not None:
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)

self._gca_resource = gca_pipeline_job_v1beta1.PipelineJob(
self._gca_resource = gca_pipeline_job_v1.PipelineJob(
display_name=display_name,
pipeline_spec=pipeline_job["pipelineSpec"],
labels=labels,
@@ -299,7 +299,7 @@ def pipeline_spec(self):
return self._gca_resource.pipeline_spec

@property
def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
def state(self) -> Optional[gca_pipeline_state_v1.PipelineState]:
"""Current pipeline state."""
self._sync_gca_resource()
return self._gca_resource.state
@@ -310,9 +310,7 @@ def has_failed(self) -> bool:
False otherwise.
"""
return (
self.state == gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED
)
return self.state == gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_FAILED

def _dashboard_uri(self) -> str:
"""Helper method to compose the dashboard uri where pipeline can be
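The state-handling logic above is unchanged apart from the enum it references; the sketch below restates it against the v1 enum (is_terminal is a hypothetical helper for illustration, not part of the SDK):

    from google.cloud.aiplatform.compat.types import (
        pipeline_state_v1 as gca_pipeline_state_v1,
    )

    _TERMINAL_STATES = {
        gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED,
        gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_FAILED,
        gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_CANCELLED,
        gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_PAUSED,
    }

    def is_terminal(state: gca_pipeline_state_v1.PipelineState) -> bool:
        # Mirrors _PIPELINE_COMPLETE_STATES in pipeline_jobs.py: True once the
        # job has succeeded, failed, been cancelled, or been paused.
        return state in _TERMINAL_STATES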
3 changes: 2 additions & 1 deletion google/cloud/aiplatform/utils/__init__.py
@@ -479,8 +479,9 @@ class PipelineClientWithOverride(ClientWithOverride):

class PipelineJobClientWithOverride(ClientWithOverride):
_is_temporary = True
_default_version = compat.V1BETA1
_default_version = compat.DEFAULT_VERSION
_version_map = (
(compat.V1, pipeline_service_client_v1.PipelineServiceClient),
(compat.V1BETA1, pipeline_service_client_v1beta1.PipelineServiceClient),
)

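PipelineJobClientWithOverride now defaults to the v1 client while keeping the v1beta1 entry in its version map, so callers can still request the preview surface explicitly. A minimal sketch of that selection pattern (simplified and hypothetical, not the library's actual ClientWithOverride implementation):

    # Hypothetical, simplified stand-in for ClientWithOverride's version selection.
    class VersionedClientSketch:
        _default_version = "v1"  # previously "v1beta1" for pipeline jobs
        _version_map = {
            "v1": "PipelineServiceClient (aiplatform_v1)",
            "v1beta1": "PipelineServiceClient (aiplatform_v1beta1)",
        }

        @classmethod
        def select(cls, version=None):
            # Fall back to the default version unless one is requested explicitly.
            return cls._version_map[version or cls._default_version]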
62 changes: 31 additions & 31 deletions tests/unit/aiplatform/test_pipeline_jobs.py
@@ -31,12 +31,12 @@
from google.cloud import storage
from google.protobuf import json_format

from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
client as pipeline_service_client_v1beta1,
from google.cloud.aiplatform_v1.services.pipeline_service import (
client as pipeline_service_client_v1,
)
from google.cloud.aiplatform_v1beta1.types import (
pipeline_job as gca_pipeline_job_v1beta1,
pipeline_state as gca_pipeline_state_v1beta1,
from google.cloud.aiplatform_v1.types import (
pipeline_job as gca_pipeline_job_v1,
pipeline_state as gca_pipeline_state_v1,
)

_TEST_PROJECT = "test-project"
@@ -128,11 +128,11 @@
@pytest.fixture
def mock_pipeline_service_create():
with mock.patch.object(
pipeline_service_client_v1beta1.PipelineServiceClient, "create_pipeline_job"
pipeline_service_client_v1.PipelineServiceClient, "create_pipeline_job"
) as mock_create_pipeline_job:
mock_create_pipeline_job.return_value = gca_pipeline_job_v1beta1.PipelineJob(
mock_create_pipeline_job.return_value = gca_pipeline_job_v1.PipelineJob(
name=_TEST_PIPELINE_JOB_NAME,
state=gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
state=gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED,
create_time=_TEST_PIPELINE_CREATE_TIME,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
@@ -141,7 +141,7 @@ def mock_pipeline_service_create():


def make_pipeline_job(state):
return gca_pipeline_job_v1beta1.PipelineJob(
return gca_pipeline_job_v1.PipelineJob(
name=_TEST_PIPELINE_JOB_NAME,
state=state,
create_time=_TEST_PIPELINE_CREATE_TIME,
@@ -153,35 +153,35 @@ def make_pipeline_job(state):
@pytest.fixture
def mock_pipeline_service_get():
with mock.patch.object(
pipeline_service_client_v1beta1.PipelineServiceClient, "get_pipeline_job"
pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job"
) as mock_get_pipeline_job:
mock_get_pipeline_job.side_effect = [
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_RUNNING
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_RUNNING
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
),
]

@@ -191,17 +191,17 @@ def mock_pipeline_service_get():
@pytest.fixture
def mock_pipeline_service_get_with_fail():
with mock.patch.object(
pipeline_service_client_v1beta1.PipelineServiceClient, "get_pipeline_job"
pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job"
) as mock_get_pipeline_job:
mock_get_pipeline_job.side_effect = [
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_RUNNING
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_RUNNING
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_RUNNING
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_RUNNING
),
make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_FAILED
),
]

@@ -211,15 +211,15 @@ def mock_pipeline_service_get_with_fail():
@pytest.fixture
def mock_pipeline_service_cancel():
with mock.patch.object(
pipeline_service_client_v1beta1.PipelineServiceClient, "cancel_pipeline_job"
pipeline_service_client_v1.PipelineServiceClient, "cancel_pipeline_job"
) as mock_cancel_pipeline_job:
yield mock_cancel_pipeline_job


@pytest.fixture
def mock_pipeline_service_list():
with mock.patch.object(
pipeline_service_client_v1beta1.PipelineServiceClient, "list_pipeline_jobs"
pipeline_service_client_v1.PipelineServiceClient, "list_pipeline_jobs"
) as mock_list_pipeline_jobs:
yield mock_list_pipeline_jobs

@@ -293,13 +293,13 @@ def test_run_call_pipeline_service_create(
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameters": {"string_param": {"stringValue": "hello"}},
}
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
Expand All @@ -322,7 +322,7 @@ def test_run_call_pipeline_service_create(
)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -362,13 +362,13 @@ def test_submit_call_pipeline_service_pipeline_job_create(
"gcs_output_directory": _TEST_GCS_BUCKET_NAME,
"parameters": {"string_param": {"stringValue": "hello"}},
}
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
Expand All @@ -395,7 +395,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(
)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.usefixtures("mock_pipeline_service_get")
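The fixtures above now patch the v1 PipelineServiceClient, so a test can assert that the SDK constructed a v1 PipelineJob message. A hedged sketch of such an assertion (the helper name and keyword access are illustrative, not taken from the test file):

    from google.cloud.aiplatform_v1.types import pipeline_job as gca_pipeline_job_v1

    def assert_v1_pipeline_job_created(mock_create_pipeline_job, expected_display_name):
        # Inspect the keyword arguments the mocked create_pipeline_job received.
        _, kwargs = mock_create_pipeline_job.call_args
        created = kwargs["pipeline_job"]
        assert isinstance(created, gca_pipeline_job_v1.PipelineJob)
        assert created.display_name == expected_display_name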
