Merge pull request #130 from INGEOTEC/develop

Version - 2.0.7

mgraffg committed Feb 22, 2024
2 parents 62f94a1 + dc95014 commit 478f352
Showing 4 changed files with 271 additions and 24 deletions.
2 changes: 1 addition & 1 deletion EvoMSA/__init__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = '2.0.6'
__version__ = '2.0.7'

try:
    from EvoMSA.text_repr import BoW, TextRepresentations, StackGeneralization, DenseBoW
237 changes: 220 additions & 17 deletions EvoMSA/back_prop.py
@@ -15,12 +15,14 @@
import jax
from jax import nn
import jax.numpy as jnp
import numpy as np
from scipy.sparse import spmatrix
from jax.experimental.sparse import BCSR
import numpy as np
from scipy.special import expit, softmax
from sklearn.model_selection import StratifiedShuffleSplit
from IngeoML.optimizer import classifier
from EvoMSA.text_repr import BoW, DenseBoW
from sklearn.base import clone
from IngeoML.optimizer import classifier, array
from IngeoML.utils import soft_BER
from EvoMSA.text_repr import BoW, DenseBoW, StackGeneralization


@jax.jit
@@ -65,6 +67,50 @@ def stack_model_binary(params, X, df):
    return Y * pesos[0] + nn.sigmoid(df) * pesos[1] - 0.5


def initial_parameters(hy_dense, df, y,
                       nclasses=2, score=None):
    """Estimate the initial mixing parameters for :py:class:`~EvoMSA.back_prop.StackBoWBP`"""
    from sklearn.metrics import f1_score

    if score is None:
        score = lambda y, hy: f1_score(y, hy, average='macro')

    def f(x):
        hy = x[0] * hy_dense + x[1] * df
        if nclasses == 2:
            hy = np.where(hy > 0.5, 1, 0)
        else:
            hy = hy.argmax(axis=1)
        return score(y, hy)

    value = np.linspace(0, 1, 100)
    _ = [f([v, 1 - v]) for v in value]
    index = np.argmax(_)
    return jnp.array([value[index], 1 - value[index]])
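
# A minimal usage sketch (illustrative only; the toy arrays below are
# hypothetical): `hy_dense` and `df` are probability-like decision values of
# two classifiers on the same binary problem, and the helper grid-searches
# the convex combination of the two that maximizes macro-f1.
def _example_initial_parameters():
    hy_dense = np.array([0.9, 0.2, 0.7, 0.4])
    df = np.array([0.8, 0.3, 0.6, 0.1])
    y = np.array([1, 0, 1, 0])
    E = initial_parameters(hy_dense, df, y, nclasses=2)
    # E is jnp.array([w, 1 - w]) with the best weight found in [0, 1].
    return E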


@jax.jit
def stackbow(params, X, X2):
    mixer = nn.sigmoid(params['mixer'])
    frst = X * mixer
    scnd = X2 * (1 - mixer)
    return frst + scnd


@jax.jit
def stackbow_b_k(params, X):
    mixer = nn.softmax(params['mixer'])
    return X @ mixer


@jax.jit
def stackbow_m_k(params, X):
    mixer = nn.softmax(params['mixer'], axis=1)
    hy = X * mixer
    return hy.sum(axis=-1)
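
# Illustrative sketch (hypothetical shapes) of the three jitted mixers above:
# `stackbow` blends two (n, k) probability matrices with one sigmoid weight
# per class, `stackbow_b_k` blends m binary classifiers stored column-wise in
# an (n, m) matrix through a softmax over classifiers, and `stackbow_m_k`
# blends an (n, k, m) tensor of m multiclass classifiers with one
# softmax-normalized weight vector per class.
def _example_mixers():
    n, k, m = 4, 3, 5
    hy = stackbow(dict(mixer=jnp.zeros(k)),
                  jnp.ones((n, k)) / k, jnp.ones((n, k)) / k)  # (n, k)
    hy_b = stackbow_b_k(dict(mixer=jnp.ones(m)),
                        jnp.ones((n, m)) / 2)                  # (n,)
    hy_m = stackbow_m_k(dict(mixer=jnp.ones((k, m))),
                        jnp.ones((n, k, m)) / k)               # (n, k)
    return hy, hy_b, hy_m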


class BoWBP(BoW):
"""BoWBP is a :py:class:`~EvoMSA.text_repr.BoW` with the difference that the parameters are fine-tuned using jax
@@ -135,7 +181,7 @@ def initial_parameters(self, X, y):
            tr = np.arange(X.shape[0])
        else:
            _ = StratifiedShuffleSplit(n_splits=1,
                                       train_size=train_size).split(X, y)
            tr, _ = next(_)
        m = self.estimator_class(**self.estimator_kwargs).fit(X[tr], y[tr])
        W = jnp.array(m.coef_.T)
@@ -191,7 +237,7 @@ def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
        if args is None:
            hy = self.model(params, BCSR.from_scipy_sparse(X))
        else:
            args = [self.array(x) for x in args]
            args = [array(x) for x in args]
            hy = self.model(params, BCSR.from_scipy_sparse(X), *args)
        return hy

@@ -203,14 +249,6 @@ def predict(self, D: List[Union[dict, list]]) -> np.ndarray:
        index = df.argmax(axis=1)
        return self.classes_[index]

    @staticmethod
    def array(data):
        """Encode data on jax"""

        if isinstance(data, spmatrix):
            return BCSR.from_scipy_sparse(data)
        return jnp.array(data)


class DenseBoWBP(DenseBoW, BoWBP):
"""DenseBoWBP is a :py:class:`~EvoMSA.text_repr.DenseBoW` with the difference that the parameters are fine-tuned using jax
@@ -258,7 +296,7 @@ def initial_parameters(self, X, y):
        W = jnp.array(m.coef_.T)
        W0 = jnp.array(m.intercept_)
        return dict(W_cl=W, W0_cl=W0,
                    W=jnp.array(dense_w), W0=jnp.array(dense_bias))

    @property
    def weights(self):
@@ -275,6 +313,157 @@ def bias(self):
        # return ins


class StackBoW(StackGeneralization):
    """StackBoW is a :py:class:`~EvoMSA.text_repr.StackGeneralization` where
    the decision functions of the base classifiers are combined with a convex
    combination whose weights are estimated on the training-set predictions."""

    def __init__(self, decision_function_models: list=None,
                 transform_models: list=[],
                 deviation=None, optimizer_kwargs: dict=None,
                 lang: str='es', **kwargs):
        if decision_function_models is None:
            estimator_kwargs = dict(dual='auto', class_weight='balanced')
            bow_np = BoW(lang=lang, pretrain=False,
                         estimator_kwargs=estimator_kwargs)
            bow = BoW(lang=lang,
                      estimator_kwargs=estimator_kwargs)
            dense = DenseBoW(lang=lang, voc_size_exponent=15,
                             estimator_kwargs=estimator_kwargs)
            decision_function_models = [bow_np, bow, dense]
        assert len(decision_function_models) > 1
        assert len(transform_models) == 0
        super().__init__(decision_function_models=decision_function_models,
                         **kwargs)
        self.deviation = deviation
        self.optimizer_kwargs = optimizer_kwargs
        self.classes_ = None
        self._mixer_value = None

    @property
    def optimizer_kwargs(self):
        """Arguments for the optimizer"""
        return self._optimizer_kwargs

    @optimizer_kwargs.setter
    def optimizer_kwargs(self, value):
        if value is None:
            value = {}
        self._optimizer_kwargs = value

    @property
    def mixer_value(self):
        """Contribution of each classifier to the prediction"""
        return self._mixer_value

    @property
    def deviation(self):
        """Function to measure the deviation between the true observations and the predictions."""
        return self._deviation

    @deviation.setter
    def deviation(self, value):
        self._deviation = value

    def _fit_bin_2(self, dfs, y):
        """Fit a binary problem with two classifiers"""
        X1 = jnp.c_[1 - dfs[0], dfs[0]]
        X2 = jnp.c_[1 - dfs[1], dfs[1]]
        h = {v: k for k, v in enumerate(self.classes_)}
        y_ = jnp.array([h[i] for i in y])
        y_ = np.c_[1 - y_, y_]
        if self.deviation is None:
            deviation = soft_BER
        else:
            deviation = self.deviation
        params = jnp.linspace(0, 1, 100)
        perf = [deviation(y_, p * X1 + (1 - p) * X2)
                for p in params]
        self._mixer_value = params[np.argmin(perf)]

    def _fit_bin_k(self, dfs, y):
        """Fit binary classification problems with k classifiers"""
        _ = np.ones(len(dfs))
        X = np.concatenate(dfs, axis=1)
        params = dict(mixer=jnp.array(_))
        p = classifier(params, stackbow_b_k, X, y,
                       validation=0, epochs=10000,
                       deviation=self.deviation,
                       distribution=True,
                       **self.optimizer_kwargs)
        self._mixer_value = softmax(p['mixer'])

    def _fit_mul_2(self, dfs, y):
        """Fit a multiclass problem with two classifiers"""
        _ = np.zeros(dfs[0].shape[1])
        params = dict(mixer=jnp.array(_))
        p = classifier(params, stackbow, dfs[0], y,
                       model_args=(dfs[1], ),
                       validation=0, epochs=10000,
                       deviation=self.deviation,
                       distribution=True,
                       **self.optimizer_kwargs)
        self._mixer_value = expit(p['mixer'])

    def _fit_mul_k(self, dfs, y):
        """Fit a multiclass problem with k classifiers"""
        X = np.array([x.T for x in dfs]).T
        _ = np.ones(X.shape[1:])
        params = dict(mixer=jnp.array(_))
        p = classifier(params, stackbow_m_k, X, y,
                       validation=0, epochs=10000,
                       deviation=self.deviation,
                       distribution=True,
                       **self.optimizer_kwargs)
        self._mixer_value = softmax(p['mixer'], axis=1)

    def fit(self, D: List[Union[dict, list]],
            y: Union[np.ndarray, None]=None) -> 'StackBoW':
        y = self.dependent_variable(D, y=y)
        self.classes_ = np.unique(y)
        dfs = [ins.train_predict_decision_function(D, y=y)
               for ins in self._decision_function_models]
        if dfs[0].shape[1] > 1:
            # Multiclass problem: mix softmax probabilities.
            dfs = [softmax(x, axis=1) for x in dfs]
            if len(dfs) == 2:
                self._fit_mul_2(dfs, y)
            else:
                self._fit_mul_k(dfs, y)
        else:
            # Binary problem: mix sigmoid probabilities.
            dfs = [expit(x) for x in dfs]
            if len(dfs) == 2:
                self._fit_bin_2(dfs, y)
            else:
                self._fit_bin_k(dfs, y)
        # Refit each base classifier on the whole training set.
        _ = [clone(ins).fit(D, y=y) for ins in self._decision_function_models]
        self._decision_function_models = _
        return self

    def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
        dfs = [ins.decision_function(D)
               for ins in self._decision_function_models]
        if dfs[0].shape[1] > 1:
            dfs = [softmax(x, axis=1) for x in dfs]
            mixer = self.mixer_value
            if len(dfs) == 2:
                frst = dfs[0] * mixer
                scnd = dfs[1] * (1 - mixer)
                return frst + scnd
            X = np.array([x.T for x in dfs]).T
            return (X * mixer).sum(axis=-1)
        else:
            dfs = [expit(x) for x in dfs]
            p = self.mixer_value
            if len(dfs) == 2:
                X1 = jnp.c_[1 - dfs[0], dfs[0]]
                X2 = jnp.c_[1 - dfs[1], dfs[1]]
                return p * X1 + (1 - p) * X2
            X = np.concatenate(dfs, axis=1)
            hy = X @ p
            return np.c_[1 - hy, hy]

    def predict(self, D: List[Union[dict, list]]) -> np.ndarray:
        df = self.decision_function(D)
        index = df.argmax(axis=1)
        return self.classes_[index]
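
# Usage sketch for StackBoW (mirrors the tests in test_back_prop.py): fit the
# default ensemble -- a non-pretrained BoW, a pretrained BoW, and a DenseBoW --
# and predict. `dataset` is assumed to be a list of dicts with the text and a
# 'klass' key, e.g., EvoMSA's TWEETS test file.
def _example_stackbow(dataset):
    stack = StackBoW(lang='es').fit(dataset)
    return stack.predict(dataset), stack.mixer_value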


class StackBoWBP(DenseBoWBP):
    @property
    def model(self):
@@ -284,12 +473,26 @@ def model(self):

    def initial_parameters(self, X, y, df):
        params = super(StackBoWBP, self).initial_parameters(X, y)
        params['E'] = jnp.array([0.5, 0.5])
        dense_w = self.weights.T
        dense_bias = self.bias
        Xd = X @ dense_w + dense_bias
        if self.classes_.shape[0] > 2:
            y = y.argmax(axis=1)
        hy_dense = self.train_predict_decision_function([1] * Xd.shape[0], y=y, X=Xd)
        # Map both decision functions to probabilities: softmax for the
        # multiclass case, the logistic function for the binary one.
        if self.classes_.shape[0] > 2:
            df = softmax(df, axis=1)
            hy_dense = softmax(hy_dense, axis=1)
        else:
            df = expit(df)
            hy_dense = expit(hy_dense)
        params['E'] = initial_parameters(hy_dense, df, y,
                                         nclasses=self.classes_.shape[0])
        return params

    def model_args(self, D: List[Union[dict, list]]):
        if not hasattr(self, '_bow_ins'):
            hy = BoW.train_predict_decision_function(self, D)
            X = self._transform(D)
            hy = self.train_predict_decision_function(D, X=X)
        else:
            X = super(StackBoWBP, self)._transform(D)
            hy = getattr(self._bow_ins, self.decision_function_name)(X)
47 changes: 43 additions & 4 deletions EvoMSA/tests/test_back_prop.py
@@ -17,8 +17,8 @@
from microtc.utils import tweet_iterator
from jax.experimental.sparse import BCSR
import numpy as np
from EvoMSA.back_prop import BoWBP, bow_model, DenseBoWBP, StackBoWBP
from EvoMSA.text_repr import BoW, DenseBoW
from EvoMSA.back_prop import BoWBP, bow_model, DenseBoWBP, StackBoWBP, StackBoW
from EvoMSA.text_repr import BoW, DenseBoW
from EvoMSA.tests.test_base import TWEETS


@@ -115,9 +115,48 @@ def test_DenseBoWBP():
    assert len(hy) == len(D)


def test_StackBoWBP():
def test_StackBoWBP_initial_parameters():
    """Test StackBoWBP"""

    dataset = list(tweet_iterator(TWEETS))
    ins = StackBoWBP(lang='es', voc_size_exponent=13).fit(dataset)
    assert 'E' in ins.parameters
    assert np.fabs(ins.parameters['E'] - np.array([0.5, 0.5])).sum() > 0
    D = [x for x in dataset if x['klass'] in {'N', 'P'}]
    ins = StackBoWBP(lang='es', voc_size_exponent=13).fit(D)
    assert np.fabs(ins.parameters['E'] - np.array([0.5, 0.5])).sum() > 0


def test_StackBoW():
    """Test StackBoW"""
    from sklearn.metrics import f1_score

    D = list(tweet_iterator(TWEETS))
    D2 = [x for x in D if x['klass'] in {'N', 'P'}]
    ins = StackBoW(lang='es').fit(D2)
    y = np.array([x['klass'] for x in D2])
    _ = f1_score(y, ins.predict(D2), average='macro')
    assert _ > 0.95
    ins = StackBoW(lang='es').fit(D)
    y = np.array([x['klass'] for x in D])
    _ = f1_score(y, ins.predict(D), average='macro')
    assert _ > 0.95


def test_StackBoW_3_algs():
    """Test StackBoW with three classifiers"""
    from sklearn.metrics import f1_score

    D = list(tweet_iterator(TWEETS))
    D2 = [x for x in D if x['klass'] in {'N', 'P'}]
    bow = BoW(lang='es')
    bow2 = BoW(lang='es', pretrain=False)
    dense = DenseBoW(lang='es', voc_size_exponent=13)
    ins = StackBoW([bow, bow2, dense], lang='es').fit(D2)
    y = np.array([x['klass'] for x in D2])
    _ = f1_score(y, ins.predict(D2), average='macro')
    assert _ > 0.95
    bow2 = BoW(lang='es', pretrain=False)
    ins = StackBoW([bow, bow2, dense], lang='es').fit(D)
    y = np.array([x['klass'] for x in D])
    _ = f1_score(y, ins.predict(D), average='macro')
    assert _ > 0.95
