Commit 62f94a1

Merge pull request #129 from INGEOTEC/develop
Version - 2.0.6

mgraffg committed Feb 1, 2024
2 parents bab04f0 + 30466ef commit 62f94a1

Showing 3 changed files with 184 additions and 149 deletions.
2 changes: 1 addition & 1 deletion EvoMSA/__init__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = '2.0.5'
__version__ = '2.0.6'

try:
from EvoMSA.text_repr import BoW, TextRepresentations, StackGeneralization, DenseBoW
232 changes: 145 additions & 87 deletions EvoMSA/back_prop.py
@@ -13,33 +13,58 @@
# limitations under the License.
from typing import Union, List
import jax
from jax import nn
import jax.numpy as jnp
import numpy as np
from scipy.sparse import spmatrix
from jax.experimental.sparse import BCSR
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.base import clone
from IngeoML.optimizer import classifier
from EvoMSA.text_repr import BoW, DenseBoW


@jax.jit
def bow_model(params, X):
"""BoW model"""
"""BoWBP model"""

Y = X @ params['W_cl'] + params['W0_cl']
return Y


@jax.jit
def dense_model(params, X):
"""DenseBoW model"""
"""DenseBoWBP model"""

_ = X @ params['W'] + params['W0']
Y = _ / jnp.linalg.norm(_, axis=1, keepdims=True)
Y = Y @ params['W_cl'] + params['W0_cl']
return Y


@jax.jit
def stack_model(params, X, df):
"""StackBoWBP model"""

_ = X @ params['W'] + params['W0']
Y = _ / jnp.linalg.norm(_, axis=1, keepdims=True)
Y = Y @ params['W_cl'] + params['W0_cl']
Y = nn.softmax(Y, axis=1)
pesos = nn.softmax(params['E'])
return Y * pesos[0] + nn.softmax(df, axis=1) * pesos[1]


@jax.jit
def stack_model_binary(params, X, df):
"""StackBoWBP model"""

_ = X @ params['W'] + params['W0']
Y = _ / jnp.linalg.norm(_, axis=1, keepdims=True)
Y = Y @ params['W_cl'] + params['W0_cl']
Y = nn.sigmoid(Y)
pesos = nn.softmax(params['E'])
return Y * pesos[0] + nn.sigmoid(df) * pesos[1] - 0.5
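
A minimal sketch of the convex blend stack_model computes (illustrative values,
not part of this commit): the mixture logits params['E'] pass through a softmax,
so the dense branch and the auxiliary decision function are combined with
weights that sum to one.

import jax.numpy as jnp
from jax import nn

E = jnp.array([0.5, 0.5])           # mixture logits, as seeded in initial_parameters
pesos = nn.softmax(E)               # -> [0.5, 0.5]: equal weights before training
p_dense = jnp.array([[0.7, 0.3]])   # softmax output of the dense branch (illustrative)
p_df = jnp.array([[0.6, 0.4]])      # softmax of the stacked decision function
blend = p_dense * pesos[0] + p_df * pesos[1]   # -> [[0.65, 0.35]]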


class BoWBP(BoW):
"""BoWBP is a :py:class:`~EvoMSA.text_repr.BoW` with the difference that the parameters are fine-tuned using jax
@@ -54,15 +79,24 @@ class BoWBP(BoW):

def __init__(self, voc_size_exponent: int=15,
estimator_kwargs=dict(dual=True, class_weight='balanced'),
deviation=None,
validation_set=None,
deviation=None, fraction_initial_parameters=1,
optimizer_kwargs: dict=None,
**kwargs):
super(BoWBP, self).__init__(voc_size_exponent=voc_size_exponent,
estimator_kwargs=estimator_kwargs, **kwargs)
self.deviation = deviation
self.validation_set = validation_set
self.optimizer_kwargs = optimizer_kwargs
self.fraction_initial_parameters = fraction_initial_parameters
self.classes_ = None

@property
def fraction_initial_parameters(self):
"""Fraction of the training set to estimate the initial parameters"""
return self._fraction_initial_parameters

@fraction_initial_parameters.setter
def fraction_initial_parameters(self, value):
self._fraction_initial_parameters = value

@property
def evolution(self):
Expand All @@ -84,28 +118,6 @@ def optimizer_kwargs(self, value):
value = {}
self._optimizer_kwargs = value

@property
def validation_set(self):
"""Validation set"""
return self._validation_set

@validation_set.setter
def validation_set(self, value):
if value is None or value == 0:
self._validation_set = value
return
if hasattr(value, 'split'):
self._validation_set = value
return
assert isinstance(value, list) and len(value)
if isinstance(value[0], dict):
y = self.dependent_variable(value)
X = self._transform(value)
self._validation_set = [X, y]
else:
X, y = value
self._validation_set = [self._transform(X), y]

@property
def deviation(self):
"""Function to measure the deviation between the true observations and the predictions."""
@@ -115,18 +127,29 @@ def deviation(self):
def deviation(self, value):
self._deviation = value

def initial_parameters(self, X, y):
if y.ndim > 1:
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
if train_size == 1:
tr = np.arange(X.shape[0])
else:
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
m = self.estimator_class(**self.estimator_kwargs).fit(X[tr], y[tr])
W = jnp.array(m.coef_.T)
W0 = jnp.array(m.intercept_)
return dict(W_cl=W, W0_cl=W0)
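
A sketch of what initial_parameters returns, assuming scikit-learn's LinearSVC
as the estimator_class (an assumption; any linear estimator exposing coef_ and
intercept_ would do):

import numpy as np
import jax.numpy as jnp
from sklearn.svm import LinearSVC

X = np.random.rand(100, 8)            # transformed texts (illustrative)
y = np.repeat([0, 1], 50)
m = LinearSVC(dual=True, class_weight='balanced').fit(X, y)
params = dict(W_cl=jnp.array(m.coef_.T),      # shape (8, 1) in the binary case
              W0_cl=jnp.array(m.intercept_))  # shape (1,)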

@property
def parameters(self):
"""Parameter to optimize"""

W = jnp.array(self.estimator_instance.coef_.T)
W0 = jnp.array(self.estimator_instance.intercept_)
return dict(W_cl=W, W0_cl=W0)
return self._parameters

@parameters.setter
def parameters(self, value):
self.estimator_instance.coef_ = np.array(value['W_cl'].T)
self.estimator_instance.intercept_ = np.array(value['W0_cl'])
self._parameters = value

@property
def model(self):
@@ -135,56 +158,59 @@ def model(self):
def _transform(self, X):
return super(BoWBP, self).transform(X)

def set_validation_set(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None):
"""Procedure to create the validation set"""

if self.validation_set is None:
if len(D) < 2048:
test_size = 0.2
else:
test_size = 512
y = self.dependent_variable(D, y=y)
_ = StratifiedShuffleSplit(n_splits=1,
test_size=test_size).split(D, y)
tr, vs = next(_)
self.validation_set = [D[x] for x in vs]
D = [D[x] for x in tr]
y = y[tr]
elif self.validation_set == 0:
self.validation_set = None
return D, y

def _combine_optimizer_kwargs(self):
decoder = self.estimator_instance.classes_
n_outputs = 1 if decoder.shape[0] == 2 else decoder.shape[0]
optimizer_defaults = dict(array=BCSR.from_scipy_sparse, n_outputs=n_outputs,
return_evolution=True)
optimizer_defaults = dict(return_evolution=True)
optimizer_defaults.update(self.optimizer_kwargs)
return optimizer_defaults

def initial_parameters(self, D, y=None):
"""Compute the initial parameters"""
super(BoWBP, self).fit(D, y=y)
def model_args(self, D: List[Union[dict, list]]):
"""Extra arguments pass to the model"""
return None

def fit(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None) -> 'BoWBP':
D, y = self.set_validation_set(D, y=y)
self.initial_parameters(D, y=y)
optimizer_kwargs = self._combine_optimizer_kwargs()
texts = self._transform(D)
labels = self.dependent_variable(D, y=y)
p = classifier(self.parameters, self.model,
self.classes_ = np.unique(labels)
p = classifier(self.initial_parameters, self.model,
texts, labels,
deviation=self.deviation,
validation=self.validation_set,
model_args=self.model_args(D),
**optimizer_kwargs)
if optimizer_kwargs['return_evolution']:
self.evolution = p[1]
p = p[0]
self.parameters = p
return self

def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
X = self._transform(D)
params = self.parameters
args = self.model_args(D)
if args is None:
hy = self.model(params, BCSR.from_scipy_sparse(X))
else:
args = [self.array(x) for x in args]
hy = self.model(params, BCSR.from_scipy_sparse(X), *args)
return hy

def predict(self, D: List[Union[dict, list]]) -> np.ndarray:
df = self.decision_function(D)
if df.shape[1] == 1:
index = np.where(df.flatten() > 0, 1, 0)
else:
index = df.argmax(axis=1)
return self.classes_[index]
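
The predict rule in isolation (illustrative values, not part of this commit):
a single-column decision function is thresholded at zero, anything wider takes
the argmax.

import numpy as np

classes_ = np.array(['N', 'P'])
df = np.array([[-1.2], [0.3], [2.1]])      # binary case: one column
index = np.where(df.flatten() > 0, 1, 0)   # -> [0, 1, 1]
classes_[index]                            # -> ['N', 'P', 'P']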

@staticmethod
def array(data):
"""Encode data on jax"""

if isinstance(data, spmatrix):
return BCSR.from_scipy_sparse(data)
return jnp.array(data)
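
A usage sketch of the array helper (not part of this commit): scipy sparse
input becomes a jax BCSR so the bag-of-words matrix stays sparse on device;
anything else becomes a dense jnp array.

import numpy as np
import jax.numpy as jnp
from scipy.sparse import csr_matrix, spmatrix
from jax.experimental.sparse import BCSR

def as_jax(data):                  # mirrors BoWBP.array above
    if isinstance(data, spmatrix):
        return BCSR.from_scipy_sparse(data)
    return jnp.array(data)

as_jax(csr_matrix(np.eye(3)))      # sparse BoW matrix -> BCSR
as_jax(np.ones((3, 2)))            # dense model_args -> jnp.ndarray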


class DenseBoWBP(DenseBoW, BoWBP):
"""DenseBoWBP is a :py:class:`~EvoMSA.text_repr.DenseBoW` with the difference that the parameters are fine-tuned using jax
@@ -214,6 +240,26 @@ def model(self):
def _transform(self, X):
return self.bow.transform(X)

def initial_parameters(self, X, y):
if y.ndim > 1:
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
if train_size == 1:
tr = np.arange(X.shape[0])
else:
_ = StratifiedShuffleSplit(n_splits=1,
train_size=train_size).split(X, y)
tr, _ = next(_)
dense_w = self.weights.T
dense_bias = self.bias
_ = X[tr] @ dense_w + dense_bias
_ = _ / np.linalg.norm(_, axis=1, keepdims=True)
m = self.estimator_class(**self.estimator_kwargs).fit(_, y[tr])
W = jnp.array(m.coef_.T)
W0 = jnp.array(m.intercept_)
return dict(W_cl=W, W0_cl=W0,
W=jnp.array(dense_w), W0=jnp.array(dense_bias))
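
The dense projection step in isolation (illustrative shapes, not part of this
commit): project the BoW features with the pretrained weights, then
L2-normalize each row before fitting the linear classifier, mirroring
dense_model.

import numpy as np

X = np.random.rand(4, 10)        # BoW features for 4 texts (illustrative)
dense_w = np.random.rand(10, 3)  # self.weights.T: 3 dense text models
dense_bias = np.random.rand(3)   # self.bias
H = X @ dense_w + dense_bias
H = H / np.linalg.norm(H, axis=1, keepdims=True)   # unit-norm rows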

@property
def weights(self):
return np.array([x.coef for x in self.text_representations])
@@ -222,27 +268,39 @@ def weights(self):
def bias(self):
return np.array([x.intercept for x in self.text_representations])

# def __sklearn_clone__(self):
# ins = super(DenseBoWBP, self).__sklearn_clone__()
# _ = [clone(m) for m in self.text_representations]
# ins.text_representations = _
# return ins


class StackBoWBP(DenseBoWBP):
@property
def parameters(self):
"""Parameters to optimize"""
_parameters = super(DenseBoWBP, self).parameters
_parameters['W'] = jnp.array(self.weights.T)
_parameters['W0'] = jnp.array(self.bias)
return _parameters
def model(self):
if self.classes_.shape[0] == 2:
return stack_model_binary
return stack_model

def initial_parameters(self, X, y, df):
params = super(StackBoWBP, self).initial_parameters(X, y)
params['E'] = jnp.array([0.5, 0.5])
return params

def model_args(self, D: List[Union[dict, list]]):
if not hasattr(self, '_bow_ins'):
hy = BoW.train_predict_decision_function(self, D)
else:
X = super(StackBoWBP, self)._transform(D)
hy = getattr(self._bow_ins, self.decision_function_name)(X)
if hy.ndim == 1:
hy = np.atleast_2d(hy).T
return (hy, )

@parameters.setter
def parameters(self, value):
self.estimator_instance.coef_ = np.array(value['W_cl'].T)
self.estimator_instance.intercept_ = np.array(value['W0_cl'])
for x, m in zip(np.array(value['W'].T),
self.text_representations):
m.coef[:] = x[:]
for x, m in zip(np.array(value['W0']),
self.text_representations):
m.intercept = float(x)

def __sklearn_clone__(self):
ins = super(DenseBoWBP, self).__sklearn_clone__()
_ = [clone(m) for m in self.text_representations]
ins.text_representations = _
return ins
def fit(self, D: List[Union[dict, list]],
y: Union[np.ndarray, None]=None) -> "StackBoWBP":
super(StackBoWBP, self).fit(D, y=y)
_ = self._transform(D)
labels = self.dependent_variable(D, y=y)
self._bow_ins = self.estimator_class(**self.estimator_kwargs).fit(_, labels)
return self
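
A hypothetical end-to-end call (not part of this commit), assuming EvoMSA's
usual list-of-dicts input with a 'klass' label key and Spanish as the default
language:

from EvoMSA.back_prop import StackBoWBP

D = [dict(text='excelente película', klass='P'),
     dict(text='pésima película', klass='N')] * 32   # toy corpus (hypothetical)
m = StackBoWBP(lang='es').fit(D)
m.predict([dict(text='muy buena película')])          # expected to lean 'P'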
