Skip to content

Commit

Permalink
feat: Added explain tabular samples (#348)
Browse files Browse the repository at this point in the history
* Added tabular explanation sample

* Cleaned up mocks

* Ran linter

* Fixed mock and added explanation printing

* Added more verbose explanations

* Fixed endpoint fixture

* Fixed linting issues
  • Loading branch information
ivanmkc committed May 4, 2021
1 parent 9245d30 commit c95d1ce
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 9 deletions.
14 changes: 11 additions & 3 deletions samples/model-builder/conftest.py
Expand Up @@ -237,8 +237,9 @@ def mock_batch_predict_model(mock_model):


@pytest.fixture
def mock_upload_model():
with patch.object(aiplatform.models.Model, "upload") as mock:
def mock_upload_model(mock_model):
with patch.object(aiplatform.Model, "upload") as mock:
mock.return_value = mock_model
yield mock


Expand Down Expand Up @@ -277,7 +278,7 @@ def mock_endpoint():

@pytest.fixture
def mock_create_endpoint():
with patch.object(aiplatform.Endpoint, "create") as mock:
with patch.object(aiplatform.models.Endpoint, "create") as mock:
yield mock


Expand All @@ -286,3 +287,10 @@ def mock_get_endpoint(mock_endpoint):
with patch.object(aiplatform, "Endpoint") as mock_get_endpoint:
mock_get_endpoint.return_value = mock_endpoint
yield mock_get_endpoint


@pytest.fixture
def mock_endpoint_explain(mock_endpoint):
with patch.object(mock_endpoint, "explain") as mock_endpoint_explain:
mock_get_endpoint.return_value = mock_endpoint
yield mock_endpoint_explain
Expand Up @@ -16,7 +16,6 @@
from google.cloud.aiplatform import schema

import create_and_import_dataset_video_sample

import test_constants as constants


Expand Down
50 changes: 50 additions & 0 deletions samples/model-builder/explain_tabular_sample.py
@@ -0,0 +1,50 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from google.cloud import aiplatform


# [START aiplatform_sdk_explain_tabular_sample]
def explain_tabular_sample(
project: str, location: str, endpoint_id: str, instance_dict: Dict
):

aiplatform.init(project=project, location=location)

endpoint = aiplatform.Endpoint(endpoint_id)

response = endpoint.explain(instances=[instance_dict], parameters={})

for explanation in response.explanations:
print(" explanation")
# Feature attributions.
attributions = explanation.attributions
for attribution in attributions:
print(" attribution")
print(" baseline_output_value:", attribution.baseline_output_value)
print(" instance_output_value:", attribution.instance_output_value)
print(" output_display_name:", attribution.output_display_name)
print(" approximation_error:", attribution.approximation_error)
print(" output_name:", attribution.output_name)
output_index = attribution.output_index
for output_index in output_index:
print(" output_index:", output_index)

for prediction in response.predictions:
print(prediction)


# [END aiplatform_sdk_explain_tabular_sample]
39 changes: 39 additions & 0 deletions samples/model-builder/explain_tabular_sample_test.py
@@ -0,0 +1,39 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import explain_tabular_sample
import test_constants as constants


def test_explain_tabular_sample(
mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain
):

explain_tabular_sample.explain_tabular_sample(
project=constants.PROJECT,
location=constants.LOCATION,
endpoint_id=constants.ENDPOINT_NAME,
instance_dict=constants.PREDICTION_TABULAR_INSTANCE,
)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
)

mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,)

mock_endpoint_explain.assert_called_once_with(
instances=[constants.PREDICTION_TABULAR_INSTANCE], parameters={}
)
Expand Up @@ -17,7 +17,6 @@
import pytest

import import_data_video_classification_sample

import test_constants as constants


Expand Down
36 changes: 32 additions & 4 deletions samples/model-builder/test_constants.py
Expand Up @@ -131,15 +131,31 @@
inputs={
"features": {
"input_tensor_name": "dense_input",
"encoding": "BAG_OF_FEATURES",
# Input is tabular data
"modality": "numeric",
"index_feature_mapping": ["abc", "def", "ghj"],
# Assign feature names to the inputs for explanation
"encoding": "BAG_OF_FEATURES",
"index_feature_mapping": [
"crim",
"zn",
"indus",
"chas",
"nox",
"rm",
"age",
"dis",
"rad",
"tax",
"ptratio",
"b",
"lstat",
],
}
},
outputs={"medv": {"output_tensor_name": "dense_2"}},
outputs={"prediction": {"output_tensor_name": "dense_2"}},
)
EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters(
{"sampled_shapley_attribution": {"path_count": 10}}
{"xrai_attribution": {"step_count": 1}}
)

