Skip to content

Commit

Permalink
review
Browse files (browse the repository at this point in the history)
Signed-off-by: giandos200 <giando95menico@gmail.com>
  • Loading branch information
giandos200 committed Jan 14, 2022
1 parent 785bc4a commit 00dda52
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 36 deletions.
9 changes: 5 additions & 4 deletions dice_ml/explainer_interfaces/dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
# post-hoc operation on continuous features to enhance sparsity - only for public data
if posthoc_sparsity_param is not None and posthoc_sparsity_param > 0 and 'data_df' in self.data_interface.__dict__:
self.final_cfs_df_sparse = copy.deepcopy(self.final_cfs)
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse, query_instance,
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse,
query_instance,
posthoc_sparsity_param,
posthoc_sparsity_algorithm)
else:
Expand All @@ -265,9 +266,9 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
'change the query instance or the features to vary...' '; total time taken: %02d' % m,
'min %02d' % s, 'sec')
elif total_cfs_found == 0:
print(
'No Counterfactuals found for the given configuration, perhaps try with different parameters...',
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
print(
'No Counterfactuals found for the given configuration, perhaps try with different parameters...',
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
else:
print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec')

Expand Down
8 changes: 3 additions & 5 deletions dice_ml/explainer_interfaces/dice_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir
kx += 1
self.cfs = np.array(row)

#if len(self.cfs) > self.population_size:
# pass
if len(self.cfs) != self.population_size:
print("Pericolo Loop infinito....!!!!")
remaining_cfs = self.do_random_init(
Expand Down Expand Up @@ -264,7 +262,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k
(see diverse_counterfactuals.py).
"""

self.population_size = 3 * total_CFs
self.population_size = 10 * total_CFs

self.start_time = timeit.default_timer()

Expand Down Expand Up @@ -470,8 +468,8 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
if rest_members > 0:
new_generation_2 = np.zeros((rest_members, self.data_interface.number_of_features))
for new_gen_idx in range(rest_members):
parent1 = random.choice(population[:max(int(len(population) / 2),1)])
parent2 = random.choice(population[:max(int(len(population) / 2),1)])
parent1 = random.choice(population[:max(int(len(population) / 2), 1)])
parent2 = random.choice(population[:max(int(len(population) / 2), 1)])
child = self.mate(parent1, parent2, features_to_vary, query_instance)
new_generation_2[new_gen_idx] = child

Expand Down
61 changes: 35 additions & 26 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,22 @@ def generate_counterfactuals(self, query_instances, total_CFs,
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
if total_CFs > 10:
if posthoc_sparsity_algorithm == None:
if posthoc_sparsity_algorithm is None:
posthoc_sparsity_algorithm = 'binary'
elif total_CFs >50 and posthoc_sparsity_algorithm == 'linear':
elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear':
import warnings
warnings.warn("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
"'binary' search!".format(total_CFs))
elif posthoc_sparsity_algorithm == None:
warnings.warn(
"The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
"'binary' search!".format(total_CFs))
elif posthoc_sparsity_algorithm is None:
posthoc_sparsity_algorithm = 'linear'

cf_examples_arr = []
query_instances_list = []
if isinstance(query_instances, pd.DataFrame):
for ix in range(query_instances.shape[0]):
query_instances_list.append(query_instances[ix:(ix+1)])
query_instances_list.append(query_instances[ix:(ix + 1)])
elif isinstance(query_instances, Iterable):
query_instances_list = query_instances

Expand Down Expand Up @@ -190,11 +191,14 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query

if feature not in features_to_vary and permitted_range is not None:
if feature in permitted_range and feature in self.data_interface.continuous_feature_names:
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][1]:
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][
1]:
raise ValueError("Feature:", feature,
"is outside the permitted range and isn't allowed to vary.")
elif feature in permitted_range and feature in self.data_interface.categorical_feature_names:
if query_instance[feature].values[0] not in self.feature_range[feature]:
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
raise ValueError("Feature:", feature,
"is outside the permitted range and isn't allowed to vary.")

def local_feature_importance(self, query_instances, cf_examples_list=None,
total_CFs=10,
Expand Down Expand Up @@ -440,12 +444,13 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
cfs_preds_sparse = []

for cf_ix in list(final_cfs_sparse.index):
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
for feature in features_sorted:
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
if(abs(diff) <= quantiles[feature]):
if (abs(diff) <= quantiles[feature]):
if posthoc_sparsity_algorithm == "linear":
final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix,
feature, final_cfs_sparse, current_pred)
Expand All @@ -466,13 +471,14 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
query_instance greedily until the prediction class changes."""

old_diff = diff
change = (10**-decimal_prec[feature]) # the minimal possible change for a feature
change = (10 ** -decimal_prec[feature]) # the minimal possible change for a feature
current_pred = current_pred_orig
if self.model.model_type == ModelTypes.Classifier:
while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)):
while ((abs(diff) > 10e-4) and (np.sign(diff * old_diff) > 0) and self.is_cf_valid(current_pred)):
old_val = int(final_cfs_sparse.at[cf_ix, feature])
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff) * change
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
old_diff = diff

