Skip to content

Commit

Permalink
import update
Browse files Browse the repository at this point in the history
  • Loading branch information
giandos200 committed Jan 13, 2022
1 parent 2a88433 commit b440eb1
Show file tree
Hide file tree
Showing 21 changed files with 132 additions and 86 deletions.
2 changes: 1 addition & 1 deletion dice_ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .data import Data
from .model import Model
from .dice import Dice
from .model import Model

__all__ = ["Data",
"Model",
Expand Down
9 changes: 5 additions & 4 deletions dice_ml/counterfactual_explanations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import jsonschema
import os

from dice_ml.diverse_counterfactuals import CounterfactualExamples
from dice_ml.utils.exception import UserConfigValidationException
from dice_ml.diverse_counterfactuals import _DiverseCFV2SchemaConstants
import jsonschema

from dice_ml.constants import _SchemaVersions
from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
_DiverseCFV2SchemaConstants)
from dice_ml.utils.exception import UserConfigValidationException


class _CommonSchemaConstants:
Expand Down
7 changes: 4 additions & 3 deletions dice_ml/data_interfaces/private_data_interface.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Module containing meta data information about private data."""

import sys
import pandas as pd
import numpy as np
import collections
import logging
import sys

import numpy as np
import pandas as pd

from dice_ml.data_interfaces.base_data_interface import _BaseData

Expand Down
11 changes: 6 additions & 5 deletions dice_ml/data_interfaces/public_data_interface.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Module containing all required information about the interface between raw (or transformed)
public data and DiCE explainers."""

import pandas as pd
import numpy as np
import logging
from collections import defaultdict

from dice_ml.data_interfaces.base_data_interface import _BaseData
from dice_ml.utils.exception import SystemException, UserConfigValidationException
import numpy as np
import pandas as pd

from dice_ml.data_interfaces.base_data_interface import _BaseData
from dice_ml.utils.exception import (SystemException,
UserConfigValidationException)

class PublicData(_BaseData):
"""A data interface for public data. This class is an interface to DiCE explainers
Expand Down Expand Up @@ -258,7 +259,7 @@ def get_valid_feature_range(self, feature_range_input, normalized=True):
"""
feature_range = {}

for idx, feature_name in enumerate(self.feature_names):
for _, feature_name in enumerate(self.feature_names):
feature_range[feature_name] = []
if feature_name in self.continuous_feature_names:
max_value = self.data_df[feature_name].max()
Expand Down
8 changes: 5 additions & 3 deletions dice_ml/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
such as RandomSampling, DiCEKD or DiCEGenetic"""

from dice_ml.constants import BackEndTypes, SamplingStrategy
from dice_ml.utils.exception import UserConfigValidationException
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
from dice_ml.data_interfaces.private_data_interface import PrivateData
from dice_ml.utils.exception import UserConfigValidationException


class Dice(ExplainerBase):
Expand Down Expand Up @@ -67,12 +67,14 @@ def decide(model_interface, method):

elif model_interface.backend == BackEndTypes.Tensorflow1:
# pretrained Keras Sequential model with Tensorflow 1.x backend
from dice_ml.explainer_interfaces.dice_tensorflow1 import DiceTensorFlow1
from dice_ml.explainer_interfaces.dice_tensorflow1 import \
DiceTensorFlow1
return DiceTensorFlow1

elif model_interface.backend == BackEndTypes.Tensorflow2:
# pretrained Keras Sequential model with Tensorflow 2.x backend
from dice_ml.explainer_interfaces.dice_tensorflow2 import DiceTensorFlow2
from dice_ml.explainer_interfaces.dice_tensorflow2 import \
DiceTensorFlow2
return DiceTensorFlow2

elif model_interface.backend == BackEndTypes.Pytorch:
Expand Down
7 changes: 5 additions & 2 deletions dice_ml/diverse_counterfactuals.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pandas as pd
import copy
import json

import pandas as pd

from dice_ml.constants import ModelTypes, _SchemaVersions
from dice_ml.utils.serialize import DummyDataInterface
from dice_ml.constants import _SchemaVersions, ModelTypes


class _DiverseCFV1SchemaConstants:
Expand Down Expand Up @@ -115,6 +117,7 @@ def _visualize_internal(self, display_sparse_df=True, show_only_changes=False,

def visualize_as_dataframe(self, display_sparse_df=True, show_only_changes=False):
from IPython.display import display

# original instance
print('Query instance (original outcome : %i)' % round(self.test_pred))
display(self.test_instance_df) # works only in Jupyter notebook
Expand Down
7 changes: 4 additions & 3 deletions dice_ml/explainer_interfaces/dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
Module to generate counterfactual explanations from a KD-Tree
This code is similar to 'Interpretable Counterfactual Explanations Guided by Prototypes': https://arxiv.org/pdf/1907.02584.pdf
"""
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
import numpy as np
import copy
import timeit