# Endpoint constants
Expand All @@ -148,4 +164,16 @@
TRAFFIC_SPLIT = {"a": 99, "b": 1}
MIN_REPLICA_COUNT = 1
MAX_REPLICA_COUNT = 1
ACCELERATOR_TYPE = "NVIDIA_TESLA_P100"
ACCELERATOR_COUNT = 2
ENDPOINT_DEPLOY_METADATA = ()
PREDICTION_TABULAR_INSTANCE = {
"longitude": "-124.35",
"latitude": "40.54",
"housing_median_age": "52.0",
"total_rooms": "1820.0",
"total_bedrooms": "300.0",
"population": "806",
"households": "270.0",
"median_income": "3.014700",
}
@@ -0,0 +1,70 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Sequence

from google.cloud import aiplatform


# [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample]
def upload_model_explain_tabular_managed_container_sample(
project,
location,
model_display_name: str,
serving_container_image_uri: str,
artifact_uri: Optional[str] = None,
serving_container_predict_route: Optional[str] = None,
serving_container_health_route: Optional[str] = None,
description: Optional[str] = None,
serving_container_command: Optional[Sequence[str]] = None,
serving_container_args: Optional[Sequence[str]] = None,
serving_container_environment_variables: Optional[Dict[str, str]] = None,
serving_container_ports: Optional[Sequence[int]] = None,
instance_schema_uri: Optional[str] = None,
parameters_schema_uri: Optional[str] = None,
prediction_schema_uri: Optional[str] = None,
explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None,
explanation_parameters: Optional[aiplatform.explain.ExplanationParameters] = None,
sync: bool = True,
):

aiplatform.init(project=project, location=location)

model = aiplatform.Model.upload(
display_name=model_display_name,
serving_container_image_uri=serving_container_image_uri,
artifact_uri=artifact_uri,
serving_container_predict_route=serving_container_predict_route,
serving_container_health_route=serving_container_health_route,
description=description,
serving_container_command=serving_container_command,
serving_container_args=serving_container_args,
serving_container_environment_variables=serving_container_environment_variables,
serving_container_ports=serving_container_ports,
instance_schema_uri=instance_schema_uri,
parameters_schema_uri=parameters_schema_uri,
prediction_schema_uri=prediction_schema_uri,
explanation_metadata=explanation_metadata,
explanation_parameters=explanation_parameters,
sync=sync,
)

model.wait()

print(model.display_name)
print(model.resource_name)
return model


# [END aiplatform_sdk_upload_model_explain_tabular_managed_container_sample]
@@ -0,0 +1,65 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import test_constants as constants

import upload_model_explain_tabular_managed_container_sample


def test_upload_model_explain_tabular_managed_container_sample(
mock_sdk_init, mock_model, mock_init_model, mock_upload_model
):

upload_model_explain_tabular_managed_container_sample.upload_model_explain_tabular_managed_container_sample(
project=constants.PROJECT,
location=constants.LOCATION,
model_display_name=constants.MODEL_NAME,
serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI,
artifact_uri=constants.MODEL_ARTIFACT_URI,
serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE,
serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE,
description=constants.DESCRIPTION,
serving_container_command=constants.SERVING_CONTAINER_COMMAND,
serving_container_args=constants.SERVING_CONTAINER_ARGS,
serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
serving_container_ports=constants.SERVING_CONTAINER_PORTS,
instance_schema_uri=constants.INSTANCE_SCHEMA_URI,
parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI,
prediction_schema_uri=constants.PREDICTION_SCHEMA_URI,
explanation_metadata=constants.EXPLANATION_METADATA,
explanation_parameters=constants.EXPLANATION_PARAMETERS,
)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
)

mock_upload_model.assert_called_once_with(
display_name=constants.MODEL_NAME,
serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI,
artifact_uri=constants.MODEL_ARTIFACT_URI,
serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE,
serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE,
description=constants.DESCRIPTION,
serving_container_command=constants.SERVING_CONTAINER_COMMAND,
serving_container_args=constants.SERVING_CONTAINER_ARGS,
serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
serving_container_ports=constants.SERVING_CONTAINER_PORTS,
instance_schema_uri=constants.INSTANCE_SCHEMA_URI,
parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI,
prediction_schema_uri=constants.PREDICTION_SCHEMA_URI,
explanation_metadata=constants.EXPLANATION_METADATA,
explanation_parameters=constants.EXPLANATION_PARAMETERS,
sync=True,
)

0 comments on commit c95d1ce

Please sign in to comment.