
Sklearn PLS Regression incompatibility with ExplainerDashboard #258

Open
shyam-bayer opened this issue Mar 30, 2023 · 4 comments

Comments

@shyam-bayer

I would like to use PLS regression with the explainerdashboard package. However, it throws an error that I can't resolve, which looks like a compatibility issue. Could you please confirm whether PLS regression is compatible or not?

Below is my script:

from sklearn.cross_decomposition import PLSRegression
from sklearn.datasets import load_diabetes
from explainerdashboard import ExplainerDashboard, RegressionExplainer

# Load the diabetes data as a DataFrame (features) and a Series (target)
diabetes_X, diabetes_y = load_diabetes(as_frame=True, return_X_y=True)

# Fit a two-component PLS regression on the same features the explainer uses
regr = PLSRegression(n_components=2)
regr.fit(diabetes_X, diabetes_y)

explainer = RegressionExplainer(regr, diabetes_X, diabetes_y)
db = ExplainerDashboard(explainer)
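A quick way to see what the fitted model actually hands to the explainer is to inspect the shape of its `predict` output. The helper below, `check_prediction_shape`, is hypothetical (not part of explainerdashboard or scikit-learn); it only warns when predictions are not one-dimensional.

```python
import warnings

import numpy as np


def check_prediction_shape(model, X):
    """Return the shape of model.predict(X) and warn if it is not 1-D.

    Hypothetical helper: single-target regressors are usually expected to
    return shape (n_samples,), not (n_samples, 1).
    """
    preds = np.asarray(model.predict(X))
    if preds.ndim != 1:
        warnings.warn(
            f"predict() returned shape {preds.shape}; expected (n_samples,)",
            stacklevel=2,
        )
    return preds.shape
```

Called as `check_prediction_shape(regr, diabetes_X)` on the script above, this is expected to report `(442, 1)` with scikit-learn 1.2.x, because that release's `PLSRegression.predict` returns a 2-D array even for a single target; newer releases may return `(442,)` instead.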

I am getting the following error:
Building ExplainerDashboard..
Detected notebook environment, consider setting mode='external', mode='inline' or mode='jupyterlab' to keep the notebook interactive while the dashboard is running...
For this type of model and model_output interactions don't work, so setting shap_interaction=False...
The explainer object has no decision_trees property. so setting decision_trees=False...
Generating layout...
Calculating shap values...
100%|██████████| 442/442 [01:22<00:00, 5.33it/s]

ValueError Traceback (most recent call last)
Cell In[38], line 1
----> 1 db = ExplainerDashboard(explainer)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboards.py:795, in ExplainerDashboard.__init__(self, explainer, tabs, title, name, description, simple, hide_header, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, mode, width, height, bootstrap, external_stylesheets, server, url_base_pathname, routes_pathname_prefix, requests_pathname_prefix, responsive, logins, port, importances, model_summary, contributions, whatif, shap_dependence, shap_interaction, decision_trees, **kwargs)
793 if isinstance(tabs, list):
794 tabs = [self._convert_str_tabs(tab) for tab in tabs]
--> 795 self.explainer_layout = ExplainerTabsLayout(
796 explainer,
797 tabs,
798 title,
799 description=self.description,
800 **update_kwargs(
801 kwargs,
802 header_hide_title=self.header_hide_title,
803 header_hide_selector=self.header_hide_selector,
804 header_hide_download=self.header_hide_download,
805 hide_poweredby=self.hide_poweredby,
806 block_selector_callbacks=self.block_selector_callbacks,
807 pos_label=self.pos_label,
808 fluid=fluid,
809 ),
810 )
811 else:
812 tabs = self._convert_str_tabs(tabs)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboards.py:111, in ExplainerTabsLayout.__init__(self, explainer, tabs, title, name, description, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, **kwargs)
108 self.fluid = fluid
110 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
--> 111 self.tabs = [
112 instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
113 for i, tab in enumerate(tabs)
114 ]
115 assert (
116 len(self.tabs) > 0
117 ), "When passing a list to tabs, need to pass at least one valid tab!"
119 self.register_components(*self.tabs)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboards.py:112, in <listcomp>(.0)
108 self.fluid = fluid
110 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
111 self.tabs = [
--> 112 instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
113 for i, tab in enumerate(tabs)
114 ]
115 assert (
116 len(self.tabs) > 0
117 ), "When passing a list to tabs, need to pass at least one valid tab!"
119 self.register_components(*self.tabs)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboard_methods.py:890, in instantiate_component(component, explainer, name, **kwargs)
884 kwargs = {
885 k: v
886 for k, v in kwargs.items()
887 if k in init_argspec.args + init_argspec.kwonlyargs
888 }
889 if "name" in init_argspec.args + init_argspec.kwonlyargs:
--> 890 component = component(explainer, name=name, **kwargs)
891 else:
892 print(
893 f"ExplainerComponent {component} does not accept a name parameter, "
894 f"so cannot assign name='{name}': "
(...)
899 "cluster will generate its own random uuid name!"
900 )

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboard_components/composites.py:413, in RegressionModelStatsComposite.__init__(self, explainer, title, name, hide_title, hide_modelsummary, hide_predsvsactual, hide_residuals, hide_regvscol, logs, pred_or_actual, residuals, col, **kwargs)
403 self.preds_vs_actual = PredictedVsActualComponent(
404 explainer, name=self.name + "0", logs=logs, **kwargs
405 )
406 self.residuals = ResidualsComponent(
407 explainer,
408 name=self.name + "1",
(...)
411 **kwargs,
412 )
--> 413 self.reg_vs_col = RegressionVsColComponent(
414 explainer, name=self.name + "2", logs=logs, **kwargs
415 )

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboard_components/regression_components.py:1676, in RegressionVsColComponent.__init__(self, explainer, title, name, subtitle, hide_title, hide_subtitle, hide_footer, hide_col, hide_ratio, hide_points, hide_winsor, hide_cats_topx, hide_cats_sort, hide_popout, col, display, round, points, winsor, cats_topx, cats_sort, plot_sample, description, **kwargs)
1673 super().__init__(explainer, title, name)
1675 if self.col is None:
-> 1676 self.col = self.explainer.columns_ranked_by_shap()[0]
1678 assert self.display in {
1679 "observed",
1680 "predicted",
(...)
1686 f" but you passed display={self.display}!"
1687 )
1689 if self.description is None:

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:66, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
63 @wraps(func)
64 def inner(self, *args, **kwargs):
65 if not self.is_classifier:
---> 66 return func(self, *args, **kwargs)
67 if "pos_label" in kwargs:
68 if kwargs["pos_label"] is not None:
69 # ensure that pos_label is int

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:1310, in BaseExplainer.columns_ranked_by_shap(self, pos_label)
1298 @insert_pos_label
1299 def columns_ranked_by_shap(self, pos_label=None):
1300 """returns the columns of X, ranked by mean abs shap value
1301
1302 Args:
(...)
1308
1309 """
-> 1310 return self.mean_abs_shap_df(pos_label).Feature.tolist()

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:66, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
63 @wraps(func)
64 def inner(self, *args, **kwargs):
65 if not self.is_classifier:
---> 66 return func(self, *args, **kwargs)
67 if "pos_label" in kwargs:
68 if kwargs["pos_label"] is not None:
69 # ensure that pos_label is int

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:1287, in BaseExplainer.mean_abs_shap_df(self, pos_label)
1284 """Mean absolute SHAP values per feature."""
1285 if not hasattr(self, "_mean_abs_shap_df"):
1286 self._mean_abs_shap_df = (
-> 1287 self.get_shap_values_df(pos_label)[self.merged_cols]
1288 .abs()
1289 .mean()
1290 .sort_values(ascending=False)
1291 .to_frame()
1292 .rename_axis(index="Feature")
1293 .reset_index()
1294 .rename(columns={0: "MEAN_ABS_SHAP"})
1295 )
1296 return self._mean_abs_shap_df

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:66, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
63 @wraps(func)
64 def inner(self, *args, **kwargs):
65 if not self.is_classifier:
---> 66 return func(self, *args, **kwargs)
67 if "pos_label" in kwargs:
68 if kwargs["pos_label"] is not None:
69 # ensure that pos_label is int

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:1151, in BaseExplainer.get_shap_values_df(self, pos_label)
1144 self._shap_values_df = pd.DataFrame(
1145 self.shap_explainer.shap_values(
1146 torch.tensor(self.X.values), **self.shap_kwargs
1147 ),
1148 columns=self.columns,
1149 )
1150 else:
-> 1151 self._shap_values_df = pd.DataFrame(
1152 self.shap_explainer.shap_values(self.X, **self.shap_kwargs),
1153 columns=self.columns,
1154 )
1155 self._shap_values_df = merge_categorical_shap_values(
1156 self._shap_values_df, self.onehot_dict, self.merged_cols
1157 ).astype(self.precision)
1158 return self._shap_values_df

File /usr/local/lib/python3.10/site-packages/pandas/core/frame.py:762, in DataFrame.__init__(self, data, index, columns, dtype, copy)
754 mgr = arrays_to_mgr(
755 arrays,
756 columns,
(...)
759 typ=manager,
760 )
761 else:
--> 762 mgr = ndarray_to_mgr(
763 data,
764 index,
765 columns,
766 dtype=dtype,
767 copy=copy,
768 typ=manager,
769 )
770 else:
771 mgr = dict_to_mgr(
772 {},
773 index,
(...)
776 typ=manager,
777 )

File /usr/local/lib/python3.10/site-packages/pandas/core/internals/construction.py:329, in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
324 values = values.reshape(-1, 1)
326 else:
327 # by definition an array here
328 # the dtypes will be coerced to a single dtype
--> 329 values = _prep_ndarraylike(values, copy=copy_on_sanitize)
331 if dtype is not None and not is_dtype_equal(values.dtype, dtype):
332 # GH#40110 see similar check inside sanitize_array
333 rcf = not (is_integer_dtype(dtype) and values.dtype.kind == "f")

File /usr/local/lib/python3.10/site-packages/pandas/core/internals/construction.py:583, in _prep_ndarraylike(values, copy)
581 values = values.reshape((values.shape[0], 1))
582 elif values.ndim != 2:
--> 583 raise ValueError(f"Must pass 2-d input. shape={values.shape}")
585 return values

ValueError: Must pass 2-d input. shape=(1, 442, 10)
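One plausible reading of `shape=(1, 442, 10)`, sketched with NumPy only: when a model's `predict` returns a single column of shape (n_samples, 1), shap's `KernelExplainer` (which the 442-step progress bar suggests was used) appears to return a list holding one (n_samples, n_features) array per output column, and pandas turns that one-element list into a 3-D array. The zero arrays are placeholders rather than real SHAP values, the list-per-output behaviour is an assumption about shap 0.41, and `to_2d_shap_values` is a hypothetical helper.

```python
import numpy as np

n_samples, n_features = 442, 10

# Placeholder for what a SHAP explainer might return for a model whose
# predict() yields shape (n_samples, 1): a list holding one array per output.
shap_values = [np.zeros((n_samples, n_features))]

stacked = np.asarray(shap_values)
print(stacked.shape)  # (1, 442, 10): the shape pandas rejects as not 2-d


def to_2d_shap_values(values):
    """Reduce a single-output SHAP result to (n_samples, n_features).

    Hypothetical helper: accepts a 2-D array, or a list/3-D array holding
    exactly one output, and refuses genuinely multi-output input.
    """
    arr = np.asarray(values)
    if arr.ndim == 3 and arr.shape[0] == 1:
        return arr[0]
    if arr.ndim == 2:
        return arr
    raise ValueError(f"expected a single output, got shape {arr.shape}")


print(to_2d_shap_values(shap_values).shape)  # (442, 10)
```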

List of packages:
Package Version


ansi2html 1.8.0
asttokens 2.2.1
attrs 22.2.0
backcall 0.2.0
certifi 2022.12.7
charset-normalizer 3.1.0
click 8.1.3
cloudpickle 2.2.1
colour 0.1.5
comm 0.1.3
contourpy 1.0.7
cycler 0.11.0
dash 2.9.2
dash-auth 2.0.0
dash-bootstrap-components 1.4.1
dash-core-components 2.0.0
dash-html-components 2.0.0
dash-table 5.0.0
debugpy 1.6.6
decorator 5.1.1
dtreeviz 2.2.0
exceptiongroup 1.1.1
executing 1.2.0
explainerdashboard 0.4.2.1
Flask 2.2.3
flask-simplelogin 0.1.1
Flask-WTF 0.15.1
fonttools 4.39.3
graphviz 0.20.1
idna 3.4
iniconfig 2.0.0
ipykernel 6.22.0
ipython 8.12.0
itsdangerous 2.1.2
jedi 0.18.2
Jinja2 3.1.2
joblib 1.2.0
jupyter_client 8.1.0
jupyter_core 5.3.0
jupyter-dash 0.4.2
kiwisolver 1.4.4
llvmlite 0.39.1
MarkupSafe 2.1.2
matplotlib 3.7.1
matplotlib-inline 0.1.6
nest-asyncio 1.5.6
numba 0.56.4
numpy 1.23.5
oyaml 1.0
packaging 23.0
pandas 1.5.3
parso 0.8.3
pexpect 4.8.0
pickleshare 0.7.5
Pillow 9.4.0
pip 22.3.1
platformdirs 3.2.0
plotly 5.14.0
pluggy 1.0.0
prompt-toolkit 3.0.38
psutil 5.9.4
ptyprocess 0.7.0
pure-eval 0.2.2
Pygments 2.14.0
pyparsing 3.0.9
pytest 7.2.2
python-dateutil 2.8.2
pytz 2023.3
PyYAML 6.0
pyzmq 25.0.2
requests 2.28.2
retrying 1.3.4
scikit-learn 1.2.2
scipy 1.10.1
setuptools 67.4.0
shap 0.41.0
six 1.16.0
slicer 0.0.7
stack-data 0.6.2
tenacity 8.2.2
threadpoolctl 3.1.0
tomli 2.0.1
tornado 6.2
tqdm 4.65.0
traitlets 5.9.0
urllib3 1.26.15
waitress 2.1.2
wcwidth 0.2.6
Werkzeug 2.2.3
wheel 0.38.4
WTForms 3.0.1

@oegedijk
Owner

oegedijk commented May 9, 2023

Hi @shyam-bayer, my guess is that this is due to the fact that PLSRegression does not return a single prediction but can return multiple components. This means that model.predict(X_test) returns a two-dimensional numpy array instead of a one-dimensional one, which results in the error.
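If the 2-D `predict` output is indeed the cause, one possible user-side workaround, sketched under that assumption, is to wrap the regressor so that single-column predictions are flattened before the explainer sees them. `RavelPredictions` is a hypothetical name, not an explainerdashboard or scikit-learn class, and it has not been checked against `RegressionExplainer`, which may inspect the model in ways a plain wrapper does not satisfy.

```python
import numpy as np


class RavelPredictions:
    """Hypothetical wrapper that flattens single-column predictions.

    Forwards fit/predict to the wrapped regressor and turns output of shape
    (n_samples, 1) into shape (n_samples,). Multi-column output is returned
    unchanged so that genuine multi-target models are not silently hidden.
    """

    def __init__(self, model):
        self.model = model

    def fit(self, X, y):
        self.model.fit(X, y)
        return self

    def predict(self, X):
        preds = np.asarray(self.model.predict(X))
        if preds.ndim == 2 and preds.shape[1] == 1:
            return preds.ravel()
        return preds
```

With scikit-learn installed, the reporter's model would be wrapped as `RavelPredictions(PLSRegression(n_components=2)).fit(diabetes_X, diabetes_y)` before being passed to `RegressionExplainer`.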

@oegedijk
Owner

oegedijk commented May 9, 2023

This seems to fix it, at least as long as the PLSRegression model has a single component: 3f4a9a0

@oegedijk
Owner

oegedijk commented May 9, 2023

So it should be in the next release.

@shyam-bayer
Author

Thanks for looking into this. However, like any other regression model, PLSRegression returns a single "ypred" output if there is only one "y" column. PLSRegression can use multiple latent variables, but that does not affect the shape of the "ypred" vector/matrix.
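The point about latent variables can be illustrated with a small NumPy sketch of single-response PLS (the NIPALS-style PLS1 algorithm, not scikit-learn's implementation): the number of components changes the regression coefficients, but the coefficient vector always has one entry per feature, so predictions for one response have shape (n_samples,) for every `n_components`. The data are random and `pls1_fit`/`pls1_predict` are hypothetical names. Note that scikit-learn's own `PLSRegression` may still wrap that single column in a 2-D (n_samples, 1) array.

```python
import numpy as np


def pls1_fit(X, y, n_components):
    """Fit single-response PLS (NIPALS-style PLS1); return (coef, x_mean, y_mean)."""
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xk, yk = X - x_mean, y - y_mean  # centred working copies, deflated below
    W, P, q = [], [], []
    for _ in range(n_components):
        w = Xk.T @ yk
        w /= np.linalg.norm(w)      # unit weight vector
        t = Xk @ w                  # scores
        tt = t @ t
        p = Xk.T @ t / tt           # X loadings
        c = (yk @ t) / tt           # y loading (a scalar for one response)
        Xk = Xk - np.outer(t, p)    # deflate X
        yk = yk - c * t             # deflate y
        W.append(w)
        P.append(p)
        q.append(c)
    W, P, q = np.column_stack(W), np.column_stack(P), np.array(q)
    coef = W @ np.linalg.solve(P.T @ W, q)  # one coefficient per feature
    return coef, x_mean, y_mean


def pls1_predict(model, X):
    coef, x_mean, y_mean = model
    return (X - x_mean) @ coef + y_mean


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + rng.normal(scale=0.1, size=50)

for k in (1, 2, 3, 4):
    preds = pls1_predict(pls1_fit(X, y, k), X)
    print(k, preds.shape)  # shape is (50,) for every k
```

With as many components as features, PLS1 reproduces ordinary least squares on the centred data, which is a handy sanity check for the sketch.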
