diff --git a/demo_script.m b/demo_script.m index 6c99904..ad343a9 100644 --- a/demo_script.m +++ b/demo_script.m @@ -60,9 +60,23 @@ [C,f,P,S,YrA] = update_temporal_components(Yr,A,b,Cin,fin,P,options); %% classify components + [ROIvars.rval_space,ROIvars.rval_time,ROIvars.max_pr,ROIvars.sizeA,keep] = classify_components(Y,A,C,b,f,YrA,options); -A_keep = A(:,keep); -C_keep = C(keep,:); + +%% further classification with cnn_classifier +try % matlab 2017b or later is needed + [ind,value] = cnn_classifier(A,[d1,d2],'cnn_model',0.2); +catch + ind = true(size(A,2),1); +end +%% display kept and discarded components +A_keep = A(:,(keep & ind)); +C_keep = C((keep & ind),:); +figure; + subplot(121); montage(extract_patch(A(:,(keep & ind)),[d1,d2],[30,30]),'DisplayRange',[0,0.15]); + title('Kept Components'); + subplot(122); montage(extract_patch(A(:,~(keep & ind)),[d1,d2],[30,30]),'DisplayRange',[0,0.15]) + title('Discarded Components'); %% merge found components [Am,Cm,K_m,merged_ROIs,Pm,Sm] = merge_components(Yr,A_keep,b,C_keep,f,P,S,options); diff --git a/utilities/extract_patch.m b/utilities/extract_patch.m index 3dbd65b..5ee8357 100644 --- a/utilities/extract_patch.m +++ b/utilities/extract_patch.m @@ -17,6 +17,7 @@ nd = length(dims); if nd == 2; dims(3) = 1; patch_size(3) = 1; end K = size(A,2); +A = A/spdiags(sqrt(sum(A.^2,1))'+eps,0,K,K); % normalize to sum 1 for each compoennt cm = com(A,dims(1),dims(2),dims(3)); xx = -ceil(patch_size(1)/2-1):floor(patch_size(1)/2); yy = -ceil(patch_size(2)/2-1):floor(patch_size(2)/2);