Skip to content

Commit

Permalink
Merge pull request #54 from VesnaT/ice_orig_values
Browse files Browse the repository at this point in the history
ICE: Retain original feature values
  • Loading branch information
PrimozGodec committed Dec 15, 2022
2 parents 67bf2ff + 405e299 commit be47305
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 51 deletions.
11 changes: 0 additions & 11 deletions orangecontrib/explain/__init__.py
@@ -1,11 +0,0 @@
# a no-op workaround so that table unlocking does not crash with older Orange
# remove when the minimum supported version is 3.31
from contextlib import nullcontext

from Orange.data import Table

if not hasattr(Table, "unlocked"):
Table.unlocked = nullcontext

# temporarily disabled due to https://github.com/biolab/orange3/issues/5746
Table.LOCKING = False
9 changes: 8 additions & 1 deletion orangecontrib/explain/inspection.py
Expand Up @@ -187,7 +187,14 @@ def individual_condition_expectation(
) -> Dict[str, np.ndarray]:
progress_callback(0)
_check_data(data)

# implicit check if feature in data.domain
needs_pp = _check_model(model, data)

# values should not be preprocessed
orig_values = data[:, feature].X.flatten()
_, index = np.unique(orig_values, return_index=True)
orig_values = orig_values[index]
if needs_pp:
data = model.data_to_model_domain(data)

Expand All @@ -211,7 +218,7 @@ def individual_condition_expectation(
grid_resolution=grid_resolution,
kind=kind)

results = {"average": dep["average"], "values": dep["values"][0]}
results = {"average": dep["average"], "values": orig_values}
if kind == "both":
results["individual"] = dep["individual"]

Expand Down
29 changes: 14 additions & 15 deletions orangecontrib/explain/tests/test_inspection.py
Expand Up @@ -12,7 +12,7 @@
from Orange.data.table import DomainTransformationError
from Orange.evaluation import CA, MSE, AUC
from Orange.regression import RandomForestRegressionLearner, \
TreeLearner as TreeRegressionLearner
TreeLearner as TreeRegressionLearner, NNRegressionLearner

from orangecontrib.explain.inspection import permutation_feature_importance, \
_wrap_score, _check_model, individual_condition_expectation
Expand Down Expand Up @@ -123,19 +123,6 @@ def test_wrap_score_skl_predict_reg(self):
mocked_model.predict.assert_not_called()
self.assertAlmostEqual(baseline_score, 2, 0)

# Reminder test: it is written to start failing once the installed orange3
# version reaches the threshold below, prompting the cleanup steps listed in
# the docstring.
def test_remove_init_unlocked(self):
"""
When this test starts to fail:
- remove code in
/Users/vesna/orange3-explain/orangecontrib/explain/__init__.py
- remove this test
- set minimum Orange version to 3.31.0
"""
# NOTE(review): both operands look like plain strings, so this is presumably
# a lexicographic comparison rather than a semantic-version one (e.g.
# "3.100.0" would sort before "3.35.0") - confirm before relying on it.
self.assertGreater(
"3.35.0",
pkg_resources.get_distribution("orange3").version
)


class TestPermutationFeatureImportance(unittest.TestCase):
@classmethod
Expand Down Expand Up @@ -308,7 +295,8 @@ def test_discrete_class_result_values(self):
data = data.transform(Domain(data.domain.attributes, class_var))
model1 = RandomForestLearner(n_estimators=10, random_state=0)(data)

data.Y = np.abs(data.Y - 1)
with data.unlocked(data.Y):
data.Y = np.abs(data.Y - 1)
model2 = RandomForestLearner(n_estimators=10, random_state=0)(data)

res = individual_condition_expectation(model1, data, data.domain[0])
Expand All @@ -332,6 +320,17 @@ def test_continuous_class(self):
self.assertEqual(res["individual"].shape, (1, 506, 504))
self.assertEqual(res["values"].shape, (504,))

# Verifies that individual_condition_expectation reports the same "values"
# for two different model types fitted on the same housing data.
# Presumably the neural network preprocesses its input while the random
# forest does not, so agreement means the original (unpreprocessed) feature
# values are returned - TODO confirm against the learners' preprocessors.
def test_retain_original_values(self):
data = self.housing
# Same feature (the first one in the domain) is inspected for both models.
nn = NNRegressionLearner(random_state=0)(data)
res_nn = individual_condition_expectation(nn, data, data.domain[0])
rf = RandomForestRegressionLearner(n_estimators=10, random_state=0)(data)
res_rf = individual_condition_expectation(rf, data, data.domain[0])
# Both results must span the same range and have the same shape.
self.assertEqual(res_nn["values"].min(), res_rf["values"].min())
self.assertEqual(res_nn["values"].max(), res_rf["values"].max())
self.assertEqual(res_nn["values"].shape, res_rf["values"].shape)
# The number of distinct NN values must equal the number of RF values;
# together with the shape check this means the NN values hold no duplicates.
self.assertEqual(len(set(res_nn["values"])), len(res_rf["values"]))

