Skip to content

Commit

Permalink
Oob the right way
Browse files Browse the repository at this point in the history
  • Loading branch information
arminwitte committed Feb 11, 2023
1 parent b0b6cdf commit 49a89eb
Showing 1 changed file with 23 additions and 29 deletions.
52 changes: 23 additions & 29 deletions binarybeech/binarybeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def _handle_missings(self,df_in):
# use mean if numerical
for name in self.X_names:
if np.issubdtype(df_out[name].values.dtype, np.number):
df_out[name] = df_out[name].apply(lambda x: x.fillna(np.nanmean(df_out[name].values)))
df_out[name] = df_out[name].fillna(np.nanmean(df_out[name].values))
else:
df_out[name] = df_out[name].apply(lambda x: x.fillna("missing"))
df_out[name] = df_out[name].fillna("missing")
return df_out

def create_tree(self, leaf_loss_threshold=1e-12):
Expand Down Expand Up @@ -562,9 +562,9 @@ def __init__(self,df,y_name,X_names=None, verbose=False ,sample_frac=1, n_attrib
self.logger = logging.getLogger(__name__)

def create_trees(self,M):
df = self.df.sample(frac=self.sample_frac, replace=True)
self.trees = []
for i in range(M):
df = self.df.sample(frac=self.sample_frac, replace=True)
if self.n_attributes is None:
X_names = self.X_names
else:
Expand Down Expand Up @@ -595,39 +595,33 @@ def predict_all(self,df):
return y_hat

def validate_oob(self):
# Out-of-bag validation: derive one predicted label per row from the
# out-of-bag tree predictions and hand the result to self.metrics.validate.
#
# NOTE(review): this is a rendered diff hunk whose +/- markers and
# indentation were lost, so two implementations are interleaved below.
# The split into "removed" and "added" lines is inferred from the page
# header (23 additions / 29 deletions) -- confirm against the repository.
#
# --- appears to be the pre-commit body (removed) ---
# It took (unique, names) from self._oob_df(), called self._oob_predict()
# with no arguments, and for every row picked the label whose column in
# `names` held the largest value.
unique, names = self._oob_df()
self._oob_predict()
y_hat = []
for x in self.df.iloc:
votes = []
for n in names:
votes.append(x[n])
idx_max = np.argmax(votes)
y_hat.append(unique[idx_max])
return self.metrics.validate(y_hat, self.df)
# --- appears to be the post-commit body (added) ---
# Works on a separate frame returned by self._oob_df(); self._oob_predict
# receives that frame and returns it.
df = self._oob_df()
df = self._oob_predict(df)
for index, row in df.iterrows():
# Rows with an empty "votes" list are skipped and keep whatever
# "majority_vote" value self._oob_df() gave them.
if not row["votes"]:
continue
# Majority vote: the most frequent value in the row's vote list.
unique, counts = np.unique(row["votes"], return_counts=True)
idx_max = np.argmax(counts)
df.loc[index, "majority_vote"] = unique[idx_max]
# NOTE(review): the int cast runs BEFORE dropna. If any row was skipped
# above and still holds NaN in "majority_vote", pandas cannot cast NaN to
# int and is expected to raise here -- the two statements probably need to
# be swapped. The cast also presumes the class labels are integers; verify.
df = df.astype({'majority_vote':'int'})
df = df.dropna(subset=["majority_vote"])
return self.metrics.validate(df["majority_vote"].values, df)

def _oob_predict(self):
# NOTE(review): rendered diff hunk without +/- markers or indentation; the
# line above and the `for` line directly below appear to be the pre-commit
# (removed) versions, followed by their post-commit replacements. The split
# is inferred from the page header's addition/deletion counts -- confirm.
#
# Pre-commit loop header: unpacked each element of self.trees as (i, t).
for i, t in self.trees:
# Post-commit signature: takes the vote frame `df` and returns it after
# appending each tree's out-of-bag predictions.
def _oob_predict(self,df):
# Post-commit loop header: the position i comes from enumerate and is used
# to look up that tree's entry in self.oob_indices.
for i, t in enumerate(self.trees):
idx = self.oob_indices[i]
for j in idx:
# Row j is read from self.df by label and predicted by tree t.
x = self.df.loc[j,:]
y = t.predict(x).value
# Pre-commit (removed): `self.df[name] += 1` adds 1 to the whole column
# `name`, i.e. to every row, not only to row j.
name = self._oob_name(y)
self.df[name] += 1
# Post-commit (added): append the prediction to the list stored in row j's
# "votes" cell.
# NOTE(review): chained indexing -- this appears to rely on df.loc[j]
# handing back the same list object that is stored in df, and presumably on
# index labels being unique; verify.
df.loc[j]["votes"].append(y)
return df

def _oob_df(self):
# NOTE(review): rendered diff hunk without +/- markers or indentation. The
# first body below and the whole `_oob_name` helper appear to be pre-commit
# (removed); the last five statements appear to be the post-commit body of
# `_oob_df` (added). The split is inferred from the page header's
# addition/deletion counts -- confirm against the repository.
#
# --- appears to be the pre-commit body (removed) ---
# Added one zero-initialised column per unique label directly to self.df
# and returned the unique labels together with those column names.
y = self.df[self.y_name].values
# N is assigned but not used anywhere in the visible lines.
N = y.size
unique = np.unique(y)
names = []
for u in unique:
name = self._oob_name(u)
names.append(name)
self.df[name] = 0
return unique, names

# Pre-commit helper (removed): builds the column name "oob_<y_name>_<value>".
def _oob_name(self,value):
return "_".join(("oob",str(self.y_name),str(value)))
# --- appears to be the post-commit body of _oob_df (added) ---
# Builds a separate frame on self.df's index instead of adding columns to
# self.df: the label column, a "votes" column and a "majority_vote" column.
df = pd.DataFrame(index=self.df.index, dtype="object")
df[self.y_name] = self.df[self.y_name].values
# np.empty((len(df), 0)).tolist() yields one separate empty list per row.
df["votes"] = np.empty((len(df), 0)).tolist()
# NOTE(review): the `np.NaN` alias was removed in NumPy 2.0; `np.nan` is the
# spelling that works on both old and new NumPy versions.
df["majority_vote"] = np.NaN
return df

def validate(self, df=None):
if df is None:
Expand Down

0 comments on commit 49a89eb

Please sign in to comment.