Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update explainer_base.py #424

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 1 addition & 5 deletions dice_ml/data_interfaces/private_data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,7 @@ def get_features_range(self, permitted_range_input=None, features_dict=None):

ranges = {}
# Getting default ranges based on the dataset
for feature in features_dict:
if type(features_dict[feature][0]) is int: # continuous feature
ranges[feature] = features_dict[feature]
else:
ranges[feature] = features_dict[feature]
ranges[feature] = features_dict[feature]
feature_ranges_orig = ranges.copy()
# Overwriting the ranges for a feature if input provided
if permitted_range_input is not None:
Expand Down
7 changes: 1 addition & 6 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,7 @@ def setup(self, features_to_vary, permitted_range, query_instance, feature_weigh
if features_to_vary == 'all':
features_to_vary = self.data_interface.feature_names

if permitted_range is None: # use the precomputed default
self.feature_range = self.data_interface.permitted_range
feature_ranges_orig = self.feature_range
else: # compute the new ranges based on user input
self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range)

self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range)
self.check_query_instance_validity(features_to_vary, permitted_range, query_instance, feature_ranges_orig)

return features_to_vary
Expand Down
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,34 @@ def _load_custom_vars_dataset_model():
return model


def _load_adult_income_binary_model():
    """Build the model used for adult-income tests.

    Loads the adult income dataset through ``helpers``, separates the
    ``income`` column as the target, and hands the remaining columns to
    ``create_complex_classification_pipeline``. ``age`` and
    ``hours_per_week`` are passed as the numerical features; every other
    feature column is passed as categorical.

    Returns whatever ``create_complex_classification_pipeline`` returns.
    """
    target = "income"
    numerical = ["age", "hours_per_week"]

    adult_df = helpers.load_adult_income_dataset()
    features = adult_df.drop(target, axis=1)
    labels = adult_df[target]

    # Everything that is not explicitly numerical is treated as categorical.
    categorical = features.columns.difference(numerical)
    return create_complex_classification_pipeline(
        features, labels, numerical, categorical)


def sample_adult_income_custom_query_11(data_point=2):
    """Return identical query instances for the adult income dataset.

    Every row describes the same individual (22 years old, private sector,
    HS-grad, single, service occupation, white, female, 45 hours per week).

    Args:
        data_point: Number of rows (query instances) to generate.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        A ``pandas.DataFrame`` with ``data_point`` rows, indexed
        ``0 .. data_point - 1``, and one column per feature.
    """
    # One template row, repeated ``data_point`` times below.
    template = {
        'age': 22,
        'workclass': 'Private',
        'education': 'HS-grad',
        'marital_status': 'Single',
        'occupation': 'Service',
        'race': 'White',
        'gender': 'Female',
        'hours_per_week': 45,
    }
    query_instances = pd.DataFrame(
        {column: [value] * data_point for column, value in template.items()},
        index=list(range(data_point)))
    return query_instances


@pytest.fixture(scope='session')
def sample_adultincome_query():
"""
Expand Down
19 changes: 18 additions & 1 deletion tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
from dice_ml.utils import helpers

from ..conftest import _load_custom_testing_binary_model
from ..conftest import (private_data_object,
sample_adult_income_custom_query_11,
_load_adult_income_binary_model,
_load_custom_testing_binary_model)


@pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])
Expand Down Expand Up @@ -349,6 +352,20 @@ def test_cfs_type_consistency(
assert cf_explanations.cf_examples_list[0].final_cfs_df[col].dtype == sample_custom_query[col].dtype
if cf_explanations.cf_examples_list[0].final_cfs_df_sparse is not None:
assert cf_explanations.cf_examples_list[0].final_cfs_df_sparse[col].dtype == sample_custom_query[col].dtype

@pytest.mark.parametrize("method", ["genetic"])
def test_genetic_private_data(method):
    """Counterfactual generation must work with a private data interface.

    Builds a Dice explainer from the private data object and the adult
    income model, then requests one opposite-class counterfactual per query
    instance with random initialization.
    """
    d = private_data_object()
    query = sample_adult_income_custom_query_11()
    model = _load_adult_income_binary_model()
    m = dice_ml.Model(model=model, backend='sklearn')
    exp = dice_ml.Dice(d, m, method=method)

    cf_explanations = exp.generate_counterfactuals(
        query_instances=query,
        total_CFs=1,
        desired_class="opposite",
        initialization="random")

    # A test must assert rather than return its result: pytest warns on a
    # non-None return value, and the returned value is never checked.
    assert cf_explanations is not None
    assert len(cf_explanations.cf_examples_list) > 0


@pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])
Expand Down