Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle multi target inputs #69

Merged
merged 2 commits into from Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions orangecontrib/explain/inspection.py
Expand Up @@ -202,8 +202,11 @@ def individual_condition_expectation(
feature_index = data.domain.index(feature.name)

# fake sklearn estimator
model.fit = None
model.fit_ = None
def dummy_fit(*_, **__):
raise NotImplementedError()

model.fit = dummy_fit
model.fit_ = dummy_fit
if model.domain.class_var.is_discrete:
model._estimator_type = "classifier"
model.classes_ = model.domain.class_var.values
Expand Down
3 changes: 1 addition & 2 deletions orangecontrib/explain/tests/test_inspection.py
@@ -1,6 +1,5 @@
import unittest
from unittest.mock import Mock
import pkg_resources

import numpy as np
from sklearn.inspection import permutation_importance, partial_dependence
Expand Down Expand Up @@ -121,7 +120,7 @@ def test_wrap_score_skl_predict_reg(self):
baseline_score = scorer(mocked_model, data)
mocked_model.assert_not_called()
mocked_model.predict.assert_not_called()
self.assertAlmostEqual(baseline_score, 2, 0)
self.assertLess(baseline_score, 2.6)


class TestPermutationFeatureImportance(unittest.TestCase):
Expand Down
6 changes: 6 additions & 0 deletions orangecontrib/explain/widgets/owexplainfeaturebase.py
Expand Up @@ -21,6 +21,10 @@
from Orange.widgets.settings import Setting
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.graphicslayoutitem import SimpleLayoutItem
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.state_summary import format_summary_details
from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
Expand Down Expand Up @@ -462,6 +466,7 @@ def __n_spin_changed(self):
# Inputs
@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.data = data
summary = len(data) if data else self.info.NoInput
Expand All @@ -473,6 +478,7 @@ def _check_data(self):
pass

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.closeContext()
self.model = model
Expand Down
11 changes: 10 additions & 1 deletion orangecontrib/explain/widgets/owexplainprediction.py
Expand Up @@ -17,6 +17,10 @@
from Orange.data import Table, Domain, ContinuousVariable, StringVariable
from Orange.data.table import DomainTransformationError
from Orange.widgets import gui
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f
from Orange.widgets.settings import Setting, ContextSetting, \
ClassValuesContextHandler
from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin
Expand Down Expand Up @@ -574,15 +578,18 @@ def __zoom_changed(self, delta: float):

@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.data = data

@Inputs.background_data
@check_sql_input
@check_multiple_targets_input
def set_background_data(self, data: Optional[Table]):
self.background_data = data

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.closeContext()
self.model = model
Expand Down Expand Up @@ -614,7 +621,9 @@ def clear(self):
self.__results = None
self.cancel()
self.clear_scene()
self.clear_messages()
self.Error.domain_transform_err.clear()
self.Error.unknown_err.clear()
self.Information.multiple_instances.clear()

def check_inputs(self):
if self.data and len(self.data) > 1:
Expand Down
7 changes: 7 additions & 0 deletions orangecontrib/explain/widgets/owexplainpredictions.py
Expand Up @@ -19,6 +19,10 @@
from Orange.data import Table, Domain, ContinuousVariable, Variable
from Orange.data.table import DomainTransformationError
from Orange.widgets import gui
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f
from Orange.widgets.settings import ContextSetting, Setting, \
PerfectDomainContextHandler
from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_SIGNAL_NAME, \
Expand Down Expand Up @@ -649,6 +653,7 @@ def _add_buttons(self):

@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.closeContext()
self.data = data
Expand All @@ -658,10 +663,12 @@ def set_data(self, data: Optional[Table]):

@Inputs.background_data
@check_sql_input
@check_multiple_targets_input
def set_background_data(self, data: Optional[Table]):
self.background_data = data

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.model = model

Expand Down
6 changes: 6 additions & 0 deletions orangecontrib/explain/widgets/owice.py
Expand Up @@ -37,6 +37,10 @@

from orangecontrib.explain.inspection import individual_condition_expectation
from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f


class RunnerResults(SimpleNamespace):
Expand Down Expand Up @@ -621,12 +625,14 @@ def _add_buttons(self):

@Inputs.data
@check_sql_input
@check_multiple_targets_input
def set_data(self, data: Optional[Table]):
self.data = data
self.__sampled_mask = None
self._check_data()

@Inputs.model
@check_multiple_targets_input
def set_model(self, model: Optional[Model]):
self.model = model

Expand Down