diff --git a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py index f60de5b62a..fadac094aa 100644 --- a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py +++ b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py @@ -149,5 +149,20 @@ def test_unit_aggregation_does_not_preserve_ids_not_the_same_type(): assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"] +def test_sampling_frequency_max_diff(): + """Test that the sampling frequency max diff is respected.""" + sorting1 = generate_sorting(sampling_frequency=30000, num_units=3) + sorting2 = generate_sorting(sampling_frequency=30000.01, num_units=3) + sorting3 = generate_sorting(sampling_frequency=30000.001, num_units=3) + + # Default is 0, so should not raise an error + with pytest.raises(ValueError): + aggregate_units([sorting1, sorting2, sorting3]) + + # This should not raise an warning + with pytest.warns(UserWarning): + aggregate_units([sorting1, sorting2, sorting3], sampling_frequency_max_diff=0.02) + + if __name__ == "__main__": - test_unitsaggregationsorting() + test_sampling_frequency_max_diff() diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 8f4a2732c3..838660df46 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -1,11 +1,13 @@ from __future__ import annotations +import math import warnings import numpy as np -from .core_tools import define_function_from_class -from .base import BaseExtractor -from .basesorting import BaseSorting, BaseSortingSegment +from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.base import BaseExtractor +from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment +from spikeinterface.core.segmentutils import _check_sampling_frequencies class UnitsAggregationSorting(BaseSorting): @@ -18,6 +20,8 @@ class UnitsAggregationSorting(BaseSorting): List of BaseSorting objects to aggregate renamed_unit_ids: array-like If given, unit ids are renamed as provided. If None, unit ids are sequential integers. + sampling_frequency_max_diff : float, default: 0 + Maximum allowed difference of sampling frequencies across recordings Returns ------- @@ -25,7 +29,7 @@ class UnitsAggregationSorting(BaseSorting): The aggregated sorting object """ - def __init__(self, sorting_list, renamed_unit_ids=None): + def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_diff=0): unit_map = {} num_all_units = sum([sort.get_num_units() for sort in sorting_list]) @@ -59,13 +63,14 @@ def __init__(self, sorting_list, renamed_unit_ids=None): unit_map[unit_ids[u_id]] = {"sorting_id": s_i, "unit_id": unit_id} u_id += 1 - sampling_frequency = sorting_list[0].get_sampling_frequency() + sampling_frequencies = [sort.sampling_frequency for sort in sorting_list] num_segments = sorting_list[0].get_num_segments() - ok1 = all(sampling_frequency == sort.get_sampling_frequency() for sort in sorting_list) - ok2 = all(num_segments == sort.get_num_segments() for sort in sorting_list) - if not (ok1 and ok2): - raise ValueError("Sortings don't have the same sampling_frequency/num_segments") + _check_sampling_frequencies(sampling_frequencies, sampling_frequency_max_diff) + sampling_frequency = sampling_frequencies[0] + num_segments_ok = all(num_segments == sort.get_num_segments() for sort in sorting_list) + if not num_segments_ok: + raise ValueError("Sortings don't have the same num_segments") BaseSorting.__init__(self, sampling_frequency, unit_ids)