Skip to content

Commit

Permalink
Cart with predict_all
Browse files Browse the repository at this point in the history
  • Loading branch information
arminwitte committed Feb 9, 2023
1 parent fbc7dd2 commit b5578b6
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions binarybeech/binarybeech.py
Expand Up @@ -112,6 +112,12 @@ def __init__(self,df,y_name,X_names=None,min_leaf_samples=1,min_split_samples=1,
self.metrics = metrics_factory.create_metrics(metrics_type, self.y_name)

self.logger = logging.getLogger(__name__)

def predict_all(self,df):
y_hat = np.empty((len(df.index),))
for i, x in enumerate(df.iloc):
y_hat[i] = self.tree.predict(x).value
return y_hat

def train(self,k=5, plot=True, slack=1.):
"""
Expand Down Expand Up @@ -567,6 +573,7 @@ def create_trees(self,M):
c = CART(df.sample(frac=self.sample_frac),self.y_name,X_names=X_names,**kwargs)
c.create_tree()
self.trees.append(c.tree)
print(f"{i:4d}: Tree with {c.tree.leaf_count()} leaves created.")

def predict(self,x):
y = []
Expand Down

0 comments on commit b5578b6

Please sign in to comment.