import numpy as np
import pandas as pd
import copy

from dice_ml import diverse_counterfactuals as exp
from dice_ml.constants import ModelTypes
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


class DiceKD(ExplainerBase):
Expand Down
9 changes: 5 additions & 4 deletions dice_ml/explainer_interfaces/dice_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
Module to generate diverse counterfactual explanations based on genetic algorithm
This code is similar to 'GeCo: Quality Counterfactual Explanations in Real Time': https://arxiv.org/pdf/2101.01292.pdf
"""
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
import numpy as np
import pandas as pd
import copy
import random
import timeit
import copy

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

from dice_ml import diverse_counterfactuals as exp
from dice_ml.constants import ModelTypes
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


class DiceGenetic(ExplainerBase):
Expand Down
10 changes: 5 additions & 5 deletions dice_ml/explainer_interfaces/dice_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""
Module to generate diverse counterfactual explanations based on PyTorch framework
"""
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
import torch

import numpy as np
import copy
import random
import timeit
import copy
import numpy as np

import torch

from dice_ml import diverse_counterfactuals as exp
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


class DicePyTorch(ExplainerBase):
Expand Down
7 changes: 4 additions & 3 deletions dice_ml/explainer_interfaces/dice_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
Module to generate diverse counterfactual explanations based on random sampling.
A simple implementation.
"""
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
import numpy as np
import pandas as pd
import random
import timeit

import numpy as np
import pandas as pd

from dice_ml import diverse_counterfactuals as exp
from dice_ml.constants import ModelTypes
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


class DiceRandom(ExplainerBase):
Expand Down
12 changes: 6 additions & 6 deletions dice_ml/explainer_interfaces/dice_tensorflow1.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""
Module to generate diverse counterfactual explanations based on tensorflow 1.x
"""
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
import tensorflow as tf

import numpy as np
import random
import collections
import timeit
import copy
import random
import timeit

import numpy as np
import tensorflow as tf

from dice_ml import diverse_counterfactuals as exp
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


class DiceTensorFlow1(ExplainerBase):
Expand Down
12 changes: 6 additions & 6 deletions dice_ml/explainer_interfaces/dice_tensorflow2.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""
Module to generate diverse counterfactual explanations based on tensorflow 2.x
"""
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
import tensorflow as tf

import numpy as np
import copy
import random
import timeit
import copy

import numpy as np
import tensorflow as tf

from dice_ml import diverse_counterfactuals as exp
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


class DiceTensorFlow2(ExplainerBase):
Expand Down Expand Up @@ -177,7 +177,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
# CF initialization
if len(self.cfs) != self.total_CFs:
self.cfs = []
for ix in range(self.total_CFs):
for _ in range(self.total_CFs):
one_init = [[]]
for jx in range(self.minx.shape[1]):
one_init[0].append(np.random.uniform(self.minx[0][jx], self.maxx[0][jx]))
Expand Down
54 changes: 41 additions & 13 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch.
All methods are in dice_ml.explainer_interfaces"""

import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable

import numpy as np
import pandas as pd
from sklearn.neighbors import KDTree
from tqdm import tqdm

from collections.abc import Iterable
from sklearn.neighbors import KDTree
from dice_ml.constants import ModelTypes
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.utils.exception import UserConfigValidationException
from dice_ml.constants import ModelTypes


