Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
parrt committed Jul 13, 2023
2 parents b7279ec + be0cdfc commit 55d74b0
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 8 deletions.
9 changes: 7 additions & 2 deletions dtreeviz/models/shadow_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,17 @@ def get_shadow_tree(tree_model, X_train, y_train, feature_names, target_name, cl
from dtreeviz.models import lightgbm_decision_tree
return lightgbm_decision_tree.ShadowLightGBMTree(tree_model, tree_index, X_train, y_train,
feature_names, target_name, class_names)
elif "tensorflow_decision_forests.keras.RandomForestModel" in str(type(tree_model)):
elif any(tf_model in str(type(tree_model)) for tf_model in ["tensorflow_decision_forests.keras.RandomForestModel",
"tensorflow_decision_forests.keras.GradientBoostedTreesModel"]):
from dtreeviz.models import tensorflow_decision_tree
return tensorflow_decision_tree.ShadowTensorflowTree(tree_model, tree_index, X_train, y_train,
feature_names, target_name, class_names)
else:
raise ValueError(
f"Tree model must be in (DecisionTreeRegressor, DecisionTreeClassifier, "
"xgboost.core.Booster, lightgbm.basic.Booster, pyspark DecisionTreeClassificationModel, "
f"pyspark DecisionTreeClassificationModel, tensorflow_decision_forests.keras.RandomForestModel) "
f"pyspark DecisionTreeClassificationModel, tensorflow_decision_forests.keras.RandomForestModel, "
f"tensorflow_decision_forests.keras.GradientBoostedTreesModel) "
f"but you passed a {tree_model.__class__.__name__}!")


Expand Down Expand Up @@ -560,6 +562,9 @@ def prediction_name(self) -> (str, None):
Return prediction class or value otherwise.
"""
if self.isclassifier():
# In a GBT model, the trees are always regressive trees (even if the GBT is a classifier).
if "tensorflow_decision_forests.keras.GradientBoostedTreesModel" in str(type(self.shadow_tree.tree_model)):
return round(self.prediction(), 6)
if self.shadow_tree.class_names is not None:
return self.shadow_tree.class_names[self.prediction()]
return self.prediction()
Expand Down
2 changes: 1 addition & 1 deletion dtreeviz/models/sklearn_decision_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,4 @@ def get_min_samples_leaf(self):
return self.tree_model.min_samples_leaf

def shouldGoLeftAtSplit(self, id, x):
return x < self.get_node_split(id)
return x <= self.get_node_split(id)
4 changes: 4 additions & 0 deletions dtreeviz/models/tensorflow_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def get_node_nsamples_by_class(self, id):

def get_prediction(self, id):
if self.is_classifier():
# In a GBT model, the trees are always regressive trees (even if the GBT is a classifier). So we don't
# have the probability attribute
if "tensorflow_decision_forests.keras.GradientBoostedTreesModel" in str(type(self.model)):
return self.tree_nodes[id].value.value
return np.argmax(self.tree_nodes[id].value.probability)
else:
return self.tree_nodes[id].value.value
Expand Down
5 changes: 2 additions & 3 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,8 +1379,7 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
figsize = (.75, .8)

fig, ax = plt.subplots(1, 1, figsize=figsize)

m = np.mean(y)
m = node.prediction()

_format_axes(ax, None, None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False)
ax.set_ylim(y_range)
Expand Down Expand Up @@ -1534,7 +1533,7 @@ def _get_leaf_target_input(shadow_tree: ShadowDecTree, precision: int):
for i, node in enumerate(shadow_tree.leaves):
leaf_index_sample = node.samples()
leaf_target = shadow_tree.y_train[leaf_index_sample]
leaf_target_mean = np.mean(leaf_target)
leaf_target_mean = node.prediction()
np.random.seed(0) # generate the same list of random values for each call
X = np.random.normal(i, sigma, size=len(leaf_target))

Expand Down
2 changes: 1 addition & 1 deletion dtreeviz/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
__version__ = '2.2.1'
__version__ = '2.2.2'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

setup(
name='dtreeviz',
version='2.2.1',
version='2.2.2',
url='https://github.com/parrt/dtreeviz',
license='MIT',
packages=find_packages(),
Expand Down

0 comments on commit 55d74b0

Please sign in to comment.