Skip to content

Commit

Permalink
Merge pull request #47 from VesnaT/ice_other_models
Browse files Browse the repository at this point in the history
ICE: Support all Orange models
  • Loading branch information
PrimozGodec committed Oct 18, 2022
2 parents d67f879 + b9cb59e commit 3c6b0f3
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 58 deletions.
20 changes: 14 additions & 6 deletions orangecontrib/explain/inspection.py
@@ -1,13 +1,13 @@
""" Permutation feature importance for models. """
from typing import Callable, Tuple, Optional, Dict
from typing import Callable, Dict

import numpy as np
import scipy.sparse as sp
from sklearn.inspection import partial_dependence

from Orange.base import Model, SklModel
from Orange.base import Model
from Orange.classification import Model as ClsModel
from Orange.data import Table, Variable, DiscreteVariable
from Orange.data import Table, Variable
from Orange.evaluation import Results
from Orange.evaluation.scoring import Score, TargetScore, RegressionScore, R2
from Orange.regression import Model as RegModel
Expand Down Expand Up @@ -178,7 +178,7 @@ def _calculate_permutation_scores(


def individual_condition_expectation(
model: SklModel,
model: Model,
data: Table,
feature: Variable,
grid_resolution: int = 1000,
Expand All @@ -194,10 +194,18 @@ def individual_condition_expectation(
assert feature.name in [a.name for a in data.domain.attributes]
feature_index = data.domain.index(feature.name)

assert isinstance(model, SklModel), f"Model ({model}) is not supported."
# fake sklearn estimator
model.fit = None
model.fit_ = None
if model.domain.class_var.is_discrete:
model._estimator_type = "classifier"
model.classes_ = model.domain.class_var.values
else:
model._estimator_type = "regressor"

progress_callback(0.1)

dep = partial_dependence(model.skl_model,
dep = partial_dependence(model,
data.X,
[feature_index],
grid_resolution=grid_resolution,
Expand Down
29 changes: 0 additions & 29 deletions orangecontrib/explain/tests/test_inspection.py
Expand Up @@ -350,35 +350,6 @@ def test_mixed_features(self):
self.assertEqual(res["individual"].shape, (2, 303, 41))
self.assertEqual(res["values"].shape, (41,))

def _test_sklearn(self):
from matplotlib import pyplot as plt
from sklearn.ensemble import RandomForestClassifier, \
RandomForestRegressor
from sklearn.inspection import PartialDependenceDisplay

X = self.housing.X
y = self.housing.Y
model = RandomForestRegressor(random_state=0)

# X = self.iris.X[:100]
# y = self.iris.Y[:100]
# y = np.abs(y - 1)
# model = RandomForestClassifier(random_state=0)
model.fit(X, y)
display = PartialDependenceDisplay.from_estimator(
model,
X,
[X.shape[1] - 1],
target=0,
kind="both",
centered=True,
subsample=1000,
# grid_resolution=100,
random_state=0,
)

plt.show()


if __name__ == "__main__":
unittest.main()
41 changes: 26 additions & 15 deletions orangecontrib/explain/widgets/owice.py
Expand Up @@ -15,9 +15,9 @@
from orangecanvas.gui.utils import disconnected
from orangewidget.utils.listview import ListViewSearch

from Orange.base import Model, SklModel, RandomForestModel
from Orange.base import Model
from Orange.data import Table, ContinuousVariable, Variable, \
DiscreteVariable
DiscreteVariable, Domain
from Orange.data.table import DomainTransformationError
from Orange.widgets import gui
from Orange.widgets.settings import ContextSetting, Setting, \
Expand Down Expand Up @@ -491,7 +491,7 @@ class OWICE(OWWidget, ConcurrentWidgetMixin):
priority = 130

class Inputs:
model = Input("Model", (SklModel, RandomForestModel))
model = Input("Model", Model)
data = Input("Data", Table)

class Outputs:
Expand Down Expand Up @@ -534,6 +534,7 @@ def __init__(self):
self.__pending_selection = self.selection
self.model: Optional[Model] = None
self.data: Optional[Table] = None
self.domain: Optional[Domain] = None
self.graph: ICEPlot = None
self._target_combo: QComboBox = None
self._features_view: ListViewSearch = None
Expand Down Expand Up @@ -624,13 +625,9 @@ def _add_buttons(self):
@Inputs.data
@check_sql_input
def set_data(self, data: Optional[Table]):
self.closeContext()
self.data = data
self.__sampled_mask = None
self._check_data()
self._setup_controls()
self.openContext(self.data.domain if self.data else None)
self.set_list_view_selection()

@Inputs.model
def set_model(self, model: Optional[Model]):
Expand Down Expand Up @@ -660,8 +657,29 @@ def _check_data(self):
self.__sampled_mask[np.random.choice(len(self.data), **kws)] = True
self.Information.data_sampled()

def handleNewSignals(self):
self.closeContext()
self.domain = None
if self.data and self.model:
model_domain = [a.name for a in self.model.domain]
attributes = [a for a in self.data.domain.attributes
if a.is_continuous and a.name in model_domain
or a.is_discrete]
class_var = self.model.domain.class_var
metas = [m for m in self.data.domain.metas if m.is_discrete]
self.domain = Domain(attributes, class_var, metas)
self._setup_controls()
self.openContext(self.domain)
self.set_list_view_selection()

self.__results_avgs = None
self._apply_feature_sorting()
self._run()
self.selection = None
self.commit.now()

def _setup_controls(self):
domain = self.data.domain if self.data else None
domain = self.domain

self._target_combo.clear()
self._target_combo.setEnabled(True)
Expand Down Expand Up @@ -702,13 +720,6 @@ def _ensure_selection_visible(view):
if len(selection) == 1:
view.scrollTo(selection[0])

def handleNewSignals(self):
self.__results_avgs = None
self._apply_feature_sorting()
self._run()
self.selection = None
self.commit.now()

def _apply_feature_sorting(self):
if self.data is None or self.model is None:
return
Expand Down
33 changes: 31 additions & 2 deletions orangecontrib/explain/widgets/tests/test_owice.py
Expand Up @@ -4,9 +4,14 @@

from AnyQt.QtCore import Qt, QPointF

from Orange.classification import RandomForestLearner
from Orange.classification import RandomForestLearner, CalibratedLearner, \
ThresholdLearner
from Orange.data import Table
from Orange.regression import RandomForestRegressionLearner
from Orange.regression import RandomForestRegressionLearner, \
SimpleRandomForestLearner
from Orange.tests.test_classification import all_learners as all_cls_learners
from Orange.tests.test_regression import all_learners as all_reg_learners, \
init_learner
from Orange.widgets.tests.base import WidgetTest
from orangecontrib.explain.widgets.owice import OWICE

Expand Down Expand Up @@ -51,6 +56,30 @@ def test_output(self):
annotated = self.get_output(self.widget.Outputs.annotated_data)
self.assertEqual(len(annotated), len(self.heart))

def test_all_reg_models(self):
data = self.housing[:10]
self.send_signal(self.widget.Inputs.data, data)
for learner in all_reg_learners():
if issubclass(learner, (SimpleRandomForestLearner,)):
continue
learner = init_learner(learner, data)
model = learner(data)
self.send_signal(self.widget.Inputs.model, model)
self.wait_until_finished()
self.assertFalse(self.widget.Error.unknown_err.is_shown())

def test_all_cls_models(self):
data = self.heart[:10]
self.send_signal(self.widget.Inputs.data, data)
for learner in all_cls_learners():
if issubclass(learner, (CalibratedLearner, ThresholdLearner)):
model = learner(RandomForestLearner())(data)
else:
model = learner()(data)
self.send_signal(self.widget.Inputs.model, model)
self.wait_until_finished()
self.assertFalse(self.widget.Error.unknown_err.is_shown())

def test_discrete_features(self):
self.send_signal(self.widget.Inputs.data, self.titanic)
self.assertTrue(self.widget.Error.no_cont_features.is_shown())
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -38,7 +38,7 @@
]

INSTALL_REQUIRES = [
"Orange3 >= 3.28.0",
"Orange3 >= 3.33.0",
"orange-widget-base",
"AnyQt",
"numpy",
Expand Down
9 changes: 4 additions & 5 deletions tox.ini
Expand Up @@ -26,11 +26,10 @@ setenv =
deps =
pyqt5==5.12.*
pyqtwebengine==5.12.*
oldest: scikit-learn~=0.22.0
oldest: orange3==3.27.1
# Use newer canvas-core and widget-base to avoid segfaults on windows
oldest: orange-canvas-core==0.1.18
oldest: orange-widget-base==4.9.0
oldest: scikit-learn==1.0.1
oldest: orange3==3.33.0
oldest: orange-canvas-core==0.1.27
oldest: orange-widget-base==4.18.0
latest: https://github.com/biolab/orange3/archive/refs/heads/master.zip#egg=orange3
latest: https://github.com/biolab/orange-canvas-core/archive/refs/heads/master.zip#egg=orange-canvas-core
latest: https://github.com/biolab/orange-widget-base/archive/refs/heads/master.zip#egg=orange-widget-base
Expand Down

0 comments on commit 3c6b0f3

Please sign in to comment.