Skip to content

Commit

Permalink
Merge pull request #643 from MouseLand/jacob/pc_features
Browse files Browse the repository at this point in the history
Jacob/pc features
  • Loading branch information
jacobpennington committed Mar 29, 2024
2 parents 336a976 + 207d4a3 commit 60f48a2
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 8 deletions.
16 changes: 13 additions & 3 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b
return clu, Wall


def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32, ncomps = 64):
def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,
ncomps=64, ix=None, merge_dim=True):
PID = torch.from_numpy(PID).long()

#iU = ops['iU'].cpu().numpy()
Expand All @@ -341,7 +342,11 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32,
x0 = xcenter #xy[0].mean() - xcenter

#print(dmin, dminx)
ix = torch.logical_and(torch.abs(xy[1] - y0) < dmin, torch.abs(xy[0] - x0) < dminx)
if ix is None:
ix = torch.logical_and(
torch.abs(xy[1] - y0) < dmin,
torch.abs(xy[0] - x0) < dminx
)
#print(ix.nonzero()[:,0])
igood = ix[PID].nonzero()[:,0]

Expand All @@ -362,7 +367,12 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32,
#print(ij.sum())
dd[ij.unsqueeze(-1), iC[:,j]-ch_min] = data[ij]

Xd = torch.reshape(dd, (nspikes, -1))
if merge_dim:
Xd = torch.reshape(dd, (nspikes, -1))
else:
# Keep channels and features separate
Xd = dd

return Xd, ch_min, ch_max, igood


Expand Down
26 changes: 23 additions & 3 deletions kilosort/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

from kilosort import CCG
from kilosort.preprocessing import get_drift_matrix, fft_highpass
from kilosort.postprocessing import remove_duplicates, compute_spike_positions
from kilosort.postprocessing import (
remove_duplicates, compute_spike_positions, make_pc_features
)

_torch_warning = ".*PyTorch does not support non-writable tensors"

Expand Down Expand Up @@ -197,6 +199,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
np.save((results_dir / 'templates.npy'), templates)
np.save((results_dir / 'templates_ind.npy'), templates_ind)

# pc features
if save_extra_vars:
# Save tF first since it gets updated in-place
np.save(results_dir / 'tF.npy', tF.cpu().numpy())
# This will momentarily copy tF which is pretty large, but it's on CPU
# so the extra memory hopefully won't be an issue.
tF = tF[kept_spikes]
pc_features, pc_feature_ind = make_pc_features(
ops, spike_templates, spike_clusters, tF
)
np.save(results_dir / 'pc_features.npy', pc_features)
np.save(results_dir / 'pc_feature_ind.npy', pc_feature_ind)

# contamination ratio
acg_threshold = ops['settings']['acg_threshold']
ccg_threshold = ops['settings']['ccg_threshold']
Expand Down Expand Up @@ -233,14 +248,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
f.write(f'{key} = {params[key]}\n')

if save_extra_vars:
# Also save tF and Wall, for easier debugging/analysis
np.save(results_dir / 'tF.npy', tF.cpu().numpy())
# Also save Wall, for easier debugging/analysis
np.save(results_dir / 'Wall.npy', Wall.cpu().numpy())
# And full st, clu, amp arrays with no spikes removed
np.save(results_dir / 'full_st.npy', st)
np.save(results_dir / 'full_clu.npy', clu)
np.save(results_dir / 'full_amp.npy', amplitudes)

# Remove cached .phy results if present from running Phy on a previous
# version of results in the same directory.
phy_cache_path = Path(results_dir / '.phy')
if phy_cache_path.is_dir():
shutil.rmtree(phy_cache_path)

return results_dir, similar_templates, is_ref, est_contam_rate


Expand Down
69 changes: 69 additions & 0 deletions kilosort/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import torch

from kilosort.clustering_qr import xy_templates, get_data_cpu


@njit("(int64[:], int32[:], int32)")
def remove_duplicates(spike_times, spike_clusters, dt=15):
Expand Down Expand Up @@ -42,3 +44,70 @@ def compute_spike_positions(st, tF, ops):
ys = (yc0 * tmass).sum(1).cpu().numpy()

return xs, ys


def make_pc_features(ops, spike_templates, spike_clusters, tF):
'''Get PC Features and corresponding indices for export to Phy.
NOTE: This function will update tF in-place!
Parameters
----------
ops : dict
Dictionary of state variables updated throughout the sorting process.
This function is intended to be used with the final state of ops, after
all sorting has finished.
spike_templates : np.ndarray
Vector of template ids with shape `(n_spikes,)`. This is equivalent to
`st[:,1]`, where `st` is returned by `template_matching.extract`.
spike_clusters : np.ndarray
Vector of cluster ids with shape `(n_pikes,)`. This is equivalent to
`clu` returned by `template_matching.merging_function`.
tF : torch.Tensor
Tensor of pc features as returned by `template_matching.extract`,
with shape `(n_spikes, nearest_chans, n_pcs)`.
Returns
-------
tF : torch.Tensor
As above, but with some data replaced so that features are associated
with the final clusters instead of templates. The second and third
dimensions are also swapped to conform to the shape expected by Phy.
feature_ind : np.ndarray
Channel indices associated with the data present in tF for each cluster,
with shape `(n_clusters, nearest_chans)`.
'''

# xy: template centers, iC: channels associated with each template
xy, iC = xy_templates(ops)
n_clusters = np.unique(spike_clusters).size
n_chans = ops['nearest_chans']
feature_ind = np.zeros((n_clusters, n_chans), dtype=np.uint32)

for i in np.unique(spike_clusters):
# Get templates associated with cluster (often just 1)
iunq = np.unique(spike_templates[spike_clusters==i]).astype(int)
# Get boolean mask with size (n_templates,), True if they match cluster
ix = torch.from_numpy(np.zeros(int(spike_templates.max())+1, bool))
ix[iunq] = True
# Get PC features for all spikes detected with those templates (Xd),
# and the indices in tF where those spikes occur (igood).
Xd, ch_min, ch_max, igood = get_data_cpu(
ops, xy, iC, spike_templates, tF, None, None,
dmin=ops['dmin'], dminx=ops['dminx'], ix=ix, merge_dim=False
)

# Take mean of features across spikes, find channels w/ largest norm
spike_mean = Xd.mean(0)
chan_norm = torch.linalg.norm(spike_mean, dim=1)
sorted_chans, ind = torch.sort(chan_norm, descending=True)
# Assign features to overwrite tF in-place
tF[igood,:] = Xd[:, ind[:n_chans], :]
# Save channel inds for phy
feature_ind[i,:] = ind[:n_chans].numpy() + ch_min.cpu().numpy()

# Swap last 2 dimensions to get ordering Phy expects
tF = torch.permute(tF, (0, 2, 1))

return tF, feature_ind
5 changes: 3 additions & 2 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,11 @@ def load_sorting(results_dir, device=None, load_extra_vars=False):
results = [ops, st, clu, similar_templates, is_ref, est_contam_rate]

if load_extra_vars:
# NOTE: tF and Wall always go on CPU, not CUDA
tF = np.load(results_dir / 'tF.npy')
tF = torch.from_numpy(tF).to(device)
tF = torch.from_numpy(tF)
Wall = np.load(results_dir / 'Wall.npy')
Wall = torch.from_numpy(Wall).to(device)
Wall = torch.from_numpy(Wall)
full_st = np.load(results_dir / 'full_st.npy')
full_clu = np.load(results_dir / 'full_clu.npy')
full_amp = np.load(results_dir / 'full_amp.npy')
Expand Down

0 comments on commit 60f48a2

Please sign in to comment.