From 4c4d8a73db42948b04dc23fde5b870b5e7d07193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 2 Jul 2025 16:13:17 +0200 Subject: [PATCH 1/3] Fix sampling rate issue when aggregating Different versions of Kilosort output different sampling rates (sometimes rounded to 2 decimal places, sometimes not). This causes a crash when trying to aggregate both sortings together. This PR fixes this by comparing the sampling rates with an absolute tolerance of 1e-2. --- src/spikeinterface/core/unitsaggregationsorting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 8f4a2732c3..300d668982 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import warnings import numpy as np @@ -62,7 +63,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None): sampling_frequency = sorting_list[0].get_sampling_frequency() num_segments = sorting_list[0].get_num_segments() - ok1 = all(sampling_frequency == sort.get_sampling_frequency() for sort in sorting_list) + ok1 = all(math.isclose(sampling_frequency, sort.get_sampling_frequency(), abs_tol=1e-2) for sort in sorting_list) ok2 = all(num_segments == sort.get_num_segments() for sort in sorting_list) if not (ok1 and ok2): raise ValueError("Sortings don't have the same sampling_frequency/num_segments") From 8167f65a5d6e5d9d54f6c61fdc25b66227363978 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Jul 2025 14:13:54 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/unitsaggregationsorting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/unitsaggregationsorting.py 
b/src/spikeinterface/core/unitsaggregationsorting.py index 300d668982..cb824abd49 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -63,7 +63,9 @@ def __init__(self, sorting_list, renamed_unit_ids=None): sampling_frequency = sorting_list[0].get_sampling_frequency() num_segments = sorting_list[0].get_num_segments() - ok1 = all(math.isclose(sampling_frequency, sort.get_sampling_frequency(), abs_tol=1e-2) for sort in sorting_list) + ok1 = all( + math.isclose(sampling_frequency, sort.get_sampling_frequency(), abs_tol=1e-2) for sort in sorting_list + ) ok2 = all(num_segments == sort.get_num_segments() for sort in sorting_list) if not (ok1 and ok2): raise ValueError("Sortings don't have the same sampling_frequency/num_segments") From 9aff21d548d75a847f6fe67ddf8450376bd6a72e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 17 Jul 2025 16:59:04 +0200 Subject: [PATCH 3/3] Add sampling_frequency_max_diff to aggregate_sortings --- .../tests/test_unitsaggregationsorting.py | 17 ++++++++++++- .../core/unitsaggregationsorting.py | 24 ++++++++++--------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py index f60de5b62a..fadac094aa 100644 --- a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py +++ b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py @@ -149,5 +149,20 @@ def test_unit_aggregation_does_not_preserve_ids_not_the_same_type(): assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"] +def test_sampling_frequency_max_diff(): + """Test that the sampling frequency max diff is respected.""" + sorting1 = generate_sorting(sampling_frequency=30000, num_units=3) + sorting2 = generate_sorting(sampling_frequency=30000.01, num_units=3) + sorting3 = generate_sorting(sampling_frequency=30000.001, num_units=3) + + # Default 
is 0, so this should raise an error + with pytest.raises(ValueError): + aggregate_units([sorting1, sorting2, sorting3]) + + # This should not raise an error, only emit a warning + with pytest.warns(UserWarning): + aggregate_units([sorting1, sorting2, sorting3], sampling_frequency_max_diff=0.02) + + if __name__ == "__main__": - test_unitsaggregationsorting() + test_sampling_frequency_max_diff() diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index cb824abd49..838660df46 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -4,9 +4,10 @@ import warnings import numpy as np -from .core_tools import define_function_from_class -from .base import BaseExtractor -from .basesorting import BaseSorting, BaseSortingSegment +from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.base import BaseExtractor +from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment +from spikeinterface.core.segmentutils import _check_sampling_frequencies class UnitsAggregationSorting(BaseSorting): @@ -19,6 +20,8 @@ class UnitsAggregationSorting(BaseSorting): List of BaseSorting objects to aggregate renamed_unit_ids: array-like If given, unit ids are renamed as provided. If None, unit ids are sequential integers. 
+ sampling_frequency_max_diff : float, default: 0 + Maximum allowed difference of sampling frequencies across sortings Returns ------- @@ -26,7 +29,7 @@ class UnitsAggregationSorting(BaseSorting): The aggregated sorting object """ - def __init__(self, sorting_list, renamed_unit_ids=None): + def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_diff=0): unit_map = {} num_all_units = sum([sort.get_num_units() for sort in sorting_list]) @@ -60,15 +63,14 @@ def __init__(self, sorting_list, renamed_unit_ids=None): unit_map[unit_ids[u_id]] = {"sorting_id": s_i, "unit_id": unit_id} u_id += 1 - sampling_frequency = sorting_list[0].get_sampling_frequency() + sampling_frequencies = [sort.sampling_frequency for sort in sorting_list] num_segments = sorting_list[0].get_num_segments() - ok1 = all( - math.isclose(sampling_frequency, sort.get_sampling_frequency(), abs_tol=1e-2) for sort in sorting_list - ) - ok2 = all(num_segments == sort.get_num_segments() for sort in sorting_list) - if not (ok1 and ok2): - raise ValueError("Sortings don't have the same sampling_frequency/num_segments") + _check_sampling_frequencies(sampling_frequencies, sampling_frequency_max_diff) + sampling_frequency = sampling_frequencies[0] + num_segments_ok = all(num_segments == sort.get_num_segments() for sort in sorting_list) + if not num_segments_ok: + raise ValueError("Sortings don't have the same num_segments") BaseSorting.__init__(self, sampling_frequency, unit_ids)