From 243b75c2655beeef47848410a40d86a072428ac3 Mon Sep 17 00:00:00 2001 From: Yaqi Ji Date: Thu, 16 Sep 2021 15:44:03 -0700 Subject: [PATCH] =?UTF-8?q?feat(PipelineJob):=20support=20dict,=20list,=20?= =?UTF-8?q?bool=20typed=20input=20parameters=20fr=E2=80=A6=20(#693)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: release 1.4.2 Release-As: 1.4.2 * address comments --- google/cloud/aiplatform/utils/pipeline_utils.py | 7 ++++++- tests/unit/aiplatform/test_utils.py | 14 +++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/google/cloud/aiplatform/utils/pipeline_utils.py b/google/cloud/aiplatform/utils/pipeline_utils.py index 31b08671a5..bc531d2b12 100644 --- a/google/cloud/aiplatform/utils/pipeline_utils.py +++ b/google/cloud/aiplatform/utils/pipeline_utils.py @@ -15,6 +15,7 @@ # import copy +import json from typing import Any, Dict, Mapping, Optional, Union @@ -89,7 +90,11 @@ def update_runtime_parameters( Optional. The mapping from runtime parameter names to its values. """ if parameter_values: - self._parameter_values.update(parameter_values) + parameters = dict(parameter_values) + for k, v in parameter_values.items(): + if isinstance(v, (dict, list, bool)): + parameters[k] = json.dumps(v) + self._parameter_values.update(parameters) def build(self) -> Dict[str, Any]: """Build a RuntimeConfig proto. diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index bdc674ebc0..418014ee45 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -370,6 +370,9 @@ class TestPipelineUtils: "int_param": {"type": "INT"}, "float_param": {"type": "DOUBLE"}, "new_param": {"type": "STRING"}, + "bool_param": {"type": "STRING"}, + "dict_param": {"type": "STRING"}, + "list_param": {"type": "STRING"}, } } } @@ -430,7 +433,13 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(self): ) my_builder.update_pipeline_root("path/to/my/new/root") my_builder.update_runtime_parameters( - {"int_param": 888, "new_param": "new-string"} + { + "int_param": 888, + "new_param": "new-string", + "dict_param": {"a": 1}, + "list_param": [1, 2, 3], + "bool_param": True, + } ) actual_runtime_config = my_builder.build() @@ -441,6 +450,9 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(self): "int_param": {"intValue": 888}, "float_param": {"doubleValue": 3.14}, "new_param": {"stringValue": "new-string"}, + "dict_param": {"stringValue": '{"a": 1}'}, + "list_param": {"stringValue": "[1, 2, 3]"}, + "bool_param": {"stringValue": "true"}, }, } assert expected_runtime_config == actual_runtime_config