Skip to content

Commit

Permalink
Cleaning code
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 31, 2024
1 parent 872f026 commit 72a6b79
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions EvoMSA/back_prop.py
Expand Up @@ -17,9 +17,9 @@
import jax.numpy as jnp
import numpy as np
from scipy.special import softmax
from scipy.sparse import spmatrix
from jax.experimental.sparse import BCSR
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.base import clone
from IngeoML.optimizer import classifier
from EvoMSA.text_repr import BoW, DenseBoW

Expand Down Expand Up @@ -174,12 +174,25 @@ def fit(self, D: List[Union[dict, list]],
def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
X = self._transform(D)
params = self.parameters
hy = self.model(params, BCSR.from_scipy_sparse(X))
args = self.model_args(D)
if args is None:
hy = self.model(params, BCSR.from_scipy_sparse(X))
else:
args = [self.array(x) for x in args]
hy = self.model(params, BCSR.from_scipy_sparse(X), *args)
return hy

def predict(self, D: List[Union[dict, list]]) -> np.ndarray:
df = self.decision_function(D)
return self.classes_[df.argmax(axis=1)]
return self.classes_[df.argmax(axis=1)]

@staticmethod
def array(data):
"""Encode data on jax"""

if isinstance(data, spmatrix):
return BCSR.from_scipy_sparse(data)
return jnp.array(data)


class DenseBoWBP(DenseBoW, BoWBP):
Expand Down Expand Up @@ -272,12 +285,4 @@ def fit(self, D: List[Union[dict, list]],
_ = self._transform(D)
labels = self.dependent_variable(D, y=y)
self._bow_ins = self.estimator_class(**self.estimator_kwargs).fit(_, labels)
return self

def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:
X = self._transform(D)
params = self.parameters
df = self.model_args(D)[0]
hy = self.model(params, BCSR.from_scipy_sparse(X),
jnp.array(df))
return hy
return self

0 comments on commit 72a6b79

Please sign in to comment.