Skip to content

Commit

Permalink
Merge pull request #20 from VesnaT/add_tests
Browse files Browse the repository at this point in the history
[FIX] Improve coverage
  • Loading branch information
PrimozGodec committed Mar 22, 2021
2 parents 8feafec + 5ce7be7 commit 255c666
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
16 changes: 11 additions & 5 deletions orangecontrib/explain/widgets/owexplainfeaturebase.py
Expand Up @@ -397,6 +397,7 @@ class Error(OWWidget.Error):
class Information(OWWidget.Information):
data_sampled = Msg("Data has been sampled.")

settingsHandler = NotImplemented
n_attributes = Setting(10)
zoom_level = Setting(0)
selection = Setting((), schema_only=True)
Expand Down Expand Up @@ -502,11 +503,6 @@ def _clear_scene(self):

def update_scene(self):
self._clear_scene()
if self.results is not None:
values = np.mean(self.results.x, axis=0)
indices = np.argsort(values)[::-1]
names = [self.results.names[i] for i in indices]
self.setup_plot(values[indices], names)

def setup_plot(self, values: np.ndarray, names: List[str], *plot_args):
width = int(self.view.viewport().rect().width())
Expand Down Expand Up @@ -622,6 +618,7 @@ def run(data: Table, model: Model, *, state: TaskState) -> BaseResults:

if __name__ == "__main__": # pragma: no cover
from Orange.classification import RandomForestLearner
from Orange.widgets.settings import DomainContextHandler
from Orange.widgets.utils.widgetpreview import WidgetPreview


Expand Down Expand Up @@ -676,6 +673,7 @@ def select_from_settings(self, *_):
class Widget(OWExplainFeatureBase):
name = "Explain"
PLOT_CLASS = Plot
settingsHandler = DomainContextHandler()

def update_selection(self, *_):
pass
Expand All @@ -686,6 +684,14 @@ def get_selected_data(self):
def get_scores_table(self):
return None

def update_scene(self):
super().update_scene()
if self.results is not None:
values = np.mean(self.results.x, axis=0)
indices = np.argsort(values)[::-1]
names = [self.results.names[i] for i in indices]
self.setup_plot(values[indices], names)

@staticmethod
def run(data, model, state):
if not data or not model:
Expand Down
4 changes: 2 additions & 2 deletions orangecontrib/explain/widgets/owexplainmodel.py
Expand Up @@ -332,7 +332,7 @@ def _set_items(self, x: np.ndarray, labels: List[str], colors: np.ndarray):
item.selection_changed.connect(self.select)
self._items.append(item)
self._layout.addItem(item, i, FeaturesPlot.ITEM_COLUMN)
if i == MAX_N_ITEMS:
if i == MAX_N_ITEMS - 1:
break


Expand Down Expand Up @@ -399,7 +399,7 @@ def setup_controls(self):

# Plot setup
def update_scene(self):
self._clear_scene()
super().update_scene()
if self.results is not None:
assert isinstance(self.results.x, list)
x = self.results.x[self.target_index]
Expand Down
4 changes: 2 additions & 2 deletions orangecontrib/explain/widgets/owpermutationimportance.py
Expand Up @@ -122,7 +122,7 @@ def _set_items(self, x: np.ndarray, labels: List[str], std: np.ndarray,
item.set_data(x[i], std[i])
self._items.append(item)
self._layout.addItem(item, i, FeaturesPlot.ITEM_COLUMN)
if i == MAX_N_ITEMS:
if i == MAX_N_ITEMS - 1:
break
self._bottom_axis.setLabel(x_label)

Expand Down Expand Up @@ -290,7 +290,7 @@ def get_runner_parameters(self) -> Tuple[Optional[Table], Optional[Model],

# Plot setup
def update_scene(self):
self._clear_scene()
super().update_scene()
if self.results is not None:
importance = self.results.x
mean = np.mean(importance, axis=1)
Expand Down
Expand Up @@ -205,6 +205,24 @@ def test_n_repeats(self, mocked_func: Mock):
self.wait_until_finished()
self.assertEqual(mocked_func.call_args[0][3], 3)

@patch("orangecontrib.explain.widgets.owpermutationimportance."
"MAX_N_ITEMS", 3)
def test_n_attributes(self):
self.widget.controls.n_attributes.setValue(3)
self.send_signal(self.widget.Inputs.data, self.iris)
self.send_signal(self.widget.Inputs.model, self.rf_cls)
self.wait_until_finished()
domain = self.iris.domain
domain = Domain(domain.attributes[:3], domain.class_vars)
self.assertDomainInPlot(self.widget.plot, domain)

def test_zoom_level(self):
self.send_signal(self.widget.Inputs.data, self.iris)
self.send_signal(self.widget.Inputs.model, self.rf_cls)
self.wait_until_finished()
self.widget.controls.zoom_level.setValue(10)
self.assertDomainInPlot(self.widget.plot, self.iris.domain)

def test_plot(self):
self.send_signal(self.widget.Inputs.data, self.iris)
self.wait_until_finished()
Expand Down

0 comments on commit 255c666

Please sign in to comment.