Skip to content

Commit

Permalink
BoW backpropagation
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 30, 2024
1 parent 3889b94 commit bc26f1d
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 150 deletions.
2 changes: 1 addition & 1 deletion EvoMSA/__init__.py
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = '2.0.5'
__version__ = '2.0.6'

try:
from EvoMSA.text_repr import BoW, TextRepresentations, StackGeneralization, DenseBoW
Expand Down
200 changes: 114 additions & 86 deletions EvoMSA/back_prop.py
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
from typing import Union, List
import jax
from jax import nn
import jax.numpy as jnp
import numpy as np
from scipy.special import softmax
from jax.experimental.sparse import BCSR
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.base import clone
Expand All @@ -24,22 +26,34 @@

@jax.jit
def bow_model(params, X):
"""BoW model"""
"""BoWBP model"""

Y = X @ params['W_cl'] + params['W0_cl']
return Y


@jax.jit
def dense_model(params, X):
"""DenseBoW model"""
"""DenseBoWBP model"""

_ = X @ params['W'] + params['W0']
Y = _ / jnp.linalg.norm(_, axis=1, keepdims=True)
Y = Y @ params['W_cl'] + params['W0_cl']
return Y


@jax.jit
def stack_model(params, X, df):
"""StackBoWBP model"""

_ = X @ params['W'] + params['W0']
Y = _ / jnp.linalg.norm(_, axis=1, keepdims=True)
Y = Y @ params['W_cl'] + params['W0_cl']
Y = nn.softmax(Y, axis=1)
pesos = nn.softmax(params['E'])
return Y * pesos[0] + df * pesos[1]


