Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add parameters_value in PipelineJob for schema > 2.0.0 #817

Merged
merged 14 commits into from Nov 15, 2021
Merged
62 changes: 36 additions & 26 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Expand Up @@ -17,6 +17,7 @@
import copy
import json
from typing import Any, Dict, Mapping, Optional, Union
import packaging.version


class PipelineRuntimeConfigBuilder(object):
Expand All @@ -28,6 +29,7 @@ class PipelineRuntimeConfigBuilder(object):
def __init__(
self,
pipeline_root: str,
schema_version: str,
parameter_types: Mapping[str, str],
parameter_values: Optional[Dict[str, Any]] = None,
):
Expand All @@ -36,12 +38,15 @@ def __init__(
Args:
pipeline_root (str):
Required. The root of the pipeline outputs.
schema_version (str):
Required. Schema version of the IR. This field determines the fields supported in current version of IR.
parameter_types (Mapping[str, str]):
Required. The mapping from pipeline parameter name to its type.
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter name to its value.
"""
self._pipeline_root = pipeline_root
self._schema_version = schema_version
self._parameter_types = parameter_types
self._parameter_values = copy.deepcopy(parameter_values or {})

Expand All @@ -64,6 +69,8 @@ def from_job_spec_json(
.get("inputDefinitions", {})
.get("parameters", {})
)
schema_version = job_spec["pipelineSpec"]["schemaVersion"]

# 'type' is deprecated in IR and change to 'parameterType'.
parameter_types = {
k: v.get("parameterType") or v.get("type")
Expand All @@ -72,7 +79,7 @@ def from_job_spec_json(

pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
parameter_values = _parse_runtime_parameters(runtime_config_spec)
return cls(pipeline_root, parameter_types, parameter_values)
return cls(pipeline_root, schema_version, parameter_types, parameter_values)

def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
"""Updates pipeline_root value.
Expand All @@ -95,9 +102,12 @@ def update_runtime_parameters(
"""
if parameter_values:
parameters = dict(parameter_values)
for k, v in parameter_values.items():
if isinstance(v, (dict, list, bool)):
parameters[k] = json.dumps(v)
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
for k, v in parameter_values.items():
if isinstance(v, (dict, list, bool)):
parameters[k] = json.dumps(v)
self._parameter_values.update(parameters)

def build(self) -> Dict[str, Any]:
Expand All @@ -111,9 +121,15 @@ def build(self) -> Dict[str, Any]:
"Pipeline root must be specified, either during "
"compile time, or when calling the service."
)
if packaging.version.parse(self._schema_version) > packaging.version.parse(
"2.0.0"
):
parameter_values_key = "parameterValues"
else:
parameter_values_key = "parameters"
return {
"gcsOutputDirectory": self._pipeline_root,
"parameters": {
parameter_values_key: {
k: self._get_vertex_value(k, v)
for k, v in self._parameter_values.items()
if v is not None
Expand All @@ -122,7 +138,7 @@ def build(self) -> Dict[str, Any]:

def _get_vertex_value(
self, name: str, value: Union[int, float, str, bool, list, dict]
) -> Dict[str, Any]:
) -> Union[int, float, str, bool, list, dict]:
"""Converts primitive values into Vertex pipeline Value proto message.

Args:
Expand All @@ -147,27 +163,21 @@ def _get_vertex_value(
"pipeline job input definitions.".format(name)
)

result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
elif self._parameter_types[name] == "BOOLEAN":
result["boolValue"] = value
elif self._parameter_types[name] == "NUMBER_DOUBLE":
result["numberValue"] = value
elif self._parameter_types[name] == "NUMBER_INTEGER":
result["numberValue"] = value
elif self._parameter_types[name] == "LIST":
result["listValue"] = value
elif self._parameter_types[name] == "STRUCT":
result["structValue"] = value
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
else:
raise TypeError("Got unknown type of value: {}".format(value))

return result
return value


def _parse_runtime_parameters(
Expand Down