Skip to content

Commit

Permalink
Changed default device to None for run_kilosort, use GPU if available
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Mar 1, 2024
1 parent 328eaa0 commit 504339d
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def run_kilosort(settings=None, probe=None, probe_name=None, data_dir=None,
filename=None, file_object=None, data_dtype=None,
results_dir=None, do_CAR=True, invert_sign=False,
device=torch.device('cuda'), progress_bar=None,
device=None, progress_bar=None,
save_extra_vars=False):
"""Spike sort the given dataset.
Expand Down Expand Up @@ -65,6 +65,16 @@ def run_kilosort(settings=None, probe=None, probe_name=None, data_dir=None,
print("Interpreting binary file as default dtype='int16'. If data was "
"saved in a different format, specify `data_dtype`.")

if device is None:
if torch.cuda.is_available():
print('Using GPU for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cuda')
else:
print('Using CPU for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cpu')

# NOTE: Also modifies settings in-place
filename, data_dir, results_dir, probe = \
set_files(settings, filename, probe, probe_name, data_dir, results_dir)
Expand Down

0 comments on commit 504339d

Please sign in to comment.