Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: XAI Metadata compatibility with Model.upload #705

Merged
merged 4 commits into from Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -17,9 +17,7 @@
from google.protobuf import json_format
from typing import Any, Dict, List, Optional

from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata
from google.cloud.aiplatform.explain.metadata import metadata_builder


Expand Down
Expand Up @@ -18,9 +18,7 @@
from typing import Optional, List, Dict, Any, Tuple

from google.cloud.aiplatform.explain.metadata import metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata


class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
Expand Down
Expand Up @@ -15,12 +15,15 @@
# limitations under the License.
#

import pytest
import tensorflow.compat.v1 as tf

from google.cloud.aiplatform import models
from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata

import test_models
from test_models import upload_model_mock, get_model_mock # noqa: F401


class SavedModelMetadataBuilderTF1Test(tf.test.TestCase):
Expand Down Expand Up @@ -108,3 +111,28 @@ def test_get_metadata_protobuf_double_output(self):
)

assert md_builder.get_metadata_protobuf() == expected_object

@pytest.mark.usefixtures("upload_model_mock", "get_model_mock")
def test_model_upload_compatibility(self):
self._set_up()
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.model_path, tags=[tf.saved_model.tag_constants.SERVING]
)

generated_md = md_builder.get_metadata_protobuf()

try:
models.Model.upload(
display_name=test_models._TEST_MODEL_NAME,
serving_container_image_uri=test_models._TEST_SERVING_CONTAINER_IMAGE,
explanation_parameters=test_models._TEST_EXPLANATION_PARAMETERS,
explanation_metadata=generated_md, # Test metadata from builder
labels=test_models._TEST_LABEL,
)
except TypeError as e:
if "Parameter to MergeFrom() must be instance of same class" in str(e):
pytest.fail(
f"Model.upload() expects different proto version, more info: {e}"
)
else:
raise e
Expand Up @@ -15,14 +15,16 @@
# limitations under the License.
#


import pytest
import tensorflow as tf
import numpy as np

from google.cloud.aiplatform import models
from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)
from google.cloud.aiplatform.compat.types import explanation_metadata

import test_models
from test_models import upload_model_mock, get_model_mock # noqa: F401


class SavedModelMetadataBuilderTF2Test(tf.test.TestCase):
Expand Down Expand Up @@ -184,3 +186,28 @@ def test_model_with_feature_column(self):
"outputs": {"output_1": {"outputTensorName": "output_1"}},
}
assert expected_md == generated_md

@pytest.mark.usefixtures("upload_model_mock", "get_model_mock")
def test_model_upload_compatibility(self):
self._set_up_sequential()

builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.saved_model_path
)
generated_md = builder.get_metadata_protobuf()

try:
models.Model.upload(
display_name=test_models._TEST_MODEL_NAME,
serving_container_image_uri=test_models._TEST_SERVING_CONTAINER_IMAGE,
explanation_parameters=test_models._TEST_EXPLANATION_PARAMETERS,
explanation_metadata=generated_md, # Test metadata from builder
labels=test_models._TEST_LABEL,
)
except TypeError as e:
if "Parameter to MergeFrom() must be instance of same class" in str(e):
pytest.fail(
f"Model.upload() expects different proto version, more info: {e}"
)
else:
raise e