class ExplainerBase(ABC):
Expand Down Expand Up @@ -85,6 +85,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
if posthoc_sparsity_algorithm == None:
posthoc_sparsity_algorithm = 'binary'
elif total_CFs >50 and posthoc_sparsity_algorithm == 'linear':
import warnings
warnings.warn("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
"'binary' search!".format(total_CFs))
Expand All @@ -98,6 +99,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
query_instances_list.append(query_instances[ix:(ix+1)])
elif isinstance(query_instances, Iterable):
query_instances_list = query_instances

for query_instance in tqdm(query_instances_list):
self.data_interface.set_continuous_feature_indexes(query_instance)
res = self._generate_counterfactuals(
Expand All @@ -112,6 +114,9 @@ def generate_counterfactuals(self, query_instances, total_CFs,
verbose=verbose,
**kwargs)
cf_examples_arr.append(res)

self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)

return CounterfactualExplanations(cf_examples_list=cf_examples_arr)

@abstractmethod
Expand Down Expand Up @@ -217,10 +222,12 @@ def local_feature_importance(self, query_instances, cf_examples_list=None,
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance should be "
"greater than or equal to 10")
"greater than or equal to 10 to compute feature importance for all query points")
elif total_CFs < 10:
raise UserConfigValidationException("The number of counterfactuals generated per "
"query instance should be greater than or equal to 10")
raise UserConfigValidationException(
"The number of counterfactuals requested per "
"query instance should be greater than or equal to 10 "
"to compute feature importance for all query points")
importances = self.feature_importance(
query_instances,
cf_examples_list=cf_examples_list,
Expand Down Expand Up @@ -261,16 +268,25 @@ def global_feature_importance(self, query_instances, cf_examples_list=None,
input, and the global feature importance summarized over all inputs.
"""
if query_instances is not None and len(query_instances) < 10:
raise UserConfigValidationException("The number of query instances should be greater than or equal to 10")
raise UserConfigValidationException(
"The number of query instances should be greater than or equal to 10 "
"to compute global feature importance over all query points")
if cf_examples_list is not None:
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
if len(cf_examples_list) < 10:
raise UserConfigValidationException(
"The number of points for which counterfactuals generated should be "
"greater than or equal to 10 "
"to compute global feature importance")
elif any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance should be "
"greater than or equal to 10")
"greater than or equal to 10"
"to compute global feature importance over all query points")
elif total_CFs < 10:
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance should be greater "
"than or equal to 10")
"than or equal to 10"
"to compute global feature importance over all query points")
importances = self.feature_importance(
query_instances,
cf_examples_list=cf_examples_list,
Expand Down Expand Up @@ -349,7 +365,7 @@ def feature_importance(self, query_instances, cf_examples_list=None,
continue

per_query_point_cfs = 0
for index, row in df.iterrows():
for _, row in df.iterrows():
per_query_point_cfs += 1
for col in self.data_interface.continuous_feature_names:
if not np.isclose(org_instance[col].iat[0], row[col]):
Expand Down Expand Up @@ -530,7 +546,7 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred)
self.target_cf_class = np.array(
[[self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)]],
dtype=np.float32)
desired_class = self.target_cf_class[0][0]
desired_class = int(self.target_cf_class[0][0])
if self.target_cf_class == 0 and self.stopping_threshold > 0.5:
self.stopping_threshold = 0.25
elif self.target_cf_class == 1 and self.stopping_threshold < 0.5:
Expand Down Expand Up @@ -695,3 +711,15 @@ def round_to_precision(self):
self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
if self.final_cfs_df_sparse is not None:
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])

def _check_any_counterfactuals_computed(self, cf_examples_arr):
"""Check if any counterfactuals were generated for any query point."""
no_cf_generated = True
# Check if any counterfactuals were generated for any query point
for cf_examples in cf_examples_arr:
if cf_examples.final_cfs_df is not None and len(cf_examples.final_cfs_df) > 0:
no_cf_generated = False
break
if no_cf_generated:
raise UserConfigValidationException(
"No counterfactuals found for any of the query points! Kindly check your configuration.")

0 comments on commit b440eb1

Please sign in to comment.