/
spikesorting_helperfunctions.py
138 lines (117 loc) · 5.66 KB
/
spikesorting_helperfunctions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# -*- coding: utf-8 -*-
# +
import numpy as np
import matplotlib.pyplot as plt
def plotSignal(signal, sampling_freq, from_in_s=0, to_in_s=None, show=True):
"""
Plots data from a single channel
:param analog_stream: A AnalogStream object
:param sampling_freq: Sampling frequency
:param from_in_s: The start timestamp of the plot (0 <= from_in_s < to_in_s). Default: 0
:param to_in_s: The end timestamp of the plot (from_in_s < to_in_s <= duration). Default: None (= recording duration)
:param show: If True (default), the plot is directly created. For further plotting, use show=False
"""
# get start and end index
from_idx = max(0, int(from_in_s * sampling_freq))
if to_in_s is None:
to_idx = analog_stream.channel_data.shape[1]
else:
to_idx = min(analog_stream.channel_data.shape[1], int(to_in_s * sampling_frequency))
# get the timestamps for each sample
time = analog_stream.get_channel_sample_timestamps(channel_id, from_idx, to_idx)
# scale time to seconds:
scale_factor_for_second = Q_(1,time[1]).to(ureg.s).magnitude
time_in_sec = time[0] * scale_factor_for_second
# get the signal
signal = analog_stream.get_channel_in_range(channel_id, from_idx, to_idx)
# scale signal to µV:
scale_factor_for_uV = Q_(1,signal[1]).to(ureg.uV).magnitude
signal_in_uV = signal[0] * scale_factor_for_uV
# construct the plot
_ = plt.figure(figsize=(20,6))
_ = plt.plot(time_in_sec, signal_in_uV)
_ = plt.xlabel('Time (%s)' % ureg.s)
_ = plt.ylabel('Voltage (%s)' % ureg.uV)
_ = plt.title('Channel %s' % channel_info.info['Label'])
if show:
plt.show()
def detect_threshold_crossings(signal, fs, threshold, dead_time):
    """
    Detect downward threshold crossings in a signal with dead time and return them as an array.

    A crossing is detected where the signal transitions from a sample above the threshold
    to a sample at or below the threshold; the returned index is that of the last sample
    above the threshold. A crossing is only kept if it lies at least ``dead_time`` seconds
    after the previously *kept* detection (crossings within the dead time are discarded).

    :param signal: The signal as a 1-dimensional numpy array
    :param fs: The sampling frequency in Hz
    :param threshold: The threshold for the signal
    :param dead_time: The dead time in seconds.
    :return: 1-dimensional integer numpy array with the sample indices of the detections
    """
    dead_time_idx = dead_time * fs
    below = (np.asarray(signal) <= threshold).astype(int)
    # diff is +1 on an above->below transition and -1 on below->above; only the
    # downward transitions are detections.
    crossings = (np.diff(below) > 0).nonzero()[0]
    keep = np.zeros(crossings.shape[0], dtype=bool)
    last_kept = None
    # Greedy pass: the dead time is measured from the last accepted detection, not from
    # the previous raw crossing, so bursts still yield one detection per dead-time interval.
    for i, idx in enumerate(crossings):
        if last_kept is None or idx - last_kept >= dead_time_idx:
            keep[i] = True
            last_kept = idx
    return crossings[keep]
def get_next_minimum(signal, index, max_samples_to_search):
    """
    Find the position of the smallest sample in a window starting at ``index``.

    The window covers at most ``max_samples_to_search`` samples and is cut off at the
    end of the signal.

    :param signal: The signal as a 1-dimensional numpy array
    :param index: The scalar index at which the search window starts
    :param max_samples_to_search: The number of samples to search for a minimum after the index
    :return: The absolute index (into ``signal``) of the first minimum within the window
    """
    stop = index + max_samples_to_search
    if stop > signal.shape[0]:
        stop = signal.shape[0]
    window = signal[index:stop]
    return index + np.argmin(window)
def align_to_minimum(signal, fs, threshold_crossings, search_range):
    """
    Move every threshold crossing forward to the negative peak that follows it.

    For each crossing, the minimum of ``signal`` within the next
    ``int(search_range * fs)`` samples is located with get_next_minimum.

    :param signal: The signal as a 1-dimensional numpy array
    :param fs: The sampling frequency in Hz
    :param threshold_crossings: The array of indices where the signal crossed the detection threshold
    :param search_range: The maximum duration in seconds to search for the minimum after each crossing
    :return: numpy array holding one peak index per threshold crossing
    """
    window_samples = int(search_range * fs)
    peaks = []
    for crossing in threshold_crossings:
        peaks.append(get_next_minimum(signal, crossing, window_samples))
    return np.array(peaks)
def extract_waveforms(signal, fs, spikes_idx, pre, post):
    """
    Extract spike waveforms as signal cutouts around each spike index as a spikes x samples numpy array.

    Each cutout spans ``signal[idx - int(pre * fs) : idx + int(post * fs)]``. Spikes whose
    window would extend past either end of the signal are skipped. If no spike has a
    complete window, an empty array of shape ``(0, int(pre * fs) + int(post * fs))`` is
    returned instead of raising.

    :param signal: The signal as a 1-dimensional numpy array
    :param fs: The sampling frequency in Hz
    :param spikes_idx: The sample index of all spikes as a 1-dim numpy array
    :param pre: The duration of the cutout before the spike in seconds
    :param post: The duration of the cutout after the spike in seconds
    :return: numpy array of shape (n_valid_spikes, pre_samples + post_samples), same dtype as signal
    """
    signal = np.asarray(signal)
    pre_idx = int(pre * fs)
    post_idx = int(post * fs)
    spikes = np.asarray(spikes_idx, dtype=np.intp)
    n_samples = signal.shape[0]
    # Keep only spikes whose full cutout window lies inside the signal.
    valid = spikes[(spikes - pre_idx >= 0) & (spikes + post_idx <= n_samples)]
    # One row of sample indices per spike: idx - pre_idx ... idx + post_idx - 1.
    offsets = np.arange(-pre_idx, post_idx)
    return signal[valid[:, np.newaxis] + offsets]
def plot_waveforms(cutouts, fs, pre, post, n=100, color='k', show=True):
"""
Plot an overlay of spike cutouts
:param cutouts: A spikes x samples array of cutouts
:param fs: The sampling frequency in Hz
:param pre: The duration of the cutout before the spike in seconds
:param post: The duration of the cutout after the spike in seconds
:param n: The number of cutouts to plot, or None to plot all. Default: 100
:param color: The line color as a pyplot line/marker style. Default: 'k'=black
:param show: Set this to False to disable showing the plot. Default: True
"""
if n is None:
n = cutouts.shape[0]
n = min(n, cutouts.shape[0])
time_in_us = np.arange(-pre*1000, post*1000, 1e3/fs)
if show:
_ = plt.figure(figsize=(10,6))
for i in range(n):
_ = plt.plot(time_in_us, cutouts[i,]*1e6, color, linewidth=1, alpha=0.3)
_ = plt.xlabel('Time (ms)')
_ = plt.ylabel('Voltage (mV)')
_ = plt.title('Spike Waveforms')
if show:
plt.show()