diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 112d5c200b..70431c9565 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -237,8 +237,9 @@ def mock_batch_predict_model(mock_model): @pytest.fixture -def mock_upload_model(): - with patch.object(aiplatform.models.Model, "upload") as mock: +def mock_upload_model(mock_model): + with patch.object(aiplatform.Model, "upload") as mock: + mock.return_value = mock_model yield mock @@ -277,7 +278,7 @@ def mock_endpoint(): @pytest.fixture def mock_create_endpoint(): - with patch.object(aiplatform.Endpoint, "create") as mock: + with patch.object(aiplatform.models.Endpoint, "create") as mock: yield mock @@ -286,3 +287,10 @@ def mock_get_endpoint(mock_endpoint): with patch.object(aiplatform, "Endpoint") as mock_get_endpoint: mock_get_endpoint.return_value = mock_endpoint yield mock_get_endpoint + + +@pytest.fixture +def mock_endpoint_explain(mock_endpoint): + with patch.object(mock_endpoint, "explain") as mock_endpoint_explain: + mock_get_endpoint.return_value = mock_endpoint + yield mock_endpoint_explain diff --git a/samples/model-builder/create_and_import_dataset_video_sample_test.py b/samples/model-builder/create_and_import_dataset_video_sample_test.py index 1ebbc7a3d0..e1d1ddeb19 100644 --- a/samples/model-builder/create_and_import_dataset_video_sample_test.py +++ b/samples/model-builder/create_and_import_dataset_video_sample_test.py @@ -16,7 +16,6 @@ from google.cloud.aiplatform import schema import create_and_import_dataset_video_sample - import test_constants as constants diff --git a/samples/model-builder/explain_tabular_sample.py b/samples/model-builder/explain_tabular_sample.py new file mode 100644 index 0000000000..16d1204787 --- /dev/null +++ b/samples/model-builder/explain_tabular_sample.py @@ -0,0 +1,50 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_explain_tabular_sample] +def explain_tabular_sample( + project: str, location: str, endpoint_id: str, instance_dict: Dict +): + + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint_id) + + response = endpoint.explain(instances=[instance_dict], parameters={}) + + for explanation in response.explanations: + print(" explanation") + # Feature attributions. + attributions = explanation.attributions + for attribution in attributions: + print(" attribution") + print(" baseline_output_value:", attribution.baseline_output_value) + print(" instance_output_value:", attribution.instance_output_value) + print(" output_display_name:", attribution.output_display_name) + print(" approximation_error:", attribution.approximation_error) + print(" output_name:", attribution.output_name) + output_index = attribution.output_index + for output_index in output_index: + print(" output_index:", output_index) + + for prediction in response.predictions: + print(prediction) + + +# [END aiplatform_sdk_explain_tabular_sample] diff --git a/samples/model-builder/explain_tabular_sample_test.py b/samples/model-builder/explain_tabular_sample_test.py new file mode 100644 index 0000000000..d088da9658 --- /dev/null +++ b/samples/model-builder/explain_tabular_sample_test.py @@ -0,0 +1,39 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import explain_tabular_sample +import test_constants as constants + + +def test_explain_tabular_sample( + mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain +): + + explain_tabular_sample.explain_tabular_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint_id=constants.ENDPOINT_NAME, + instance_dict=constants.PREDICTION_TABULAR_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) + + mock_endpoint_explain.assert_called_once_with( + instances=[constants.PREDICTION_TABULAR_INSTANCE], parameters={} + ) diff --git a/samples/model-builder/import_data_video_classification_sample_test.py b/samples/model-builder/import_data_video_classification_sample_test.py index cce5c0abd6..5e5e142533 100644 --- a/samples/model-builder/import_data_video_classification_sample_test.py +++ b/samples/model-builder/import_data_video_classification_sample_test.py @@ -17,7 +17,6 @@ import pytest import import_data_video_classification_sample - import test_constants as constants diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index 994a8724ee..641fa1c490 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -131,15 +131,31 @@ inputs={ "features": { "input_tensor_name": "dense_input", - "encoding": "BAG_OF_FEATURES", + # Input is tabular data "modality": "numeric", - "index_feature_mapping": ["abc", "def", "ghj"], + # Assign feature names to the inputs for explanation + "encoding": "BAG_OF_FEATURES", + "index_feature_mapping": [ + "crim", + "zn", + "indus", + "chas", + "nox", + "rm", + "age", + "dis", + "rad", + "tax", + "ptratio", + "b", + "lstat", + ], } }, - outputs={"medv": {"output_tensor_name": "dense_2"}}, + outputs={"prediction": {"output_tensor_name": "dense_2"}}, ) EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( - {"sampled_shapley_attribution": {"path_count": 10}} + {"xrai_attribution": {"step_count": 1}} ) # Endpoint constants @@ -148,4 +164,16 @@ TRAFFIC_SPLIT = {"a": 99, "b": 1} MIN_REPLICA_COUNT = 1 MAX_REPLICA_COUNT = 1 +ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +ACCELERATOR_COUNT = 2 ENDPOINT_DEPLOY_METADATA = () +PREDICTION_TABULAR_INSTANCE = { + "longitude": "-124.35", + "latitude": "40.54", + "housing_median_age": "52.0", + "total_rooms": "1820.0", + "total_bedrooms": "300.0", + "population": "806", + "households": "270.0", + "median_income": "3.014700", +} diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py new file mode 100644 index 0000000000..bc676ba917 --- /dev/null +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py @@ -0,0 +1,70 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Sequence + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] +def upload_model_explain_tabular_managed_container_sample( + project, + location, + model_display_name: str, + serving_container_image_uri: str, + artifact_uri: Optional[str] = None, + serving_container_predict_route: Optional[str] = None, + serving_container_health_route: Optional[str] = None, + description: Optional[str] = None, + serving_container_command: Optional[Sequence[str]] = None, + serving_container_args: Optional[Sequence[str]] = None, + serving_container_environment_variables: Optional[Dict[str, str]] = None, + serving_container_ports: Optional[Sequence[int]] = None, + instance_schema_uri: Optional[str] = None, + parameters_schema_uri: Optional[str] = None, + prediction_schema_uri: Optional[str] = None, + explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None, + explanation_parameters: Optional[aiplatform.explain.ExplanationParameters] = None, + sync: bool = True, +): + + aiplatform.init(project=project, location=location) + + model = aiplatform.Model.upload( + display_name=model_display_name, + serving_container_image_uri=serving_container_image_uri, + artifact_uri=artifact_uri, + serving_container_predict_route=serving_container_predict_route, + serving_container_health_route=serving_container_health_route, + description=description, + serving_container_command=serving_container_command, + serving_container_args=serving_container_args, + serving_container_environment_variables=serving_container_environment_variables, + serving_container_ports=serving_container_ports, + instance_schema_uri=instance_schema_uri, + parameters_schema_uri=parameters_schema_uri, + prediction_schema_uri=prediction_schema_uri, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + return model + + +# [END aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py new file mode 100644 index 0000000000..653de93f74 --- /dev/null +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py @@ -0,0 +1,65 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import test_constants as constants + +import upload_model_explain_tabular_managed_container_sample + + +def test_upload_model_explain_tabular_managed_container_sample( + mock_sdk_init, mock_model, mock_init_model, mock_upload_model +): + + upload_model_explain_tabular_managed_container_sample.upload_model_explain_tabular_managed_container_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_display_name=constants.MODEL_NAME, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + description=constants.DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_upload_model.assert_called_once_with( + display_name=constants.MODEL_NAME, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + description=constants.DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + sync=True, + )