/
classify_dataset_partition_iter.m
239 lines (190 loc) · 7.89 KB
/
classify_dataset_partition_iter.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
function [ ] = classify_dataset_partition_iter( data, params, logger, varargin )
%CLASSIFY_DATASET_PARTITION_ITER Classify a train/test partitioned graph
%dataset with hierarchical stochastic graphlet embeddings.
%   The provided train and test splits are concatenated and embedded NITS
%   times (the graphlet sampling/hashing is randomized, so each iteration
%   yields a different embedding). For every combination of graphlet sizes
%   an SVM with a precomputed kernel is model-selected by cross-validation,
%   retrained and evaluated on the test split; the mean accuracy and its
%   standard deviation over the iterations are reported through LOGGER.
%
%   Inputs:
%       data     - struct; data.dataset provides graphs_train/graphs_test
%                  (1xN arrays) and clss_train/clss_test (column vectors
%                  of class labels).
%       params   - struct; params.T(t) is the number of sampled graphlets
%                  of size t, used below to size the hash tables M.
%       logger   - function handle invoked once per graphlet-size
%                  combination with the run configuration and the obtained
%                  mean accuracy / standard deviation.
%       varargin - name/value options decoded by input_parser: epsilon,
%                  delta, pyramid levels/reduction, edge threshold,
%                  clustering function, MAX2, node label type, number of
%                  iterations, verbosity, task id and hierarchy config.
function [parts2level , n_parts] = get_parts(config, pyr_levels)
%% Decide the parts according to the config
%   Returns the number of hierarchy parts and, for each part, the pyramid
%   level it belongs to. The main loop emits, per level i, one "level"
%   graph (unless config is 'hier') followed by the sub-hierarchies (i,j),
%   j>i (unless config is 'level'); LEN/VAL encode how many consecutive
%   parts each level owns so the run-length expansion below can build
%   PARTS2LEVEL.
switch lower(config)
    case 'comb'
        % Levels plus all sub-hierarchies: level i owns 1+(L-i) parts.
        n_parts = pyr_levels*(pyr_levels+1)/2 ;
        len = pyr_levels:-1:1 ;
        val = 1:pyr_levels ;
    case 'hier'
        % Sub-hierarchies only: level i owns L-i parts.
        n_parts = pyr_levels*(pyr_levels-1)/2 ;
        len = pyr_levels-1:-1:1 ;
        val = 1:pyr_levels-1 ;
    case 'level'
        % One part per level.
        % NOTE(review): val maps every part to level 1 here; by analogy
        % with 'comb' it looks like it should be 1:pyr_levels -- confirm
        % before changing, kept as-is to preserve behavior.
        n_parts = pyr_levels ;
        len = ones(pyr_levels, 1) ;
        val = ones(pyr_levels, 1) ;
    case 'base'
        % Single part: the original, non-hierarchical graph.
        if pyr_levels ~= 1
            warning('config is base but pyr_levels is not one')
            pyr_levels = 1 ;
        end
        n_parts = pyr_levels ;
        len = 1 ;
        val = 1 ;
    otherwise
        % BUGFIX: identifier previously referenced classify_dataset_kfold
        % (copy-paste leftover from the sibling k-fold variant).
        error('classify_dataset_partition_iter:get_parts:incorrectConfig', ...
            'Error.\nNot implemented config %s.', config)
end
% Run-length decode: repeat val(k) exactly len(k) times.
temp1 = cumsum(len);
temp2 = zeros(1, temp1(end));
temp2(temp1(1:end-1)+1) = 1;
temp2(1) = 1;
parts2level = val(cumsum(temp2));
end % function
%% Default values
[epsi, del, pyr_levels, pyr_reduction, edge_thresh, clustering_func, MAX2,...
    node_label, nits, VERBOSE, task_id, config] = input_parser( varargin ) ;
rng(0); % Fixed seed for reproducibility of the randomized embedding.
%% Decide the parts according to the config
[parts2level , n_parts] = get_parts(config, pyr_levels) ;
%% Database information
data.dataset.graphs = [data.dataset.graphs_train , data.dataset.graphs_test];
data.dataset.clss = [data.dataset.clss_train; data.dataset.clss_test];
ntrain = size(data.dataset.graphs_train, 2);
ntest = size(data.dataset.graphs_test, 2);
ngraphs = size(data.dataset.graphs, 2);
classes = unique(data.dataset.clss);
nclasses = size(classes, 1);
if VERBOSE
    fprintf('Dataset Information\n\tNumber of graph:%d\t(Train %d\tTest %d)\n\tNumber of classes: %d\n',...
        ngraphs,ntrain,ntest, nclasses) ;
end
% All graphlet-size combinations across pyramid levels. Computed once:
% the original recomputed the identical value inside the iteration loop.
combinations = cell(1, pyr_levels) ;
for j = 1:pyr_levels
    combinations{j} = (1:MAX2(j)-2)' ;
end
combinations = allcomb( combinations ) ;
accs = zeros(nits, size(combinations, 1)) ;
% Hash-table sizes per part (hoisted: invariant across iterations).
M = cell(1, n_parts) ;
for i = 1:n_parts
    M{i} = uint32(ceil(2*(params.T(1:MAX2(parts2level(i)))*log(2)+log(1/del))/epsi^2)) ;
end
% Inverse-frequency class weights for the SVM, normalized so the largest
% weight is 1 (hoisted: invariant across iterations).
w_classes = ones(1, nclasses) ;
for i = 1:nclasses
    w_classes(i) = nnz(data.dataset.clss_train == classes(i)) ;
end
w_classes = 1./w_classes ;
w_classes = w_classes/max(w_classes) ;
w_str = [] ;
for i = 1:nclasses
    w_str = [w_str, sprintf('-w%d %f ', classes(i), w_classes(i))] ;
