Skip to content

Commit

Permalink
Bug in binary problems
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 31, 2024
1 parent 72a6b79 commit be9e76b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
6 changes: 4 additions & 2 deletions EvoMSA/back_prop.py
Expand Up @@ -116,7 +116,8 @@ def deviation(self, value):
self._deviation = value

def initial_parameters(self, X, y):
y = y.argmax(axis=1)
if y.ndim > 1:
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
if train_size == 1:
tr = np.arange(X.shape[0])
Expand Down Expand Up @@ -224,7 +225,8 @@ def _transform(self, X):
return self.bow.transform(X)

def initial_parameters(self, X, y):
y = y.argmax(axis=1)
if y.ndim > 1:
y = y.argmax(axis=1)
train_size = self.fraction_initial_parameters
if train_size == 1:
tr = np.arange(X.shape[0])
Expand Down
9 changes: 9 additions & 0 deletions EvoMSA/tests/test_back_prop.py
Expand Up @@ -43,6 +43,15 @@ def test_BoWBP():
assert bow.predict(D) is not None


def test_binary():
"""Test BoWBP"""
D = list(tweet_iterator(TWEETS))
D = [x for x in D if x['klass'] in {'N', 'P'}]
bow = BoWBP(lang='es').fit(D)
dense = DenseBoWBP(lang='es',
voc_size_exponent=13).fit(D)


# def test_BoWBP_validation_set():
# """Test the validation_set property"""
# D = list(tweet_iterator(TWEETS))
Expand Down

0 comments on commit be9e76b

Please sign in to comment.