
Commit

Merge pull request #45 from arminwitte/metrics_regularisation
Metrics regularisation
arminwitte committed Sep 16, 2023
2 parents b8f7717 + 301c51e commit 028fcb9
Showing 8 changed files with 232 additions and 17 deletions.
33 changes: 23 additions & 10 deletions binarybeech/attributehandler.py
@@ -78,9 +78,12 @@ def split(self, df):
N = len(df.index)
n = [len(df_.index) for df_ in split_df]

loss_args = [{}, {}]
loss_args = {key: self.algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
loss_args = [loss_args.copy(), loss_args.copy()]
if "__weights__" in df:
loss_args = [{"weights":df_["__weights__"].values} for df_ in split_df]
for i, df_ in enumerate(split_df):
loss_args[i]["weights"] = df_["__weights__"].values


val = [
self.metrics.node_value(df_[self.y_name], **loss_args[i])
@@ -162,9 +165,12 @@ def fun(x):
if min(n) == 0:
return np.Inf

loss_args = [{}, {}]
loss_args = {key: self.algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
loss_args = [loss_args.copy(), loss_args.copy()]
if "__weights__" in df:
w = [{"weights":df_["__weights__"].values} for df_ in split_df]
for i, df_ in enumerate(split_df):
loss_args[i]["weights"] = df_["__weights__"].values

val = [
self.metrics.node_value(df_[self.y_name], **loss_args[i])
for i, df_ in enumerate(split_df)
@@ -212,10 +218,13 @@ def split(self, df):
]
N = len(df.index)
n = [len(df_.index) for df_ in self.split_df]

loss_args = [{}, {}]

loss_args = {key: self.algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
loss_args = [loss_args.copy(), loss_args.copy()]
if "__weights__" in df:
loss_args = [{"weights":df_["__weights__"].values} for df_ in self.split_df]
for i, df_ in enumerate(self.split_df):
loss_args[i]["weights"] = df_["__weights__"].values


val = [
self.metrics.node_value(df_[self.y_name], **loss_args[i])
@@ -293,10 +302,14 @@ def _opt_fun(self, df):
def fun(x):
split_df = [df[df[split_name] < x], df[df[split_name] >= x]]
n = [len(df_.index) for df_ in split_df]

loss_args = [{}, {}]


loss_args = {key: self.algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
loss_args = [loss_args.copy(), loss_args.copy()]
if "__weights__" in df:
loss_args = [{"weights":df_["__weights__"].values} for df_ in split_df]
for i, df_ in enumerate(split_df):
loss_args[i]["weights"] = df_["__weights__"].values

val = [
self.metrics.node_value(df_[self.y_name], **loss_args[i])
for i, df_ in enumerate(split_df)
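In attributehandler.py, the previously empty per-branch loss_args dicts are now seeded with lambda_l1 and lambda_l2 from algorithm_kwargs, and sample weights are merged into those dicts instead of replacing them. A minimal sketch of the pattern, with the surrounding handler class omitted and the helper name build_loss_args made up for illustration:

```python
import pandas as pd


def build_loss_args(df, split_df, algorithm_kwargs):
    """One kwargs dict per branch, each carrying the regularisation strengths;
    per-branch sample weights are added on top instead of replacing the dict."""
    base = {key: algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
    loss_args = [base.copy() for _ in split_df]
    if "__weights__" in df:
        for i, df_ in enumerate(split_df):
            loss_args[i]["weights"] = df_["__weights__"].values
    return loss_args


# hypothetical two-way split
df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [0.1, 0.2, 0.3, 0.4]})
split_df = [df[df["x"] < 3], df[df["x"] >= 3]]
print(build_loss_args(df, split_df, {"lambda_l1": 0.0, "lambda_l2": 1.0}))
```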
21 changes: 17 additions & 4 deletions binarybeech/binarybeech.py
@@ -100,12 +100,15 @@ def __init__(
min_split_samples=1,
max_depth=10,
min_split_loss = 0.,
lambda_l1 = 0.,
lambda_l2 = 0.,
method="regression",
handle_missings="simple",
attribute_handlers=None,
seed=None,
algorithm_kwargs={},
):
algorithm_kwargs.update(locals())
super().__init__(
training_data,
df,
@@ -125,6 +128,7 @@ def __init__(
self.max_depth = max_depth
self.min_split_loss = min_split_loss


self.depth = 0
self.seed = seed

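The algorithm_kwargs.update(locals()) call folds every constructor argument, including the new lambda_l1 and lambda_l2 defaults, into algorithm_kwargs, so attribute handlers and metrics can look the penalties up later; note that locals() at that point also captures self and the data arguments. A standalone sketch of the behaviour (ExampleModel is illustrative only, not part of the library):

```python
class ExampleModel:
    # Mirrors the update(locals()) pattern from this hunk, including the
    # shared mutable default for algorithm_kwargs.
    def __init__(self, df=None, lambda_l1=0.0, lambda_l2=0.0, algorithm_kwargs={}):
        # locals() holds every __init__ parameter at this point, so defaults
        # and overrides alike end up in algorithm_kwargs.
        algorithm_kwargs.update(locals())
        self.algorithm_kwargs = algorithm_kwargs


m = ExampleModel(lambda_l1=0.5)
print(m.algorithm_kwargs["lambda_l1"])  # 0.5
print(m.algorithm_kwargs["lambda_l2"])  # 0.0
print("self" in m.algorithm_kwargs)     # True: locals() also captures self
```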
@@ -226,7 +230,7 @@ def create_tree(self, leaf_loss_threshold=1e-12):
def _node_or_leaf(self, df):
y = df[self.y_name]

loss_args = {}
loss_args = {key:self.algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
if "__weights__" in df:
loss_args["weights"] = df["__weights__"].values

@@ -270,7 +274,7 @@ def _node_or_leaf(self, df):
decision_fun=self.dmgr[split_name].decide,
)
item.pinfo["N"] = len(df.index)
loss_args ={}
loss_args = {key:self.algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
item.pinfo["r"] = self.dmgr.metrics.loss_prune(y, y_hat, **loss_args)
item.pinfo["R"] = (
item.pinfo["N"] / len(self.training_data.df.index) * item.pinfo["r"]
@@ -286,7 +290,7 @@ def _leaf(self, y, y_hat):
leaf = Node(value=y_hat)

leaf.pinfo["N"] = y.size
loss_args = {}
loss_args = {key:self.algorithm_kwargs[key] for key in ["lambda_l1","lambda_l2"]}
leaf.pinfo["r"] = self.dmgr.metrics.loss_prune(y, y_hat, **loss_args)
leaf.pinfo["R"] = (
leaf.pinfo["N"] / len(self.training_data.df.index) * leaf.pinfo["r"]
@@ -402,6 +406,8 @@ def __init__(
sample_frac=1,
n_attributes=None,
learning_rate=0.1,
lambda_l1 = 0.,
lambda_l2 = 0.,
cart_settings={},
init_method="logistic",
gamma=None,
@@ -410,6 +416,7 @@ def __init__(
seed=None,
algorithm_kwargs={},
):
algorithm_kwargs.update(locals())
super().__init__(
training_data,
df,
@@ -545,7 +552,7 @@ def _opt_fun(self, tree):
delta[i] = tree.traverse(x).value
y = self.df[self.y_name].values

loss_args = {}
loss_args = {key:self.algorithm_kwargs[key] for key in ["lambda_l1", "lambda_l2"]}
if "__weights__" in self.df:
loss_args["weights"] = self.df["__weights__"].values

@@ -641,13 +648,16 @@ def __init__(
X_names=None,
sample_frac=1,
n_attributes=None,
lambda_l1 = 0.,
lambda_l2 = 0.,
cart_settings={},
method="classification",
handle_missings="simple",
attribute_handlers=None,
seed=None,
algorithm_kwargs={},
):
algorithm_kwargs.update(locals())
super().__init__(
training_data,
df,
@@ -803,13 +813,16 @@ def __init__(
verbose=False,
sample_frac=1,
n_attributes=None,
lambda_l1 = 0.,
lambda_l2 = 0.,
cart_settings={},
method="regression",
handle_missings="simple",
attribute_handlers=None,
seed=None,
algorithm_kwargs={},
):
algorithm_kwargs.update(locals())
super().__init__(
training_data,
df,
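With lambda_l1 and lambda_l2 exposed as constructor arguments on the tree and ensemble classes, the regularisation strengths can be passed at model creation and reach the metrics through algorithm_kwargs. A hedged usage sketch: the import path, the CART entry point and the create_tree/validate calls follow the project's tests and the code visible in this diff, but the exact signature is an assumption:

```python
import numpy as np
import pandas as pd

from binarybeech.binarybeech import CART  # import path assumed

rng = np.random.default_rng(42)
df = pd.DataFrame({"x": rng.normal(size=200)})
df["y"] = 2.0 * df["x"] + rng.normal(scale=0.1, size=200)

# lambda_l1/lambda_l2 are ordinary keyword arguments now and are forwarded
# to the metrics via algorithm_kwargs.
model = CART(
    df=df,
    y_name="y",
    method="regression:regularized",
    lambda_l1=1.0,
    lambda_l2=1.0,
)
model.create_tree()
print(model.validate())
```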
27 changes: 25 additions & 2 deletions binarybeech/metrics.py
@@ -138,6 +138,27 @@ def bins(self, df, y_name, attribute):
@staticmethod
def check(x):
return math.check_interval(x)


class RegressionMetricsRegularized(RegressionMetrics):
def __init__(self):
super().__init__()

def node_value(self, y, **kwargs):
y = np.array(y).ravel()
n = y.shape[0]
lambda_l1 = kwargs.get("lambda_l1")
lambda_l2 = kwargs.get("lambda_l2")
y_sum = np.sum(y)

if y_sum < -lambda_l1:
return (y_sum + lambda_l1)/(n + lambda_l2)
elif y_sum > lambda_l1:
return (y_sum - lambda_l1)/(n + lambda_l2)
else:
return 0.




class LogisticMetrics(Metrics):
@@ -237,8 +258,9 @@ def loss(self, y, y_hat, **kwargs):

def loss_prune(self, y, y_hat, **kwargs):
# Implementation of the loss pruning calculation for classification
if "weights" in kwargs.keys():
return math.misclassification_cost_weighted(y, kwargs["weights"])
# if "weights" in kwargs.keys():
# print(len(y), len(y_hat), len(kwargs["weights"]))
# return math.misclassification_cost_weighted(y, kwargs["weights"])
return math.misclassification_cost(y)

def node_value(self, y, **kwargs):
@@ -376,6 +398,7 @@ def from_data(self, y, algorithm_kwargs):

metrics_factory = MetricsFactory()
metrics_factory.register("regression", RegressionMetrics)
metrics_factory.register("regression:regularized", RegressionMetrics)
metrics_factory.register("classification:gini", ClassificationMetrics)
metrics_factory.register("classification:entropy", ClassificationMetricsEntropy)
metrics_factory.register("logistic", LogisticMetrics)
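RegressionMetricsRegularized.node_value is the usual L1/L2-penalised leaf estimate: with G = sum(y) over n samples, the minimiser of 0.5*sum((y_i - w)^2) + lambda_l1*|w| + 0.5*lambda_l2*w^2 is the soft-thresholded value w = sign(G) * max(|G| - lambda_l1, 0) / (n + lambda_l2), which reduces to the plain mean when both penalties are zero. A small numeric check that mirrors the method above (a standalone sketch, not a call into the library):

```python
import numpy as np


def regularized_node_value(y, lambda_l1=0.0, lambda_l2=0.0):
    """Soft-thresholded leaf value, mirroring RegressionMetricsRegularized.node_value."""
    y = np.asarray(y, dtype=float).ravel()
    n = y.shape[0]
    y_sum = y.sum()
    if y_sum < -lambda_l1:
        return (y_sum + lambda_l1) / (n + lambda_l2)
    elif y_sum > lambda_l1:
        return (y_sum - lambda_l1) / (n + lambda_l2)
    return 0.0


y = [1.0, 2.0, 3.0]                               # sum = 6, n = 3
print(regularized_node_value(y))                  # 2.0 -> plain mean
print(regularized_node_value(y, lambda_l1=1.0))   # (6 - 1) / 3 ~ 1.667
print(regularized_node_value(y, lambda_l2=1.0))   # 6 / (3 + 1) = 1.5
print(regularized_node_value(y, lambda_l1=10.0))  # 0.0 -> shrunk to zero
```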
98 changes: 98 additions & 0 deletions data/prostate.data
@@ -0,0 +1,98 @@
lcavol lweight age lbph svi lcp gleason pgg45 lpsa train
1 -0.579818495 2.769459 50 -1.38629436 0 -1.38629436 6 0 -0.4307829 T
2 -0.994252273 3.319626 58 -1.38629436 0 -1.38629436 6 0 -0.1625189 T
3 -0.510825624 2.691243 74 -1.38629436 0 -1.38629436 7 20 -0.1625189 T
4 -1.203972804 3.282789 58 -1.38629436 0 -1.38629436 6 0 -0.1625189 T
5 0.751416089 3.432373 62 -1.38629436 0 -1.38629436 6 0 0.3715636 T
6 -1.049822124 3.228826 50 -1.38629436 0 -1.38629436 6 0 0.7654678 T
7 0.737164066 3.473518 64 0.61518564 0 -1.38629436 6 0 0.7654678 F
8 0.693147181 3.539509 58 1.53686722 0 -1.38629436 6 0 0.8544153 T
9 -0.776528789 3.539509 47 -1.38629436 0 -1.38629436 6 0 1.0473190 F
10 0.223143551 3.244544 63 -1.38629436 0 -1.38629436 6 0 1.0473190 F
11 0.254642218 3.604138 65 -1.38629436 0 -1.38629436 6 0 1.2669476 T
12 -1.347073648 3.598681 63 1.26694760 0 -1.38629436 6 0 1.2669476 T
13 1.613429934 3.022861 63 -1.38629436 0 -0.59783700 7 30 1.2669476 T
14 1.477048724 2.998229 67 -1.38629436 0 -1.38629436 7 5 1.3480731 T
15 1.205970807 3.442019 57 -1.38629436 0 -0.43078292 7 5 1.3987169 F
16 1.541159072 3.061052 66 -1.38629436 0 -1.38629436 6 0 1.4469190 T
17 -0.415515444 3.516013 70 1.24415459 0 -0.59783700 7 30 1.4701758 T
18 2.288486169 3.649359 66 -1.38629436 0 0.37156356 6 0 1.4929041 T
19 -0.562118918 3.267666 41 -1.38629436 0 -1.38629436 6 0 1.5581446 T
20 0.182321557 3.825375 70 1.65822808 0 -1.38629436 6 0 1.5993876 T
21 1.147402453 3.419365 59 -1.38629436 0 -1.38629436 6 0 1.6389967 T
22 2.059238834 3.501043 60 1.47476301 0 1.34807315 7 20 1.6582281 F
23 -0.544727175 3.375880 59 -0.79850770 0 -1.38629436 6 0 1.6956156 T
24 1.781709133 3.451574 63 0.43825493 0 1.17865500 7 60 1.7137979 T
25 0.385262401 3.667400 69 1.59938758 0 -1.38629436 6 0 1.7316555 F
26 1.446918983 3.124565 68 0.30010459 0 -1.38629436 6 0 1.7664417 F
27 0.512823626 3.719651 65 -1.38629436 0 -0.79850770 7 70 1.8000583 T
28 -0.400477567 3.865979 67 1.81645208 0 -1.38629436 7 20 1.8164521 F
29 1.040276712 3.128951 67 0.22314355 0 0.04879016 7 80 1.8484548 T
30 2.409644165 3.375880 65 -1.38629436 0 1.61938824 6 0 1.8946169 T
31 0.285178942 4.090169 65 1.96290773 0 -0.79850770 6 0 1.9242487 T
32 0.182321557 3.804438 65 1.70474809 0 -1.38629436 6 0 2.0082140 F
33 1.275362800 3.037354 71 1.26694760 0 -1.38629436 6 0 2.0082140 T
34 0.009950331 3.267666 54 -1.38629436 0 -1.38629436 6 0 2.0215476 F
35 -0.010050336 3.216874 63 -1.38629436 0 -0.79850770 6 0 2.0476928 T
36 1.308332820 4.119850 64 2.17133681 0 -1.38629436 7 5 2.0856721 F
37 1.423108334 3.657131 73 -0.57981850 0 1.65822808 8 15 2.1575593 T
38 0.457424847 2.374906 64 -1.38629436 0 -1.38629436 7 15 2.1916535 T
39 2.660958594 4.085136 68 1.37371558 1 1.83258146 7 35 2.2137539 T
40 0.797507196 3.013081 56 0.93609336 0 -0.16251893 7 5 2.2772673 T
41 0.620576488 3.141995 60 -1.38629436 0 -1.38629436 9 80 2.2975726 T
42 1.442201993 3.682610 68 -1.38629436 0 -1.38629436 7 10 2.3075726 F
43 0.582215620 3.865979 62 1.71379793 0 -0.43078292 6 0 2.3272777 T
44 1.771556762 3.896909 61 -1.38629436 0 0.81093022 7 6 2.3749058 F
45 1.486139696 3.409496 66 1.74919985 0 -0.43078292 7 20 2.5217206 T
46 1.663926098 3.392829 61 0.61518564 0 -1.38629436 7 15 2.5533438 T
47 2.727852828 3.995445 79 1.87946505 1 2.65675691 9 100 2.5687881 T
48 1.163150810 4.035125 68 1.71379793 0 -0.43078292 7 40 2.5687881 F
49 1.745715531 3.498022 43 -1.38629436 0 -1.38629436 6 0 2.5915164 F
50 1.220829921 3.568123 70 1.37371558 0 -0.79850770 6 0 2.5915164 F
51 1.091923301 3.993603 68 -1.38629436 0 -1.38629436 7 50 2.6567569 T
52 1.660131027 4.234831 64 2.07317193 0 -1.38629436 6 0 2.6775910 T
53 0.512823626 3.633631 64 1.49290410 0 0.04879016 7 70 2.6844403 F
54 2.127040520 4.121473 68 1.76644166 0 1.44691898 7 40 2.6912431 F
55 3.153590358 3.516013 59 -1.38629436 0 -1.38629436 7 5 2.7047113 F
56 1.266947603 4.280132 66 2.12226154 0 -1.38629436 7 15 2.7180005 T
57 0.974559640 2.865054 47 -1.38629436 0 0.50077529 7 4 2.7880929 F
58 0.463734016 3.764682 49 1.42310833 0 -1.38629436 6 0 2.7942279 T
59 0.542324291 4.178226 70 0.43825493 0 -1.38629436 7 20 2.8063861 T
60 1.061256502 3.851211 61 1.29472717 0 -1.38629436 7 40 2.8124102 T
61 0.457424847 4.524502 73 2.32630162 0 -1.38629436 6 0 2.8419982 T
62 1.997417706 3.719651 63 1.61938824 1 1.90954250 7 40 2.8535925 F
63 2.775708850 3.524889 72 -1.38629436 0 1.55814462 9 95 2.8535925 T
64 2.034705648 3.917011 66 2.00821403 1 2.11021320 7 60 2.8820035 F
65 2.073171929 3.623007 64 -1.38629436 0 -1.38629436 6 0 2.8820035 F
66 1.458615023 3.836221 61 1.32175584 0 -0.43078292 7 20 2.8875901 F
67 2.022871190 3.878466 68 1.78339122 0 1.32175584 7 70 2.9204698 T
68 2.198335072 4.050915 72 2.30757263 0 -0.43078292 7 10 2.9626924 T
69 -0.446287103 4.408547 69 -1.38629436 0 -1.38629436 6 0 2.9626924 T
70 1.193922468 4.780383 72 2.32630162 0 -0.79850770 7 5 2.9729753 T
71 1.864080131 3.593194 60 -1.38629436 1 1.32175584 7 60 3.0130809 T
72 1.160020917 3.341093 77 1.74919985 0 -1.38629436 7 25 3.0373539 T
73 1.214912744 3.825375 69 -1.38629436 1 0.22314355 7 20 3.0563569 F
74 1.838961071 3.236716 60 0.43825493 1 1.17865500 9 90 3.0750055 F
75 2.999226163 3.849083 69 -1.38629436 1 1.90954250 7 20 3.2752562 T
76 3.141130476 3.263849 68 -0.05129329 1 2.42036813 7 50 3.3375474 T
77 2.010894999 4.433789 72 2.12226154 0 0.50077529 7 60 3.3928291 T
78 2.537657215 4.354784 78 2.32630162 0 -1.38629436 7 10 3.4355988 T
79 2.648300197 3.582129 69 -1.38629436 1 2.58399755 7 70 3.4578927 T
80 2.779440197 3.823192 63 -1.38629436 0 0.37156356 7 50 3.5130369 F
81 1.467874348 3.070376 66 0.55961579 0 0.22314355 7 40 3.5160131 T
82 2.513656063 3.473518 57 0.43825493 0 2.32727771 7 60 3.5307626 T
83 2.613006652 3.888754 77 -0.52763274 1 0.55961579 7 30 3.5652984 T
84 2.677590994 3.838376 65 1.11514159 0 1.74919985 9 70 3.5709402 F
85 1.562346305 3.709907 60 1.69561561 0 0.81093022 7 30 3.5876769 T
86 3.302849259 3.518980 64 -1.38629436 1 2.32727771 7 60 3.6309855 T
87 2.024193067 3.731699 58 1.63899671 0 -1.38629436 6 0 3.6800909 T
88 1.731655545 3.369018 62 -1.38629436 1 0.30010459 7 30 3.7123518 T
89 2.807593831 4.718052 65 -1.38629436 1 2.46385324 7 60 3.9843437 T
90 1.562346305 3.695110 76 0.93609336 1 0.81093022 7 75 3.9936030 T
91 3.246490992 4.101817 68 -1.38629436 0 -1.38629436 6 0 4.0298060 T
92 2.532902848 3.677566 61 1.34807315 1 -1.38629436 7 15 4.1295508 T
93 2.830267834 3.876396 68 -1.38629436 1 1.32175584 7 60 4.3851468 T
94 3.821003607 3.896909 44 -1.38629436 1 2.16905370 7 40 4.6844434 T
95 2.907447359 3.396185 52 -1.38629436 1 2.46385324 7 10 5.1431245 F
96 2.882563575 3.773910 68 1.55814462 1 1.55814462 7 80 5.4775090 T
97 3.471966453 3.974998 68 0.43825493 1 2.90416508 7 20 5.5829322 F
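data/prostate.data is the whitespace-separated prostate cancer dataset: a leading row-index column, eight predictors, the lpsa response, and a train/test flag in the last column. A sketch of how it might be loaded for experiments with the regularised regression metrics (the pandas options are assumptions, not part of the commit):

```python
import pandas as pd

cols = ["lcavol", "lweight", "age", "lbph", "svi", "lcp",
        "gleason", "pgg45", "lpsa", "train"]

# Whitespace-separated; the header row names only the 10 data columns, so the
# unnamed leading column is read explicitly as the row index here.
df = pd.read_csv("data/prostate.data", sep=r"\s+", skiprows=1,
                 names=["row"] + cols, index_col="row")

train = df[df["train"] == "T"].drop(columns="train")
test = df[df["train"] == "F"].drop(columns="train")
print(train.shape, test.shape)  # expected around (67, 9) and (30, 9)
```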
2 changes: 1 addition & 1 deletion tests/test_adaboost.py
@@ -13,7 +13,7 @@ def test_adaboost_iris():
val = c.validate()
acc = val["accuracy"]
np.testing.assert_array_equal(p[:10], ["setosa"] * 10)
assert acc <= 1.0 and acc > 0.98
assert acc <= 1.0 and acc > 0.97


def test_adaboost_titanic():
1 change: 1 addition & 0 deletions tests/test_datamanager.py
@@ -11,6 +11,7 @@ def test_datamanager_info():
assert ah == ["default", "clustering"]
assert m == [
"regression",
"regression:regularized",
"classification:gini",
"classification:entropy",
"logistic",
