Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Counterplots #402

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 53 additions & 1 deletion dice_ml/counterfactual_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import os

import jsonschema
import numpy as np
import pandas as pd
from counterplots import CreatePlot
from raiutils.exceptions import UserConfigValidationException

from dice_ml.constants import _SchemaVersions
from dice_ml.constants import BackEndTypes, _SchemaVersions
from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
_DiverseCFV2SchemaConstants)

Expand Down Expand Up @@ -111,6 +114,55 @@ def visualize_as_list(self, display_sparse_df=True,
display_sparse_df=display_sparse_df,
show_only_changes=show_only_changes)

def plot_counterplots(self, dice_model):
    """Plot counterfactuals with the CounterPlots package.

    One ``CreatePlot`` object is built for every counterfactual row of every
    entry in ``self.cf_examples_list``. The last column of both
    ``test_instance_df`` and ``final_cfs_df`` is dropped before plotting
    (presumably the outcome/target column -- confirm against the callers
    that build these frames).

    :param dice_model: DiCE's model object. Its ``backend`` selects the
        prediction path: for the sklearn backend the score is the
        ``model.predict_proba`` probability of the class predicted for the
        factual instance (one-vs-rest); otherwise ``model.predict`` is
        applied to the output of ``transformer.transform``.
    :return: list of ``CreatePlot`` objects, ordered by query instance and
        then by counterfactual.
    """

    def make_converter(feature_names, feature_dtypes):
        """Build a converter from raw rows to a DataFrame typed like the query instance."""
        def convert_data(x):
            df_x = pd.DataFrame(data=x, columns=feature_names)
            # Restore each column's original dtype: rows arrive as plain
            # (often object-typed) arrays once they went through to_numpy().
            for feature_name, f_dtype in zip(feature_names, feature_dtypes):
                column = df_x[feature_name]
                try:
                    column = pd.to_numeric(column)
                except (ValueError, TypeError):
                    # Deliberate best effort: non-numeric (e.g. categorical)
                    # columns are kept as they are and only cast below.
                    pass
                df_x[feature_name] = column.astype(f_dtype)
            return df_x
        return convert_data

    def make_model_pred(convert_data, factual_instance):
        """Build the scoring function handed to CounterPlots for one query instance."""
        if dice_model.backend == BackEndTypes.Sklearn:
            factual_class_idx = np.argmax(
                dice_model.model.predict_proba(convert_data([factual_instance])))

            def model_pred(x):
                # One-against-all strategy: the score is the probability of
                # the factual class; every other class is implicitly
                # 1 - score, so no further normalisation is required.
                pred_prob = dice_model.model.predict_proba(convert_data(x))
                return pred_prob[:, factual_class_idx]
        else:
            def model_pred(x):
                return dice_model.model.predict(
                    dice_model.transformer.transform(convert_data(x)))
        return model_pred

    counterplots_out = []
    for cf_examples in self.cf_examples_list:
        test_instance_df = cf_examples.test_instance_df
        # Keep per-instance values in locals and bind them through the
        # factories above, so each plot keeps its own feature names, dtypes
        # and factual class even if its model_pred is called after the loop
        # has moved on to the next query instance.
        feature_names = list(test_instance_df.columns)[:-1]
        feature_dtypes = list(test_instance_df.dtypes)[:-1]
        factual_instance = test_instance_df.to_numpy()[0][:-1]

        model_pred = make_model_pred(
            make_converter(feature_names, feature_dtypes), factual_instance)

        final_cfs_df = cf_examples.final_cfs_df
        if final_cfs_df is None:
            # NOTE(review): assumes None means "no counterfactuals found"
            # for this query instance -- nothing to plot in that case.
            continue

        for cf_instance in final_cfs_df.to_numpy():
            counterplots_out.append(
                CreatePlot(
                    factual=factual_instance,
                    cf=cf_instance[:-1],
                    model_pred=model_pred,
                    feature_names=feature_names,
                ))
    return counterplots_out

@staticmethod
def _check_cf_exp_output_against_json_schema(
cf_dict, version):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pandas<2.0.0
scikit-learn
tqdm
raiutils>=0.4.0
counterplots>=0.0.7
79 changes: 79 additions & 0 deletions tests/test_counterfactual_explanations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import json
import unittest
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest
from raiutils.exceptions import UserConfigValidationException

Expand Down Expand Up @@ -319,3 +323,78 @@ def test_unsupported_versions_to_json(self, unsupported_version):
counterfactual_explanations.to_json()

assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve)


class TestCounterfactualExplanationsPlot(unittest.TestCase):
    """Exercise ``plot_counterplots`` with ``CreatePlot`` patched out.

    Both tests feed one query instance with two counterfactual rows and
    check that one plot object is produced per counterfactual.
    """

    @staticmethod
    def _build_explanations():
        """Return explanations wrapping a single mocked cf_examples entry."""
        cf_examples_mock = Mock()
        # One factual row: two features followed by the target column.
        cf_examples_mock.test_instance_df = pd.DataFrame({
            'feature1': [1],
            'feature2': [2],
            'target': [0]
        })
        # Two counterfactual rows with the same column layout.
        cf_examples_mock.final_cfs_df = pd.DataFrame({
            'feature1': [1.1, 1.2],
            'feature2': [2.1, 2.2],
            'target': [1, 1]
        })
        return CounterfactualExplanations(
            cf_examples_list=[cf_examples_mock],
            local_importance=None,
            summary_importance=None,
            version=None)

    @staticmethod
    def _check_two_plots(result, mock_create_plot):
        """Assert one patched plot was created per counterfactual row."""
        # CreatePlot must be called twice (there are 2 counterfactual instances)
        assert mock_create_plot.call_count == 2
        # ... and the patched return values are collected in order
        assert result == ["dummy_plot", "dummy_plot"]

    @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
    def test_plot_counterplots_sklearn(self, mock_create_plot):
        # Stand-in for DiCE's model object using the Sklearn backend
        dummy_model = Mock()
        dummy_model.backend = "sklearn"
        probabilities = np.array([[0.4, 0.6], [0.2, 0.8]])
        dummy_model.model.predict_proba = Mock(return_value=probabilities)

        counterfact = self._build_explanations()
        result = counterfact.plot_counterplots(dummy_model)

        self._check_two_plots(result, mock_create_plot)

    @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
    def test_plot_counterplots_non_sklearn(self, mock_create_plot):
        # Stand-in for DiCE's model object using any non-Sklearn backend
        dummy_model = Mock()
        dummy_model.backend = "NonSklearn"
        dummy_model.model.predict = Mock(return_value=np.array([0, 1]))
        dummy_model.transformer = Mock()
        transformed = np.array([[1, 2], [1.1, 2.1]])
        dummy_model.transformer.transform = Mock(return_value=transformed)

        counterfact = self._build_explanations()
        result = counterfact.plot_counterplots(dummy_model)

        self._check_two_plots(result, mock_create_plot)