Skip to content

Commit

Permalink
Docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Jan 8, 2025
2 parents 437d8b4 + 158347c commit 2cad313
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
10 changes: 7 additions & 3 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from spikeinterface.core.template import Templates
from spikeinterface.core.waveform_tools import estimate_templates
from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
from spikeinterface.sortingcomponents.tools import cache_preprocessing, get_prototype_and_waveforms, get_shuffled_recording_slices
from spikeinterface.sortingcomponents.tools import (
cache_preprocessing,
get_prototype_and_waveforms,
get_shuffled_recording_slices,
)
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.sparsity import compute_sparsity
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
Expand Down Expand Up @@ -208,8 +212,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
ms_after=ms_after,
seed=params["seed"],
return_waveforms=True,
**detection_params,
**job_kwargs
**detection_params,
**job_kwargs,
)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs)

ms_before = 1.0
ms_after = 1.0
prototype = get_prototype_and_waveforms(recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs)
prototype = get_prototype_and_waveforms(
recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs
)

peaks_local_mf_filtering = detect_peaks(
recording,
Expand Down
26 changes: 17 additions & 9 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j

return all_wfs

def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs):

def get_prototype_and_waveforms(
recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs
):
"""
Function to extract a prototype waveform from a peak list or from a peak detection. Note that in case
of a peak detection, the detection stops as soon as n_peaks are detected.
Expand Down Expand Up @@ -99,7 +102,7 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0
waveforms : numpy.array, optional
The extracted waveforms, returned if return_waveforms is True.
"""

seed = seed if seed else None
rng = np.random.default_rng(seed=seed)

Expand All @@ -110,6 +113,7 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0
if peaks is None:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import ExtractSparseWaveforms

node = ExtractSparseWaveforms(
recording,
parents=None,
Expand All @@ -123,16 +127,20 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0
recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs)

res = detect_peaks(
recording, pipeline_nodes=pipeline_nodes,
skip_after_n_peaks=n_peaks,
recording_slices=recording_slices,
**detection_kwargs,
**job_kwargs,
recording,
pipeline_nodes=pipeline_nodes,
skip_after_n_peaks=n_peaks,
recording_slices=recording_slices,
**detection_kwargs,
**job_kwargs,
)
waveforms = res[1]
else:
from spikeinterface.sortingcomponents.peak_selection import select_peaks
few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed)

few_peaks = select_peaks(
peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed
)
waveforms = extract_waveform_at_max_channel(
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
)
Expand Down Expand Up @@ -229,4 +237,4 @@ def get_shuffled_recording_slices(recording, seed=None, **job_kwargs):
rng = np.random.RandomState(seed)
recording_slices = rng.permutation(recording_slices)

return recording_slices
return recording_slices

0 comments on commit 2cad313

Please sign in to comment.