Commit

Merge branch 'dev'

parrt committed Apr 16, 2023
2 parents 1fb2749 + d42b273 commit 683cd48
Showing 3 changed files with 88 additions and 21 deletions.
2 changes: 1 addition & 1 deletion dtreeviz/compatibility.py
@@ -255,7 +255,7 @@ def dtreeviz(tree_model,
instance_orientation,
show_root_edge_labels, show_node_labels, show_just_path, fancy, histtype, highlight_path, X,
max_X_features_LR, max_X_features_TD, depth_range_to_display, label_fontsize, ticks_fontsize,
fontname, title, title_fontsize, colors, scale)
fontname, title, title_fontsize, colors=colors, scale=scale)


def viz_leaf_samples(tree_model,
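This one-line fix passes colors and scale by keyword instead of positionally. A hypothetical illustration (the callee signature below is invented, not the dtreeviz one) of the misbinding that positional forwarding can cause in long argument lists:

# Hypothetical illustration: positional forwarding misbinds when the callee
# has extra parameters between the ones the caller intends to set.
def view(fontname, title, title_fontsize, ticks_fontsize=8, colors=None, scale=1.0):
    return {"colors": colors, "scale": scale}

args = ("Arial", "My tree", 12)
print(view(*args, {"classes": ["#fefecd"]}, 1.5))               # colors lands in ticks_fontsize
print(view(*args, colors={"classes": ["#fefecd"]}, scale=1.5))  # binds as intended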
13 changes: 12 additions & 1 deletion dtreeviz/models/sklearn_decision_trees.py
@@ -94,8 +94,19 @@ def get_node_feature(self, id) -> int:
return self.tree_model.tree_.feature[id]

def get_node_nsamples_by_class(self, id):
# This is the code to return the nsamples/class from tree metadata. It's faster, but the visualisations cannot
# be made on new datasets.
# if self.is_classifier():
# return self.tree_model.tree_.value[id][0]

# This code allows us to return the nsamples/class based on a dataset, train or validation
if self.is_classifier():
return self.tree_model.tree_.value[id][0]
all_nodes = self.internal + self.leaves
node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
if self.get_class_weights() is None:
return node_value[0]
else:
return node_value[0] * self.get_class_weights()

def get_prediction(self, id):
if self.is_classifier():
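The commented-out branch above keeps the fast path that reads per-class counts from sklearn's tree_.value metadata; the new branch recomputes counts from whichever dataset is bound to the shadow tree (train or validation) and scales them by class weights when present. A minimal standalone sketch of that idea, using a hypothetical helper rather than the dtreeviz API:

# Hypothetical sketch: count a node's samples per class from an arbitrary
# dataset (train or validation), then optionally re-weight by class weights.
import numpy as np

def node_nsamples_by_class(node_sample_idx, y, n_classes, class_weights=None):
    counts = np.zeros(n_classes)
    for cls in range(n_classes):
        counts[cls] = np.sum(y[node_sample_idx] == cls)
    if class_weights is None:
        return counts
    return counts * class_weights

y_valid = np.array([0, 1, 1, 0, 1, 1])
print(node_nsamples_by_class([0, 1, 2, 4], y_valid, n_classes=2))             # [1. 3.]
print(node_nsamples_by_class([0, 1, 2, 4], y_valid, 2, np.array([2., 1.])))   # [2. 3.]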
94 changes: 75 additions & 19 deletions dtreeviz/trees.py
@@ -1,7 +1,8 @@
import os
import tempfile
from typing import Mapping, List
from typing import Mapping, List, Callable

import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
@@ -134,6 +135,9 @@ def leaf_sizes(self,

def ctree_leaf_distributions(self,
display_type: ("plot", "text") = "plot",
xaxis_display_type: str = "individual",
show_leaf_id_list: list = None,
show_leaf_filter: Callable[[np.ndarray], bool] = None,
plot_ylim: int = None,
colors: dict = None,
fontsize: int = 10,
@@ -156,6 +160,16 @@ def ctree_leaf_distributions(self,
:param display_type: str, optional
'plot' or 'text'
:param xaxis_display_type: str, optional
'individual': Displays every node ID individually
'auto': Let matplotlib automatically manage the node ID ticks
'y_sorted': Display in y order with no x-axis tick labels
:param show_leaf_id_list: list, optional
The allowed list of node id values to plot
:param show_leaf_filter: Callable[[np.ndarray], bool], optional
The filtering function to apply to leaf values before displaying the leaves.
The function is applied to a numpy array whose row i holds the sample count for class i.
For example, to view only those leaves with more than 100 total samples, and more than 5 class 1 samples, use show_leaf_filter = lambda x: (100 < np.sum(x)) & (5 < x[1])
:param plot_ylim: int, optional
The maximum value for the y-axis. This is useful when a few leaves with large sample counts 'shadow'
the values of the other leaves.
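A hypothetical usage sketch of the new parameters, assuming the dtreeviz 2.x dtreeviz.model(...) entry point and an sklearn classifier:

import dtreeviz
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=4).fit(iris.data, iris.target)
viz_model = dtreeviz.model(clf, iris.data, iris.target,
                           feature_names=iris.feature_names,
                           target_name="variety",
                           class_names=list(iris.target_names))

# Only leaves with more than 10 total samples and at least one class-1 sample,
# drawn in descending y order without per-leaf x tick labels.
viz_model.ctree_leaf_distributions(
    xaxis_display_type="y_sorted",
    show_leaf_filter=lambda x: (np.sum(x) > 10) & (x[1] >= 1))

# Restrict the plot to an explicit list of leaf node ids.
viz_model.ctree_leaf_distributions(show_leaf_id_list=[3, 5, 8])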
@@ -181,26 +195,54 @@ def ctree_leaf_distributions(self,
else:
fig, ax = plt.subplots()

ax.set_xticks(range(0, len(index)))
ax.set_xticklabels(index)
if plot_ylim is not None:
ax.set_ylim(0, plot_ylim)

leaf_samples_hist = [[] for i in range(self.shadow_tree.nclasses())]
for leaf_sample in leaf_samples:
for i, leaf_count in enumerate(leaf_sample):
leaf_samples_hist[i].append(leaf_count)
leaf_samples_hist = np.array(leaf_samples_hist)

if show_leaf_id_list is not None:
_mask = np.isin(index, show_leaf_id_list)
leaf_samples_hist = leaf_samples_hist[:, _mask]
index = tuple(np.array(index)[_mask])
if show_leaf_filter is not None:
_mask = np.apply_along_axis(show_leaf_filter, 0, leaf_samples_hist)
leaf_samples_hist = leaf_samples_hist[:, _mask]
index = tuple(np.array(index)[_mask])

if xaxis_display_type == 'individual':
x = np.arange(0, len(index))
ax.set_xticks(x)
ax.set_xticklabels(index)
elif xaxis_display_type == 'auto':
x = np.array(index)
ax.set_xlim(np.min(x)-1, np.max(x)+1)
elif xaxis_display_type == 'y_sorted':
# sort by total y = sum(classes), then class 0, 1, 2, ...
sort_cols = [np.sum(leaf_samples_hist, axis=0)]
for i in range(leaf_samples_hist.shape[0]):
sort_cols.append(leaf_samples_hist[i])
_sort = np.lexsort(sort_cols[::-1])[::-1]
leaf_samples_hist = leaf_samples_hist[:, _sort]
index = tuple(np.array(index)[_sort])

x = np.arange(0, len(index))
ax.set_xticks(x)
ax.set_xticklabels([])
ax.tick_params(axis='x', which='both', bottom=False)
else:
raise ValueError(f'Unknown xaxis_display_type = {xaxis_display_type}!')

bar_containers = []
bottom_values = np.full(len(index), 0)
for i, leaf_sample in enumerate(leaf_samples_hist):
bar_container = ax.bar(range(0, len(index)), leaf_sample, bottom=bottom_values,
if plot_ylim is not None:
ax.set_ylim(0, plot_ylim)

bottom_values = np.zeros(len(index))
for i in range(leaf_samples_hist.shape[0]):
bar_container = ax.bar(x, leaf_samples_hist[i], bottom=bottom_values,
color=colors_classes[i],
lw=.3, align='center', width=1)
bottom_values = bottom_values + np.array(leaf_sample)
bar_containers.append(bar_container)
bottom_values = bottom_values + leaf_samples_hist[i]

for bar_container in bar_containers:
for rect in bar_container.patches:
rect.set_linewidth(.5)
rect.set_edgecolor(colors['rect_edge'])
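To trace what the masking and y_sorted branches above compute, here is a small standalone example with made-up counts (rows are classes, columns are leaves):

import numpy as np

index = np.array([3, 5, 8, 9])                    # leaf node ids
leaf_samples_hist = np.array([[10,  2, 40,  7],   # class 0 counts per leaf
                              [ 1, 30,  5,  7]])  # class 1 counts per leaf

# show_leaf_filter: keep leaves with > 10 total samples and > 2 class-1 samples
mask = np.apply_along_axis(lambda x: (np.sum(x) > 10) & (x[1] > 2), 0, leaf_samples_hist)
print(index[mask])                                # [5 8 9]

# y_sorted: order leaves by total samples (then class 0, 1, ...), descending
sort_cols = [np.sum(leaf_samples_hist, axis=0), leaf_samples_hist[0], leaf_samples_hist[1]]
order = np.lexsort(sort_cols[::-1])[::-1]
print(index[order])                               # [8 5 9 3]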
@@ -762,8 +804,8 @@ def node_stats(self, node_id: int) -> pd.DataFrame:
"""

node_samples = self.shadow_tree.get_node_samples()
df = pd.DataFrame(self.shadow_tree.X_train, columns=self.shadow_tree.feature_names)
return df.iloc[node_samples[node_id]].describe()
df = pd.DataFrame(self.shadow_tree.X_train, columns=self.shadow_tree.feature_names).convert_dtypes()
return df.iloc[node_samples[node_id]].describe(include='all')

def instance_feature_importance(self, x,
colors: dict = None,
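The node_stats() change chains convert_dtypes() so that mixed-type feature matrices keep sensible column dtypes, and describe(include='all') so that non-numeric columns are summarized (count/unique/top/freq) rather than dropped. A small pandas-only illustration with made-up data:

import pandas as pd

df = pd.DataFrame({"age": [22, 35, 58], "port": ["S", "C", "S"]}).convert_dtypes()
print(df.dtypes)                    # age: Int64, port: string
print(df.describe(include='all'))   # numeric stats for age, unique/top/freq for port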
@@ -884,7 +926,7 @@ def rtree_leaf_distributions(self,
for i in range(len(means)):
ax.plot(means[i], means_range[i], color=colors['split_line'], linewidth=prediction_line_width)

_format_axes(ax, self.shadow_tree.target_name, "Leaf", colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=None, grid=grid)
_format_axes(ax, self.shadow_tree.target_name, "Leaf IDs", colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=None, grid=grid)

def ctree_feature_space(self,
fontsize=10,
@@ -1142,12 +1184,22 @@ def _class_split_viz(node: ShadowDecTreeNode,
histtype=histtype,
bins=bins,
label=class_names)

# Alter appearance of each bar
for patch in barcontainers:
for rect in patch.patches:
if isinstance(barcontainers[0], matplotlib.container.BarContainer):
for patch in barcontainers:
for rect in patch.patches:
rect.set_linewidth(.5)
rect.set_edgecolor(colors['rect_edge'])
ax.set_yticks([0, max([max(h) for h in hist])])
elif isinstance(barcontainers[0], matplotlib.patches.Rectangle):
# In case a node contains samples from only one class.
for rect in barcontainers.patches:
rect.set_linewidth(.5)
rect.set_edgecolor(colors['rect_edge'])
ax.set_yticks([0, max([max(h) for h in hist])])
ax.set_yticks([0, max(hist)])



# set an empty space at the beginning and the end of the node visualisation for better clarity
bin_length = bins[1] - bins[0]
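The isinstance checks above distinguish the two shapes ax.hist can return: a list of per-class arrays yields a list of BarContainers, while a single array (a node holding only one class) yields one BarContainer whose elements are Rectangles. A minimal check of that behavior (standard matplotlib, not dtreeviz code):

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.container import BarContainer
from matplotlib.patches import Rectangle

fig, (ax1, ax2) = plt.subplots(1, 2)
_, _, multi = ax1.hist([np.random.randn(50), np.random.randn(30)], bins=10)
_, _, single = ax2.hist(np.random.randn(50), bins=10)

print(isinstance(multi[0], BarContainer))   # True: one container per class
print(isinstance(single[0], Rectangle))     # True: elements of a single container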
@@ -1200,6 +1252,10 @@ def _class_leaf_viz(node: ShadowDecTreeNode,
counts = node.class_counts()
prediction = node.prediction_name()

# When using a dataset other than the training dataset, some leaves could have 0 samples.
# Trying to make a pie chart for them would raise a deprecation warning, so skip these leaves.
if sum(counts) == 0:
return
if leaftype == 'pie':
_draw_piechart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
graph_colors=graph_colors, fontname=fontname)
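The early return in _class_leaf_viz skips leaves that hold no samples from the supplied dataset before attempting a pie chart. A hypothetical guard sketch in the same spirit (exact matplotlib behavior on all-zero input varies by version, hence the guard):

import matplotlib.pyplot as plt

def draw_leaf_pie(counts, labels):
    if sum(counts) == 0:     # leaf holds no samples from the supplied dataset
        return None
    fig, ax = plt.subplots()
    ax.pie(counts, labels=labels)
    return fig

print(draw_leaf_pie([0, 0], ["setosa", "versicolor"]))   # None: nothing drawn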
