Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jacob/pc features #643

Merged
merged 7 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 13 additions & 3 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b
return clu, Wall


def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32, ncomps = 64):
def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,
ncomps=64, ix=None, merge_dim=True):
PID = torch.from_numpy(PID).long()

#iU = ops['iU'].cpu().numpy()
Expand All @@ -341,7 +342,11 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32,
x0 = xcenter #xy[0].mean() - xcenter

#print(dmin, dminx)
ix = torch.logical_and(torch.abs(xy[1] - y0) < dmin, torch.abs(xy[0] - x0) < dminx)
if ix is None:
ix = torch.logical_and(
torch.abs(xy[1] - y0) < dmin,
torch.abs(xy[0] - x0) < dminx
)
#print(ix.nonzero()[:,0])
igood = ix[PID].nonzero()[:,0]

Expand All @@ -362,7 +367,12 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32,
#print(ij.sum())
dd[ij.unsqueeze(-1), iC[:,j]-ch_min] = data[ij]

Xd = torch.reshape(dd, (nspikes, -1))
if merge_dim:
Xd = torch.reshape(dd, (nspikes, -1))
else:
# Keep channels and features separate
Xd = dd

return Xd, ch_min, ch_max, igood


Expand Down
26 changes: 23 additions & 3 deletions kilosort/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

from kilosort import CCG
from kilosort.preprocessing import get_drift_matrix, fft_highpass
from kilosort.postprocessing import remove_duplicates, compute_spike_positions
from kilosort.postprocessing import (
remove_duplicates, compute_spike_positions, make_pc_features
)


def find_binary(data_dir: Union[str, os.PathLike]) -> Path:
Expand Down Expand Up @@ -195,6 +197,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
np.save((results_dir / 'templates.npy'), templates)
np.save((results_dir / 'templates_ind.npy'), templates_ind)

# pc features
if save_extra_vars:
# Save tF first since it gets updated in-place
np.save(results_dir / 'tF.npy', tF.cpu().numpy())
# This will momentarily copy tF which is pretty large, but it's on CPU
# so the extra memory hopefully won't be an issue.
tF = tF[kept_spikes]
pc_features, pc_feature_ind = make_pc_features(
ops, spike_templates, spike_clusters, tF
)
np.save(results_dir / 'pc_features.npy', pc_features)
np.save(results_dir / 'pc_feature_ind.npy', pc_feature_ind)

# contamination ratio
acg_threshold = ops['settings']['acg_threshold']
ccg_threshold = ops['settings']['ccg_threshold']
Expand Down Expand Up @@ -231,14 +246,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
f.write(f'{key} = {params[key]}\n')

if save_extra_vars:
# Also save tF and Wall, for easier debugging/analysis
np.save(results_dir / 'tF.npy', tF.cpu().numpy())
# Also save Wall, for easier debugging/analysis
np.save(results_dir / 'Wall.npy', Wall.cpu().numpy())
# And full st, clu, amp arrays with no spikes removed
np.save(results_dir / 'full_st.npy', st)
np.save(results_dir / 'full_clu.npy', clu)
np.save(results_dir / 'full_amp.npy', amplitudes)

# Remove cached .phy results if present from running Phy on a previous
# version of results in the same directory.
phy_cache_path = Path(results_dir / '.phy')
if phy_cache_path.is_dir():
shutil.rmtree(phy_cache_path)

return results_dir, similar_templates, is_ref, est_contam_rate


Expand Down
69 changes: 69 additions & 0 deletions kilosort/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import torch

from kilosort.clustering_qr import xy_templates, get_data_cpu


@njit("(int64[:], int32[:], int32)")
def remove_duplicates(spike_times, spike_clusters, dt=15):
Expand Down Expand Up @@ -42,3 +44,70 @@ def compute_spike_positions(st, tF, ops):
ys = (yc0 * tmass).sum(1).cpu().numpy()

return xs, ys


def make_pc_features(ops, spike_templates, spike_clusters, tF):
'''Get PC Features and corresponding indices for export to Phy.

NOTE: This function will update tF in-place!

Parameters
----------
ops : dict
Dictionary of state variables updated throughout the sorting process.
This function is intended to be used with the final state of ops, after
all sorting has finished.
spike_templates : np.ndarray
Vector of template ids with shape `(n_spikes,)`. This is equivalent to
`st[:,1]`, where `st` is returned by `template_matching.extract`.
spike_clusters : np.ndarray
Vector of cluster ids with shape `(n_pikes,)`. This is equivalent to
`clu` returned by `template_matching.merging_function`.
tF : torch.Tensor
Tensor of pc features as returned by `template_matching.extract`,
with shape `(n_spikes, nearest_chans, n_pcs)`.

Returns
-------
tF : torch.Tensor
As above, but with some data replaced so that features are associated
with the final clusters instead of templates. The second and third
dimensions are also swapped to conform to the shape expected by Phy.
feature_ind : np.ndarray
Channel indices associated with the data present in tF for each cluster,
with shape `(n_clusters, nearest_chans)`.

'''

# xy: template centers, iC: channels associated with each template
xy, iC = xy_templates(ops)
n_clusters = np.unique(spike_clusters).size
n_chans = ops['nearest_chans']
feature_ind = np.zeros((n_clusters, n_chans), dtype=np.uint32)

for i in np.unique(spike_clusters):
# Get templates associated with cluster (often just 1)
iunq = np.unique(spike_templates[spike_clusters==i]).astype(int)
# Get boolean mask with size (n_templates,), True if they match cluster
ix = torch.from_numpy(np.zeros(int(spike_templates.max())+1, bool))
ix[iunq] = True
# Get PC features for all spikes detected with those templates (Xd),
# and the indices in tF where those spikes occur (igood).
Xd, ch_min, ch_max, igood = get_data_cpu(
ops, xy, iC, spike_templates, tF, None, None,
dmin=ops['dmin'], dminx=ops['dminx'], ix=ix, merge_dim=False
)

# Take mean of features across spikes, find channels w/ largest norm
spike_mean = Xd.mean(0)
chan_norm = torch.linalg.norm(spike_mean, dim=1)
sorted_chans, ind = torch.sort(chan_norm, descending=True)
# Assign features to overwrite tF in-place
tF[igood,:] = Xd[:, ind[:n_chans], :]
# Save channel inds for phy
feature_ind[i,:] = ind[:n_chans].numpy() + ch_min.cpu().numpy()

# Swap last 2 dimensions to get ordering Phy expects
tF = torch.permute(tF, (0, 2, 1))

return tF, feature_ind
5 changes: 3 additions & 2 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,11 @@ def load_sorting(results_dir, device=None, load_extra_vars=False):
results = [ops, st, clu, similar_templates, is_ref, est_contam_rate]

if load_extra_vars:
# NOTE: tF and Wall always go on CPU, not CUDA
tF = np.load(results_dir / 'tF.npy')
tF = torch.from_numpy(tF).to(device)
tF = torch.from_numpy(tF)
Wall = np.load(results_dir / 'Wall.npy')
Wall = torch.from_numpy(Wall).to(device)
Wall = torch.from_numpy(Wall)
full_st = np.load(results_dir / 'full_st.npy')
full_clu = np.load(results_dir / 'full_clu.npy')
full_amp = np.load(results_dir / 'full_amp.npy')
Expand Down