/
utils.py
72 lines (56 loc) · 1.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
from itertools import combinations
from huber_obj import pair_col_diff_norm2
def norm(x):
    '''
    Euclidean (l2) norm of x: the square root of the sum of the squares
    of all of its entries (the sum runs over every axis).
    '''
    sum_of_squares = np.sum(x ** 2)
    return np.sqrt(sum_of_squares)
def norm2(x):
    '''
    Squared Euclidean norm of x: the sum of the squares of all of its
    entries (the sum runs over every axis, no square root is taken).
    '''
    squares = x ** 2
    return np.sum(squares)
def cluster(x, epsilon):
    '''
    cluster n points according to x, points i and j are linked if
    ||xi - xj|| < epsilon, and a cluster is a connected component of the
    resulting graph (so two points may also share a cluster through a
    chain of linked points)

    Args:
        x: (d, n) stores each point's "centroid" in its columns
        epsilon: distance threshold below which two points are linked
    Output:
        clustered: list containing the ID of the cluster of each point
        num_clusters: total number of clusters
    '''
    # Unpacking the shape also rejects inputs that are not 2-D.
    _, n = x.shape
    G = Graph(n)

    # With fewer than two points there are no pairs to compare, so the
    # pairwise helper is not called with an empty index array.
    if n >= 2:
        # Every unordered pair (i, j) with i < j, in row-major order.
        pairs = np.array(list(combinations(range(n), 2)))
        # NOTE(review): pair_col_diff_norm2 is defined in huber_obj; it is
        # presumably the squared norm of x[:, i] - x[:, j] for each pair,
        # returned in the same order as `pairs` -- confirm against huber_obj.
        dist = np.sqrt(pair_col_diff_norm2(x, pairs))

        # Boolean matrix of linked pairs, filled on the strict upper
        # triangle (same row-major order as `pairs`). Using a boolean
        # mask instead of a -1 sentinel keeps pairs at distance exactly
        # 0: coincident points must be linked, since 0 < epsilon.
        linked = np.zeros((n, n), dtype=bool)
        linked[np.triu_indices(n, 1)] = dist < epsilon

        for i, j in zip(*np.nonzero(linked)):
            G.add_edge(i, j)

    CC = ConnectedComponent(G)
    return CC.ID, CC.count
class Graph:
    '''
    Undirected graph stored as adjacency lists.

    Attributes:
        V: number of vertices (vertices are 0, ..., V - 1)
        E: number of edges added so far
        adj: adj[v] is the list of neighbours of vertex v
    '''

    def __init__(self, V):
        self.V = V
        self.E = 0
        adjacency = []
        for _ in range(V):
            adjacency.append([])
        self.adj = adjacency

    def add_edge(self, v, w):
        '''
        Add the undirected edge v-w: w is appended to v's neighbour list,
        v to w's, and the edge counter is increased by one.
        '''
        for src, dst in ((v, w), (w, v)):
            self.adj[src].append(dst)
        self.E = self.E + 1
class ConnectedComponent:
    '''
    Label the connected components of an undirected graph.

    Args:
        G: graph object exposing V (number of vertices) and adj
           (adj[v] is the list of neighbours of vertex v)
    Attributes:
        G: the graph that was passed in
        count: number of connected components found
        ID: ID[v] is the component label of vertex v; labels are
            0, 1, ... in the order in which the lowest-numbered vertex
            of each component is reached
        marked: marked[v] is True once vertex v has been visited
    '''

    def __init__(self, G):
        self.G = G
        self.count = 0
        self.ID = [0] * G.V
        self.marked = [False] * G.V
        # Each still-unmarked vertex starts a new component; the search
        # below marks everything reachable from it before moving on.
        for v in range(G.V):
            if not self.marked[v]:
                self.dfs(G, v)
                self.count += 1

    def dfs(self, G, v):
        '''
        Mark every vertex reachable from v and give it the current
        component label (self.count).

        The traversal uses an explicit stack instead of recursion, so a
        large component (e.g. a long chain of vertices) does not exceed
        Python's recursion limit.
        '''
        self.marked[v] = True
        self.ID[v] = self.count
        stack = [v]
        while stack:
            u = stack.pop()
            for w in G.adj[u]:
                if not self.marked[w]:
                    # Mark on push so a vertex is never stacked twice.
                    self.marked[w] = True
                    self.ID[w] = self.count
                    stack.append(w)