class BoWBP(BoW):
"""BoWBP is a :py:class:`~EvoMSA.text_repr.BoW` with the difference that the parameters are fine-tuned using jax
Expand All @@ -54,15 +68,23 @@ class BoWBP(BoW):

def __init__(self, voc_size_exponent: int=15,
estimator_kwargs=dict(dual=True, class_weight='balanced'),
deviation=None,
validation_set=None,
deviation=None, fraction_initial_parameters=0.6,
optimizer_kwargs: dict=None,
**kwargs):
super(BoWBP, self).__init__(voc_size_exponent=voc_size_exponent,
estimator_kwargs=estimator_kwargs, **kwargs)
self.deviation = deviation
self.validation_set = validation_set
self.optimizer_kwargs = optimizer_kwargs
self.fraction_initial_parameters = fraction_initial_parameters

@property
def fraction_initial_parameters(self):
"""Fraction of the training set to estimate the initial parameters"""
return self._fraction_initial_parameters

@fraction_initial_parameters.setter
def fraction_initial_parameters(self, value):
self._fraction_initial_parameters = value

@property
def evolution(self):
Expand All @@ -84,28 +106,6 @@ def optimizer_kwargs(self, value):
value = {}
self._optimizer_kwargs = value

@property
def validation_set(self):
"""Validation set"""
return self._validation_set

@validation_set.setter
def validation_set(self, value):
if value is None or value == 0:
self._validation_set = value
return
if hasattr(value, 'split'):
self._validation_set = value
return
assert isinstance(value, list) and len(value)
if isinstance(value[0], dict):
y = self.dependent_variable(value)
X = self._transform(value)
self._validation_set = [X, y]
else:
X, y = value
self._validation_set = [self._transform(X), y]

@property
def deviation(self):
"""Function to measure the deviation between the true observations and the predictions."""
Expand All @@ -115,18 +115,25 @@ def deviation(self):
def deviation(self, value):
self._deviation = value

def initial_parameters(self, X, y):
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
m = self.estimator_class(**self.estimator_kwargs).fit(X[tr], y[tr])
W = jnp.array(m.coef_.T)
W0 = jnp.array(m.intercept_)
return dict(W_cl=W, W0_cl=W0)

@property
def parameters(self):
"""Parameter to optimize"""

W = jnp.array(self.estimator_instance.coef_.T)
W0 = jnp.array(self.estimator_instance.intercept_)
return dict(W_cl=W, W0_cl=W0)
return self._parameters

@parameters.setter
def parameters(self, value):
self.estimator_instance.coef_ = np.array(value['W_cl'].T)
self.estimator_instance.intercept_ = np.array(value['W0_cl'])
self._parameters = value

@property
def model(self):
Expand All @@ -135,55 +142,41 @@ def model(self):
def _transform(self, X):
return super(BoWBP, self).transform(X)

def set_validation_set(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None):
"""Procedure to create the validation set"""

if self.validation_set is None:
if len(D) < 2048:
test_size = 0.2
else:
test_size = 512
y = self.dependent_variable(D, y=y)
_ = StratifiedShuffleSplit(n_splits=1,
test_size=test_size).split(D, y)
tr, vs = next(_)
self.validation_set = [D[x] for x in vs]
D = [D[x] for x in tr]
y = y[tr]
elif self.validation_set == 0:
self.validation_set = None
return D, y

def _combine_optimizer_kwargs(self):
decoder = self.estimator_instance.classes_
n_outputs = 1 if decoder.shape[0] == 2 else decoder.shape[0]
optimizer_defaults = dict(array=BCSR.from_scipy_sparse, n_outputs=n_outputs,
return_evolution=True)
optimizer_defaults = dict(return_evolution=True)
optimizer_defaults.update(self.optimizer_kwargs)
return optimizer_defaults

def initial_parameters(self, D, y=None):
"""Compute the initial parameters"""
super(BoWBP, self).fit(D, y=y)
def model_args(self, D: List[Union[dict, list]]):
"""Extra arguments pass to the model"""
return None

def fit(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None) -> 'BoWBP':
D, y = self.set_validation_set(D, y=y)
self.initial_parameters(D, y=y)
optimizer_kwargs = self._combine_optimizer_kwargs()
texts = self._transform(D)
labels = self.dependent_variable(D, y=y)
p = classifier(self.parameters, self.model,
self.classes_ = np.unique(labels)
p = classifier(self.initial_parameters, self.model,
texts, labels,
deviation=self.deviation,
validation=self.validation_set,
model_args=self.model_args(D),
**optimizer_kwargs)
if optimizer_kwargs['return_evolution']:
self.evolution = p[1]
p = p[0]
self.parameters = p
return self

def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
X = self._transform(D)
params = self.parameters
hy = self.model(params, BCSR.from_scipy_sparse(X))
return hy

def predict(self, D: List[Union[dict, list]]) -> np.ndarray:
df = self.decision_function(D)
return self.classes_[df.argmax(axis=1)]


class DenseBoWBP(DenseBoW, BoWBP):
Expand Down Expand Up @@ -213,6 +206,21 @@ def model(self):

def _transform(self, X):
return self.bow.transform(X)

def initial_parameters(self, X, y):
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
dense_w = self.weights.T
dense_bias = self.bias
_ = X[tr] @ dense_w + dense_bias
m = self.estimator_class(**self.estimator_kwargs).fit(_, y[tr])
W = jnp.array(m.coef_.T)
W0 = jnp.array(m.intercept_)
return dict(W_cl=W, W0_cl=W0,
W=jnp.array(dense_w), W0=jnp.array(dense_bias))

@property
def weights(self):
Expand All @@ -222,27 +230,47 @@ def weights(self):
def bias(self):
return np.array([x.intercept for x in self.text_representations])

# def __sklearn_clone__(self):
# ins = super(DenseBoWBP, self).__sklearn_clone__()
# _ = [clone(m) for m in self.text_representations]
# ins.text_representations = _
# return ins


class StackBoWBP(DenseBoWBP):
@property
def parameters(self):
"""Parameters to optimize"""
_parameters = super(DenseBoWBP, self).parameters
_parameters['W'] = jnp.array(self.weights.T)
_parameters['W0'] = jnp.array(self.bias)
return _parameters
def model(self):
return stack_model

@parameters.setter
def parameters(self, value):
self.estimator_instance.coef_ = np.array(value['W_cl'].T)
self.estimator_instance.intercept_ = np.array(value['W0_cl'])
for x, m in zip(np.array(value['W'].T),
self.text_representations):
m.coef[:] = x[:]
for x, m in zip(np.array(value['W0']),
self.text_representations):
m.intercept = float(x)

def __sklearn_clone__(self):
ins = super(DenseBoWBP, self).__sklearn_clone__()
_ = [clone(m) for m in self.text_representations]
ins.text_representations = _
return ins
def initial_parameters(self, X, y, df):
params = super(StackBoWBP, self).initial_parameters(X, y)
params['E'] = jnp.array([0.5, 0.5])
return params

def model_args(self, D: List[Union[dict, list]]):
if not hasattr(self, '_bow_ins'):
hy = BoW.train_predict_decision_function(self, D)
else:
X = super(StackBoWBP, self)._transform(D)
hy = getattr(self._bow_ins, self.decision_function_name)(X)
if hy.ndim == 1:
hy = np.atleast_2d(hy).T
else:
hy = softmax(hy, axis=1)
return (hy, )

def fit(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None) -> "StackBoWBP":
super(StackBoWBP, self).fit(D, y=y)
_ = self._transform(D)
labels = self.dependent_variable(D, y=y)
self._bow_ins = self.estimator_class(**self.estimator_kwargs).fit(_, labels)
return self

def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
X = self._transform(D)
params = self.parameters
df = self.model_args(D)[0]
hy = self.model(params, BCSR.from_scipy_sparse(X),
jnp.array(df))
return hy

0 comments on commit bc26f1d

Please sign in to comment.