From 38ff387152411f39cef763c456419de3e2129838 Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Thu, 22 Feb 2024 09:14:59 +0100 Subject: [PATCH 1/2] ICE: Adapt to newest scikit --- orangecontrib/explain/inspection.py | 7 +++++-- orangecontrib/explain/tests/test_inspection.py | 3 +-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/orangecontrib/explain/inspection.py b/orangecontrib/explain/inspection.py index 1c2a6b5..93e7e56 100644 --- a/orangecontrib/explain/inspection.py +++ b/orangecontrib/explain/inspection.py @@ -202,8 +202,11 @@ def individual_condition_expectation( feature_index = data.domain.index(feature.name) # fake sklearn estimator - model.fit = None - model.fit_ = None + def dummy_fit(*_, **__): + raise NotImplementedError() + + model.fit = dummy_fit + model.fit_ = dummy_fit if model.domain.class_var.is_discrete: model._estimator_type = "classifier" model.classes_ = model.domain.class_var.values diff --git a/orangecontrib/explain/tests/test_inspection.py b/orangecontrib/explain/tests/test_inspection.py index 6d33129..1b05f9c 100644 --- a/orangecontrib/explain/tests/test_inspection.py +++ b/orangecontrib/explain/tests/test_inspection.py @@ -1,6 +1,5 @@ import unittest from unittest.mock import Mock -import pkg_resources import numpy as np from sklearn.inspection import permutation_importance, partial_dependence @@ -121,7 +120,7 @@ def test_wrap_score_skl_predict_reg(self): baseline_score = scorer(mocked_model, data) mocked_model.assert_not_called() mocked_model.predict.assert_not_called() - self.assertAlmostEqual(baseline_score, 2, 0) + self.assertLess(baseline_score, 2.6) class TestPermutationFeatureImportance(unittest.TestCase): From 65721d691b3a6ed8f158c73249384c621acbd67e Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Thu, 22 Feb 2024 10:09:04 +0100 Subject: [PATCH 2/2] Handle multi target inputs --- orangecontrib/explain/widgets/owexplainfeaturebase.py | 6 ++++++ orangecontrib/explain/widgets/owexplainprediction.py | 11 ++++++++++- orangecontrib/explain/widgets/owexplainpredictions.py | 7 +++++++ orangecontrib/explain/widgets/owice.py | 6 ++++++ 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/orangecontrib/explain/widgets/owexplainfeaturebase.py b/orangecontrib/explain/widgets/owexplainfeaturebase.py index eb68624..b3ac357 100644 --- a/orangecontrib/explain/widgets/owexplainfeaturebase.py +++ b/orangecontrib/explain/widgets/owexplainfeaturebase.py @@ -21,6 +21,10 @@ from Orange.widgets.settings import Setting from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState from Orange.widgets.utils.graphicslayoutitem import SimpleLayoutItem +try: + from Orange.widgets.utils.multi_target import check_multiple_targets_input +except ImportError: + check_multiple_targets_input = lambda f: f from Orange.widgets.utils.sql import check_sql_input from Orange.widgets.utils.state_summary import format_summary_details from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView @@ -462,6 +466,7 @@ def __n_spin_changed(self): # Inputs @Inputs.data @check_sql_input + @check_multiple_targets_input def set_data(self, data: Optional[Table]): self.data = data summary = len(data) if data else self.info.NoInput @@ -473,6 +478,7 @@ def _check_data(self): pass @Inputs.model + @check_multiple_targets_input def set_model(self, model: Optional[Model]): self.closeContext() self.model = model diff --git a/orangecontrib/explain/widgets/owexplainprediction.py b/orangecontrib/explain/widgets/owexplainprediction.py index d42945c..614440a 100644 --- a/orangecontrib/explain/widgets/owexplainprediction.py +++ b/orangecontrib/explain/widgets/owexplainprediction.py @@ -17,6 +17,10 @@ from Orange.data import Table, Domain, ContinuousVariable, StringVariable from Orange.data.table import DomainTransformationError from Orange.widgets import gui +try: + from Orange.widgets.utils.multi_target import check_multiple_targets_input +except ImportError: + check_multiple_targets_input = lambda f: f from Orange.widgets.settings import Setting, ContextSetting, \ ClassValuesContextHandler from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin @@ -574,15 +578,18 @@ def __zoom_changed(self, delta: float): @Inputs.data @check_sql_input + @check_multiple_targets_input def set_data(self, data: Optional[Table]): self.data = data @Inputs.background_data @check_sql_input + @check_multiple_targets_input def set_background_data(self, data: Optional[Table]): self.background_data = data @Inputs.model + @check_multiple_targets_input def set_model(self, model: Optional[Model]): self.closeContext() self.model = model @@ -614,7 +621,9 @@ def clear(self): self.__results = None self.cancel() self.clear_scene() - self.clear_messages() + self.Error.domain_transform_err.clear() + self.Error.unknown_err.clear() + self.Information.multiple_instances.clear() def check_inputs(self): if self.data and len(self.data) > 1: diff --git a/orangecontrib/explain/widgets/owexplainpredictions.py b/orangecontrib/explain/widgets/owexplainpredictions.py index a50730f..b12c73d 100644 --- a/orangecontrib/explain/widgets/owexplainpredictions.py +++ b/orangecontrib/explain/widgets/owexplainpredictions.py @@ -19,6 +19,10 @@ from Orange.data import Table, Domain, ContinuousVariable, Variable from Orange.data.table import DomainTransformationError from Orange.widgets import gui +try: + from Orange.widgets.utils.multi_target import check_multiple_targets_input +except ImportError: + check_multiple_targets_input = lambda f: f from Orange.widgets.settings import ContextSetting, Setting, \ PerfectDomainContextHandler from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_SIGNAL_NAME, \ @@ -649,6 +653,7 @@ def _add_buttons(self): @Inputs.data @check_sql_input + @check_multiple_targets_input def set_data(self, data: Optional[Table]): self.closeContext() self.data = data @@ -658,10 +663,12 @@ def set_data(self, data: Optional[Table]): @Inputs.background_data @check_sql_input + @check_multiple_targets_input def set_background_data(self, data: Optional[Table]): self.background_data = data @Inputs.model + @check_multiple_targets_input def set_model(self, model: Optional[Model]): self.model = model diff --git a/orangecontrib/explain/widgets/owice.py b/orangecontrib/explain/widgets/owice.py index 3764515..914b45b 100644 --- a/orangecontrib/explain/widgets/owice.py +++ b/orangecontrib/explain/widgets/owice.py @@ -37,6 +37,10 @@ from orangecontrib.explain.inspection import individual_condition_expectation from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog +try: + from Orange.widgets.utils.multi_target import check_multiple_targets_input +except ImportError: + check_multiple_targets_input = lambda f: f class RunnerResults(SimpleNamespace): @@ -621,12 +625,14 @@ def _add_buttons(self): @Inputs.data @check_sql_input + @check_multiple_targets_input def set_data(self, data: Optional[Table]): self.data = data self.__sampled_mask = None self._check_data() @Inputs.model + @check_multiple_targets_input def set_model(self, model: Optional[Model]): self.model = model