Skip to content

Commit

Permalink
fix: add parameters_value in PipelineJob for schema > 2.0.0 (#817)
Browse files Browse the repository at this point in the history
* feat: update PipelineJob to accept protobuf value

* fix tests

* address comments

* fix: update Pipeline Job parameter values according to schema_version

* fix test

* fix key to parameters

* fix format'

* address comments

Co-authored-by: nicain <nicain.seattle@gmail.com>
  • Loading branch information
ji-yaqi and nicain committed Nov 15, 2021
1 parent 2f9a879 commit 900a449
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 82 deletions.
62 changes: 36 additions & 26 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Expand Up @@ -17,6 +17,7 @@
import copy
import json
from typing import Any, Dict, Mapping, Optional, Union
import packaging.version


class PipelineRuntimeConfigBuilder(object):
Expand All @@ -28,6 +29,7 @@ class PipelineRuntimeConfigBuilder(object):
def __init__(
self,
pipeline_root: str,
schema_version: str,
parameter_types: Mapping[str, str],
parameter_values: Optional[Dict[str, Any]] = None,
):
Expand All @@ -36,12 +38,15 @@ def __init__(
Args:
pipeline_root (str):
Required. The root of the pipeline outputs.
schema_version (str):
Required. Schema version of the IR. This field determines the fields supported in current version of IR.
parameter_types (Mapping[str, str]):
Required. The mapping from pipeline parameter name to its type.
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter name to its value.
"""
self._pipeline_root = pipeline_root
self._schema_version = schema_version
self._parameter_types = parameter_types
self._parameter_values = copy.deepcopy(parameter_values or {})

Expand All @@ -64,6 +69,8 @@ def from_job_spec_json(
.get("inputDefinitions", {})
.get("parameters", {})
)
schema_version = job_spec["pipelineSpec"]["schemaVersion"]

# 'type' is deprecated in IR and change to 'parameterType'.
parameter_types = {
k: v.get("parameterType") or v.get("type")
Expand All @@ -72,7 +79,7 @@ def from_job_spec_json(

pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
parameter_values = _parse_runtime_parameters(runtime_config_spec)
return cls(pipeline_root, parameter_types, parameter_values)
return cls(pipeline_root, schema_version, parameter_types, parameter_values)

def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
"""Updates pipeline_root value.
Expand All @@ -95,9 +102,12 @@ def update_runtime_parameters(
"""
if parameter_values:
parameters = dict(parameter_values)
for k, v in parameter_values.items():
if isinstance(v, (dict, list, bool)):
parameters[k] = json.dumps(v)
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
for k, v in parameter_values.items():
if isinstance(v, (dict, list, bool)):
parameters[k] = json.dumps(v)
self._parameter_values.update(parameters)

def build(self) -> Dict[str, Any]:
Expand All @@ -111,9 +121,15 @@ def build(self) -> Dict[str, Any]:
"Pipeline root must be specified, either during "
"compile time, or when calling the service."
)
if packaging.version.parse(self._schema_version) > packaging.version.parse(
"2.0.0"
):
parameter_values_key = "parameterValues"
else:
parameter_values_key = "parameters"
return {
"gcsOutputDirectory": self._pipeline_root,
"parameters": {
parameter_values_key: {
k: self._get_vertex_value(k, v)
for k, v in self._parameter_values.items()
if v is not None
Expand All @@ -122,7 +138,7 @@ def build(self) -> Dict[str, Any]:

def _get_vertex_value(
self, name: str, value: Union[int, float, str, bool, list, dict]
) -> Dict[str, Any]:
) -> Union[int, float, str, bool, list, dict]:
"""Converts primitive values into Vertex pipeline Value proto message.
Args:
Expand All @@ -147,27 +163,21 @@ def _get_vertex_value(
"pipeline job input definitions.".format(name)
)

result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
elif self._parameter_types[name] == "BOOLEAN":
result["boolValue"] = value
elif self._parameter_types[name] == "NUMBER_DOUBLE":
result["numberValue"] = value
elif self._parameter_types[name] == "NUMBER_INTEGER":
result["numberValue"] = value
elif self._parameter_types[name] == "LIST":
result["listValue"] = value
elif self._parameter_types[name] == "STRUCT":
result["structValue"] = value
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
else:
raise TypeError("Got unknown type of value: {}".format(value))

return result
return value


def _parse_runtime_parameters(
Expand Down

0 comments on commit 900a449

Please sign in to comment.