-
Notifications
You must be signed in to change notification settings - Fork 2
/
ex1.py
32 lines (24 loc) · 920 Bytes
/
ex1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from gmum.melm import MELM
from sklearn.datasets import load_svmlight_file
import sys
from sklearn import cross_validation
from gmum.utils import BAC
import numpy as np
def cross_validate_scores(clf, X, y, folds, scorer=None):
    """Fit ``clf`` on each training split and score it on the held-out split.

    The same ``clf`` instance is refit on every fold, and each fold's score
    is printed (indented by two spaces) as soon as it is computed.

    Parameters
    ----------
    clf : object with ``fit(X, y)`` and ``predict(X)``
        The classifier to evaluate.
    X, y : array-like supporting fancy indexing (e.g. numpy arrays)
        Features and labels; indexed with each fold's index arrays.
    folds : iterable of ``(train_index, test_index)`` pairs
        For example a stratified K-fold splitter.
    scorer : callable, optional
        Called as ``scorer(predictions, y_true)``. Defaults to
        ``gmum.utils.BAC``, with the same argument order the original script
        used. NOTE(review): confirm BAC's expected argument order in gmum.utils.

    Returns
    -------
    list
        One score per fold, in fold order.
    """
    if scorer is None:
        # Resolved lazily so this helper does not need gmum at definition time.
        scorer = BAC
    scores = []
    for train_index, test_index in folds:
        clf.fit(X[train_index], y[train_index])
        score = scorer(clf.predict(X[test_index]), y[test_index])
        scores.append(score)
        # Single-argument print: identical output on Python 2 and 3.
        print('  {}'.format(score))
    return scores


def main(argv=None):
    """Cross-validate three MELM variants (KNN, SVM, KDE) on an svmlight file.

    ``argv`` defaults to ``sys.argv``; ``argv[1]`` is the path of the dataset
    in svmlight format. Exits with a usage message if no path is given.
    For each variant this prints the classifier, the per-fold scores, and
    the mean +/- standard deviation over the 5 stratified folds.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        program = argv[0] if argv else 'ex1.py'
        sys.exit('usage: {} DATASET_SVMLIGHT_FILE'.format(program))

    X, y = load_svmlight_file(argv[1])
    X = X.toarray()  # densify the sparse matrix before handing it to MELM
    # NOTE(review): sklearn.cross_validation was removed in scikit-learn 0.20;
    # this script needs an older release, or a migration of both the
    # top-level import and this call to
    # sklearn.model_selection.StratifiedKFold(n_splits=5).split(X, y).
    skf = cross_validation.StratifiedKFold(y, n_folds=5)

    for classifier in ('KNN', 'SVM', 'KDE'):
        clf = MELM(gamma=1.0, k=2, random_state=666, n_starts=17,
                   classifier=classifier)
        print(clf)
        scores = cross_validate_scores(clf, X, y, skf)
        print('{} +/- {}'.format(np.mean(scores), np.std(scores)))
        print('')  # blank separator line (print() would emit "()" on Py2)


if __name__ == '__main__':
    main()