Skip to content

Commit

Permalink
Bug in binary
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jan 31, 2024
1 parent be9e76b commit 6159d51
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
6 changes: 5 additions & 1 deletion EvoMSA/back_prop.py
Expand Up @@ -185,7 +185,11 @@ def decision_function(self, D: List[Union[dict, list]]) -> np.ndarray:

def predict(self, D: List[Union[dict, list]]) -> np.ndarray:
df = self.decision_function(D)
return self.classes_[df.argmax(axis=1)]
if df.shape[1] == 1:
index = np.where(df.flatten() > 0, 1, 0)
else:
index = df.argmax(axis=1)
return self.classes_[index]

@staticmethod
def array(data):
Expand Down
7 changes: 7 additions & 0 deletions EvoMSA/tests/test_back_prop.py
Expand Up @@ -48,8 +48,15 @@ def test_binary():
D = list(tweet_iterator(TWEETS))
D = [x for x in D if x['klass'] in {'N', 'P'}]
bow = BoWBP(lang='es').fit(D)
hy = bow.predict(D)
acc = (np.array([x['klass'] for x in D]) == hy).mean()
assert acc > 0.9
dense = DenseBoWBP(lang='es',
voc_size_exponent=13).fit(D)
hy = dense.predict(D)
acc = (np.array([x['klass'] for x in D]) == hy).mean()
assert acc > 0.85



# def test_BoWBP_validation_set():
Expand Down

0 comments on commit 6159d51

Please sign in to comment.