Merge pull request #44 from arminwitte/refactoring_loss_parameters

loss parameters refactored

arminwitte committed Sep 14, 2023
2 parents 33270d1 + b4cf38f commit b8f7717

Showing 4 changed files with 101 additions and 109 deletions.
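The diff below applies one pattern throughout: optional sample weights are no longer threaded through as a positional argument `w` (which could be None), but collected in a `loss_args` dict that is unpacked with `**` at each metrics call. A minimal, self-contained sketch of that convention — the `Metrics` stand-in here is illustrative, not binarybeech's actual class:

    import numpy as np
    import pandas as pd

    class Metrics:
        # Illustrative stand-in: any object whose node_value/loss accept an
        # optional weights keyword works with the loss_args pattern.
        def node_value(self, y, weights=None):
            return np.average(y, weights=weights)  # (weighted) mean

        def loss(self, y, y_hat, weights=None):
            return np.average((y - y_hat) ** 2, weights=weights)  # (weighted) MSE

    def node_loss(df, y_name, metrics):
        loss_args = {}
        if "__weights__" in df:
            loss_args["weights"] = df["__weights__"].values
        y = df[y_name]
        y_hat = metrics.node_value(y, **loss_args)
        return metrics.loss(y, y_hat, **loss_args)

    df = pd.DataFrame({"y": [1.0, 2.0, 4.0], "__weights__": [0.2, 0.3, 0.5]})
    print(node_loss(df, "y", Metrics()))  # weighted; drop the column for unweighted

The empty dict is the unweighted default, so every call site reads the same whether or not a __weights__ column is present.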
48 changes: 23 additions & 25 deletions binarybeech/attributehandler.py
@@ -77,18 +77,18 @@ def split(self, df):
         ]
         N = len(df.index)
         n = [len(df_.index) for df_ in split_df]
 
+        loss_args = [{}, {}]
         if "__weights__" in df:
-            w = [df_["__weights__"].values for df_ in split_df]
-        else:
-            w = [None for df_ in split_df]
+            loss_args = [{"weights": df_["__weights__"].values} for df_ in split_df]
+
         val = [
-            self.metrics.node_value(df_[self.y_name], w[i])
+            self.metrics.node_value(df_[self.y_name], **loss_args[i])
             for i, df_ in enumerate(split_df)
         ]
         loss = n[0] / N * self.metrics.loss(
-            split_df[0][self.y_name], val[0], w[0]
-        ) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], w[1])
+            split_df[0][self.y_name], val[0], **loss_args[0]
+        ) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], **loss_args[1])
         if loss < self.loss:
             success = True
             self.loss = loss
@@ -162,17 +162,16 @@ def fun(x):
             if min(n) == 0:
                 return np.Inf
 
+            loss_args = [{}, {}]
             if "__weights__" in df:
-                w = [df_["__weights__"].values for df_ in split_df]
-            else:
-                w = [None for df_ in split_df]
+                loss_args = [{"weights": df_["__weights__"].values} for df_ in split_df]
             val = [
-                self.metrics.node_value(df_[self.y_name], w[i])
+                self.metrics.node_value(df_[self.y_name], **loss_args[i])
                 for i, df_ in enumerate(split_df)
             ]
             return n[0] / N * self.metrics.loss(
-                split_df[0][self.y_name], val[0], w[0]
-            ) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], w[1])
+                split_df[0][self.y_name], val[0], **loss_args[0]
+            ) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], **loss_args[1])
 
         return fun

@@ -214,17 +213,17 @@ def split(self, df):
         N = len(df.index)
         n = [len(df_.index) for df_ in self.split_df]
 
+        loss_args = [{}, {}]
         if "__weights__" in df:
-            w = [df_["__weights__"].values for df_ in self.split_df]
-        else:
-            w = [None for df_ in self.split_df]
+            loss_args = [{"weights": df_["__weights__"].values} for df_ in self.split_df]
+
         val = [
-            self.metrics.node_value(df_[self.y_name], w[i])
+            self.metrics.node_value(df_[self.y_name], **loss_args[i])
             for i, df_ in enumerate(self.split_df)
         ]
         self.loss = n[0] / N * self.metrics.loss(
-            self.split_df[0][self.y_name], val[0], w[0]
-        ) + n[1] / N * self.metrics.loss(self.split_df[1][self.y_name], val[1], w[1])
+            self.split_df[0][self.y_name], val[0], **loss_args[0]
+        ) + n[1] / N * self.metrics.loss(self.split_df[1][self.y_name], val[1], **loss_args[1])
 
         return success
@@ -295,17 +294,16 @@ def fun(x):
             split_df = [df[df[split_name] < x], df[df[split_name] >= x]]
             n = [len(df_.index) for df_ in split_df]
 
+            loss_args = [{}, {}]
             if "__weights__" in df:
-                w = [df_["__weights__"].values for df_ in split_df]
-            else:
-                w = [None for df_ in split_df]
+                loss_args = [{"weights": df_["__weights__"].values} for df_ in split_df]
             val = [
-                self.metrics.node_value(df_[self.y_name], w[i])
+                self.metrics.node_value(df_[self.y_name], **loss_args[i])
                 for i, df_ in enumerate(split_df)
            ]
             return n[0] / N * self.metrics.loss(
-                split_df[0][self.y_name], val[0], w[0]
-            ) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], w[1])
+                split_df[0][self.y_name], val[0], **loss_args[0]
+            ) + n[1] / N * self.metrics.loss(split_df[1][self.y_name], val[1], **loss_args[1])
 
         return fun
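For split evaluation the same idea is applied per partition: loss_args becomes a two-element list of kwarg dicts, one per child, and each child's loss is weighted by its share of the samples. A condensed sketch of the split loss computed in the hunks above, reusing an MSE stub like the one in the earlier sketch (names assumed for illustration):

    import numpy as np
    import pandas as pd

    class _MSE:
        def node_value(self, y, weights=None):
            return np.average(y, weights=weights)
        def loss(self, y, y_hat, weights=None):
            return np.average((y - y_hat) ** 2, weights=weights)

    def split_loss(df, y_name, split_df, metrics):
        # split_df holds the two child partitions of df.
        N = len(df.index)
        n = [len(df_.index) for df_ in split_df]
        loss_args = [{}, {}]
        if "__weights__" in df:
            loss_args = [{"weights": df_["__weights__"].values} for df_ in split_df]
        val = [
            metrics.node_value(df_[y_name], **loss_args[i])
            for i, df_ in enumerate(split_df)
        ]
        # Each child contributes in proportion to its sample count.
        return n[0] / N * metrics.loss(
            split_df[0][y_name], val[0], **loss_args[0]
        ) + n[1] / N * metrics.loss(split_df[1][y_name], val[1], **loss_args[1])

    df = pd.DataFrame({"x": [0.1, 0.4, 0.6, 0.9], "y": [0.0, 0.0, 1.0, 1.0]})
    print(split_loss(df, "y", [df[df["x"] < 0.5], df[df["x"] >= 0.5]], _MSE()))  # 0.0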

24 changes: 13 additions & 11 deletions binarybeech/binarybeech.py
@@ -226,12 +226,12 @@ def create_tree(self, leaf_loss_threshold=1e-12):
     def _node_or_leaf(self, df):
         y = df[self.y_name]
 
+        loss_args = {}
         if "__weights__" in df:
-            w = df["__weights__"].values
-        else:
-            w = None
-        y_hat = self.dmgr.metrics.node_value(y, w)
-        loss_parent = self.dmgr.metrics.loss(y, y_hat, w)
+            loss_args["weights"] = df["__weights__"].values
+
+        y_hat = self.dmgr.metrics.node_value(y, **loss_args)
+        loss_parent = self.dmgr.metrics.loss(y, y_hat, **loss_args)
         # p = self._probability(df)
         if (
             loss_parent < self.leaf_loss_threshold
@@ -270,7 +270,8 @@ def _node_or_leaf(self, df):
                 decision_fun=self.dmgr[split_name].decide,
             )
             item.pinfo["N"] = len(df.index)
-            item.pinfo["r"] = self.dmgr.metrics.loss_prune(y, y_hat)
+            loss_args = {}
+            item.pinfo["r"] = self.dmgr.metrics.loss_prune(y, y_hat, **loss_args)
             item.pinfo["R"] = (
                 item.pinfo["N"] / len(self.training_data.df.index) * item.pinfo["r"]
             )
@@ -285,7 +286,8 @@ def _leaf(self, y, y_hat):
         leaf = Node(value=y_hat)
 
         leaf.pinfo["N"] = y.size
-        leaf.pinfo["r"] = self.dmgr.metrics.loss_prune(y, y_hat)
+        loss_args = {}
+        leaf.pinfo["r"] = self.dmgr.metrics.loss_prune(y, y_hat, **loss_args)
         leaf.pinfo["R"] = (
             leaf.pinfo["N"] / len(self.training_data.df.index) * leaf.pinfo["r"]
         )
@@ -542,15 +544,15 @@ def _opt_fun(self, tree):
         for i, x in enumerate(self.df.iloc):
             delta[i] = tree.traverse(x).value
         y = self.df[self.y_name].values
 
+        loss_args = {}
         if "__weights__" in self.df:
-            w = self.df["__weights__"].values
-        else:
-            w = None
+            loss_args["weights"] = self.df["__weights__"].values
 
         def fun(gamma):
             y_ = y_hat + gamma * delta
             p = self.dmgr.metrics.output_transform(y_)
-            return self.dmgr.metrics.loss(y, p, w)
+            return self.dmgr.metrics.loss(y, p, **loss_args)
 
         return fun

32 changes: 4 additions & 28 deletions binarybeech/math.py
@@ -17,7 +17,7 @@ def unique_weighted(x, w):
     return np.array(u), np.array(c) / np.sum(c)
 
 
-def gini_impurity_fast(x):
+def gini_impurity(x):
     unique, counts = np.unique(x, return_counts=True)
     N = x.size
     p = counts / N
@@ -29,13 +29,7 @@ def gini_impurity_weighted(x, w):
     return 1.0 - np.sum(p**2)
 
 
-def gini_impurity(x, w=None):
-    if w is None:
-        return gini_impurity_fast(x)
-    return gini_impurity_weighted(x, w)
-
-
-def shannon_entropy_fast(x):
+def shannon_entropy(x):
     unique, counts = np.unique(x, return_counts=True)
     N = x.size
     p = counts / N
@@ -47,13 +41,7 @@ def shannon_entropy_weighted(x, w):
     return -np.sum(p * np.log2(p))
 
 
-def shannon_entropy(x, w=None):
-    if w is None:
-        return shannon_entropy_fast(x)
-    return shannon_entropy_weighted(x, w)
-
-
-def misclassification_cost_fast(x):
+def misclassification_cost(x):
     unique, counts = np.unique(x, return_counts=True)
     N = x.size
     p = np.max(counts) / N
@@ -66,12 +54,6 @@ def misclassification_cost_weighted(x, w):
     return 1.0 - p
 
 
-def misclassification_cost(x, w=None):
-    if w is None:
-        return misclassification_cost_fast(x)
-    return misclassification_cost_weighted(x, w)
-
-
 def logistic_loss(y, p):
     p = np.clip(p, 1e-12, 1.0 - 1e-12)
     return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))
@@ -94,13 +76,7 @@ def r_squared(y, y_hat):
     return 1 - sse / sst
 
 
-def majority_class(x, w=None):
-    if w is None:
-        return majority_class_fast(x)
-    return majority_class_weighted(x, w)
-
-
-def majority_class_fast(x):
+def majority_class(x):
     unique, counts = np.unique(x, return_counts=True)
     ind_max = np.argmax(counts)
     return unique[ind_max]
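In math.py the w=None dispatcher functions are removed, so the _fast suffix disappears and the unweighted implementations take the plain names; the choice between gini_impurity(x) and gini_impurity_weighted(x, w) now rests with the caller, driven in the new scheme by whether loss_args carries a weights entry. A small illustrative comparison — only gini_impurity_weighted's signature is taken from the diff, its internals here are an assumed implementation:

    import numpy as np

    def gini_impurity(x):
        # Unweighted variant, as kept in math.py.
        unique, counts = np.unique(x, return_counts=True)
        p = counts / x.size
        return 1.0 - np.sum(p**2)

    def gini_impurity_weighted(x, w):
        # Class probabilities from summed weights (assumed implementation;
        # unique_weighted in math.py plays this role).
        u = np.unique(x)
        p = np.array([np.sum(w[x == v]) for v in u]) / np.sum(w)
        return 1.0 - np.sum(p**2)

    x = np.array(["a", "a", "b", "b"])
    print(gini_impurity(x))                                           # 0.5
    print(gini_impurity_weighted(x, np.array([3.0, 3.0, 1.0, 1.0])))  # 0.375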