diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py
index 7bd512e7e8..93345cb5d8 100644
--- a/google/cloud/aiplatform/compat/types/__init__.py
+++ b/google/cloud/aiplatform/compat/types/__init__.py
@@ -44,6 +44,7 @@
     model_evaluation_slice as model_evaluation_slice_v1beta1,
     model_service as model_service_v1beta1,
     operation as operation_v1beta1,
+    pipeline_job as pipeline_job_v1beta1,
     pipeline_service as pipeline_service_v1beta1,
     pipeline_state as pipeline_state_v1beta1,
     prediction_service as prediction_service_v1beta1,
@@ -158,6 +159,7 @@
     model_evaluation_slice_v1beta1,
     model_service_v1beta1,
     operation_v1beta1,
+    pipeline_job_v1beta1,
     pipeline_service_v1beta1,
     pipeline_state_v1beta1,
     prediction_service_v1beta1,
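The alias exported above is what the new PipelineJob implementation below consumes; a minimal sketch of the import path it enables (the field value is illustrative):

    from google.cloud.aiplatform.compat.types import (
        pipeline_job_v1beta1 as gca_pipeline_job_v1beta1,
    )

    # Build a v1beta1 PipelineJob proto through the compat alias.
    job_proto = gca_pipeline_job_v1beta1.PipelineJob(display_name="example-job")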
+ """ + for component in [pipeline_spec['root']] + list( + pipeline_spec['components'].values()): + if 'dag' in component: + for task in component['dag']['tasks'].values(): + task['cachingOptions'] = {'enableCache': enable_caching} + + +class PipelineJob(base.VertexAiResourceNounWithFutureManager): + + client_class = utils.PipelineClientWithOverride + _is_client_prediction_client = False + + _resource_noun = "pipelineJobs" + _getter_method = "get_pipeline_job" + _list_method = "list_pipeline_jobs" + _cancel_method = "cancel_pipeline_job" + _delete_method = "delete_pipeline_job" + + def __init__( + self, + display_name: str, + job_spec_path: str, + job_id: Optional[str] = None, + pipeline_root: Optional[str] = None, + parameter_values: Optional[Dict] = None, + enable_caching: bool = True, + encryption_spec_key_name: Optional[str] = None, + network: Optional[str] = None, + labels: Optional[Dict] = None, + service_account: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + project: Optional[str] = None, + location: Optional[str] = None, + ): + """Retrieves a PipelineJob resource and instantiates its + representation. + + Args: + display_name (str): + Required. The user-defined name of this Pipeline. + job_spec_path (str): + The path of PipelineJob JSON file. It can be a local path or a + GS URI. Example: "gs://project.name" + job_id (Optional[str]): + Optionally, the user can provide the unique ID of the job run. + If not specified, pipeline name + timestamp will be used. + pipeline_root (Optional[str]): + Optionally the user can override the pipeline root + specified during the compile time. Default to be staging bucket. + parameter_values (Optional[Dict]): + The mapping from runtime parameter names to its values that + control the pipeline run. + enable_caching (bool): + Required. Whether to turn on caching for the run. Defaults to True. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the job. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If this is set, then all + resources created by the BatchPredictionJob will + be encrypted with the provided encryption key. + + Overrides encryption_spec_key_name set in aiplatform.init. + labels (Optional[Dict]): + The user defined metadata to organize PipelineJob. + credentials (Optional[auth_credentials.Credentials]): + Custom credentials to use to create this batch prediction + job. Overrides credentials set in aiplatform.init. + project: Optional[str] = None, + Optional project to retrieve PipelineJob from. If not set, + project set in aiplatform.init will be used. + location: Optional[str] = None, + Optional location to retrieve PipelineJob from. If not set, + location set in aiplatform.init will be used. 
+ """ + utils.validate_display_name(display_name) + + super().__init__(project=project, location=location, credentials=credentials) + + self._parent = initializer.global_config.common_location_path( + project=project, location=location + ) + pipeline_root = pipeline_root or initializer.global_config.staging_bucket + pipeline_spec = utils.load_json(job_spec_path) + pipeline_name = pipeline_spec['pipelineSpec']['pipelineInfo']['name'] + job_id = job_id or '{pipeline_name}-{timestamp}'.format( + pipeline_name=re.sub('[^-0-9a-z]+', '-', + pipeline_name.lower()).lstrip('-').rstrip('-'), + timestamp=_get_current_time().strftime('%Y%m%d%H%M%S')) + if not _VALID_NAME_PATTERN.match(job_id): + raise ValueError( + 'Generated job ID: {} is illegal as a Vertex pipelines job ID. ' + 'Expecting an ID following the regex pattern ' + '"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)) + + job_name = _JOB_NAME_PATTERN.format(parent=self._parent, job_id=job_id) + + pipeline_spec['name'] = job_name + pipeline_spec['displayName'] = job_id + + builder = pipeline_runtime_config_builder.PipelineRuntimeConfigBuilder.from_job_spec_json( + pipeline_spec) + builder.update_pipeline_root(pipeline_root) + builder.update_runtime_parameters(parameter_values) + + runtime_config = builder.build() + pipeline_spec['runtimeConfig'] = runtime_config + + _set_enable_caching_value(pipeline_spec['pipelineSpec'], enable_caching) + + if encryption_spec_key_name is not None: + pipeline_spec['encryptionSpec'] = {'kmsKeyName': encryption_spec_key_name} + if service_account is not None: + pipeline_spec['serviceAccount'] = service_account + if network is not None: + pipeline_spec['network'] = network + + if labels: + if not isinstance(labels, Dict): + raise ValueError( + 'Expect labels to be a mapping of string key value pairs. ' + 'Got "{}" of type "{}"'.format(labels, type(labels))) + for k, v in labels.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise ValueError( + 'Expect labels to be a mapping of string key value pairs. ' + 'Got "{}".'.format(labels)) + + pipeline_spec['labels'] = labels + + self._gca_resource = gca_pipeline_job_v1beta1.PipelineJob( + display_name=display_name, + pipeline_spec=pipeline_spec, + labels=labels, + runtime_config=None, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + service_account=service_account, + network=network, + ) + + @base.optional_sync() + def run( + self, + service_account: Optional[str] = None, + network: Optional[str] = None, + sync: bool = True, + ) -> None: + """Run this configured PipelineJob. + Args: + service_account (str): + Optional. Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + network (str): + Optional. The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + sync (bool): + Whether to execute this method synchronously. If False, this method + will unblock and it will be executed in a concurrent Future. 
+ """ + + if service_account: + self._gca_resource.pipeline_spec.service_account = service_account + + if network: + self._gca_resource.pipeline_spec.network = network + + _LOGGER.log_create_with_lro(self.__class__) + + self._gca_resource = self.api_client.select_version(_PIPELINE_CLIENT_VERSION).create_pipeline_job( + parent=self._parent, + pipeline_job=self._gca_resource + ) + + _LOGGER.log_create_complete_with_getter( + self.__class__, self._gca_resource, "pipeline_job" + ) + + _LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri()) + + self._block_until_complete() + + @property + def pipeline_spec(self): + return self._gca_resource.pipeline_spec + + @property + def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]: + """Current pipeline state.""" + + if self._assert_has_run(): + return + + self._sync_gca_resource() + return self._gca_resource.state + + @property + def _has_run(self) -> bool: + """Helper property to check if this pipeline job has been run.""" + return self._gca_resource is not None + + def _assert_has_run(self) -> bool: + """Helper method to assert that this pipeline has run.""" + if not self._has_run: + if self._is_waiting_to_run(): + return True + raise RuntimeError( + "PipelineJob has not been launched. You must run this" + " PipelineJob using PipelineJob.run. " + ) + return False + + @property + def has_failed(self) -> bool: + """Returns True if pipeline has failed. + + False otherwise. + """ + self._assert_has_run() + return self.state == gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED + + def _dashboard_uri(self) -> str: + """Helper method to compose the dashboard uri where pipeline can be + viewed.""" + fields = utils.extract_fields_from_resource_name(self.resource_name) + url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/pipelines/runs/{fields.id}?project={fields.project}" + return url + + def _sync_gca_resource(self): + """Helper method to sync the local gca_source against the service.""" + + self._gca_resource = self.api_client.select_version(_PIPELINE_CLIENT_VERSION).get_pipeline_job( + name=self.resource_name + ) + + def _block_until_complete(self): + """Helper method to block and check on job until complete.""" + + # Used these numbers so failures surface fast + wait = 5 # start at five seconds + log_wait = 5 + max_wait = 60 * 5 # 5 minute wait + multiplier = 2 # scale wait by 2 every iteration + + previous_time = time.time() + while self.state not in _PIPELINE_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + _LOGGER.info( + "%s %s current state:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + self._gca_resource.state, + ) + ) + log_wait = min(log_wait * multiplier, max_wait) + previous_time = current_time + time.sleep(wait) + + self._raise_failure() + + _LOGGER.log_action_completed_against_resource("run", "completed", self) + + if self._gca_resource.name and not self.has_failed: + _LOGGER.info("Pipeline Job available at:\n%s" % self._dashboard_uri()) + + def _raise_failure(self): + """Helper method to raise failure if PipelineJob fails. + + Raises: + RuntimeError: If pipeline failed. 
+ """ + + if self._gca_resource.error.code != code_pb2.OK: + raise RuntimeError("Pipeline failed with:\n%s" % self._gca_resource.error) + diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py index 4404defb21..24ebc38128 100644 --- a/google/cloud/aiplatform/utils/__init__.py +++ b/google/cloud/aiplatform/utils/__init__.py @@ -22,7 +22,7 @@ from collections import namedtuple import logging import re -from typing import Any, Match, Optional, Type, TypeVar, Tuple +from typing import Any, Dict, Match, Optional, Type, TypeVar, Tuple from google.api_core import client_options from google.api_core import gapic_v1 @@ -194,6 +194,48 @@ def full_resource_name( return resource_name +def load_json(path: str) -> Dict[str, Any]: + """Loads data from a JSON document. + Args: + path: The path of the JSON document. It can be a local path or a GS URI. + Returns: + A deserialized Dict object representing the JSON document. + """ + + if path.startswith('gs://'): + return _load_json_from_gs_uri(path) + else: + return _load_json_from_local_file(path) + + +def _load_json_from_gs_uri(uri: str) -> Dict[str, Any]: + """Loads data from a JSON document referenced by a GS URI. + Args: + uri: The GCS URI of the JSON document. + Returns: + A deserialized Dict object representing the JSON document. + Raises: + google.cloud.exceptions.NotFound: If the blob is not found. + json.decoder.JSONDecodeError: On JSON parsing problems. + ValueError: If uri is not a valid gs URI. + """ + storage_client = storage.Client() + blob = storage.Blob.from_string(uri, storage_client) + return json.loads(blob.download_as_string()) + + +def _load_json_from_local_file(file_path: str) -> Dict[str, Any]: + """Loads data from a JSON local file. + Args: + file_path: The local file path of the JSON document. + Returns: + A deserialized Dict object representing the JSON document. + Raises: + json.decoder.JSONDecodeError: On JSON parsing problems. + """ + with open(file_path) as f: + return json.load(f) + # TODO(b/172286889) validate resource noun def validate_resource_noun(resource_noun: str) -> bool: """Validates resource noun. diff --git a/google/cloud/aiplatform/utils/pipeline_runtime_config_builder.py b/google/cloud/aiplatform/utils/pipeline_runtime_config_builder.py new file mode 100644 index 0000000000..5e5382aa80 --- /dev/null +++ b/google/cloud/aiplatform/utils/pipeline_runtime_config_builder.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +from typing import Any, Dict, Mapping, Optional, Union + + +class PipelineRuntimeConfigBuilder(object): + """Pipeline RuntimeConfig builder. + + Constructs a RuntimeConfig spec with pipeline_root and parameter overrides. + """ + + def __init__( + self, + pipeline_root: str, + parameter_types: Mapping[str, str], + parameter_values: Optional[Dict[str, Any]] = None, + ): + """Creates a PipelineRuntimeConfigBuilder object. 
diff --git a/google/cloud/aiplatform/utils/pipeline_runtime_config_builder.py b/google/cloud/aiplatform/utils/pipeline_runtime_config_builder.py
new file mode 100644
index 0000000000..5e5382aa80
--- /dev/null
+++ b/google/cloud/aiplatform/utils/pipeline_runtime_config_builder.py
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+from typing import Any, Dict, Mapping, Optional, Union
+
+
+class PipelineRuntimeConfigBuilder(object):
+    """Pipeline RuntimeConfig builder.
+
+    Constructs a RuntimeConfig spec with pipeline_root and parameter overrides.
+    """
+
+    def __init__(
+        self,
+        pipeline_root: str,
+        parameter_types: Mapping[str, str],
+        parameter_values: Optional[Dict[str, Any]] = None,
+    ):
+        """Creates a PipelineRuntimeConfigBuilder object.
+
+        Args:
+            pipeline_root: The root of the pipeline outputs.
+            parameter_types: The mapping from pipeline parameter names to
+                their types.
+            parameter_values: The mapping from runtime parameter names to
+                their values.
+        """
+        self._pipeline_root = pipeline_root
+        self._parameter_types = parameter_types
+        self._parameter_values = copy.deepcopy(parameter_values or {})
+
+    @classmethod
+    def from_job_spec_json(
+        cls, job_spec: Mapping[str, Any]
+    ) -> "PipelineRuntimeConfigBuilder":
+        """Creates a PipelineRuntimeConfigBuilder object from a PipelineJob
+        JSON spec.
+
+        Args:
+            job_spec: The PipelineJob spec.
+
+        Returns:
+            A PipelineRuntimeConfigBuilder object.
+        """
+        runtime_config_spec = job_spec["runtimeConfig"]
+        parameter_types = {}
+        parameter_input_definitions = (
+            job_spec["pipelineSpec"]["root"]
+            .get("inputDefinitions", {})
+            .get("parameters", {})
+        )
+        for k, v in parameter_input_definitions.items():
+            parameter_types[k] = v["type"]
+
+        pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
+        parameter_values = _parse_runtime_parameters(runtime_config_spec)
+        return cls(pipeline_root, parameter_types, parameter_values)
+
+    def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
+        """Updates the pipeline_root value.
+
+        Args:
+            pipeline_root: The root of the pipeline outputs.
+        """
+        if pipeline_root:
+            self._pipeline_root = pipeline_root
+
+    def update_runtime_parameters(
+        self, parameter_values: Optional[Mapping[str, Any]]
+    ) -> None:
+        """Merges runtime parameter values.
+
+        Args:
+            parameter_values: The mapping from runtime parameter names to
+                their values.
+        """
+        if parameter_values:
+            self._parameter_values.update(parameter_values)
+
+    def build(self) -> Mapping[str, Any]:
+        """Builds a RuntimeConfig proto.
+
+        Raises:
+            ValueError: If the pipeline root is not specified.
+        """
+        if not self._pipeline_root:
+            raise ValueError(
+                "Pipeline root must be specified, either during compile "
+                "time, or when calling the service."
+            )
+        return {
+            "gcsOutputDirectory": self._pipeline_root,
+            "parameters": {
+                k: self._get_vertex_value(k, v)
+                for k, v in self._parameter_values.items()
+                if v is not None
+            },
+        }
+
+    def _get_vertex_value(
+        self, name: str, value: Union[int, float, str]
+    ) -> Mapping[str, Any]:
+        """Converts a primitive value into a Vertex pipeline Value proto
+        message.
+
+        Args:
+            name: The name of the pipeline parameter.
+            value: The value of the pipeline parameter.
+
+        Returns:
+            A dictionary representing the Vertex pipeline Value proto message.
+
+        Raises:
+            AssertionError: If the value is None.
+            ValueError: If the parameter name is not found in the pipeline
+                root inputs.
+            TypeError: If the parameter type is not one of 'INT', 'DOUBLE',
+                'STRING'.
+        """
+        assert value is not None, "None values should be filtered out."
+
+        if name not in self._parameter_types:
+            raise ValueError(
+                "The pipeline parameter {} is not found in the pipeline "
+                "job input definitions.".format(name)
+            )
+
+        result = {}
+        if self._parameter_types[name] == "INT":
+            result["intValue"] = value
+        elif self._parameter_types[name] == "DOUBLE":
+            result["doubleValue"] = value
+        elif self._parameter_types[name] == "STRING":
+            result["stringValue"] = value
+        else:
+            raise TypeError(
+                "Got unknown parameter type: {}.".format(self._parameter_types[name])
+            )
+
+        return result
+
+
+def _parse_runtime_parameters(
+    runtime_config_spec: Mapping[str, Any]
+) -> Optional[Dict[str, Any]]:
+    """Extracts runtime parameters from a runtime config JSON spec.
+
+    Raises:
+        TypeError: If a parameter value specifies an unknown Value field.
+    """
+    runtime_parameters = runtime_config_spec.get("parameters")
+    if not runtime_parameters:
+        return None
+
+    result = {}
+    for name, value in runtime_parameters.items():
+        if "intValue" in value:
+            result[name] = int(value["intValue"])
+        elif "doubleValue" in value:
+            result[name] = float(value["doubleValue"])
+        elif "stringValue" in value:
+            result[name] = value["stringValue"]
+        else:
+            raise TypeError("Got unknown type of value: {}".format(value))
+
+    return result
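A hedged round-trip sketch of the builder on a toy job spec (names and values are made up); note how the declared parameter types, not the Python types, pick the Value field:

    from google.cloud.aiplatform.utils.pipeline_runtime_config_builder import (
        PipelineRuntimeConfigBuilder,
    )

    job_spec = {
        "runtimeConfig": {"parameters": {"lr": {"doubleValue": "0.1"}}},
        "pipelineSpec": {
            "root": {
                "inputDefinitions": {
                    "parameters": {
                        "lr": {"type": "DOUBLE"},
                        "steps": {"type": "INT"},
                    }
                }
            }
        },
    }

    builder = PipelineRuntimeConfigBuilder.from_job_spec_json(job_spec)
    builder.update_pipeline_root("gs://my-bucket/pipeline-root")
    builder.update_runtime_parameters({"steps": 100})
    assert builder.build() == {
        "gcsOutputDirectory": "gs://my-bucket/pipeline-root",
        "parameters": {
            "lr": {"doubleValue": 0.1},
            "steps": {"intValue": 100},
        },
    }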
diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py
new file mode 100644
index 0000000000..8f16ce6b2c
--- /dev/null
+++ b/tests/unit/aiplatform/test_pipeline_jobs.py
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+from unittest import mock
+from unittest.mock import patch
+from importlib import reload
+
+from google.auth import credentials as auth_credentials
+
+from google.cloud import aiplatform
+
+from google.cloud.aiplatform import pipeline_jobs
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import utils
+
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
+    client as pipeline_service_client_v1beta1,
+)
+
+from google.cloud.aiplatform_v1beta1.types import (
+    pipeline_job as gca_pipeline_job_v1beta1,
+    pipeline_state as gca_pipeline_state_v1beta1,
+)
+
+_TEST_PROJECT = "test-project"
+_TEST_LOCATION = "us-central1"
+_TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111"
+_TEST_GCS_BUCKET_NAME = "my-bucket"
+_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
+_TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com"
+
+_TEST_JOB_SPEC_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
+_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
+_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"
+
+_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
+_TEST_PIPELINE_JOB_DISPLAY_NAME = "test-pipeline-job"
+
+# Plain values; PipelineJob wraps them into Value protos via the runtime
+# config builder, so they must not be pre-wrapped in {"stringValue": ...}.
+_TEST_PIPELINE_PARAMETER_VALUES = {"name_param": "hello"}
+_TEST_PIPELINE_JOB_SPEC = {
+    "displayName": "my-pipeline",
+    "runtimeConfig": {
+        "gcsOutputDirectory": "gs://some-bucket/tmp/",
+        "parameters": {"name_param": {"stringValue": "world"}},
+    },
+    "pipelineSpec": {
+        "pipelineInfo": {"name": "my-pipeline"},
+        "root": {
+            "dag": {
+                "tasks": {
+                    "task-test": {
+                        "taskInfo": {"name": "task-test"},
+                        "inputs": {
+                            "parameters": {
+                                "name": {"componentInputParameter": "name_param"}
+                            }
+                        },
+                        "componentRef": {"name": "comp-test"},
+                    }
+                }
+            },
+            "inputDefinitions": {
+                "parameters": {"name_param": {"type": "STRING"}}
+            },
+        },
+        "deploymentSpec": {
+            "executors": {
+                "exec-test": {
+                    "container": {
+                        "image": "python:3.7-alpine",
+                        "command": [
+                            "echo",
+                            "Hello {{$.inputs.parameters['name']}} {{$.scheduledTime.strftime('%Y-%m-%d')}}",
+                        ],
+                    }
+                }
+            }
+        },
+        "components": {
+            "comp-test": {
+                "executorLabel": "exec-test",
+                "inputDefinitions": {
+                    "parameters": {"name": {"type": "STRING"}}
+                },
+            }
+        },
+        "sdkVersion": "dummy-kfp-version",
+        "schemaVersion": "2.0.0",
+    },
+}
+
+_TEST_PIPELINE_GET_METHOD_NAME = "get_fake_pipeline_job"
+_TEST_PIPELINE_LIST_METHOD_NAME = "list_fake_pipeline_jobs"
+_TEST_PIPELINE_CANCEL_METHOD_NAME = "cancel_fake_pipeline_job"
+_TEST_PIPELINE_DELETE_METHOD_NAME = "delete_fake_pipeline_job"
+_TEST_PIPELINE_RESOURCE_NAME = (
+    f"{_TEST_PARENT}/fakePipelineJobs/{_TEST_PIPELINE_JOB_ID}"
+)
+
+
+@pytest.fixture
+def mock_pipeline_service_create():
+    with mock.patch.object(
+        pipeline_service_client_v1beta1.PipelineServiceClient, "create_pipeline_job"
+    ) as mock_create_pipeline_job:
+        mock_create_pipeline_job.return_value = gca_pipeline_job_v1beta1.PipelineJob(
+            name=_TEST_PIPELINE_JOB_NAME,
+            state=gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
+        )
+        yield mock_create_pipeline_job
+
+
+@pytest.fixture
+def mock_pipeline_service_get():
+    with mock.patch.object(
+        pipeline_service_client_v1beta1.PipelineServiceClient, "get_pipeline_job"
+    ) as mock_get_pipeline_job:
+        mock_get_pipeline_job.return_value = gca_pipeline_job_v1beta1.PipelineJob(
+            name=_TEST_PIPELINE_JOB_NAME,
+            state=gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
+        )
+        yield mock_get_pipeline_job
+
+
+@pytest.fixture
+def mock_load_json():
+    with patch.object(utils, "load_json") as mock_load_json:
+        mock_load_json.return_value = _TEST_PIPELINE_JOB_SPEC
+        yield mock_load_json
+
+
+class TestPipelineJob:
+    class FakePipelineJob(pipeline_jobs.PipelineJob):
+
+        _resource_noun = "fakePipelineJobs"
+        _getter_method = _TEST_PIPELINE_GET_METHOD_NAME
+        _list_method = _TEST_PIPELINE_LIST_METHOD_NAME
+        _cancel_method = _TEST_PIPELINE_CANCEL_METHOD_NAME
+        _delete_method = _TEST_PIPELINE_DELETE_METHOD_NAME
+        resource_name = _TEST_PIPELINE_RESOURCE_NAME
+
+    def setup_method(self):
+        reload(initializer)
+        reload(aiplatform)
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_call_pipeline_service_create(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        mock_load_json,
+        sync,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            job_spec_path=_TEST_JOB_SPEC_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        # run() returns None; the job object itself tracks the launch.
+        job.run(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            sync=sync,
+        )
+
+        # TODO(ji-yaqi): uncomment when the wait method is added.
+        # if not sync:
+        #     job.wait()
+
+        # Construct the expected request. Note that mock_load_json hands the
+        # constructor the same dict object, which is mutated in place, so the
+        # expected spec below reflects those mutations at assert time.
+        expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            pipeline_spec=_TEST_PIPELINE_JOB_SPEC,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=_TEST_PARENT, pipeline_job=expected_gapic_pipeline_job,
+        )
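For reference, a sketch (not part of the change) of the runtime config the builder derives from _TEST_PIPELINE_JOB_SPEC once the test's parameter override is applied: "hello" replaces the compile-time "world". The pipeline root below is illustrative:

    from google.cloud.aiplatform.utils.pipeline_runtime_config_builder import (
        PipelineRuntimeConfigBuilder,
    )

    builder = PipelineRuntimeConfigBuilder.from_job_spec_json(_TEST_PIPELINE_JOB_SPEC)
    builder.update_pipeline_root("gs://my-bucket")
    builder.update_runtime_parameters({"name_param": "hello"})
    assert builder.build() == {
        "gcsOutputDirectory": "gs://my-bucket",
        "parameters": {"name_param": {"stringValue": "hello"}},
    }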