Skip to content

Commit

Permalink
fix: XAI Metadata compatibility with Model.upload (#705)
Browse files Browse the repository at this point in the history
* Add tests to check XAI metadata builder and Model.upload compatibility

* Use get_metadata_protobuf in test

* Fix version mismatch bug, tests passing
  • Loading branch information
vinnysenthil committed Sep 14, 2021
1 parent f0b58b0 commit f0570cb
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 13 deletions.
Expand Up @@ -17,9 +17,7 @@
from google.protobuf import json_format
from typing import Any, Dict, List, Optional

from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata
from google.cloud.aiplatform.explain.metadata import metadata_builder


Expand Down
Expand Up @@ -18,9 +18,7 @@
from typing import Optional, List, Dict, Any, Tuple

from google.cloud.aiplatform.explain.metadata import metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata


class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
Expand Down
Expand Up @@ -15,12 +15,15 @@
# limitations under the License.
#

import pytest
import tensorflow.compat.v1 as tf

from google.cloud.aiplatform import models
from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata

import test_models
from test_models import upload_model_mock, get_model_mock # noqa: F401


class SavedModelMetadataBuilderTF1Test(tf.test.TestCase):
Expand Down Expand Up @@ -108,3 +111,28 @@ def test_get_metadata_protobuf_double_output(self):
)

assert md_builder.get_metadata_protobuf() == expected_object

@pytest.mark.usefixtures("upload_model_mock", "get_model_mock")
def test_model_upload_compatibility(self):
self._set_up()
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.model_path, tags=[tf.saved_model.tag_constants.SERVING]
)

generated_md = md_builder.get_metadata_protobuf()

try:
models.Model.upload(
display_name=test_models._TEST_MODEL_NAME,
serving_container_image_uri=test_models._TEST_SERVING_CONTAINER_IMAGE,
explanation_parameters=test_models._TEST_EXPLANATION_PARAMETERS,
explanation_metadata=generated_md, # Test metadata from builder
labels=test_models._TEST_LABEL,
)
except TypeError as e:
if "Parameter to MergeFrom() must be instance of same class" in str(e):
pytest.fail(
f"Model.upload() expects different proto version, more info: {e}"
)
else:
raise e
Expand Up @@ -15,14 +15,16 @@
# limitations under the License.
#


import pytest
import tensorflow as tf
import numpy as np

from google.cloud.aiplatform import models
from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata

import test_models
from test_models import upload_model_mock, get_model_mock # noqa: F401


class SavedModelMetadataBuilderTF2Test(tf.test.TestCase):
Expand Down Expand Up @@ -184,3 +186,28 @@ def test_model_with_feature_column(self):
"outputs": {"output_1": {"outputTensorName": "output_1"}},
}
assert expected_md == generated_md

@pytest.mark.usefixtures("upload_model_mock", "get_model_mock")
def test_model_upload_compatibility(self):
self._set_up_sequential()

builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.saved_model_path
)
generated_md = builder.get_metadata_protobuf()

try:
models.Model.upload(
display_name=test_models._TEST_MODEL_NAME,
serving_container_image_uri=test_models._TEST_SERVING_CONTAINER_IMAGE,
explanation_parameters=test_models._TEST_EXPLANATION_PARAMETERS,
explanation_metadata=generated_md, # Test metadata from builder
labels=test_models._TEST_LABEL,
)
except TypeError as e:
if "Parameter to MergeFrom() must be instance of same class" in str(e):
pytest.fail(
f"Model.upload() expects different proto version, more info: {e}"
)
else:
raise e

0 comments on commit f0570cb

Please sign in to comment.