Skip to content

Commit

Permalink
cleaning code
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 25, 2024
1 parent 0efc818 commit 435105d
Showing 1 changed file with 9 additions and 27 deletions.
36 changes: 9 additions & 27 deletions EvoMSA/back_prop.py
Expand Up @@ -18,7 +18,6 @@
from jax.experimental.sparse import BCSR
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.base import clone
from IngeoML.utils import Batches
from IngeoML.optimizer import classifier
from EvoMSA.text_repr import BoW, DenseBoW

Expand Down Expand Up @@ -57,34 +56,23 @@ def __init__(self, voc_size_exponent: int=15,
estimator_kwargs=dict(dual=True, class_weight='balanced'),
deviation=None,
validation_set=None,
batches=None,
optimizer_kwargs: dict=None,
**kwargs):
super(BoWBP, self).__init__(voc_size_exponent=voc_size_exponent,
estimator_kwargs=estimator_kwargs, **kwargs)
self.deviation = deviation
self.validation_set = validation_set
self.optimizer_kwargs = optimizer_kwargs
self.batches = batches

@property
def evolution(self):
"""Evolution of the objective-function value"""
return self._evolution

@evolution.setter
def evolution(self, value):
self._evolution = value

@property
def batches(self):
"""Instance to create the batches"""
return self._batches

@batches.setter
def batches(self, value):
self._batches = value

@property
def optimizer_kwargs(self):
"""Arguments for the optimizer"""
Expand All @@ -100,7 +88,7 @@ def optimizer_kwargs(self, value):
def validation_set(self):
"""Validation set"""
return self._validation_set

@validation_set.setter
def validation_set(self, value):
if value is None or value == 0:
Expand All @@ -122,11 +110,11 @@ def validation_set(self, value):
def deviation(self):
"""Function to measure the deviation between the true observations and the predictions."""
return self._deviation

@deviation.setter
def deviation(self, value):
self._deviation = value

@property
def parameters(self):
"""Parameter to optimize"""
Expand Down Expand Up @@ -170,19 +158,13 @@ def set_validation_set(self, D: List[Union[dict, list]],
def _combine_optimizer_kwargs(self):
decoder = self.estimator_instance.classes_
n_outputs = 1 if decoder.shape[0] == 2 else decoder.shape[0]
optimizer_defaults = dict(array=BCSR.from_scipy_sparse,
every_k_schedule=4, n_outputs=n_outputs,
epochs=100, learning_rate=1e-4,
return_evolution=True,
n_iter_no_change=5)
optimizer_defaults = dict(array=BCSR.from_scipy_sparse, n_outputs=n_outputs,
return_evolution=True)
optimizer_defaults.update(self.optimizer_kwargs)
return optimizer_defaults

def fit(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None) -> 'BoWBP':
if self.batches is None:
self.batches = Batches(size=512 if len(D) >= 2048 else 256,
random_state=0)
D, y = self.set_validation_set(D, y=y)
super(BoWBP, self).fit(D, y=y)
optimizer_kwargs = self._combine_optimizer_kwargs()
Expand All @@ -192,7 +174,7 @@ def fit(self, D: List[Union[dict, list]],
texts, labels,
deviation=self.deviation,
validation=self.validation_set,
batches=self.batches, **optimizer_kwargs)
**optimizer_kwargs)
if optimizer_kwargs['return_evolution']:
self.evolution = p[1]
p = p[0]
Expand Down Expand Up @@ -225,14 +207,14 @@ def __init__(self, emoji: bool=True,
@property
def model(self):
return dense_model

def _transform(self, X):
return self.bow.transform(X)

@property
def weights(self):
return np.array([x.coef for x in self.text_representations])

@property
def bias(self):
return np.array([x.intercept for x in self.text_representations])
Expand Down

0 comments on commit 435105d

Please sign in to comment.