shap_values should be 2d, instead shape=(200, 21, 2)! #297

Open
harshil17 opened this issue Mar 9, 2024 · 11 comments

Comments

@harshil17

I am running the sample code exactly as given at https://github.com/oegedijk/explainerdashboard, using the Titanic dataset.

I am running into the error "shap_values should be 2d, instead shape=(200, 21, 2)!"

The full traceback is below. Can anyone please help me understand why I am getting this error and how I can resolve it?

`AssertionError Traceback (most recent call last)
Cell In[7], line 12
1 explainer = ClassifierExplainer(model, X_test, y_test,
2 cats=['Deck', 'Embarked',
3 {'Gender': ['Sex_male', 'Sex_female', 'Sex_nan']}],
(...)
9 target = "Survival", # defaults to y.name
10 )
---> 12 db = ExplainerDashboard(explainer,
13 title="Titanic Explainer", # defaults to "Model Explainer"
14 shap_interaction=False, # you can switch off tabs with bools
15 )
16 db.run(port=8050)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboards.py:803, in ExplainerDashboard.__init__(self, explainer, tabs, title, name, description, simple, hide_header, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, mode, width, height, bootstrap, external_stylesheets, server, url_base_pathname, routes_pathname_prefix, requests_pathname_prefix, responsive, logins, port, importances, model_summary, contributions, whatif, shap_dependence, shap_interaction, decision_trees, **kwargs)
801 if isinstance(tabs, list):
802 tabs = [self._convert_str_tabs(tab) for tab in tabs]
--> 803 self.explainer_layout = ExplainerTabsLayout(
804 explainer,
805 tabs,
806 title,
807 description=self.description,
808 **update_kwargs(
809 kwargs,
810 header_hide_title=self.header_hide_title,
811 header_hide_selector=self.header_hide_selector,
812 header_hide_download=self.header_hide_download,
813 hide_poweredby=self.hide_poweredby,
814 block_selector_callbacks=self.block_selector_callbacks,
815 pos_label=self.pos_label,
816 fluid=fluid,
817 ),
818 )
819 else:
820 tabs = self._convert_str_tabs(tabs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboards.py:119, in ExplainerTabsLayout.__init__(self, explainer, tabs, title, name, description, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, **kwargs)
116 self.fluid = fluid
118 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
--> 119 self.tabs = [
120 instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
121 for i, tab in enumerate(tabs)
122 ]
123 assert (
124 len(self.tabs) > 0
125 ), "When passing a list to tabs, need to pass at least one valid tab!"
127 self.register_components(*self.tabs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboards.py:120, in <listcomp>(.0)
116 self.fluid = fluid
118 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
119 self.tabs = [
--> 120 instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
121 for i, tab in enumerate(tabs)
122 ]
123 assert (
124 len(self.tabs) > 0
125 ), "When passing a list to tabs, need to pass at least one valid tab!"
127 self.register_components(*self.tabs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboard_methods.py:890, in instantiate_component(component, explainer, name, **kwargs)
884 kwargs = {
885 k: v
886 for k, v in kwargs.items()
887 if k in init_argspec.args + init_argspec.kwonlyargs
888 }
889 if "name" in init_argspec.args + init_argspec.kwonlyargs:
--> 890 component = component(explainer, name=name, **kwargs)
891 else:
892 print(
893 f"ExplainerComponent {component} does not accept a name parameter, "
894 f"so cannot assign name='{name}': "
(...)
899 "cluster will generate its own random uuid name!"
900 )

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboard_components\composites.py:545, in IndividualPredictionsComposite.__init__(self, explainer, title, name, hide_predindexselector, hide_predictionsummary, hide_contributiongraph, hide_pdp, hide_contributiontable, hide_title, hide_selector, index_check, **kwargs)
538 self.summary = RegressionPredictionSummaryComponent(
539 explainer, hide_selector=hide_selector, **kwargs
540 )
542 self.contributions = ShapContributionsGraphComponent(
543 explainer, hide_selector=hide_selector, **kwargs
544 )
--> 545 self.pdp = PdpComponent(
546 explainer, name=self.name + "3", hide_selector=hide_selector, **kwargs
547 )
548 self.contributions_list = ShapContributionsTableComponent(
549 explainer, hide_selector=hide_selector, **kwargs
550 )
552 self.index_connector = IndexConnector(
553 self.index,
554 [self.summary, self.contributions, self.pdp, self.contributions_list],
555 explainer=explainer if index_check else None,
556 )

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboard_components\overview_components.py:639, in PdpComponent.__init__(self, explainer, title, name, subtitle, hide_col, hide_index, hide_title, hide_subtitle, hide_footer, hide_selector, hide_popout, hide_dropna, hide_sample, hide_gridlines, hide_gridpoints, hide_cats_sort, index_dropdown, feature_input_component, pos_label, col, index, dropna, sample, gridlines, gridpoints, cats_sort, description, **kwargs)
636 self.index_name = "pdp-index-" + self.name
638 if self.col is None:
--> 639 self.col = self.explainer.columns_ranked_by_shap()[0]
641 if self.feature_input_component is not None:
642 self.exclude_callbacks(self.feature_input_component)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:86, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
84 else:
85 kwargs.update(dict(pos_label=self.pos_label))
---> 86 return func(self, **kwargs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:1318, in BaseExplainer.columns_ranked_by_shap(self, pos_label)
1306 @insert_pos_label
1307 def columns_ranked_by_shap(self, pos_label=None):
1308 """returns the columns of X, ranked by mean abs shap value
1309
1310 Args:
(...)
1316
1317 """
-> 1318 return self.mean_abs_shap_df(pos_label).Feature.tolist()

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:86, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
84 else:
85 kwargs.update(dict(pos_label=self.pos_label))
---> 86 return func(self, **kwargs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:3128, in ClassifierExplainer.mean_abs_shap_df(self, pos_label)
3126 """mean absolute SHAP values"""
3127 if not hasattr(self, "_mean_abs_shap_df"):
-> 3128 _ = self.get_shap_values_df()
3129 self._mean_abs_shap_df = [
3130 self.get_shap_values_df(pos_label)[self.merged_cols]
3131 .abs()
(...)
3138 for pos_label in self.labels
3139 ]
3140 return self._mean_abs_shap_df[pos_label]

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:86, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
84 else:
85 kwargs.update(dict(pos_label=self.pos_label))
---> 86 return func(self, **kwargs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:2845, in ClassifierExplainer.get_shap_values_df(self, pos_label)
2843 if len(self.labels) == 2:
2844 if not isinstance(_shap_values, list):
-> 2845 assert (
2846 len(_shap_values.shape) == 2
2847 ), f"shap_values should be 2d, instead shape={_shap_values.shape}!"
2848 elif isinstance(_shap_values, list) and len(_shap_values) == 2:
2849 # for binary classifier only keep positive class
2850 _shap_values = _shap_values[1]

AssertionError: shap_values should be 2d, instead shape=(200, 21, 2)!`

@Brian-AlphaPlay

Brian-AlphaPlay commented Mar 10, 2024

I have the same error. The code is unchanged and worked fine before.

EDIT: This looks to be caused by a breaking change in the newest version of shap: https://github.com/shap/shap/releases
The issue does not appear after downgrading to shap==0.44.1.
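
A quick way to tell whether an environment is affected is to compare the installed shap version against the last version reported as working in this thread (0.44.1). The following is a minimal sketch, not part of either library: the helper names `version_tuple`, `installed_version`, and `is_newer_than_known_good` are hypothetical, and the comparison only looks at the leading numeric parts of the version string.

```python
import re
from importlib import metadata


def version_tuple(version):
    """Turn a string such as '0.44.1' into (0, 44, 1).

    Only the leading digits of the first three dot-separated parts are used,
    so suffixes such as 'rc1' are ignored (a simplification).
    """
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def is_newer_than_known_good(version, known_good="0.44.1"):
    """True if `version` is newer than the version reported as working here."""
    return version_tuple(version) > version_tuple(known_good)


if __name__ == "__main__":
    found = installed_version("shap")
    if found is None:
        print("shap is not installed in this environment")
    elif is_newer_than_known_good(found):
        print(f"shap {found} is newer than 0.44.1; the new output shape may apply")
    else:
        print(f"shap {found} is at or below 0.44.1")
```

If the check reports a newer shap, the workaround reported above is to pin `shap==0.44.1` until a compatible explainerdashboard release is installed.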

@oegedijk
Owner

Yes, tests are failing as well. Seems like the output shape of the shap library has changed. Will look into it...
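
To illustrate the shape change: the assertion in the traceback expects a 2D array of shape `(n_samples, n_features)` or a list of two such arrays, while the failing run received a single 3D array of shape `(200, 21, 2)`. The sketch below shows one way to reduce either layout to a 2D array for one class. It is an illustration only and not necessarily the fix that went into explainerdashboard; `positive_class_shap` is a hypothetical helper, and treating the last axis as the class axis is an assumption based on the shape in the error message.

```python
import numpy as np


def positive_class_shap(shap_values, pos_label=1):
    """Return a 2D (n_samples, n_features) array of SHAP values for one class.

    Handles the two layouts seen in this thread:
    - a list with one 2D array per class (the branch shown in the traceback)
    - a single 3D array, assumed to be (n_samples, n_features, n_classes)
    A 2D array is passed through unchanged.
    """
    if isinstance(shap_values, list):
        # List layout: pick the array belonging to the requested class.
        return np.asarray(shap_values[pos_label])

    values = np.asarray(shap_values)
    if values.ndim == 3:
        # 3D layout: slice the requested class off the last axis.
        return values[:, :, pos_label]
    if values.ndim == 2:
        return values
    raise ValueError(f"unexpected shap_values shape: {values.shape}")


if __name__ == "__main__":
    # Dummy data with the shape from the error message.
    dummy = np.zeros((200, 21, 2))
    print(positive_class_shap(dummy).shape)
```

The helper raises on any other dimensionality instead of guessing, which keeps an unexpected layout visible, in the same spirit as the original assertion.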

@oegedijk
Owner

Okay, I have the fix on master. I will see if I can release it tomorrow.

@oegedijk
Owner

Okay, it's released as 0.4.6. I had to change the GitHub-to-PyPI release mechanism, so I would appreciate it if you could let me know whether it worked!

@harshil17
Author

harshil17 commented Mar 12, 2024

Thanks. I just tried with 0.4.6 and it is still giving me the same error.

@oegedijk
Owner

`pip install -U explainerdashboard` should install version 0.4.6. Or are you using conda? (It takes about a day for the conda-forge CI to pick up the latest version and release it.)

@harshil17
Author

Yes, I am using conda. I see. I will wait until tomorrow then and see if it works.

@harshil17
Author

harshil17 commented Mar 12, 2024

I do see, though, that I have been upgraded to 0.4.6.

@oegedijk
Owner

Should also be on conda now. What model are you using? Can you run the classifier example from the README?

@harshil17
Author

I just checked the classifier example from the README with the updated version of explainerdashboard, and it still gives me the same error.

@harshil17
Author

Never mind, I just restarted Jupyter and it worked. Thank you very much.
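
A running Jupyter kernel keeps using the copy of a package it imported before the upgrade, which would explain why the error persisted until the restart. The sketch below compares the version of the module already loaded in the current process with the version installed on disk. It is a rough check under two assumptions: that the module exposes a `__version__` attribute, and that the import name matches the distribution name. `loaded_vs_installed` and `needs_restart` are hypothetical helper names.

```python
import sys
from importlib import metadata


def loaded_vs_installed(package):
    """Return (loaded_version, installed_version) for a package.

    loaded_version comes from the module already imported in this process
    (None if it is not imported or has no __version__ attribute);
    installed_version comes from the package metadata on disk.
    """
    module = sys.modules.get(package)
    loaded = getattr(module, "__version__", None) if module is not None else None
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        installed = None
    return loaded, installed


def needs_restart(loaded, installed):
    """True when both versions are known and they differ."""
    return loaded is not None and installed is not None and loaded != installed


if __name__ == "__main__":
    loaded, installed = loaded_vs_installed("explainerdashboard")
    print(f"loaded in this kernel: {loaded}, installed on disk: {installed}")
    if needs_restart(loaded, installed):
        print("Versions differ: restart the kernel to pick up the upgrade.")
```

Run in a notebook cell after upgrading, a mismatch between the two values suggests the kernel is still holding the old code.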
