Skip to content

Commit

Permalink
initial parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Feb 1, 2024
1 parent 6d5071c commit 8c517ee
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
29 changes: 24 additions & 5 deletions EvoMSA/back_prop.py
Expand Up @@ -16,6 +16,7 @@
from jax import nn
import jax.numpy as jnp
import numpy as np
from scipy.special import expit, softmax
from scipy.sparse import spmatrix
from jax.experimental.sparse import BCSR
from sklearn.model_selection import StratifiedShuffleSplit
Expand Down Expand Up @@ -65,19 +66,24 @@ def stack_model_binary(params, X, df):
return Y * pesos[0] + nn.sigmoid(df) * pesos[1] - 0.5


def initial_parameters(df, df2, y, score=None):
def initial_parameters(df, df2, y,
nclasses=2, score=None):
"""Estimate initial parameters :py:class:`~EvoMSA.back_prop.StackBoWBP`"""
from sklearn.metrics import f1_score
from scipy.special import softmax

def f(x):
hy = (x[0] * df + x[1] * df2).argmax(axis=1)
hy = (x[0] * df + x[1] * df2)
if nclasses ==2:
hy = np.where(hy > 0.5, 1, 0)
else:
hy = hy.argmax(axis=1)
return score(y, hy)

if score is None:
score = lambda y, hy: f1_score(y, hy, average='macro')
df = softmax(df, axis=1)
df2 = softmax(df2, axis=1)
# df = softmax(df, axis=1)
# df2 = softmax(df2, axis=1)
value = np.linspace(0, 1, 100)
_ = [f([v, 1-v]) for v in value]
index = np.argmax(_)
Expand Down Expand Up @@ -303,7 +309,20 @@ def model(self):

def initial_parameters(self, X, y, df):
    """Build the starting parameters, including the stacking weights ``E``.

    The parent class supplies the base parameters. ``E`` is then estimated
    by the module-level :py:func:`initial_parameters` from two sets of
    scores: ``df`` (given by the caller) and ``df2``, obtained by calling
    ``train_predict_decision_function`` on the dense projection of ``X``.

    :param X: input representation; it is multiplied by ``self.weights.T``
    :param y: targets; when there are more than two classes ``argmax(axis=1)``
        is applied, so a 2-D per-class encoding is assumed -- TODO confirm
    :param df: scores to be combined with ``df2``
    :return: the parameter mapping with the key ``'E'`` set
    """
    params = super(StackBoWBP, self).initial_parameters(X, y)
    n_classes = self.classes_.shape[0]
    # Dense projection of the input using the already available
    # weights and bias.
    Xd = X @ self.weights.T + self.bias
    if n_classes > 2:
        # Reduce the per-class columns to a single label index.
        y = y.argmax(axis=1)
    df2 = self.train_predict_decision_function([1] * Xd.shape[0], y=y, X=Xd)
    if n_classes > 2:
        # The module-level helper takes argmax(axis=1) in the multiclass
        # case, so the scores are normalised row-wise.
        df = softmax(df, axis=1)
        df2 = softmax(df2, axis=1)
    else:
        # The module-level helper thresholds at 0.5 in the binary case and
        # stack_model_binary applies a sigmoid, so the scores are squashed
        # with the logistic function. A row-wise softmax would be constant
        # on a single score column.
        df = expit(df)
        df2 = expit(df2)
    params['E'] = initial_parameters(df, df2, y, nclasses=n_classes)
    return params

def model_args(self, D: List[Union[dict, list]]):
Expand Down
7 changes: 5 additions & 2 deletions EvoMSA/tests/test_back_prop.py
Expand Up @@ -115,9 +115,12 @@ def test_DenseBoWBP():
assert len(hy) == len(D)


def test_StackBoWBP_initial_parameters():
    """Test that StackBoWBP estimates ``E`` instead of keeping [0.5, 0.5].

    The check is done twice: with every example in the dataset and with
    only the examples whose ``klass`` is ``'N'`` or ``'P'``.
    """

    dataset = list(tweet_iterator(TWEETS))
    ins = StackBoWBP(lang='es', voc_size_exponent=13).fit(dataset)
    assert 'E' in ins.parameters
    assert np.fabs(ins.parameters['E'] - np.array([0.5, 0.5])).sum() > 0
    # Keep only two labels so the two-class code path is exercised.
    two_class = [x for x in dataset if x['klass'] in {'N', 'P'}]
    ins = StackBoWBP(lang='es', voc_size_exponent=13).fit(two_class)
    assert np.fabs(ins.parameters['E'] - np.array([0.5, 0.5])).sum() > 0

0 comments on commit 8c517ee

Please sign in to comment.