KernelShap returns no explanations when link='logit' and predicted proba is 0 or 1 #952

Open
ascillitoe opened this issue Jul 17, 2023 · 0 comments

Issue

When the KernelShap explainer is configured with link='logit', and the predictor returns probabilities of 0 or 1, the logit link function raises errors due to logit(0) = -inf and logit(1) = inf.

Example error:

    240 @staticmethod
    241 def f(x):
--> 242     return np.log(x/(1-x))

ZeroDivisionError: float division by zero

This is a problem when batches of instances are passed to explain, because an error raised by a single problem instance means that no explanations are returned at all, even for instances where 0 < proba < 1.
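To make the failure mode concrete, here is a minimal sketch that evaluates the same expression as in the traceback at the boundary values. The helper names `naive_logit` and `try_logit` are hypothetical and only serve this illustration; they are not part of shap or alibi. It suggests that, with plain Python floats, it is the proba = 1 case that hits the division by zero, while NumPy scalars return infinities (with a warning) instead.

```python
import numpy as np

def naive_logit(x):
    # Same expression as in the traceback above.
    return np.log(x / (1 - x))

def try_logit(x):
    """Return naive_logit(x), or the name of the exception it raises."""
    try:
        # Silence NumPy's divide-by-zero warning so only hard errors surface.
        with np.errstate(divide="ignore"):
            return naive_logit(x)
    except ZeroDivisionError as exc:
        return type(exc).__name__

interior = try_logit(0.5)              # well-defined logit
lower = try_logit(0.0)                 # np.log(0.0) evaluates to -inf
upper = try_logit(1.0)                 # Python float division by zero
upper_np = try_logit(np.float64(1.0))  # NumPy scalar division gives inf

print(interior, lower, upper, upper_np)
```

Whether the predictor hands back Python floats or NumPy values therefore changes how the problem shows up, which is one reason to handle the boundaries explicitly.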

Solution

When link='logit', we could patch the logit link function with something like:

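import numpy as np
from shap.utils._legacy import Link

class FixedLogitLink(Link):
    """Logit link that returns -inf/+inf at proba 0/1 instead of raising."""

    def __str__(self):
        return "logit"

    @staticmethod
    def f(x):
        # Turn NumPy floating-point warnings into exceptions so they can be caught below.
        with np.errstate(all='raise'):
            try:
                return np.log(x / (1 - x))
            except Exception as e:
                # Boundary probabilities map to the limiting logit values.
                if x <= 0.0:
                    return -np.inf
                elif x >= 1.0:
                    return np.inf
                else:
                    raise e

    @staticmethod
    def finv(x):
        # Inverse of the logit (the sigmoid).
        return 1/(1+np.exp(-x))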

This would allow the explanations to run for instances whose predicted probabilities satisfy 0 < proba < 1. Problematic instances would just have np.nan explanations returned. For example:

>>> explanation.data['shap_values'][1]

array([[ 0.23972787, -0.01814982, -0.19931147, -0.70222757,  0.19566954,
        -0.52669155,  0.01747202, -0.41221566, -0.18493278, -0.04171934,
         0.19687651, -0.00301931],
       [        nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan]])
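As a complement to the proposal, the sketch below shows one possible way to handle whole batches. It rests on two assumptions: that the link function may be called with arrays (in which case a scalar comparison such as `x <= 0.0` would not work), and that callers want to separate usable explanations from NaN rows afterwards. `safe_logit` is a hypothetical helper, not part of shap or alibi, and the probabilities and `shap_values` array are made-up illustrative data.

```python
import numpy as np

def safe_logit(p):
    """Hypothetical vectorised logit: 0 -> -inf, 1 -> +inf, outside [0, 1] -> NaN."""
    p = np.asarray(p, dtype=float)
    # Suppress warnings for log(0) and for logs of negative numbers.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.log(p) - np.log1p(-p)

# Illustrative predicted probabilities for a batch of four instances.
probs = np.array([0.2, 1.0, 0.0, 0.5])
link_vals = safe_logit(probs)

# Instances with a finite logit are the ones that can be explained.
explainable = np.isfinite(link_vals)

# Illustrative shap_values array (4 instances, 3 features) with NaN rows
# standing in for the problem instances, as in the output above.
shap_values = np.full((4, 3), 0.1)
shap_values[~explainable] = np.nan

# Keep only rows that contain no NaN.
valid_rows = ~np.isnan(shap_values).any(axis=1)
usable = shap_values[valid_rows]
print(explainable, usable.shape)
```

Writing the logit as `log(p) - log1p(-p)` avoids the division entirely, so the boundary cases fall out of NumPy's own handling of `log(0)` rather than needing a try/except.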

Related issues
