Skip to content

Commit

Permalink
Merge pull request #40 from arminwitte/import_export_json
Browse files Browse the repository at this point in the history
Import export json
  • Loading branch information
arminwitte committed Jun 25, 2023
2 parents ea29d56 + 9f8b902 commit 1415c5b
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 1 deletion.
8 changes: 8 additions & 0 deletions binarybeech/attributehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,14 @@ def register_method_group(self, method_group):

def register_handler(self, attribute_handler_class, method_group="default"):
self.attribute_handlers[method_group].append(attribute_handler_class)

def __getitem__(self,name):
ahc = None
for val in self.attribute_handlers.values():
for a in val:
if name == a.__name__:
ahc = a
return ahc

def get_attribute_handler_class(self, arr, method_group="default"):
"""
Expand Down
109 changes: 108 additions & 1 deletion binarybeech/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# coding: utf-8

import numpy as np

import json
from binarybeech.attributehandler import attribute_handler_factory

class Node:
def __init__(
Expand Down Expand Up @@ -34,6 +35,39 @@ def get_child(self, df):
if self.decision_fun(df[self.attribute], self.threshold)
else self.branches[1]
)

def to_dict(self):
d = {}
d["branches"] = None
d["threshold"] = self.threshold
d["attribute"] = self.attribute
d["decision_fun"] = self.decision_fun
d["is_leaf"] = self.is_leaf
d["value"] = self.value
d["pinfo"] = self.pinfo
return d

def to_json(self,filename=None):
d = self.to_dict()
d["decision_fun"] = d["decision_fun"].__qualname__
if filename is None:
return json.dumps(d)
else:
with open(filename,"w") as f:
json.dump(d,f)

@classmethod
def from_dict(cls,d):
n = cls(branches=d.get("branches"),
attribute=d.get("attribute"),
threshold=d.get("threshold"),
value=d.get("value"),
decision_fun=d.get("decision_fun"),
parent=d.get("parent"),
)
n.pinfo = d.get("pinfo",{})
return n



class Tree:
Expand Down Expand Up @@ -76,3 +110,76 @@ def classes(self):
for n in nodes:
c.append(n.value)
return np.unique(c).tolist()

def to_dict(self):
return self._to_dict(self.root)

def _to_dict(self, node):
d = node.to_dict()
if not node.is_leaf:
d["branches"] = []
for b in node.branches:
d_ = self._to_dict(b)
d["branches"].append(d_)
return d


def to_json(self,filename=None):
d = self.to_dict()
self._replace_fun(d)
if filename is None:
return json.dumps(d)
else:
with open(filename,"w") as f:
json.dump(d,f)

def _replace_fun(self, d):
if "decision_fun" in d and d["decision_fun"] is not None:
d["decision_fun"] = d["decision_fun"].__qualname__.split(".")[-2]
if "branches" in d and d["branches"] is not None:
for b in d["branches"]:
self._replace_fun(b)

@classmethod
def from_dict(cls,d):
root = cls._from_dict(d)
return cls(root)

@staticmethod
def _from_dict(d):
# if the dict does not describe a leaf, process the branches first.
if not d["is_leaf"]:
branches = []
for b in d["branches"]:
branches.append(Tree._from_dict(b))
d["branches"] = branches
return Node.from_dict(d)

@classmethod
def from_json(cls,filename=None,string=None):
if not filename and not string:
raise ValueError("Either filename or a string has to be passed as argument to from_json.")

if filename is not None:
with open(filename, "r") as f:
d = json.load(f)
else:
d = json.loads(string)

cls._replace_str_with_fun(d)

tree = cls.from_dict(d)
return tree

@staticmethod
def _replace_str_with_fun(d):
if not d["is_leaf"]:
s = d["decision_fun"]
d["decision_fun"] = attribute_handler_factory[s].decide
for b in d["branches"]:
Tree._replace_str_with_fun(b)





43 changes: 43 additions & 0 deletions tests/test_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd

from binarybeech.binarybeech import CART, RandomForest
from binarybeech.tree import Tree


def test_iris_cart_create():
Expand Down Expand Up @@ -35,3 +36,45 @@ def test_iris_randomforest():
acc = val["accuracy"]
np.testing.assert_array_equal(p[:10], ["setosa"] * 10)
assert acc <= 1.0 and acc > 0.9


def test_iris_from_dict():
df_iris = pd.read_csv("data/iris.csv")
c = CART(df=df_iris, y_name="species", method="classification")
c.create_tree()

tree_dict = c.tree.to_dict()
assert isinstance(tree_dict, dict)

tree = Tree.from_dict(tree_dict)
assert isinstance(tree, Tree)
assert len(tree.nodes()) == 21
assert tree.leaf_count() == 11


c.tree = tree
p = c.predict(df_iris)
val = c.validate()
acc = val["accuracy"]
np.testing.assert_array_equal(p[:10], ["setosa"] * 10)
assert acc <= 1.0 and acc > 0.95

def test_iris_from_json():
df_iris = pd.read_csv("data/iris.csv")
c = CART(df=df_iris, y_name="species", method="classification", seed=42)
c.train()

tree_json = c.tree.to_json()
assert isinstance(tree_json, str)

tree = Tree.from_json(string=tree_json)
assert isinstance(tree, Tree)
assert len(tree.nodes()) == 5
assert tree.leaf_count() == 3

c.tree = tree
p = c.predict(df_iris)
val = c.validate()
acc = val["accuracy"]
np.testing.assert_array_equal(p[:10], ["setosa"] * 10)
assert acc <= 1.0 and acc > 0.95
16 changes: 16 additions & 0 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,19 @@ def test_tree_parent():
t = Tree(root=n0)
assert isinstance(t.traverse({"var": 0.0}).parent, Node)
assert len(t.leafs()) == 2

def test_tree_to_json():
def decfun(x, y):
return x < y
n1 = Node(value=1)
n2 = Node(value=2)
n0 = Node(
attribute="var",
threshold=0.5,
branches=[n1, n2],
decision_fun = decfun,
)
t = Tree(root=n0)

assert isinstance(t.to_json(),str)

0 comments on commit 1415c5b

Please sign in to comment.