-
Notifications
You must be signed in to change notification settings - Fork 0
/
combine_sampler.py
191 lines (158 loc) · 7.25 KB
/
combine_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from torch.utils.data.sampler import Sampler
import random
class CombineSampler(Sampler):
    """Sample dataset indices in single-class groups of ``n_cl`` elements.

    Every epoch each non-empty class is shuffled, padded to ``self.max + 1``
    indices by cycling through its own elements, and cut into groups of
    ``n_cl`` indices (a trailing remainder shorter than ``n_cl`` is dropped).
    The groups of all classes are then shuffled together and flattened, so
    every aligned run of ``n_cl`` consecutive indices belongs to one class.

    Args:
        l_inds (list of lists): one list of dataset indices per class.
        cl_b (int): classes in a batch.
        n_cl (int): num of obs per class inside the batch.

    Raises:
        ValueError: if ``n_cl`` is smaller than 1.
    """

    def __init__(self, l_inds, cl_b, n_cl):
        if n_cl < 1:
            # A non-positive group size would make the grouping step loop
            # without ever consuming the class, so reject it up front.
            raise ValueError("n_cl must be a positive integer, got %r" % (n_cl,))
        self.l_inds = l_inds
        self.cl_b = cl_b
        self.n_cl = n_cl
        self.batch_size = cl_b * n_cl
        self.flat_list = []
        # Size of the largest class; stays -1 when no classes are given.
        self.max = max((len(inds) for inds in l_inds), default=-1)

    def __iter__(self):
        """Build a fresh random index order and return an iterator over it."""
        target = self.max + 1
        groups = []
        for inds in self.l_inds:
            n_items = len(inds)
            if n_items == 0:
                # Nothing to repeat: an empty class contributes no groups.
                continue
            # shuffle elements inside the class (the input list is not mutated)
            shuffled = random.sample(inds, n_items)
            # Cycle through the shuffled class until it holds `target`
            # elements. A single partial copy is not enough for classes
            # smaller than half of the largest one.
            repeats = -(-target // n_items)  # ceiling division
            padded = (shuffled * repeats)[:target]
            # split the class every n_cl elements, dropping the last < n_cl
            groups.extend(
                padded[start:start + self.n_cl]
                for start in range(0, target - self.n_cl + 1, self.n_cl)
            )
        # shuffle the order of the groups across classes
        random.shuffle(groups)
        self.flat_list = [index for group in groups for index in group]
        return iter(self.flat_list)

    def __len__(self):
        """Return the number of indices produced per epoch.

        The value depends only on the class sizes and ``n_cl``, so it is
        correct even before the sampler has been iterated.
        """
        kept_per_class = (self.max + 1) // self.n_cl * self.n_cl
        populated_classes = sum(1 for inds in self.l_inds if len(inds) > 0)
        return kept_per_class * populated_classes
class CombineSamplerAdvanced(Sampler):
    """Sample batches built from a random pivot class and its listed neighbors.

    Each iteration picks a pivot class, looks up
    ``dict_class_distances[pivot]`` and keeps its first ``num_classes - 1``
    entries, then draws ``num_elements_class`` indices from the pivot and from
    each of those classes. The indices of one iteration are shuffled and
    appended to the epoch's flat index list.

    Args:
        l_inds (list of lists): one list of dataset indices per class.
        num_classes (int): classes in a batch.
        num_elements_class (int): num of obs per class inside the batch.
        dict_class_distances (dict): maps a class index to a sequence of other
            class indices (presumably ordered nearest first; verify against
            the caller).
        iterations_for_epoch (int): number of batches generated per epoch.
    """

    def __init__(self, l_inds, num_classes, num_elements_class, dict_class_distances, iterations_for_epoch):
        self.l_inds = l_inds
        self.num_classes = num_classes
        self.num_elements_class = num_elements_class
        self.batch_size = self.num_classes * self.num_elements_class
        self.flat_list = []
        self.iterations_for_epoch = iterations_for_epoch
        self.dict_class_distances = dict_class_distances

    def _draw(self, members):
        """Return ``num_elements_class`` indices drawn from one class."""
        count = self.num_elements_class
        if len(members) >= count:
            return random.sample(members, count)
        # Class is too small: draw with replacement so every class still
        # contributes the same number of indices and batches stay aligned.
        return random.choices(members, k=count)

    def __iter__(self):
        """Generate the index list of one epoch and return an iterator over it."""
        n_neighbors = self.num_classes - 1
        flat = []
        for _ in range(int(self.iterations_for_epoch)):
            # The pivot is drawn from every class in l_inds; num_classes is
            # only the number of classes per batch, not the size of the pool.
            pivot_class_index = random.randrange(len(self.l_inds))
            batch_classes = [pivot_class_index]
            # get the neighbors listed for the pivot class
            batch_classes.extend(self.dict_class_distances[pivot_class_index][:n_neighbors])
            batch = []
            for class_index in batch_classes:
                batch.extend(self._draw(self.l_inds[class_index]))
            # mix the classes inside the batch
            random.shuffle(batch)
            flat.extend(batch)
        self.flat_list = flat
        return iter(self.flat_list)

    def __len__(self):
        """Return the number of indices in an epoch.

        After an epoch has been generated this is its exact length. Before
        that it is the nominal ``iterations_for_epoch * batch_size`` so that
        callers asking for the length up front do not get 0.
        """
        return len(self.flat_list) or int(self.iterations_for_epoch) * self.batch_size
class CombineSamplerSuperclass(Sampler):
    """Sample batches whose classes all belong to one random superclass.

    Each iteration picks a superclass uniformly at random, samples
    ``num_classes`` distinct classes from it and draws ``num_elements_class``
    indices from each of those classes. The indices of one iteration are
    shuffled and appended to the epoch's flat index list.

    Args:
        l_inds: per-class collections of dataset indices, indexable by the
            class identifiers stored in ``dict_superclass``.
        num_classes (int): classes in a batch.
        num_elements_class (int): num of obs per class inside the batch.
        dict_superclass (dict): maps a superclass to the list of its classes.
        iterations_for_epoch (int): number of batches generated per epoch.

    Raises:
        ValueError: during iteration, if the chosen superclass holds fewer
            than ``num_classes`` classes (raised by ``random.sample``).
    """

    def __init__(self, l_inds, num_classes, num_elements_class, dict_superclass, iterations_for_epoch):
        self.l_inds = l_inds
        self.num_classes = num_classes
        self.num_elements_class = num_elements_class
        self.flat_list = []
        self.iterations_for_epoch = iterations_for_epoch
        self.dict_superclass = dict_superclass

    def __iter__(self):
        """Generate the index list of one epoch and return an iterator over it."""
        # The key list does not change while iterating, so build it once.
        superclasses = list(self.dict_superclass)
        per_class = self.num_elements_class
        flat = []
        for _ in range(int(self.iterations_for_epoch)):
            # randomly sample the superclass, then k of its classes
            candidates = self.dict_superclass[random.choice(superclasses)]
            batch = []
            for class_index in random.sample(candidates, self.num_classes):
                # NOTE: class ids may look like '141742158611' rather than
                # 0, 1, 2, ...; l_inds must be indexable by those ids until a
                # mapping between the two naming schemes exists.
                members = self.l_inds[class_index]
                if len(members) >= per_class:
                    drawn = random.sample(members, per_class)
                else:
                    # too few elements: draw with replacement instead
                    drawn = random.choices(members, k=per_class)
                batch.extend(drawn)
            # mix the classes inside the batch
            random.shuffle(batch)
            flat.extend(batch)
        self.flat_list = flat
        return iter(self.flat_list)

    def __len__(self):
        """Return the number of indices produced per epoch.

        Every iteration yields exactly ``num_classes * num_elements_class``
        indices, so the length is known before the sampler is iterated.
        """
        return int(self.iterations_for_epoch) * self.num_classes * self.num_elements_class
class CombineSamplerSuperclass2(Sampler):
    """Sample batches whose classes come from two different superclasses.

    Each iteration picks two distinct superclasses at random, samples
    ``num_classes // 2`` distinct classes from each of them and draws
    ``num_elements_class`` indices from every selected class. The indices of
    one iteration are shuffled and appended to the epoch's flat index list.
    With an odd ``num_classes`` a batch therefore holds ``num_classes - 1``
    classes.

    Args:
        l_inds: per-class collections of dataset indices, indexable by the
            class identifiers stored in ``dict_superclass``.
        num_classes (int): classes in a batch.
        num_elements_class (int): num of obs per class inside the batch.
        dict_superclass (dict): maps a superclass to the list of its classes.
        iterations_for_epoch (int): number of batches generated per epoch.

    Raises:
        ValueError: during iteration, if ``dict_superclass`` has fewer than
            two superclasses, or if a chosen superclass holds fewer than
            ``num_classes // 2`` classes (raised by ``random.sample``).
    """

    def __init__(self, l_inds, num_classes, num_elements_class, dict_superclass, iterations_for_epoch):
        self.l_inds = l_inds
        self.num_classes = num_classes
        self.num_elements_class = num_elements_class
        self.flat_list = []
        self.iterations_for_epoch = iterations_for_epoch
        self.dict_superclass = dict_superclass

    def __iter__(self):
        """Generate the index list of one epoch and return an iterator over it."""
        # The key list does not change while iterating, so build it once.
        superclasses = list(self.dict_superclass)
        half = self.num_classes // 2
        per_class = self.num_elements_class
        flat = []
        for _ in range(int(self.iterations_for_epoch)):
            # Sampling two keys without replacement yields two different
            # superclasses and fails loudly, instead of retrying forever,
            # when fewer than two superclasses exist.
            superclass_1, superclass_2 = random.sample(superclasses, 2)
            # randomly sample k/2 classes from each superclass
            classes = random.sample(self.dict_superclass[superclass_1], half)
            classes.extend(random.sample(self.dict_superclass[superclass_2], half))
            batch = []
            for class_index in classes:
                # NOTE: class ids may look like '141742158611' rather than
                # 0, 1, 2, ...; l_inds must be indexable by those ids until a
                # mapping between the two naming schemes exists.
                members = self.l_inds[class_index]
                if len(members) >= per_class:
                    drawn = random.sample(members, per_class)
                else:
                    # too few elements: draw with replacement instead
                    drawn = random.choices(members, k=per_class)
                batch.extend(drawn)
            # mix the classes inside the batch
            random.shuffle(batch)
            flat.extend(batch)
        self.flat_list = flat
        return iter(self.flat_list)

    def __len__(self):
        """Return the number of indices produced per epoch.

        Every iteration yields ``2 * (num_classes // 2) * num_elements_class``
        indices, so the length is known before the sampler is iterated.
        """
        classes_per_batch = 2 * (self.num_classes // 2)
        return int(self.iterations_for_epoch) * classes_per_batch * self.num_elements_class