From 158347cad1dc7f63c7e7ed8b189588b43d1b6047 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 14:49:13 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 18 ++++++++----- .../tests/test_peak_detection.py | 4 ++- src/spikeinterface/sortingcomponents/tools.py | 26 ++++++++++++------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 47e5fc191f..b47f415390 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -11,7 +11,11 @@ from spikeinterface.core.template import Templates from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion -from spikeinterface.sortingcomponents.tools import cache_preprocessing, get_prototype_and_waveforms, get_shuffled_recording_slices +from spikeinterface.sortingcomponents.tools import ( + cache_preprocessing, + get_prototype_and_waveforms, + get_shuffled_recording_slices, +) from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.sortinganalyzer import create_sorting_analyzer @@ -204,14 +208,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["matched_filtering"]: prototype, waveforms = get_prototype_and_waveforms( - recording_w, - n_peaks=5000, - ms_before=ms_before, - ms_after=ms_after, + recording_w, + n_peaks=5000, + ms_before=ms_before, + ms_after=ms_after, seed=params["seed"], return_waveforms=True, - **detection_params, - **job_kwargs + **detection_params, + **job_kwargs, ) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index e0bbf66af5..b6ef240f80 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ms_before = 1.0 ms_after = 1.0 - prototype = get_prototype_and_waveforms(recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs) + prototype = get_prototype_and_waveforms( + recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) peaks_local_mf_filtering = detect_peaks( recording, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index a5526871f1..6ee55c7b61 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -68,11 +68,14 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs -def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs): + +def get_prototype_and_waveforms( + recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs +): """ Helper function to extract a prototype waveform from a peak list or from a peak detection """ - + seed = seed if seed else None rng = np.random.default_rng(seed=seed) @@ -83,6 +86,7 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0 if peaks is None: from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ExtractSparseWaveforms + node = ExtractSparseWaveforms( recording, parents=None, @@ -96,16 +100,20 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0 recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) res = detect_peaks( - recording, pipeline_nodes=pipeline_nodes, - skip_after_n_peaks=n_peaks, - recording_slices=recording_slices, - **detection_kwargs, - **job_kwargs, + recording, + pipeline_nodes=pipeline_nodes, + skip_after_n_peaks=n_peaks, + recording_slices=recording_slices, + **detection_kwargs, + **job_kwargs, ) waveforms = res[1] else: from spikeinterface.sortingcomponents.peak_selection import select_peaks - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed) + + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed + ) waveforms = extract_waveform_at_max_channel( recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) @@ -202,4 +210,4 @@ def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): rng = np.random.RandomState(seed) recording_slices = rng.permutation(recording_slices) - return recording_slices \ No newline at end of file + return recording_slices