Rigid distance #298

Open: wants to merge 25 commits into master
1 change: 1 addition & 0 deletions changelog
@@ -11,6 +11,7 @@ Release 1.0.1
* add the option to collect normalized MSE during fitting. False by default
* fix the rhd wrapper
* divide and conquer assignment now based on barycenters instead of simple extrema
* more efficient merging of the drifts with rigid registration of the templates
* exit the clustering if too many centroids are found (sign of bad channels)
* fixes in the meta merging GUI (RPV and dip)
* optimizations for the second component, less double counting
15 changes: 15 additions & 0 deletions circus/clustering.py
@@ -108,6 +108,9 @@ def main(params, nb_cpu, nb_gpu, use_gpu):
smoothing_factor = params.getfloat('detection', 'smoothing_factor')
noise_window = params.getint('detection', 'noise_time')
low_channels_thr = params.getint('detection', 'low_channels_thr')

ss_scale = params.getfloat('clustering', 'smart_search_scale')
search_drifts = params.getboolean('clustering', 'search_drifts')
fixed_amplitudes = params.getboolean('clustering', 'fixed_amplitudes')

if not fixed_amplitudes:
@@ -1797,6 +1800,18 @@ def reject_rate(x, d, target):
normalization=templates_normalization, debug_plots=debug_plots
)
comm.Barrier()
sys.stderr.flush()

if search_drifts:

if comm.rank == 0:
print_and_log(["Identifying putative drifts for meta merging..."], 'default', logger)

algo.search_drifts(
params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu, debug_plots=debug_plots
)
comm.Barrier()

gc.collect()
sys.stderr.flush()
io.get_overlaps(
4 changes: 3 additions & 1 deletion circus/config.params
@@ -53,6 +53,7 @@ sensitivity = 3 # Single parameter for clustering sensitivity. The l
cc_merge = 0.95 # If CC between two templates is higher, they are merged
dispersion = (5, 5) # Min and Max dispersion allowed for amplitudes [in MAD]
fine_amplitude = True # Optimize the amplitudes and compute a purity index for each template
search_drifts = True # Search for putative drifts (by rigid registration) to improve meta merging
make_plots = # Generate sanity plots of the clustering [Nothing or None if no plots]

[fitting]
@@ -74,7 +75,8 @@ sparsity_limit = 0 # Sparsity level (in percentage) for selecting templ
time_rpv = 5 # Time [in ms] to consider for Refraction Period Violations (RPV) (0 to disable)
rpv_threshold = 0.02 # Percentage of RPV allowed while merging
merge_drifts = True # Try to automatically merge drifts, i.e. non overlapping spiking neurons
drift_limit = 1 # Distance for drifts. The higher, the more non-overlapping the activities should be
drift_limit = 0.5 # Distance for drifts. The higher, the more non-overlapping the activities should be
drift_space = 50 # Maximal distance allowed for template translations during rigid registration [in um]

[converting]
erase_all = True # If False, a prompt will ask you to export if export has already been done
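For reference, the two new options are read back verbatim in the code touched by this PR (circus/clustering.py above, circus/shared/algorithms.py below):

search_drifts = params.getboolean('clustering', 'search_drifts')
drift_space = params.getfloat('clustering', 'drift_space')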
242 changes: 240 additions & 2 deletions circus/shared/algorithms.py
@@ -13,9 +13,9 @@

from circus.shared.files import load_data, write_datasets, get_overlaps, load_data_memshared, get_stas, load_sp_memshared, load_sp
from circus.shared.utils import get_tqdm_progressbar, get_shared_memory_flag, dip, dip_threshold, \
batch_folding_test_with_MPA, bhatta_dist, nd_bhatta_dist, test_if_support, test_if_purity, test_if_confusion
batch_folding_test_with_MPA, bhatta_dist, nd_bhatta_dist, test_if_support, test_if_purity, test_if_drifts, test_if_confusion
from circus.shared.messages import print_and_log
from circus.shared.probes import get_nodes_and_edges
from circus.shared.probes import get_nodes_and_edges, get_nodes_and_positions
from circus.shared.mpi import all_gather_array, comm, gather_array

import scipy.linalg
@@ -427,6 +427,7 @@ def slice_templates(params, to_remove=None, to_merge=None, extension='', input_e
template_shift = params.getint('detection', 'template_shift')
has_support = test_if_support(params, input_extension)
has_purity = test_if_purity(params, input_extension)
has_drifts = test_if_drifts(params, input_extension)
has_confusion = test_if_confusion(params, input_extension)
fine_amplitude = params.getboolean('clustering', 'fine_amplitude')
fixed_amplitudes = params.getboolean('clustering', 'fixed_amplitudes')
@@ -452,6 +453,10 @@ def slice_templates(params, to_remove=None, to_merge=None, extension='', input_e
old_purity = load_data(params, 'purity', extension=input_extension)
else:
old_purity = None # default assignment
if has_drifts:
old_drifts = load_data(params, 'drifts', extension=input_extension)
else:
old_drifts = None # default assignment
if has_confusion:
old_confusion = load_data(params, 'confusion', extension=input_extension)
else:
@@ -492,6 +497,11 @@ def slice_templates(params, to_remove=None, to_merge=None, extension='', input_e
else:
purity = None

if has_drifts:
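# One record per template pair: three spatial shift components plus the
# normalized cross-correlation of the registered pair.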
drifts = hfile.create_dataset('drifts', shape=(len(to_keep), len(to_keep), 4), dtype=numpy.float32, chunks=True)
else:
drifts = None

if has_confusion:
confusion = hfile.create_dataset('confusion', shape=(len(to_keep), len(to_keep)), dtype=numpy.float32, chunks=True)
else:
@@ -513,6 +523,9 @@ def slice_templates(params, to_remove=None, to_merge=None, extension='', input_e
new_limits = old_limits[keep]
if has_purity:
new_purity = old_purity[keep]

if has_drifts:
new_drifts = old_drifts[keep]
if has_confusion:
new_confusion = old_confusion[keep, to_keep]
else:
@@ -542,12 +555,19 @@ def slice_templates(params, to_remove=None, to_merge=None, extension='', input_e
new_limits[:, 1] = numpy.max(ratios[:, numpy.newaxis] * old_limits[idx, :, 1], 0)
if has_purity:
new_purity = numpy.mean(old_purity[idx])

if has_drifts:
new_drifts = numpy.mean(old_drifts[idx], 0)

if has_confusion:
new_confusion = numpy.mean(old_confusion[idx][:, to_keep], 0)
else:
new_limits = old_limits[keep]
if has_purity:
new_purity = old_purity[keep]
if has_drifts:
new_drifts = old_drifts[keep]

if has_confusion:
new_confusion = old_confusion[keep, to_keep]

@@ -560,6 +580,8 @@ def slice_templates(params, to_remove=None, to_merge=None, extension='', input_e
if has_confusion:
confusion[count] = new_confusion

if has_drifts:
drifts[count] = new_drifts[to_keep, :]
# Copy templates to file.
templates = templates.tocoo()
if hdf5_compress:
@@ -1519,6 +1541,222 @@ def refine_amplitudes(params, nb_cpu, nb_gpu, use_gpu, normalization=True, debug
return


def search_drifts(params, nb_cpu, nb_gpu, use_gpu, debug_plots=''):

data_file = params.data_file
SHARED_MEMORY = get_shared_memory_flag(params)

if SHARED_MEMORY:
templates, mpi_memory_1 = load_data_memshared(params, 'templates', normalize=False, transpose=True)
else:
templates = load_data(params, 'templates')
templates = templates.T

norms = load_data(params, 'norm-templates')
best_elecs = load_data(params, 'electrodes')
nb_templates = templates.shape[0] // 2
_, positions = get_nodes_and_positions(params)
supports = load_data(params, 'supports')
N_e = params.getint('data', 'N_e')
N_t = params.getint('detection', 'N_t')
blosc_compress = params.getboolean('data', 'blosc_compress')
file_out_suff = params.get('data', 'file_out_suff')
plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')

drift_space = params.getfloat('clustering', 'drift_space')
drift_time = params.getint('clustering', 'drift_time')

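# Only template pairs whose electrode supports overlap are worth registering;
# build the pairwise intersection mask once and symmetrize it.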
mask_intersect = numpy.zeros((nb_templates, nb_templates), dtype=bool)
for i in range(nb_templates):
for j in range(i+1, nb_templates):
mask_intersect[i, j] = numpy.any(supports[i]*supports[j])

mask_intersect = numpy.maximum(mask_intersect, mask_intersect.T)

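# Each template sample gets a space-time coordinate: the electrode's position
# along the probe axes with non-zero extent, plus the sample time. A template
# thus becomes a scalar field over space-time that can be rigidly shifted.
# The samples are laid out on a regular grid; grid nodes with no physical
# electrode are flagged in grid_mask and filled per template by RBF
# interpolation further below.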
times = numpy.arange(N_t)
full_times = numpy.tile(times, N_e).reshape(N_e*N_t, 1)
non_zeros = numpy.where(numpy.std(positions, 0) > 0)[0]
full_xyz = numpy.repeat(positions[:, non_zeros], N_t, axis=0)
full_positions = numpy.hstack((full_xyz, full_times))

dimensions = []
for i in non_zeros:
dimensions += [numpy.unique(positions[:,i])]

if len(non_zeros) == 1:
grid = numpy.zeros((len(dimensions[0]), len(times)), dtype=numpy.float32)
elif len(non_zeros) == 2:
grid = numpy.zeros((len(dimensions[0]), len(dimensions[1]), len(times)), dtype=numpy.float32)
elif len(non_zeros) == 3:
grid = numpy.zeros((len(dimensions[0]), len(dimensions[1]), len(dimensions[2]), len(times)), dtype=numpy.float32)

mapping = numpy.zeros((len(dimensions) + 1, N_t * N_e), dtype=numpy.int32)
grid_mask = numpy.ones(grid.shape, dtype=bool)

dimensions += [times]

for c, pos in enumerate(full_positions):
for d, dim in enumerate(pos):
mapping[d, c] = numpy.where(dim == dimensions[d])[0]

if len(dimensions) == 2:
grid_mask[mapping[0], mapping[1]] = False
elif len(dimensions) == 3:
grid_mask[mapping[0], mapping[1], mapping[2]] = False
elif len(dimensions) == 4:
grid_mask[mapping[0], mapping[1], mapping[2], mapping[3]] = False

missing = numpy.where(grid_mask)
if len(dimensions) == 2:
a = dimensions[0][missing[0]], dimensions[1][missing[1]]
elif len(dimensions) == 3:
a = dimensions[0][missing[0]], dimensions[1][missing[1]], dimensions[2][missing[2]]
elif len(dimensions) == 4:
a = dimensions[0][missing[0]], dimensions[1][missing[1]], dimensions[2][missing[2]], dimensions[3][missing[3]]

missing_positions = numpy.array(a)

all_temp = numpy.arange(comm.rank, nb_templates, comm.size)

if comm.rank == 0:
to_explore = get_tqdm_progressbar(params, all_temp)
else:
to_explore = all_temp

registration = numpy.zeros((len(to_explore), nb_templates, 4), dtype=numpy.float32)
boundaries = [(-drift_space, drift_space)] * len(non_zeros) + [(-drift_time, drift_time)]
boundaries = numpy.array(boundaries)

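# Initial guess for the spatial shift: the difference between the two
# templates' barycenters, each electrode weighted by the L2 norm of its
# waveform.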
def guess_best_translation(source_template, target_template, positions):

src_norm = numpy.linalg.norm(source_template, axis=1)
tgt_norm = numpy.linalg.norm(target_template, axis=1)
bar_src = src_norm[:, numpy.newaxis] * positions
bar_src = bar_src.sum(0)/src_norm.sum()
bar_tgt = tgt_norm[:, numpy.newaxis] * positions
bar_tgt = bar_tgt.sum(0)/tgt_norm.sum()
return bar_src - bar_tgt

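# Registration objective: L2 distance between the target template and the
# source template evaluated at coordinates shifted by r.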
def get_difference(r, interpolator, target, full_positions):
registered = interpolator(full_positions + r)
return numpy.linalg.norm(target - registered)

for count, i in enumerate(to_explore):

source_template = templates[i].toarray().ravel()

if len(dimensions) == 2:
interp_full = scipy.interpolate.Rbf(full_positions[:,0], full_positions[:,1], source_template, epsilon=1e-6)
grid[missing[0], missing[1]] = interp_full(missing_positions[0], missing_positions[1])
grid[mapping[0], mapping[1]] = source_template
elif len(dimensions) == 3:
interp_full = scipy.interpolate.Rbf(full_positions[:,0], full_positions[:,1], full_positions[:,2], source_template, epsilon=1e-6)
grid[missing[0], missing[1], missing[2]] = interp_full(missing_positions[0], missing_positions[1], missing_positions[2])
grid[mapping[0], mapping[1], mapping[2]] = source_template
elif len(dimensions) == 4:
interp_full = scipy.interpolate.Rbf(full_positions[:,0], full_positions[:,1], full_positions[:,2], full_positions[:,3], source_template, epsilon=1e-6)
grid[missing[0], missing[1], missing[2], missing[3]] = interp_full(missing_positions[0], missing_positions[1], missing_positions[2], missing_positions[3])
grid[mapping[0], mapping[1], mapping[2], mapping[3]] = source_template

my_interpolating_function = scipy.interpolate.RegularGridInterpolator(dimensions, grid, bounds_error=False, fill_value=0)

for j in range(i+1, nb_templates):

target_template = templates[j].toarray().ravel()
guess = guess_best_translation(source_template.reshape(N_e, N_t), target_template.reshape(N_e, N_t), positions)
estimated_distance = numpy.linalg.norm(guess)

if mask_intersect[i, j] and estimated_distance < drift_space:

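# Global search over rigid shifts (spatial components bounded by
# +/- drift_space, the temporal one by +/- drift_time); differential
# evolution avoids local minima of the registration error, and the fit is
# summarized by a normalized cross-correlation.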
optim = scipy.optimize.differential_evolution(get_difference, bounds=boundaries, args=(my_interpolating_function, target_template, full_positions), polish=False)
registered = my_interpolating_function(full_positions + optim.x)

cc = scipy.signal.correlate(target_template, registered, 'same').max() / (numpy.linalg.norm(target_template) * numpy.linalg.norm(registered))
registration[count, j, 3] = cc
registration[count, j, non_zeros] = optim.x[:len(non_zeros)]

if debug_plots not in ['None', '']:

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
fig = plt.figure()
gs = GridSpec(3, 4)
axs = fig.add_subplot(gs[0,:])
axs.plot(target_template, c='k')
axs.set_title('Template %d' %j)
axs.spines['right'].set_visible(False)
axs.spines['top'].set_visible(False)
axs.set_yticks([], [])
axs.set_xticks([], [])

axs = fig.add_subplot(gs[1,:])
axs.plot(source_template)
axs.plot(registered)
axs.set_title('Template %d [drift %s, cc %g]' % (i, registration[count, j, :3], registration[count, j, 3]))
axs.spines['right'].set_visible(False)
axs.spines['top'].set_visible(False)
axs.set_yticks([], [])
axs.set_xticks([], [])

axs = fig.add_subplot(gs[2,:])
tmp_1 = (source_template - target_template)**2
tmp_2 = (registered - target_template)**2
axs.plot(tmp_1)
axs.plot(tmp_2)
axs.set_title('Differences [%g, %g]' %(tmp_1.mean(), tmp_2.mean()))
axs.spines['right'].set_visible(False)
axs.spines['top'].set_visible(False)
axs.set_yticks([], [])
axs.set_xticks([], [])

# ...
plt.tight_layout()
# Save and close figure.
output_path = os.path.join(
plot_path,
"registration_t{}_t{}.{}".format(
i,
j,
debug_plots
)
)
fig.savefig(output_path)
plt.close(fig)

registration = gather_array(registration.flatten(), comm, 0, 1, 'float32', compress=blosc_compress)

comm.Barrier()

if SHARED_MEMORY:
for memory in mpi_memory_1:
memory.Free()

if comm.rank == 0:
indices = []
registration = registration.reshape(nb_templates, nb_templates, 4)
for idx in range(comm.size):
indices += list(numpy.arange(idx, nb_templates, comm.size))
indices = numpy.argsort(indices).astype(numpy.int32)
registration = registration[indices, :]

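# Each pair (i, j) with j > i was fitted exactly once, so the lower triangle
# is still zero: adding the transpose fills the full matrix, and the shift
# components flip sign below the diagonal because registering j onto i is
# the inverse translation of registering i onto j.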
for i in range(4):
registration[:,:,i] = registration[:,:,i] + registration[:,:,i].T

mask = numpy.tril(numpy.ones((nb_templates, nb_templates)), -1) > 0
for i in range(3):
registration[:,:,i][mask] *= -1

file_name = file_out_suff + '.templates.hdf5'
hfile = h5py.File(file_name, 'r+', libver='earliest')

if 'drifts' not in hfile.keys():
hfile.create_dataset('drifts', data=registration)
else:
hfile['drifts'][:] = registration
hfile.close()

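A standalone toy sketch of the symmetrization convention above (not part of the diff): for two templates, one fitted entry in the upper triangle yields its signed counterpart below the diagonal, with the cc left untouched.

import numpy

reg = numpy.zeros((2, 2, 4), dtype=numpy.float32)
reg[0, 1] = [5.0, -2.0, 1.0, 0.9]  # template 0 onto 1: (dx, dy, dz, cc)
for k in range(4):
    reg[:, :, k] = reg[:, :, k] + reg[:, :, k].T
mask = numpy.tril(numpy.ones((2, 2)), -1) > 0
for k in range(3):
    reg[:, :, k][mask] *= -1
print(reg[1, 0])  # [-5.  2. -1.  0.9]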

def delete_mixtures(params, nb_cpu, nb_gpu, use_gpu):

data_file = params.data_file
10 changes: 10 additions & 0 deletions circus/shared/files.py
@@ -1157,6 +1157,16 @@ def load_data(params, data, extension=''):
if comm.rank == 0:
print_and_log(["No templates found! Check suffix?"], 'error', logger)
sys.exit(0)
elif data == 'drifts':
if os.path.exists(file_out_suff + '.templates%s.hdf5' % extension):
myfile = h5py.File(file_out_suff + '.templates%s.hdf5' % extension, 'r', libver='earliest')
if 'drifts' in myfile.keys():
drifts = myfile.get('drifts')[:]
else:
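# No drift information stored yet: fall back to all-zero shifts (and zero cc).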
N_e, N_t, nb_templates = myfile.get('temp_shape')[:].ravel()
drifts = numpy.zeros((nb_templates//2, nb_templates//2, 4), dtype=numpy.float32)
myfile.close()
return drifts
elif data == 'confusion':
if os.path.exists(file_out_suff + '.templates%s.hdf5' % extension):
myfile = h5py.File(file_out_suff + '.templates%s.hdf5' % extension, 'r', libver='earliest')
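Finally, a minimal sketch of reading the stored drifts back through this new load_data branch (assuming a finished clustering run; the CircusParser import path and the file name are illustrative):

from circus.shared.parser import CircusParser
from circus.shared.files import load_data

params = CircusParser('mydata.dat')  # hypothetical recording
drifts = load_data(params, 'drifts')
# drifts[i, j, :3]: fitted rigid shift of template i onto template j, in the
# probe's position units; drifts[i, j, 3]: normalized cross-correlation of the
# registered pair (0 when the pair was never fitted).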