
Memory usage for finding nearest channels is much higher in Kilosort4 vs. Kilosort3 #644

Closed
naterenegar opened this issue Mar 27, 2024 · 5 comments

@naterenegar

Describe the issue:

Hello,

I'm working with data from a 2D MEA. It seems that kilosort3 works on the data but not kilosort4, due to differences in how electrode locations are upsampled in the two versions. In the spike extraction step, there are two calls to a function that finds the nearest channels for every channel. The first call finds the nearest original channels to each upsampled location, and the second call finds the nearest upsampled locations to each upsampled location.

In kilosort3, the upsampled electrodes seem to be distance gated before the second call, presumably so that we only look at upsampled locations near the original channels. The snippet from extract_spikes.m:

NchanNear = 8;
% For each upsampled location, find the NchanNear nearest original channels
[iC, dist] = getClosestChannels2(ycup, xcup, rez.yc, rez.xc, NchanNear);

% Distance gating: keep only upsampled locations close to an original channel
igood = dist(1,:)<dNearActiveSite;
iC = iC(:, igood);
dist = dist(:, igood);

ycup = ycup(igood);
xcup = xcup(igood);

% For each gated upsampled location, find the nearest gated upsampled locations
NchanNearUp = min(numel(ycup), 10*NchanNear);
[iC2, dist2] = getClosestChannels2(ycup, xcup, ycup, xcup, NchanNearUp);

But in kilosort4, all of the upsampled locations are kept:

# Build the full grid of upsampled locations (no distance gating)
[ys, xs] = np.meshgrid(ops['yup'], ops['xup'])
ys, xs = ys.flatten(), xs.flatten()
ops['ycup'], ops['xcup'] = ys, xs

xc, yc = ops['xc'], ops['yc']
Nfilt = len(ys)

nC = ops['settings']['nearest_chans']
nC2 = ops['settings']['nearest_templates']
# For each upsampled location, its nearest original channels
iC, ds = nearest_chans(ys, yc, xs, xc, nC, device=device)
# For each upsampled location, its nearest upsampled locations (all of them are kept)
iC2, ds2 = nearest_chans(ys, ys, xs, xs, nC2, device=device)

This snippet fails at the very last line with a CUDA out of memory error. My GPU has 24GB of VRAM. Specifically, the distance calculation in nearest_chans fails:

def nearest_chans(ys, yc, xs, xc, nC, device=torch.device('cuda')):
    # pairwise squared distances, shape (len(yc), len(ys)) -- huge when both are ~1e5
    ds = (ys - yc[:,np.newaxis])**2 + (xs - xc[:,np.newaxis])**2 # <-- fails here
    iC = np.argsort(ds, 0)[:nC]
    iC = torch.from_numpy(iC).to(device)
    ds = np.sort(ds, 0)[:nC]
    return iC, ds

I understand the main application of Kilosort is for in vivo shank probes. In this case, the probes are not very wide, and most electrodes are close together so that the number of upsampled electrodes is manageable. For 2D-MEAs, which can be wider than they are tall, and also can have large gaps between recording clusters, the upsampling adds many electrodes that are nowhere near recording sites.

In my case, the original HD-MEA has ~26,000 electrodes, but we can only record from ~1,000 sites. If the recorded sites are spread across the MEA, the upsampling procedure creates 4 times the total number of sites on the MEA (doubling in each dimension). The distance calculation then builds a ~(100,000, 100,000) matrix of floats, which is tens of gigabytes. If the upsampled sites were distance gated, I'd guess this number would be drastically reduced.
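
For concreteness, here is a rough sketch (not actual Kilosort code) of the memory footprint and of how the kilosort3-style gating could be bolted onto the kilosort4 snippet above. It reuses ys, xs, yc, xc, nC, nC2, device, and nearest_chans from the snippets; the dNearActiveSite value is just a placeholder, and since nearest_chans returns squared distances the comparison uses the squared threshold:

import torch

# Back-of-envelope footprint of the ungated second call:
# ~4 * 26,000 upsampled sites -> a (104,000 x 104,000) float64 distance matrix.
n_up = 4 * 26_000
print(f"{n_up**2 * 8 / 1e9:.0f} GB")  # ~87 GB, far more than 24 GB of VRAM

# Kilosort3-style gating, applied before the second nearest_chans call.
dNearActiveSite = 30.0                          # placeholder threshold (probe coordinate units)
iC, ds = nearest_chans(ys, yc, xs, xc, nC, device=device)
igood = ds[0, :] < dNearActiveSite**2           # upsampled sites near a real channel
ys, xs = ys[igood], xs[igood]                   # drop everything else
iC = iC[:, torch.from_numpy(igood).to(device)]  # keep the matching columns
ds = ds[:, igood]

# The second call now scales with the gated count, not the full upsampled grid.
iC2, ds2 = nearest_chans(ys, ys, xs, xs, nC2, device=device)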

Thanks,
Nathan

@marius10p (Contributor)

Thank you for this issue. I think we just missed the distance gating in Kilosort4. Were you getting good results overall with Kilosort3? There is another step where I would imagine problems (clustering in groups of nearest channels).

@naterenegar (Author)

Hi Marius,

Yeah kilosort3 was giving me good results! I haven't looked at the specific step you mentioned to see how it would do on MEA data.

P.S., it seems someone else has encountered this issue #647

@jacobpennington (Collaborator)

@naterenegar I just pushed version 4.0.4, which should address this problem. Would you mind trying it out on your data when you have time and letting us know how it goes? Note that you will likely need to set the new x_centers parameter for a 2D array like this (under "Extra settings" if you're using the GUI). The goal with that parameter is to avoid including too many templates in a single grouping; it divides the horizontal space of the probe into that many sections. I would start with a value around 10 for a large array, and try increasing it if the clustering step seems exceptionally slow. You can see where the grouping centers get placed by checking the box by that name under the probe plot in the GUI.

You might also need to set max_channel_distance (also under "Extra settings") if you still run into memory issues. This controls the distance gating you pointed out from previous versions. By default it will keep templates that are within max(dmin, dminx), but that might be overkill if there's a lot of space between channels. You can also preview those in the GUI with the "Universal Templates" check-box.
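
If you're running from a script instead of the GUI, something along these lines should work (the data path, channel count, probe file, and example values below are just placeholders for your setup):

from kilosort import run_kilosort

# Placeholder values -- adjust for your recording and array geometry.
settings = {
    'data_dir': '/path/to/recording',  # folder containing the binary file
    'n_chan_bin': 1024,                # total number of channels in the binary file
    'x_centers': 10,                   # number of horizontal sections for template grouping
    'max_channel_distance': 50,        # distance gate for universal templates (probe coordinate units)
}

# Settings not specified here fall back to the Kilosort4 defaults.
results = run_kilosort(settings=settings, probe_name='my_2d_mea_probe.json')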

@yyyaaaaaaa

@naterenegar
Hi, we are currently trying to process 2D MEA data using Kilosort, and it seems we are using the same product. I'd like to ask: when creating probes, do I need to create a probe with all 26,400 electrode sites, or just the channels that are currently in use? Thanks :)

@jacobpennington (Collaborator)

Closing this for now; please let us know if you try out the new version and still have problems sorting this dataset.
