From b5578b657a51cbc84bc5bb45422d83e8ec9939d8 Mon Sep 17 00:00:00 2001 From: arminwitte <110226001+arminwitte@users.noreply.github.com> Date: Thu, 9 Feb 2023 09:31:14 +0100 Subject: [PATCH] Cart with predict_all --- binarybeech/binarybeech.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/binarybeech/binarybeech.py b/binarybeech/binarybeech.py index f8461ae..0e9321a 100644 --- a/binarybeech/binarybeech.py +++ b/binarybeech/binarybeech.py @@ -112,6 +112,12 @@ def __init__(self,df,y_name,X_names=None,min_leaf_samples=1,min_split_samples=1, self.metrics = metrics_factory.create_metrics(metrics_type, self.y_name) self.logger = logging.getLogger(__name__) + + def predict_all(self,df): + y_hat = np.empty((len(df.index),)) + for i, x in enumerate(df.iloc): + y_hat[i] = self.tree.predict(x).value + return y_hat def train(self,k=5, plot=True, slack=1.): """ @@ -567,6 +573,7 @@ def create_trees(self,M): c = CART(df.sample(frac=self.sample_frac),self.y_name,X_names=X_names,**kwargs) c.create_tree() self.trees.append(c.tree) + print(f"{i:4d}: Tree with {c.tree.leaf_count()} leaves created.") def predict(self,x): y = []