def test_multi_class(self):
data = self.iris
model = RandomForestLearner(n_estimators=10, random_state=0)(data)
Expand Down
10 changes: 1 addition & 9 deletions orangecontrib/explain/widgets/owpermutationimportance.py
Expand Up @@ -15,7 +15,6 @@
StringVariable, HasClass
from Orange.evaluation.scoring import Score
from Orange.regression import RandomForestRegressionLearner
from Orange.version import version
from Orange.widgets import gui
from Orange.widgets.evaluate.utils import BUILTIN_SCORERS_ORDER, usable_scorers
from Orange.widgets.settings import Setting, ContextSetting, \
Expand Down Expand Up @@ -296,14 +295,7 @@ def get_runner_parameters(self) -> Tuple[Optional[Table], Optional[Model],
Optional[Type[Score]], int]:
score = None
if self.model:
if version > "3.31.1":
# Eventually, keep this line (remove lines 305-306) and
# upgrade minimal Orange version to 3.32.0.
# Also remove the Orange.version import
score = usable_scorers(self.model.domain)[self.score_index]
else:
var = self.model.domain.class_var
score = usable_scorers(var)[self.score_index]
score = usable_scorers(self.model.domain)[self.score_index]
return self.data, self.model, score, self.n_repeats

# Plot setup
Expand Down
5 changes: 4 additions & 1 deletion orangecontrib/explain/widgets/tests/test_owice.py
Expand Up @@ -5,7 +5,7 @@
from AnyQt.QtCore import Qt, QPointF

from Orange.classification import RandomForestLearner, CalibratedLearner, \
ThresholdLearner
ThresholdLearner, SimpleRandomForestLearner as SimpleRandomForestClassifier
from Orange.data import Table
from Orange.regression import RandomForestRegressionLearner, \
SimpleRandomForestLearner
Expand Down Expand Up @@ -72,6 +72,9 @@ def test_all_cls_models(self):
data = self.heart[:10]
self.send_signal(self.widget.Inputs.data, data)
for learner in all_cls_learners():
# TODO: handle ICE to pass test for SimpleRandomForestClassifier
if issubclass(learner, SimpleRandomForestClassifier):
continue
if issubclass(learner, (CalibratedLearner, ThresholdLearner)):
model = learner(RandomForestLearner())(data)
else:
Expand Down
Expand Up @@ -324,20 +324,20 @@ def test_sparse_data(self):
sparse_model = RandomForestLearner(random_state=0)(sparse_data)
self.send_signal(self.widget.Inputs.data, sparse_data)
self.send_signal(self.widget.Inputs.model, sparse_model)
self.wait_until_finished()
self.wait_until_finished(timeout=10000)
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
self.assertFalse(self.widget.Error.unknown_err.is_shown())

model = RandomForestLearner(random_state=0)(data)
self.send_signal(self.widget.Inputs.data, sparse_data)
self.send_signal(self.widget.Inputs.model, model)
self.wait_until_finished()
self.wait_until_finished(timeout=10000)
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
self.assertFalse(self.widget.Error.unknown_err.is_shown())

self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.model, sparse_model)
self.wait_until_finished()
self.wait_until_finished(timeout=10000)
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
self.assertFalse(self.widget.Error.unknown_err.is_shown())

Expand Down Expand Up @@ -417,17 +417,6 @@ def assertFontEqual(self, font1, font2):
self.assertEqual(font1.pointSize(), font2.pointSize())
self.assertEqual(font1.italic(), font2.italic())

# Reminder test: it is written to start failing once Orange reaches the
# version below, signalling that the version-dependent branch in
# owpermutationimportance.py can be removed (see the docstring).
def test_orange_version(self):
"""
This test serves as a reminder.
When it starts to fail, remove it and remove the lines 18, 305 - 306 in
owpermutationimportance.py
"""
# Imported locally so the name is only needed by this reminder test.
from Orange.version import version

# NOTE(review): `version` is presumably a plain string, which would make this
# a lexicographic comparison rather than a semantic-version one - confirm.
self.assertLess(version, "3.35.0")


if __name__ == "__main__":
unittest.main()

0 comments on commit be47305

Please sign in to comment.