Skip to content

Commit

Permalink
Merge pull request #23 from VesnaT/feature_importance_fix
Browse files Browse the repository at this point in the history
[FIX] Feature Importance: Handle data with no features
  • Loading branch information
PrimozGodec committed Apr 28, 2021
2 parents 26d4cd4 + 9712044 commit 30bfd58
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
2 changes: 2 additions & 0 deletions orangecontrib/explain/inspection.py
Expand Up @@ -69,6 +69,8 @@ def permutation_feature_importance(
def _check_data(data: Table):
if not data.domain.class_var:
raise ValueError("Data with a target variable required.")
if not data.domain.attributes:
raise ValueError("Data with features required.")


def _check_model(model: Model, data: Table) -> bool:
Expand Down
9 changes: 9 additions & 0 deletions orangecontrib/explain/tests/test_inspection.py
Expand Up @@ -234,6 +234,15 @@ def test_inadequate_data(self):
self.assertRaises(DomainTransformationError,
permutation_feature_importance, *args)

def test_inadequate_data(self):
domain = Domain([],
class_vars=self.iris.domain.class_vars,
metas=self.iris.domain.attributes)
data = self.iris.transform(domain)
model = RandomForestLearner()(self.iris)
args = model, data, self.n_repeats
self.assertRaises(ValueError, permutation_feature_importance, *args)

def test_inadequate_model(self):
model = RandomForestLearner()(self.iris)
args = model, self.housing, self.n_repeats
Expand Down
Expand Up @@ -81,6 +81,17 @@ def test_regression_data_classification_model(self):
self.assertPlotEmpty(self.widget.plot)
self.assertTrue(self.widget.Error.unknown_err.is_shown())

def test_data_with_no_features(self):
domain = Domain([],
class_vars=self.iris.domain.class_vars,
metas=self.iris.domain.attributes)
data = self.iris.transform(domain)
self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.model, self.rf_cls)
self.wait_until_finished()
self.assertPlotEmpty(self.widget.plot)
self.assertTrue(self.widget.Error.unknown_err.is_shown())

def test_output_scores(self):
self.send_signal(self.widget.Inputs.data, self.iris)
self.send_signal(self.widget.Inputs.model, self.rf_cls)
Expand Down

0 comments on commit 30bfd58

Please sign in to comment.