Skip to content

Commit

Permalink
Default parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 30, 2024
1 parent 5341f63 commit 872f026
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
20 changes: 13 additions & 7 deletions EvoMSA/back_prop.py
Expand Up @@ -68,7 +68,7 @@ class BoWBP(BoW):

def __init__(self, voc_size_exponent: int=15,
estimator_kwargs=dict(dual=True, class_weight='balanced'),
deviation=None, fraction_initial_parameters=0.6,
deviation=None, fraction_initial_parameters=1,
optimizer_kwargs: dict=None,
**kwargs):
super(BoWBP, self).__init__(voc_size_exponent=voc_size_exponent,
Expand Down Expand Up @@ -118,9 +118,12 @@ def deviation(self, value):
def initial_parameters(self, X, y):
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
if train_size == 1:
tr = np.arange(X.shape[0])
else:
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
m = self.estimator_class(**self.estimator_kwargs).fit(X[tr], y[tr])
W = jnp.array(m.coef_.T)
W0 = jnp.array(m.intercept_)
Expand Down Expand Up @@ -210,9 +213,12 @@ def _transform(self, X):
def initial_parameters(self, X, y):
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
if train_size == 1:
tr = np.arange(X.shape[0])
else:
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
dense_w = self.weights.T
dense_bias = self.bias
_ = X[tr] @ dense_w + dense_bias
Expand Down
2 changes: 1 addition & 1 deletion EvoMSA/tests/test_back_prop.py
Expand Up @@ -32,7 +32,7 @@ def test_BoWBP():
class_weight='balanced'),
voc_size_exponent=15).fit(D)
bow2_coef = bow2.estimator_instance.coef_.T
bow = BoWBP(lang='es',
bow = BoWBP(lang='es', fraction_initial_parameters=0.6,
estimator_kwargs=dict(dual=True,
random_state=0,
class_weight='balanced')).fit(D)
Expand Down

0 comments on commit 872f026

Please sign in to comment.