Skip to content

Commit

Permalink
Removing artefactual templates due to matched filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Jan 9, 2025
1 parent 731c713 commit 7c32a65
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 15 deletions.
3 changes: 2 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
},
},
"clustering": {"legacy": True},
"matching": {"method": "wobble"},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
"matched_filtering": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
Expand Down Expand Up @@ -264,6 +264,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["ms_after"] = exclude_sweep_ms
clustering_params["verbose"] = verbose
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
clustering_params["noise_threshold"] = detection_params.get('detect_threshold', 4)

legacy = clustering_params.get("legacy", True)

Expand Down
11 changes: 4 additions & 7 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import random, string
from spikeinterface.core import get_global_tmp_folder
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.waveform_tools import estimate_templates, estimate_templates_with_accumulator
from spikeinterface.core.waveform_tools import estimate_templates
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.sortingcomponents.peak_selection import select_peaks
Expand Down Expand Up @@ -60,7 +60,7 @@ class CircusClustering:
"few_waveforms": None,
"ms_before": 0.5,
"ms_after": 0.5,
"noise_threshold": 5,
"noise_threshold": 4,
"rank": 5,
"noise_levels": None,
"tmp_folder": None,
Expand Down Expand Up @@ -239,7 +239,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

templates_array, templates_array_std = estimate_templates_with_accumulator(
templates_array = estimate_templates(
recording,
spikes,
unit_ids,
Expand All @@ -251,11 +251,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
**job_kwargs,
)

with np.errstate(divide="ignore", invalid="ignore"):
peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
mask = ~np.isfinite(peak_snrs)
peak_snrs[mask] = 0
best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
peak_snrs = np.abs(templates_array[:, nbefore, :])
best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
valid_templates = best_snrs_ratio > params["noise_threshold"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
HAVE_HDBSCAN = False

from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.waveform_tools import estimate_templates, estimate_templates_with_accumulator
from spikeinterface.core.waveform_tools import estimate_templates
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
Expand Down Expand Up @@ -53,7 +53,7 @@ class RandomProjectionClustering:
"random_seed": 42,
"noise_levels": None,
"smoothing_kwargs": {"window_length_ms": 0.25},
"noise_threshold": 5,
"noise_threshold": 4,
"tmp_folder": None,
"verbose": True,
}
Expand Down Expand Up @@ -133,7 +133,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

templates_array, templates_array_std = estimate_templates_with_accumulator(
templates_array = estimate_templates(
recording,
spikes,
unit_ids,
Expand All @@ -145,11 +145,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
**job_kwargs,
)

with np.errstate(divide="ignore", invalid="ignore"):
peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
mask = ~np.isfinite(peak_snrs)
peak_snrs[mask] = 0
best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
peak_snrs = np.abs(templates_array[:, nbefore, :])
best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
valid_templates = best_snrs_ratio > params["noise_threshold"]

Expand Down

0 comments on commit 7c32a65

Please sign in to comment.