From 918998c0bdc25b6a39d359a34f892dac1ca4efac Mon Sep 17 00:00:00 2001 From: Yaqi Ji Date: Mon, 19 Jul 2021 17:03:43 -0700 Subject: [PATCH] feat: add tf1 metadata builder (#526) * feat: add tf1 metadata builder * Change import checks --- .../explain/metadata/tf/v1/__init__.py | 15 ++ .../tf/v1/saved_model_metadata_builder.py | 165 ++++++++++++++++++ .../tf/v2/saved_model_metadata_builder.py | 2 +- ...n_saved_model_metadata_builder_tf1_test.py | 82 +++++++++ ..._saved_model_metadata_builder_tf2_test.py} | 2 +- 5 files changed, 264 insertions(+), 2 deletions(-) create mode 100644 google/cloud/aiplatform/explain/metadata/tf/v1/__init__.py create mode 100644 google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py create mode 100644 tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py rename tests/unit/aiplatform/{test_explain_saved_model_metadata_builder_test.py => test_explain_saved_model_metadata_builder_tf2_test.py} (99%) diff --git a/google/cloud/aiplatform/explain/metadata/tf/v1/__init__.py b/google/cloud/aiplatform/explain/metadata/tf/v1/__init__.py new file mode 100644 index 0000000000..0e973c9a40 --- /dev/null +++ b/google/cloud/aiplatform/explain/metadata/tf/v1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py b/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py new file mode 100644 index 0000000000..29f5b5b900 --- /dev/null +++ b/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.protobuf import json_format +from typing import Any, Dict, List, Optional + +from google.cloud.aiplatform.compat.types import ( + explanation_metadata_v1beta1 as explanation_metadata, +) +from google.cloud.aiplatform.explain.metadata import metadata_builder + + +class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder): + """Metadata builder class that accepts a TF1 saved model.""" + + def __init__( + self, + model_path: str, + tags: Optional[List[str]] = None, + signature_name: Optional[str] = None, + outputs_to_explain: Optional[List[str]] = None, + ) -> None: + """Initializes a SavedModelMetadataBuilder object. + + Args: + model_path: + Required. Local or GCS path to load the saved model from. + tags: + Optional. Tags to identify the model graph. If None or empty, + TensorFlow's default serving tag will be used. + signature_name: + Optional. Name of the signature to be explained. Inputs and + outputs of this signature will be written in the metadata. If not + provided, the default signature will be used. + outputs_to_explain: + Optional. List of output names to explain. Only single output is + supported for now. Hence, the list should contain one element. + This parameter is required if the model signature (provided via + signature_name) specifies multiple outputs. + + Raises: + ValueError if outputs_to_explain contains more than 1 element or + signature contains multiple outputs. + """ + if outputs_to_explain: + if len(outputs_to_explain) > 1: + raise ValueError( + "Only one output is supported at the moment. " + f"Received: {outputs_to_explain}." + ) + self._output_to_explain = next(iter(outputs_to_explain)) + + try: + import tensorflow.compat.v1 as tf + except ImportError: + raise ImportError( + "Tensorflow is not installed and is required to load saved model. " + 'Please install the SDK using "pip install "tensorflow>=1.15,<2.0""' + ) + + if not signature_name: + signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY + self._tags = tags or [tf.saved_model.tag_constants.SERVING] + self._graph = tf.Graph() + + with self.graph.as_default(): + self._session = tf.Session(graph=self.graph) + self._metagraph_def = tf.saved_model.loader.load( + sess=self.session, tags=self._tags, export_dir=model_path + ) + if signature_name not in self._metagraph_def.signature_def: + raise ValueError( + f"Serving sigdef key {signature_name} not in the signature def." + ) + serving_sigdef = self._metagraph_def.signature_def[signature_name] + if not outputs_to_explain: + if len(serving_sigdef.outputs) > 1: + raise ValueError( + "The signature contains multiple outputs. Specify " + 'an output via "outputs_to_explain" parameter.' + ) + self._output_to_explain = next(iter(serving_sigdef.outputs.keys())) + + self._inputs = _create_input_metadata_from_signature(serving_sigdef.inputs) + self._outputs = _create_output_metadata_from_signature( + serving_sigdef.outputs, self._output_to_explain + ) + + @property + def graph(self) -> "tf.Graph": # noqa: F821 + return self._graph + + @property + def session(self) -> "tf.Session": # noqa: F821 + return self._session + + def get_metadata(self) -> Dict[str, Any]: + """Returns the current metadata as a dictionary. + + Returns: + Json format of the explanation metadata. + """ + current_md = explanation_metadata.ExplanationMetadata( + inputs=self._inputs, outputs=self._outputs, + ) + return json_format.MessageToDict(current_md._pb) + + +def _create_input_metadata_from_signature( + signature_inputs: Dict[str, "tf.Tensor"] # noqa: F821 +) -> Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata]: + """Creates InputMetadata from signature inputs. + + Args: + signature_inputs: + Required. Inputs of the signature to be explained. If not provided, + the default signature will be used. + + Returns: + Inferred input metadata from the model. + """ + input_mds = {} + for key, tensor in signature_inputs.items(): + input_mds[key] = explanation_metadata.ExplanationMetadata.InputMetadata( + input_tensor_name=tensor.name + ) + return input_mds + + +def _create_output_metadata_from_signature( + signature_outputs: Dict[str, "tf.Tensor"], # noqa: F821 + output_to_explain: Optional[str] = None, +) -> Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata]: + """Creates OutputMetadata from signature inputs. + + Args: + signature_outputs: + Required. Inputs of the signature to be explained. If not provided, + the default signature will be used. + output_to_explain: + Optional. Output name to explain. + + Returns: + Inferred output metadata from the model. + """ + output_mds = {} + for key, tensor in signature_outputs.items(): + if not output_to_explain or output_to_explain == key: + output_mds[key] = explanation_metadata.ExplanationMetadata.OutputMetadata( + output_tensor_name=tensor.name + ) + return output_mds diff --git a/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py b/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py index 9541310d21..abff1a2e12 100644 --- a/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py +++ b/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py @@ -37,7 +37,7 @@ def __init__( Args: model_path: - Required. Path to load the saved model from. + Required. Local or GCS path to load the saved model from. signature_name: Optional. Name of the signature to be explained. Inputs and outputs of this signature will be written in the metadata. If not diff --git a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py new file mode 100644 index 0000000000..c24553751f --- /dev/null +++ b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorflow.compat.v1 as tf + +from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder + + +class SavedModelMetadataBuilderTF1Test(tf.test.TestCase): + def _set_up(self): + self.sess = tf.Session(graph=tf.Graph()) + with self.sess.graph.as_default(): + self.x = tf.placeholder(shape=[None, 10], dtype=tf.float32, name="inp") + weights = tf.constant(1.0, shape=(10, 2), name="weights") + bias_weight = tf.constant(1.0, shape=(2,), name="bias") + self.linear_layer = tf.add(tf.matmul(self.x, weights), bias_weight) + self.prediction = tf.nn.relu(self.linear_layer) + # save the model + self.model_path = self.get_temp_dir() + builder = tf.saved_model.builder.SavedModelBuilder(self.model_path) + tensor_info_x = tf.saved_model.utils.build_tensor_info(self.x) + tensor_info_pred = tf.saved_model.utils.build_tensor_info(self.prediction) + tensor_info_lin = tf.saved_model.utils.build_tensor_info(self.linear_layer) + prediction_signature = tf.saved_model.signature_def_utils.build_signature_def( + inputs={"x": tensor_info_x}, + outputs={"y": tensor_info_pred}, + method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME, + ) + double_output_signature = tf.saved_model.signature_def_utils.build_signature_def( + inputs={"x": tensor_info_x}, + outputs={"y": tensor_info_pred, "lin": tensor_info_lin}, + method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME, + ) + + builder.add_meta_graph_and_variables( + self.sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={ + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature, + "double": double_output_signature, + }, + ) + builder.save() + + def test_get_metadata_correct_inputs(self): + self._set_up() + md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder( + self.model_path, tags=[tf.saved_model.tag_constants.SERVING] + ) + expected_md = { + "inputs": {"x": {"inputTensorName": "inp:0"}}, + "outputs": {"y": {"outputTensorName": "Relu:0"}}, + } + + assert md_builder.get_metadata() == expected_md + + def test_get_metadata_double_output(self): + self._set_up() + md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder( + self.model_path, signature_name="double", outputs_to_explain=["lin"] + ) + + expected_md = { + "inputs": {"x": {"inputTensorName": "inp:0"}}, + "outputs": {"lin": {"outputTensorName": "Add:0"}}, + } + + assert md_builder.get_metadata() == expected_md diff --git a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_test.py b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py similarity index 99% rename from tests/unit/aiplatform/test_explain_saved_model_metadata_builder_test.py rename to tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py index e5e70bcea0..6e297f6949 100644 --- a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_test.py +++ b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py @@ -22,7 +22,7 @@ from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder -class SavedModelMetadataBuilderTest(tf.test.TestCase): +class SavedModelMetadataBuilderTF2Test(tf.test.TestCase): def test_get_metadata_sequential(self): # Set up for the sequential. self.seq_model = tf.keras.models.Sequential()