From f0570cb999f024ca96e7daaa102c81b681c2a575 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Tue, 14 Sep 2021 15:28:17 -0700 Subject: [PATCH] fix: XAI Metadata compatibility with Model.upload (#705) * Add tests to check XAI metadata builder and Model.upload compatibility * Use get_metadata_protobuf in test * Fix version mismatch bug, tests passing --- .../tf/v1/saved_model_metadata_builder.py | 4 +-- .../tf/v2/saved_model_metadata_builder.py | 4 +-- ...n_saved_model_metadata_builder_tf1_test.py | 34 ++++++++++++++++-- ...n_saved_model_metadata_builder_tf2_test.py | 35 ++++++++++++++++--- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py b/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py index 89261f8c1f..6f0af6d93b 100644 --- a/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py +++ b/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py @@ -17,9 +17,7 @@ from google.protobuf import json_format from typing import Any, Dict, List, Optional -from google.cloud.aiplatform.compat.types import ( - explanation_metadata_v1beta1 as explanation_metadata, -) +from google.cloud.aiplatform.compat.types import explanation_metadata from google.cloud.aiplatform.explain.metadata import metadata_builder diff --git a/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py b/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py index 36f520d7b0..dd7f2b8d0a 100644 --- a/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py +++ b/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py @@ -18,9 +18,7 @@ from typing import Optional, List, Dict, Any, Tuple from google.cloud.aiplatform.explain.metadata import metadata_builder -from google.cloud.aiplatform.compat.types import ( - explanation_metadata_v1beta1 as explanation_metadata, -) +from google.cloud.aiplatform.compat.types import explanation_metadata class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder): diff --git a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py index 41ff8bb68e..8c83b3b087 100644 --- a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py +++ b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf1_test.py @@ -15,12 +15,15 @@ # limitations under the License. # +import pytest import tensorflow.compat.v1 as tf +from google.cloud.aiplatform import models from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder -from google.cloud.aiplatform.compat.types import ( - explanation_metadata_v1beta1 as explanation_metadata, -) +from google.cloud.aiplatform.compat.types import explanation_metadata + +import test_models +from test_models import upload_model_mock, get_model_mock # noqa: F401 class SavedModelMetadataBuilderTF1Test(tf.test.TestCase): @@ -108,3 +111,28 @@ def test_get_metadata_protobuf_double_output(self): ) assert md_builder.get_metadata_protobuf() == expected_object + + @pytest.mark.usefixtures("upload_model_mock", "get_model_mock") + def test_model_upload_compatibility(self): + self._set_up() + md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder( + self.model_path, tags=[tf.saved_model.tag_constants.SERVING] + ) + + generated_md = md_builder.get_metadata_protobuf() + + try: + models.Model.upload( + display_name=test_models._TEST_MODEL_NAME, + serving_container_image_uri=test_models._TEST_SERVING_CONTAINER_IMAGE, + explanation_parameters=test_models._TEST_EXPLANATION_PARAMETERS, + explanation_metadata=generated_md, # Test metadata from builder + labels=test_models._TEST_LABEL, + ) + except TypeError as e: + if "Parameter to MergeFrom() must be instance of same class" in str(e): + pytest.fail( + f"Model.upload() expects different proto version, more info: {e}" + ) + else: + raise e diff --git a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py index 5ebc0a9af7..a18eed243c 100644 --- a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py +++ b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py @@ -15,14 +15,16 @@ # limitations under the License. # - +import pytest import tensorflow as tf import numpy as np +from google.cloud.aiplatform import models from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder -from google.cloud.aiplatform.compat.types import ( - explanation_metadata_v1beta1 as explanation_metadata, -) +from google.cloud.aiplatform.compat.types import explanation_metadata + +import test_models +from test_models import upload_model_mock, get_model_mock # noqa: F401 class SavedModelMetadataBuilderTF2Test(tf.test.TestCase): @@ -184,3 +186,28 @@ def test_model_with_feature_column(self): "outputs": {"output_1": {"outputTensorName": "output_1"}}, } assert expected_md == generated_md + + @pytest.mark.usefixtures("upload_model_mock", "get_model_mock") + def test_model_upload_compatibility(self): + self._set_up_sequential() + + builder = saved_model_metadata_builder.SavedModelMetadataBuilder( + self.saved_model_path + ) + generated_md = builder.get_metadata_protobuf() + + try: + models.Model.upload( + display_name=test_models._TEST_MODEL_NAME, + serving_container_image_uri=test_models._TEST_SERVING_CONTAINER_IMAGE, + explanation_parameters=test_models._TEST_EXPLANATION_PARAMETERS, + explanation_metadata=generated_md, # Test metadata from builder + labels=test_models._TEST_LABEL, + ) + except TypeError as e: + if "Parameter to MergeFrom() must be instance of same class" in str(e): + pytest.fail( + f"Model.upload() expects different proto version, more info: {e}" + ) + else: + raise e