Skip to content

Commit

Permalink
cleaning code
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 26, 2024
1 parent 435105d commit 3889b94
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions EvoMSA/back_prop.py
Expand Up @@ -163,10 +163,14 @@ def _combine_optimizer_kwargs(self):
optimizer_defaults.update(self.optimizer_kwargs)
return optimizer_defaults

def initial_parameters(self, D, y=None):
"""Compute the initial parameters"""
super(BoWBP, self).fit(D, y=y)

def fit(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None) -> 'BoWBP':
D, y = self.set_validation_set(D, y=y)
super(BoWBP, self).fit(D, y=y)
self.initial_parameters(D, y=y)
optimizer_kwargs = self._combine_optimizer_kwargs()
texts = self._transform(D)
labels = self.dependent_variable(D, y=y)
Expand Down Expand Up @@ -195,12 +199,11 @@ class DenseBoWBP(DenseBoW, BoWBP):
"""

def __init__(self, emoji: bool=True,
dataset: bool=True, keyword: bool=True,
dataset: bool=False, keyword: bool=True,
estimator_kwargs=dict(dual='auto', class_weight='balanced'),
validation_set=0, **kwargs):
**kwargs):
super(DenseBoWBP, self).__init__(emoji=emoji, dataset=dataset,
keyword=keyword,
validation_set=validation_set,
estimator_kwargs=estimator_kwargs,
**kwargs)

Expand Down Expand Up @@ -238,15 +241,6 @@ def parameters(self, value):
self.text_representations):
m.intercept = float(x)

def set_validation_set(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None):
"""Procedure to create the validation set"""

n_dim = len(self.text_representations)
if n_dim >= len(D) and self.validation_set == 0:
self.validation_set = None
return super(DenseBoWBP, self).set_validation_set(D=D, y=y)

def __sklearn_clone__(self):
ins = super(DenseBoWBP, self).__sklearn_clone__()
_ = [clone(m) for m in self.text_representations]
Expand Down

0 comments on commit 3889b94

Please sign in to comment.