/
art.py
62 lines (50 loc) · 1.82 KB
/
art.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
from rpy2.robjects.packages import importr
import rpy2.robjects as robjects
import rpy2.robjects.packages as rpackages
from rpy2.robjects.vectors import StrVector, FloatVector
import rpy2.robjects.numpy2ri
import logging
# Enable rpy2's numpy <-> R conversion globally so numpy arrays (such as the
# matrix handed to RSNNS' art2 in ART.cluster) are converted to R objects
# automatically when passed to R functions.
# NOTE(review): recent rpy2 releases deprecate the global activate() in favour
# of local converter contexts (rpy2.robjects.conversion.localconverter) —
# confirm against the installed rpy2 version.
rpy2.robjects.numpy2ri.activate()
def bootstrap_r(packages=('RSNNS',)):
    """Ensure the R packages this module depends on are installed.

    Every package in *packages* that rpy2 reports as not installed is
    installed from CRAN, using the first entry of R's CRAN mirror list.
    This may therefore access the network the first time it runs.

    Args:
        packages: Iterable of CRAN package names (a single name string is
            also accepted). Defaults to ``('RSNNS',)``.
    """
    logger = logging.getLogger(__name__)
    if isinstance(packages, str):
        packages = (packages,)
    missing = [name for name in packages if not rpackages.isinstalled(name)]
    if not missing:
        return

    logger.debug('%s not found. Installing...', ', '.join(missing))
    utils = rpackages.importr('utils')
    # Select a mirror explicitly so installation does not depend on an
    # interactive mirror prompt.
    utils.chooseCRANmirror(ind=1)
    utils.install_packages(StrVector(missing))

    # R's install.packages() reports failures (e.g. "package is not
    # available") as warnings rather than errors, so check the outcome
    # instead of assuming success.
    still_missing = [name for name in missing if not rpackages.isinstalled(name)]
    if still_missing:
        logger.warning(
            'Failed to install R package(s): %s', ', '.join(still_missing)
        )
    else:
        logger.debug('Installed %s', ', '.join(missing))
class BinaryMemebershipMatrix:
    """Norm-weighted one-hot membership matrix for one or more clusterings.

    ``labels`` holds non-negative integer cluster labels with one row per
    vector and one column per clustering. A 1-D ``labels`` is treated as a
    single clustering.

    Each clustering is one-hot encoded over a shared set of
    ``K = labels.max() + 1`` clusters. The per-clustering encodings are
    placed side by side, and every row is scaled by the Euclidean norm of
    the matching row of ``vectors``. The result is stored in ``self.data``
    with shape ``(n_vectors, n_clusterings * K)``.

    The class name's spelling is kept for backward compatibility.
    """

    def __init__(self, vectors, labels):
        """Build the matrix.

        Args:
            vectors: 2-D array-like, one vector per row.
            labels: 1-D or 2-D array-like of integer labels with
                ``len(vectors)`` rows.

        Raises:
            ValueError: if ``labels`` does not have one row per vector.
        """
        vectors = np.asarray(vectors)
        labels = np.asarray(labels)
        if labels.ndim == 1:
            # A single clustering: one label per vector.
            labels = labels.reshape(-1, 1)
        if labels.shape[0] != vectors.shape[0]:
            raise ValueError(
                f"labels has {labels.shape[0]} rows but there are "
                f"{vectors.shape[0]} vectors"
            )

        n_clusterings = labels.shape[1]
        n_clusters = np.max(labels) + 1
        # Flatten clustering-by-clustering (column-major), so each clustering
        # occupies a contiguous run of len(vectors) rows once one-hot encoded.
        flat_labels = labels.T.reshape([-1])
        one_hot = np.eye(n_clusters)[flat_labels]
        # One (n_vectors, K) block per clustering, concatenated column-wise.
        # The split count must be the number of clusterings, not K.
        bmm = np.hstack(np.vsplit(one_hot, n_clusterings))
        # Scale each row by the norm of its vector.
        self.data = bmm * np.linalg.norm(vectors, axis=1)[:, np.newaxis]

    def __array__(self, dtype=None, copy=None):
        """Return the matrix, following NumPy's ``dtype``/``copy`` protocol."""
        if copy:
            return np.array(self.data, dtype=dtype, copy=True)
        return np.asarray(self.data, dtype=dtype)

    def __str__(self):
        return str(self.data)
class ART:
    """Thin wrapper around RSNNS' ART2 network, called from Python via rpy2."""

    def __init__(self):
        # Install RSNNS from CRAN if it is missing, then load it.
        bootstrap_r()
        self.rsnns = importr('RSNNS')
        # R's generic predict(); not used within this class.
        self.predict = robjects.r('predict')

    def cluster(self, bmm: BinaryMemebershipMatrix, *, n_clusters):
        """Cluster the rows of *bmm* with an ART2 network.

        Args:
            bmm: Membership matrix (or anything ``np.asarray`` accepts); each
                row is passed to ``art2`` as one input pattern.
            n_clusters: Passed to RSNNS' ``art2`` as ``f2Units``.

        Returns:
            The result of RSNNS' ``encodeClassLabels`` applied to the fitted
            model's ``fitted.values`` (presumably one class label per
            pattern — confirm against the RSNNS documentation).
        """
        # Use the exact R argument name `f2Units`. The previous `f2Unit` only
        # worked through R's partial argument matching.
        model = self.rsnns.art2(np.asarray(bmm), f2Units=n_clusters)
        # rx2() is rpy2's equivalent of R's `[[` (element access by name).
        return self.rsnns.encodeClassLabels(model.rx2('fitted.values'))
if __name__ == '__main__':
    # Running this file directly only exercises start-up: constructing ART
    # installs RSNNS when it is missing and loads it through rpy2.
    #
    # Reference sketch of fitting and predicting with RSNNS directly.
    # `patterns` and `testPatterns` are placeholders and are not defined here:
    #
    #     rsnns = importr('RSNNS')
    #     model = rsnns.art2(
    #         patterns,
    #         f2Units=2,
    #         # learnFuncParams=robjects.FloatVector([0.99, 20, 20, 0.1, 0]),
    #         # updateFuncParams=robjects.FloatVector([0.99, 20, 20, 0.1, 0]),
    #     )
    #     predict = robjects.r('predict')
    #     predictions = predict(model, testPatterns)
    art = ART()