Reproducible error: SHAP ExplainerError: Additivity check failed in TreeExplainer #873
Comments
Thanks for the code. I was trying to write a consistent unit test for PR #872 (which fixes this issue and #866); however, I'm unable to reproduce the error with the following code, structured as a unit test:

    def test_check_additivity_handling(self):
        np.random.seed(0)
        n = 100
        X = np.hstack([np.random.randint(2, size=(n, 1)), np.random.randint(100, size=(n, 1))])
        T = np.random.randint(2, size=(n,))
        Y = np.random.randint(2, size=(n,))
        X_train, _, T_train, _, Y_train, _ = train_test_split(X, T, Y, test_size=0.2, random_state=42)
        scaler = MinMaxScaler()
        X_train[:, 1:2] = scaler.fit_transform(X_train[:, 1:2])
        est = CausalForestDML(
            model_y=RandomForestClassifier(),
            model_t=DummyClassifier(strategy="uniform"),
            random_state=123,
            discrete_outcome=True,
            discrete_treatment=True
        )
        est.fit(Y_train, T_train, X=X_train)
        shap_values = est.shap_values(X_train[:20])
        assert shap_values is not None

As discussed in my second comment in #866, I don't believe there's a consistent way to reproduce this error across machines, for various reasons. However, #872 will fix this issue regardless.
@yanisvdc Do you have the literal values for the arrays involved? Could you please try the following (on the set of variables that cause the error):

    arrays = {
        "X": X_train,  # after scaling (potentially call .to_numpy() on it if it's a df)
        "T": T_train.squeeze(),
        "Y": y_train.squeeze(),
    }
    for name, arr in arrays.items():
        mlw = 20 if name != "X" else 50
        arr_str = np.array2string(arr, separator=", ", floatmode="maxprec", precision=16, max_line_width=mlw)
        print(f"{name}:\n\n{arr_str}")

and copy the results here in a markdown codeblock for each variable?
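For anyone who wants to load such pasted output back into NumPy, here is a minimal sketch. It assumes the text was produced by `np.array2string` with `separator=", "` and was not abbreviated with `...` (NumPy summarizes arrays above its print threshold unless the `threshold` argument is raised). The helper name `parse_printed_array` and the small array `X` are hypothetical and only for illustration.

```python
import ast

import numpy as np


def parse_printed_array(text):
    """Parse bracketed text from np.array2string(..., separator=", ") into an array."""
    # The printed form is a nested Python list literal, so literal_eval can read it.
    return np.array(ast.literal_eval(text))


# Hypothetical stand-in for the real data; these values are exactly representable.
X = np.array([[0.0, 0.5], [1.0, 0.25]])
X_str = np.array2string(X, separator=", ", floatmode="maxprec", precision=16, max_line_width=50)
X_back = parse_printed_array(X_str)

T = np.array([0, 1, 1, 0])
T_str = np.array2string(T, separator=", ", max_line_width=20)
T_back = parse_printed_array(T_str)
```

Comparing `X_back` with the original via `np.array_equal` is a quick way to confirm that nothing was lost in the copy and paste.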
Thanks for your comments.
Error message:
Please let me know if you experience the same behavior.
To answer this, here are the exact values:
This is related to issue #866
If you try to comment out or remove the lines:
the code runs correctly.
It seems that the scaler induces the issue; the same happens with StandardScaler(). My guess is that floating-point rounding errors are responsible for this behavior. I believe there is a PR to add check_additivity=False, which seems to be the only way to resolve this, unless I am missing something.
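To illustrate the rounding hypothesis, here is a minimal NumPy-only sketch of an additivity-style check: per-feature contributions plus a base value should sum to the model output, and the check fails when the gap exceeds a tolerance. This is not SHAP's or EconML's actual implementation; the function names, the tolerance values, and the float32 round trip used to simulate precision loss are all assumptions made for illustration.

```python
import numpy as np


def additivity_gap(base_value, contributions, prediction):
    """Return |base_value + sum(contributions) - prediction| for each row."""
    contributions = np.asarray(contributions, dtype=float)
    prediction = np.asarray(prediction, dtype=float)
    return np.abs(base_value + contributions.sum(axis=1) - prediction)


def passes_additivity(base_value, contributions, prediction, tol):
    """True if every row's gap is within tol (tol is an illustrative choice)."""
    return bool(np.all(additivity_gap(base_value, contributions, prediction) <= tol))


rng = np.random.default_rng(0)
contributions = rng.normal(size=(5, 3))  # toy per-feature contributions
base_value = 0.25
prediction = base_value + contributions.sum(axis=1)

# Simulate precision loss: store the contributions in float32, then compare
# against a prediction that was computed in float64.
contributions_f32 = contributions.astype(np.float32)

exact_ok = passes_additivity(base_value, contributions, prediction, tol=0.0)
strict_ok = passes_additivity(base_value, contributions_f32, prediction, tol=0.0)
loose_ok = passes_additivity(base_value, contributions_f32, prediction, tol=1e-5)
```

In this toy setup the float64 contributions pass even a zero-tolerance check, while the float32 copy only passes once a small tolerance is allowed. That is the kind of behavior rounding could produce, though it does not establish that rounding is the cause in this issue.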