Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 8, 2025
1 parent e4d7e72 commit 158347c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 17 deletions.
18 changes: 11 additions & 7 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from spikeinterface.core.template import Templates
from spikeinterface.core.waveform_tools import estimate_templates
from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
from spikeinterface.sortingcomponents.tools import cache_preprocessing, get_prototype_and_waveforms, get_shuffled_recording_slices
from spikeinterface.sortingcomponents.tools import (
cache_preprocessing,
get_prototype_and_waveforms,
get_shuffled_recording_slices,
)
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.sparsity import compute_sparsity
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
Expand Down Expand Up @@ -204,14 +208,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

if params["matched_filtering"]:
prototype, waveforms = get_prototype_and_waveforms(
recording_w,
n_peaks=5000,
ms_before=ms_before,
ms_after=ms_after,
recording_w,
n_peaks=5000,
ms_before=ms_before,
ms_after=ms_after,
seed=params["seed"],
return_waveforms=True,
**detection_params,
**job_kwargs
**detection_params,
**job_kwargs,
)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs)

ms_before = 1.0
ms_after = 1.0
prototype = get_prototype_and_waveforms(recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs)
prototype = get_prototype_and_waveforms(
recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs
)

peaks_local_mf_filtering = detect_peaks(
recording,
Expand Down
26 changes: 17 additions & 9 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j

return all_wfs

def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs):

def get_prototype_and_waveforms(
recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs
):
"""
Helper function to extract a prototype waveform from a peak list or from a peak detection
"""

seed = seed if seed else None
rng = np.random.default_rng(seed=seed)

Expand All @@ -83,6 +86,7 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0
if peaks is None:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import ExtractSparseWaveforms

node = ExtractSparseWaveforms(
recording,
parents=None,
Expand All @@ -96,16 +100,20 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0
recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs)

res = detect_peaks(
recording, pipeline_nodes=pipeline_nodes,
skip_after_n_peaks=n_peaks,
recording_slices=recording_slices,
**detection_kwargs,
**job_kwargs,
recording,
pipeline_nodes=pipeline_nodes,
skip_after_n_peaks=n_peaks,
recording_slices=recording_slices,
**detection_kwargs,
**job_kwargs,
)
waveforms = res[1]
else:
from spikeinterface.sortingcomponents.peak_selection import select_peaks
few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed)

few_peaks = select_peaks(
peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed
)
waveforms = extract_waveform_at_max_channel(
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
)
Expand Down Expand Up @@ -202,4 +210,4 @@ def get_shuffled_recording_slices(recording, seed=None, **job_kwargs):
rng = np.random.RandomState(seed)
recording_slices = rng.permutation(recording_slices)

return recording_slices
return recording_slices

0 comments on commit 158347c

Please sign in to comment.