end
for it = 1:nits
    %% Create histogram indices
    % Initialize per-iteration embedding storage.
    for i = 1:n_parts
        global_var(i).idx_graph = cell(MAX2(parts2level(i))-2, 1) ;
        global_var(i).idx_bin = cell(MAX2(parts2level(i))-2, 1) ;
        global_var(i).hash_codes_uniq = cell(MAX2(parts2level(i))-2, 1) ;
    end
    %% Iterate whole dataset
    for i = 1:ngraphs
        if VERBOSE
            fprintf('Graph: %d. ',i);
            tic;
        end
        if ~strcmpi(config, 'base')
            H = generateHierarchy( data.dataset.graphs(i), pyr_levels, clustering_func, pyr_reduction, edge_thresh ) ;
            hier_graph = cell(n_parts, 1) ;
            hc = 1 ;
            % Construct hierarchy: per level, the level graph (unless
            % 'hier') then every sub-hierarchy towards deeper levels
            % (unless 'level'). Order must match get_parts' LEN/VAL.
            for i_ = 1:pyr_levels
                if ~strcmpi(config, 'hier')
                    hier_graph{hc} = getLevels(H, i_) ;
                    hc = hc + 1 ;
                end % if
                if ~strcmpi(config, 'level')
                    for j_ = i_+1:pyr_levels
                        hier_graph{hc} = getSubhierarchy(H, i_, j_) ;
                        hc = hc + 1 ;
                    end % for
                end % if
            end % for
        else
            hier_graph{n_parts} = data.dataset.graphs(i) ;
        end % if
        % Embedding (graphs with an empty adjacency matrix are skipped).
        for j = 1:n_parts
            if any(hier_graph{j}.am(:))
                [ global_var(j) ] = graphlet_embedding(hier_graph{j} , i , M{j} , global_var(j), MAX2(parts2level(j)) , node_label ) ;
            end % if
        end % for
        if VERBOSE
            toc
        end
    end
    % Histogram dimensions
    dim_hists = cell(n_parts,1) ;
    for i = 1:n_parts
        dim_hists{i} = cellfun(@(x) size(x,1) ,global_var(i).hash_codes_uniq);
        % BUGFIX: the original `clear global_var(i).hash_codes_uniq` was a
        % no-op (clear cannot address an indexed struct field); free the
        % potentially large hash codes with an explicit assignment.
        global_var(i).hash_codes_uniq = [] ;
    end
    %% Compute histograms and kernels
    histograms = cell(n_parts,1);
    for j = 1:n_parts
        histograms{j} = cell(1, MAX2(parts2level(j))-2);
        for i = 1:MAX2(parts2level(j))-2
            histograms{j}{i} = sparse(global_var(j).idx_graph{i}, global_var(j).idx_bin{i}, 1, ngraphs, dim_hists{j}(i)) ;
        end
    end
    for c = 1:size(combinations,1)
        comb = combinations(c,:);
        % Concatenate the histograms selected by this combination.
        comb_hist = [] ;
        for i = 1:length(comb)
            comb_hist = [comb_hist, combine_graphlet_hist(histograms(parts2level == i), comb(i), 'combine') ] ;
        end
        % L1-normalize each graph's histogram (eps guards division by 0).
        X = bsxfun(@times, comb_hist, 1./(sum(comb_hist,2)+eps)) ;
        X_train = X(1:ntrain,:) ;
        X_test = X(ntrain+(1:ntest),:) ;
        KM_train = vl_alldist2(X_train',X_train','KL1') ;
        KM_test = vl_alldist2(X_test',X_train','KL1') ;
        %% Evaluate
        train_classes = data.dataset.clss(1:ntrain) ;
        test_classes = data.dataset.clss(ntrain+(1:ntest)) ;
        % LIBSVM precomputed-kernel format: first column is sample index.
        K_train = [(1:ntrain)' KM_train] ;
        K_test = [(1:ntest)' KM_test] ;
        % Model selection for C by cross-validation accuracy.
        cs = 5:5:100 ;
        best_cv = 0 ;
        best_c = cs(1) ; % ROBUSTNESS: fallback so best_c is always defined.
        for j = 1:length(cs)
            options = sprintf('-s 0 -t 4 -v %d -c %f %s-b 1 -g 0.07 -h 0 -q',...
                nits,cs(j), w_str) ;
            % With -v, svmtrain returns the scalar CV accuracy.
            model_libsvm = svmtrain(train_classes,K_train,options) ;
            if(model_libsvm>best_cv)
                best_cv = model_libsvm ;
                best_c = cs(j) ;
            end
        end
        % Retrain on the full training set with the selected C.
        options = sprintf('-s 0 -t 4 -c %f %s-b 1 -g 0.07 -h 0 -q',...
            best_c, w_str) ;
        model_libsvm = svmtrain(train_classes,K_train,options) ;
        [~,acc,~] = svmpredict(test_classes,K_test,model_libsvm,'-b 1') ;
        % acc(1) is the accuracy percentage reported by svmpredict.
        accs(it, c) = acc(1) ;
    end
    clear global_var ;
end
maccs = mean(accs) ;
mstds = std(accs) ;
% Save results
% Report graphlet sizes in absolute terms (smallest graphlet has size 3,
% hence the +2 shift before logging).
if strcmpi(node_label, 'unlabel')
    combinations = combinations + 2 ;
else
    if size(combinations, 2) > 1
        combinations(:, 2:end) = combinations(:, 2:end) + 2 ;
    end
end
for i = 1:size(combinations, 1)
    logger(epsi, del, combinations(i,:), node_label, pyr_levels, pyr_reduction, edge_thresh, func2str(clustering_func), config, nits, maccs(i), mstds(i)) ;
end
end