Skip to content

Commit

Permalink
Add validation tests to dice explainers (#208)
Browse files Browse the repository at this point in the history
Signed-off-by: gaugup <gaugup@microsoft.com>
  • Loading branch information
gaugup committed Aug 12, 2021
1 parent fee19a8 commit 96efb6b
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,55 @@ def test_columns_out_of_order(self, desired_class, binary_classification_exp_obj
permitted_range=None,
features_to_vary='all')

@pytest.mark.parametrize("desired_class, binary_classification_exp_object",
[(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
indirect=['binary_classification_exp_object'])
def test_incorrect_features_to_vary_list(self, desired_class, binary_classification_exp_object, sample_custom_query_1):
exp = binary_classification_exp_object # explainer object
with pytest.raises(
UserConfigValidationException,
match="Got features {" + "'unknown_feature'" + "} which are not present in training data"):
exp.generate_counterfactuals(
query_instances=sample_custom_query_1,
total_CFs=10,
desired_class=desired_class,
desired_range=None,
permitted_range=None,
features_to_vary=['unknown_feature'])

@pytest.mark.parametrize("desired_class, binary_classification_exp_object",
[(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
indirect=['binary_classification_exp_object'])
def test_incorrect_features_permitted_range(self, desired_class, binary_classification_exp_object, sample_custom_query_1):
exp = binary_classification_exp_object # explainer object
with pytest.raises(
UserConfigValidationException,
match="Got features {" + "'unknown_feature'" + "} which are not present in training data"):
exp.generate_counterfactuals(
query_instances=sample_custom_query_1,
total_CFs=10,
desired_class=desired_class,
desired_range=None,
permitted_range={'unknown_feature': [1, 30]},
features_to_vary='all')

@pytest.mark.parametrize("desired_class, binary_classification_exp_object",
[(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
indirect=['binary_classification_exp_object'])
def test_incorrect_values_permitted_range(self, desired_class, binary_classification_exp_object, sample_custom_query_1):
exp = binary_classification_exp_object # explainer object
with pytest.raises(UserConfigValidationException) as ucve:
exp.generate_counterfactuals(
query_instances=sample_custom_query_1,
total_CFs=10,
desired_class=desired_class,
desired_range=None,
permitted_range={'Categorical': ['d']},
features_to_vary='all')

assert 'The category {0} does not occur in the training data for feature {1}. Allowed categories are {2}'.format(
'd', 'Categorical', ['a', 'b', 'c']) in str(ucve)


class TestExplainerBaseMultiClassClassification:

Expand Down

0 comments on commit 96efb6b

Please sign in to comment.