if not self.is_cf_valid(current_pred):
Expand Down Expand Up @@ -505,11 +511,12 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
right = query_instance[feature].iat[0]

while left <= right:
current_val = left + ((right - left)/2)
current_val = left + ((right - left) / 2)
current_val = round(current_val, decimal_prec[feature])

final_cfs_sparse.at[cf_ix, feature] = current_val
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])

if current_val == right or current_val == left:
break
Expand All @@ -524,19 +531,20 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
right = int(final_cfs_sparse.at[cf_ix, feature])

while right >= left:
current_val = right - ((right - left)/2)
current_val = right - ((right - left) / 2)
current_val = round(current_val, decimal_prec[feature])

final_cfs_sparse.at[cf_ix, feature] = current_val
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])

if current_val == right or current_val == left:
break

if self.is_cf_valid(current_pred):
right = current_val - (10**-decimal_prec[feature])
right = current_val - (10 ** -decimal_prec[feature])
else:
left = current_val + (10**-decimal_prec[feature])
left = current_val + (10 ** -decimal_prec[feature])

return final_cfs_sparse

Expand Down Expand Up @@ -578,7 +586,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_
raise UserConfigValidationException("Desired class not present in training data!")
else:
raise UserConfigValidationException("The target class for {0} could not be identified".format(
desired_class_input))
desired_class_input))

def infer_target_cfs_range(self, desired_range_input):
target_range = None
Expand All @@ -597,7 +605,7 @@ def decide_cf_validity(self, model_outputs):
pred = model_outputs[i]
if self.model.model_type == ModelTypes.Classifier:
if self.num_output_nodes == 2: # binary
pred_1 = pred[self.num_output_nodes-1]
pred_1 = pred[self.num_output_nodes - 1]
validity[i] = 1 if \
((self.target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
(self.target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else 0
Expand Down Expand Up @@ -634,7 +642,7 @@ def is_cf_valid(self, model_score):
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
return validity
if self.num_output_nodes == 2: # binary
pred_1 = model_score[self.num_output_nodes-1]
pred_1 = model_score[self.num_output_nodes - 1]
validity = True if \
((target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
Expand Down Expand Up @@ -710,7 +718,8 @@ def round_to_precision(self):
for ix, feature in enumerate(self.data_interface.continuous_feature_names):
self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
if self.final_cfs_df_sparse is not None:
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(
precisions[ix])

def _check_any_counterfactuals_computed(self, cf_examples_arr):
"""Check if any counterfactuals were generated for any query point."""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import nbformat
import pytest

NOTEBOOKS_PATH = "../docs/source/notebooks/"
NOTEBOOKS_PATH = "docs/source/notebooks/"
notebooks_list = [f.name for f in os.scandir(NOTEBOOKS_PATH) if f.name.endswith(".ipynb")]
# notebooks that should not be run
advanced_notebooks = [
Expand Down

0 comments on commit 00dda52

Please sign in to comment.