Skip to content

Commit

Permalink
Merge pull request #296 from interpretml/gaugup/AddEmptyFeaturesToVar…
Browse files Browse the repository at this point in the history
…yListValidations

Raise user config validation exception when features_to_vary list is empty
  • Loading branch information
gaugup committed May 11, 2022
2 parents 5e70ef4 + 84e8763 commit 6b5a521
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 4 additions & 0 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def _validate_counterfactual_configuration(
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")

if features_to_vary != "all":
if len(features_to_vary) == 0:
raise UserConfigValidationException("Some features need to be varied for generating counterfactuals.")

if posthoc_sparsity_algorithm not in _PostHocSparsityTypes.ALL:
raise UserConfigValidationException(
'The posthoc_sparsity_algorithm should be {0} and not {1}'.format(
Expand Down
11 changes: 9 additions & 2 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,12 @@ def test_generate_counterfactuals_user_config_validations(
explainer_function(query_instances=sample_custom_query_2,
total_CFs=10, desired_range=[0, 10])

with pytest.raises(
UserConfigValidationException,
match=r'Some features need to be varied for generating counterfactuals.'):
explainer_function(query_instances=sample_custom_query_2,
total_CFs=10, features_to_vary=[])

@pytest.mark.parametrize('explainer_function',
['generate_counterfactuals', 'local_feature_importance',
'feature_importance', 'global_feature_importance'])
Expand All @@ -572,6 +578,9 @@ def test_generate_counterfactuals_user_config_validations_regression(
explainer_function(query_instances=sample_custom_query_1,
total_CFs=10, desired_range=[4, 3])


@pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])
class TestExplainerBaseDataValidations:
def test_global_feature_importance_error_conditions_with_insufficient_query_points(
self, method,
sample_custom_query_1,
Expand Down Expand Up @@ -665,5 +674,3 @@ def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_que
exp.local_feature_importance(
query_instances=sample_custom_query_1,
total_CFs=1)

# class TestExplainerBaseDataValidations:

0 comments on commit 6b5a521

Please sign in to comment.