Skip to content

Commit

Permalink
StackBoWBP Binary
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 31, 2024
1 parent 6159d51 commit 30466ef
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
30 changes: 21 additions & 9 deletions EvoMSA/back_prop.py
Expand Up @@ -16,7 +16,6 @@
from jax import nn
import jax.numpy as jnp
import numpy as np
from scipy.special import softmax
from scipy.sparse import spmatrix
from jax.experimental.sparse import BCSR
from sklearn.model_selection import StratifiedShuffleSplit
Expand Down Expand Up @@ -51,7 +50,19 @@ def stack_model(params, X, df):
Y = Y @ params['W_cl'] + params['W0_cl']
Y = nn.softmax(Y, axis=1)
pesos = nn.softmax(params['E'])
return Y * pesos[0] + df * pesos[1]
return Y * pesos[0] + nn.softmax(df, axis=1) * pesos[1]


@jax.jit
def stack_model_binary(params, X, df):
    """Binary variant of the StackBoWBP model.

    ``X`` is projected with ``params['W']`` / ``params['W0']``, every row
    of the projection is divided by its Euclidean norm, and the result is
    passed through the ``params['W_cl']`` / ``params['W0_cl']`` layer
    followed by a sigmoid.  That score is blended with ``sigmoid(df)``
    using the two weights obtained from ``softmax(params['E'])``.

    Both sigmoids lie in (0, 1) and the blending weights sum to one, so
    subtracting 0.5 shifts the output into (-0.5, 0.5), i.e. centred on
    zero.

    :param params: mapping with the entries 'W', 'W0', 'W_cl', 'W0_cl'
                   and 'E'
    :param X: input representation, multiplied on the right by
              ``params['W']``
    :param df: decision values that are blended with the dense score
    :return: centred blend of the two sigmoid scores
    """
    hidden = X @ params['W'] + params['W0']
    row_norm = jnp.linalg.norm(hidden, axis=1, keepdims=True)
    logits = (hidden / row_norm) @ params['W_cl'] + params['W0_cl']
    weights = nn.softmax(params['E'])
    dense_term = nn.sigmoid(logits) * weights[0]
    stacked_term = nn.sigmoid(df) * weights[1]
    return dense_term + stacked_term - 0.5


class BoWBP(BoW):
Expand All @@ -76,6 +87,7 @@ def __init__(self, voc_size_exponent: int=15,
self.deviation = deviation
self.optimizer_kwargs = optimizer_kwargs
self.fraction_initial_parameters = fraction_initial_parameters
self.classes_ = None

@property
def fraction_initial_parameters(self):
Expand Down Expand Up @@ -171,7 +183,7 @@ def fit(self, D: List[Union[dict, list]],
p = p[0]
self.parameters = p
return self

def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
X = self._transform(D)
params = self.parameters
Expand All @@ -182,15 +194,15 @@ def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
args = [self.array(x) for x in args]
hy = self.model(params, BCSR.from_scipy_sparse(X), *args)
return hy

def predict(self, D: List[Union[dict, list]]) -> np.ndarray:
    """Predict the class of every element in ``D``.

    The scores come from :py:meth:`decision_function`.  When they have
    a single column, the second entry of ``self.classes_`` is chosen for
    strictly positive scores and the first one otherwise; with more
    columns the column holding the largest score is chosen.

    :param D: elements to classify
    :return: entries of ``self.classes_`` selected for each element
    """
    scores = self.decision_function(D)
    if scores.shape[1] != 1:
        # One column per class: take the best-scoring column.
        winners = scores.argmax(axis=1)
    else:
        # Single column: the sign of the score decides the class.
        winners = np.where(scores.flatten() > 0, 1, 0)
    return self.classes_[winners]

@staticmethod
def array(data):
"""Encode data on jax"""
Expand Down Expand Up @@ -227,7 +239,7 @@ def model(self):

def _transform(self, X):
return self.bow.transform(X)

def initial_parameters(self, X, y):
if y.ndim > 1:
y = y.argmax(axis=1)
Expand Down Expand Up @@ -261,11 +273,13 @@ def bias(self):
# _ = [clone(m) for m in self.text_representations]
# ins.text_representations = _
# return ins


class StackBoWBP(DenseBoWBP):
@property
def model(self):
    """Model function matching the number of classes.

    ``stack_model_binary`` is returned when ``self.classes_`` holds
    exactly two entries; ``stack_model`` is returned otherwise.
    """
    is_binary = self.classes_.shape[0] == 2
    return stack_model_binary if is_binary else stack_model

def initial_parameters(self, X, y, df):
Expand All @@ -281,8 +295,6 @@ def model_args(self, D: List[Union[dict, list]]):
hy = getattr(self._bow_ins, self.decision_function_name)(X)
if hy.ndim == 1:
hy = np.atleast_2d(hy).T
else:
hy = softmax(hy, axis=1)
return (hy, )

def fit(self, D: List[Union[dict, list]],
Expand Down
5 changes: 5 additions & 0 deletions EvoMSA/tests/test_back_prop.py
Expand Up @@ -47,6 +47,11 @@ def test_binary():
"""Test BoWBP"""
D = list(tweet_iterator(TWEETS))
D = [x for x in D if x['klass'] in {'N', 'P'}]
stack = StackBoWBP(lang='es',
voc_size_exponent=13).fit(D)
hy = stack.predict(D)
acc = (np.array([x['klass'] for x in D]) == hy).mean()
assert acc > 0.95
bow = BoWBP(lang='es').fit(D)
hy = bow.predict(D)
acc = (np.array([x['klass'] for x in D]) == hy).mean()
Expand Down

0 comments on commit 30466ef

Please sign in to comment.