Skip to content

Commit

Permalink
Merge pull request #69 from VesnaT/pls
Browse files Browse the repository at this point in the history
Handle multi target inputs
  • Loading branch information
PrimozGodec committed Mar 5, 2024
2 parents c80caa5 + 65721d6 commit ce9068a
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 5 deletions.
7 changes: 5 additions & 2 deletions orangecontrib/explain/inspection.py
Expand Up @@ -202,8 +202,11 @@ def individual_condition_expectation(
feature_index = data.domain.index(feature.name)

# fake sklearn estimator
model.fit = None
model.fit_ = None
def dummy_fit(*_, **__):
raise NotImplementedError()

model.fit = dummy_fit
model.fit_ = dummy_fit
if model.domain.class_var.is_discrete:
model._estimator_type = "classifier"
model.classes_ = model.domain.class_var.values
Expand Down
3 changes: 1 addition & 2 deletions orangecontrib/explain/tests/test_inspection.py
@@ -1,6 +1,5 @@
import unittest
from unittest.mock import Mock
import pkg_resources

import numpy as np
from sklearn.inspection import permutation_importance, partial_dependence
Expand Down Expand Up @@ -121,7 +120,7 @@ def test_wrap_score_skl_predict_reg(self):
baseline_score = scorer(mocked_model, data)
mocked_model.assert_not_called()
mocked_model.predict.assert_not_called()
self.assertAlmostEqual(baseline_score, 2, 0)
self.assertLess(baseline_score, 2.6)


class TestPermutationFeatureImportance(unittest.TestCase):
Expand Down
6 changes: 6 additions & 0 deletions orangecontrib/explain/widgets/owexplainfeaturebase.py
Expand Up @@ -21,6 +21,10 @@
from Orange.widgets.settings import Setting
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.graphicslayoutitem import SimpleLayoutItem
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.state_summary import format_summary_details
from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
Expand Down Expand Up @@ -462,6 +466,7 @@ def __n_spin_changed(self):
# Inputs
@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.data = data
summary = len(data) if data else self.info.NoInput
Expand All @@ -473,6 +478,7 @@ def _check_data(self):
pass

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.closeContext()
self.model = model
Expand Down
11 changes: 10 additions & 1 deletion orangecontrib/explain/widgets/owexplainprediction.py
Expand Up @@ -17,6 +17,10 @@
from Orange.data import Table, Domain, ContinuousVariable, StringVariable
from Orange.data.table import DomainTransformationError
from Orange.widgets import gui
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f
from Orange.widgets.settings import Setting, ContextSetting, \
ClassValuesContextHandler
from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin
Expand Down Expand Up @@ -574,15 +578,18 @@ def __zoom_changed(self, delta: float):

@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.data = data

@Inputs.background_data
@check_sql_input
@check_multiple_targets_input
def set_background_data(self, data: Optional[Table]):
self.background_data = data

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.closeContext()
self.model = model
Expand Down Expand Up @@ -614,7 +621,9 @@ def clear(self):
self.__results = None
self.cancel()
self.clear_scene()
self.clear_messages()
self.Error.domain_transform_err.clear()
self.Error.unknown_err.clear()
self.Information.multiple_instances.clear()

def check_inputs(self):
if self.data and len(self.data) > 1:
Expand Down
7 changes: 7 additions & 0 deletions orangecontrib/explain/widgets/owexplainpredictions.py
Expand Up @@ -19,6 +19,10 @@
from Orange.data import Table, Domain, ContinuousVariable, Variable
from Orange.data.table import DomainTransformationError
from Orange.widgets import gui
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f
from Orange.widgets.settings import ContextSetting, Setting, \
PerfectDomainContextHandler
from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_SIGNAL_NAME, \
Expand Down Expand Up @@ -649,6 +653,7 @@ def _add_buttons(self):

@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.closeContext()
self.data = data
Expand All @@ -658,10 +663,12 @@ def set_data(self, data: Optional[Table]):

@Inputs.background_data
@check_sql_input
@check_multiple_targets_input
def set_background_data(self, data: Optional[Table]):
self.background_data = data

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.model = model

Expand Down
6 changes: 6 additions & 0 deletions orangecontrib/explain/widgets/owice.py
Expand Up @@ -37,6 +37,10 @@

from orangecontrib.explain.inspection import individual_condition_expectation
from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f


class RunnerResults(SimpleNamespace):
Expand Down Expand Up @@ -621,12 +625,14 @@ def _add_buttons(self):

@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.data = data
self.__sampled_mask = None
self._check_data()

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.model = model

Expand Down

0 comments on commit ce9068a

Please sign in to comment.