diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 5612fce32b..b2d02739e8 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -51,7 +51,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
             },
         },
         "clustering": {"legacy": True},
-        "matching": {"method": "wobble"},
+        "matching": {"method": "circus-omp-svd"},
         "apply_preprocessing": True,
         "matched_filtering": True,
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
@@ -264,6 +264,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         clustering_params["ms_after"] = exclude_sweep_ms
         clustering_params["verbose"] = verbose
         clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
+        clustering_params["noise_threshold"] = detection_params.get('detect_threshold', 4)

         legacy = clustering_params.get("legacy", True)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 681605a5bd..1619e82398 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -15,7 +15,7 @@
 import random, string
 from spikeinterface.core import get_global_tmp_folder
 from spikeinterface.core.basesorting import minimum_spike_dtype
-from spikeinterface.core.waveform_tools import estimate_templates, estimate_templates_with_accumulator
+from spikeinterface.core.waveform_tools import estimate_templates
 from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
 from spikeinterface.sortingcomponents.peak_selection import select_peaks
@@ -60,7 +60,7 @@ class CircusClustering:
         "few_waveforms": None,
         "ms_before": 0.5,
         "ms_after": 0.5,
-        "noise_threshold": 5,
+        "noise_threshold": 4,
         "rank": 5,
         "noise_levels": None,
         "tmp_folder": None,
@@ -239,7 +239,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         if params["noise_levels"] is None:
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

-        templates_array, templates_array_std = estimate_templates_with_accumulator(
+        templates_array = estimate_templates(
             recording,
             spikes,
             unit_ids,
@@ -251,11 +251,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             **job_kwargs,
         )

-        with np.errstate(divide="ignore", invalid="ignore"):
-            peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
-            mask = ~np.isfinite(peak_snrs)
-            peak_snrs[mask] = 0

         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
+        peak_snrs = np.abs(templates_array[:, nbefore, :])
         best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
         valid_templates = best_snrs_ratio > params["noise_threshold"]
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index 4095c2a013..4f77862972 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -14,7 +14,7 @@
     HAVE_HDBSCAN = False

 from spikeinterface.core.basesorting import minimum_spike_dtype
-from spikeinterface.core.waveform_tools import estimate_templates, estimate_templates_with_accumulator
+from spikeinterface.core.waveform_tools import estimate_templates
 from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels
 from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
@@ -53,7 +53,7 @@ class RandomProjectionClustering:
         "random_seed": 42,
         "noise_levels": None,
         "smoothing_kwargs": {"window_length_ms": 0.25},
-        "noise_threshold": 5,
+        "noise_threshold": 4,
         "tmp_folder": None,
         "verbose": True,
     }
@@ -133,7 +133,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         if params["noise_levels"] is None:
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

-        templates_array, templates_array_std = estimate_templates_with_accumulator(
+        templates_array = estimate_templates(
             recording,
             spikes,
             unit_ids,
@@ -145,11 +145,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             **job_kwargs,
         )

-        with np.errstate(divide="ignore", invalid="ignore"):
-            peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
-            mask = ~np.isfinite(peak_snrs)
-            peak_snrs[mask] = 0

         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
+        peak_snrs = np.abs(templates_array[:, nbefore, :])
         best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
         valid_templates = best_snrs_ratio > params["noise_threshold"]
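
Note on the first hunk: the default matching engine for spykingcircus2 moves from "wobble" to "circus-omp-svd". A sketch of how the previous engine could still be requested through the sorter parameters, assuming `recording` is an already-loaded recording extractor and that the installed spikeinterface exposes run_sorter with a folder argument (both are assumptions, not part of this patch):

    from spikeinterface.sorters import run_sorter

    # Hypothetical override: pass the "matching" sub-dict as a sorter parameter
    # to fall back to the previous default engine for a single run.
    sorting = run_sorter(
        "spykingcircus2",
        recording,
        folder="sc2_output",
        matching={"method": "wobble"},
    )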
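
Note on the clustering hunks: both clustering backends now keep a template when its peak amplitude on the best channel, divided by the channel noise level, exceeds noise_threshold; the per-sample template std from estimate_templates_with_accumulator is no longer used. A minimal sketch of that selection rule, with made-up shapes, random templates, and unit noise levels standing in for the sorter's real inputs:

    import numpy as np

    # Placeholder inputs: in the sorter these come from estimate_templates()
    # and get_noise_levels().
    num_units, num_samples, num_channels = 10, 60, 32
    nbefore = 20  # index of the peak sample inside each template
    rng = np.random.default_rng(42)
    templates_array = rng.normal(size=(num_units, num_samples, num_channels))
    noise_levels = np.ones(num_channels)
    noise_threshold = 4

    # New selection rule, as in the patch: amplitude over noise on the best channel.
    best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
    peak_snrs = np.abs(templates_array[:, nbefore, :])
    best_snrs_ratio = (peak_snrs / noise_levels)[np.arange(len(peak_snrs)), best_channels]
    valid_templates = best_snrs_ratio > noise_threshold

Since no std is computed anymore, the threshold is expressed relative to the recording noise levels rather than to template variability, which is also why the sorter now forwards its detect_threshold as the clustering noise_threshold.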