From 16d40899d3d8dbc925d00cd034da8fb93af47946 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 1 Jul 2024 20:41:20 +0100
Subject: [PATCH 01/10] Fix t_starts not propagated when saving to memory.

---
 src/spikeinterface/core/baserecording.py   |  4 ++--
 src/spikeinterface/core/numpyextractors.py | 13 ++++++-------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index aab7577b31..bb96fb06ca 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -545,11 +545,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
             if kwargs.get("sharedmem", True):
                 from .numpyextractors import SharedMemoryRecording
 
-                cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
+                cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
             else:
                 from spikeinterface.core import NumpyRecording
 
-                cached = NumpyRecording.from_recording(self, **job_kwargs)
+                cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
 
         elif format == "zarr":
             from .zarrextractors import ZarrRecordingExtractor
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index 0ba1c05417..b60ecb52a6 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -85,7 +85,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
         }
 
     @staticmethod
-    def from_recording(source_recording, **job_kwargs):
+    def from_recording(source_recording, t_starts=None, **job_kwargs):
         traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
         if shms[0] is not None:
             # if the computation was done in parallel then traces_list is shared array
@@ -95,13 +95,14 @@ def from_recording(source_recording, **job_kwargs):
             for shm in shms:
                 shm.close()
                 shm.unlink()
-        # TODO later : propagte t_starts ?
+
         recording = NumpyRecording(
             traces_list,
             source_recording.get_sampling_frequency(),
-            t_starts=None,
+            t_starts=t_starts,
             channel_ids=source_recording.channel_ids,
         )
+
         return recording
 
 
 class NumpyRecordingSegment(BaseRecordingSegment):
@@ -211,18 +212,16 @@ def __del__(self):
                 shm.unlink()
 
     @staticmethod
-    def from_recording(source_recording, **job_kwargs):
+    def from_recording(source_recording, t_starts=None, **job_kwargs):
         traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
-        # TODO later : propagte t_starts ?
-
         recording = SharedMemoryRecording(
             shm_names=[shm.name for shm in shms],
             shape_list=[traces.shape for traces in traces_list],
             dtype=source_recording.dtype,
             sampling_frequency=source_recording.sampling_frequency,
             channel_ids=source_recording.channel_ids,
-            t_starts=None,
+            t_starts=t_starts,
             main_shm_owner=True,
         )

From 3e9652b7c64cfe112ec549bd340d55ad95d97720 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Tue, 2 Jul 2024 09:43:10 +0100
Subject: [PATCH 02/10] force tests

From f1d2755103cb8f9397e12aa1dad31033ba2e9fb3 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 8 Jul 2024 13:33:09 +0200
Subject: [PATCH 03/10] Add _get_t_starts function and move t_starts retrieval to from_recording functions

---
 src/spikeinterface/core/baserecording.py   | 31 +++++++++++++++-------
 src/spikeinterface/core/numpyextractors.py |  9 +++++--
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index f4a276a396..3c46193c02 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -501,24 +501,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
         rs = self._recording_segments[segment_index]
         return rs.time_to_sample_index(time_s)
 
-    def _save(self, format="binary", verbose: bool = False, **save_kwargs):
+    def _get_t_starts(self):
         # handle t_starts
         t_starts = []
         has_time_vectors = []
-        for segment_index, rs in enumerate(self._recording_segments):
+        for rs in self._recording_segments:
             d = rs.get_times_kwargs()
             t_starts.append(d["t_start"])
-            has_time_vectors.append(d["time_vector"] is not None)
         if all(t_start is None for t_start in t_starts):
             t_starts = None
+        return t_starts
 
+    def _get_time_vectors(self):
+        time_vectors = []
+        for rs in self._recording_segments:
+            d = rs.get_times_kwargs()
+            time_vectors.append(d["time_vector"])
+        if all(time_vector is None for time_vector in time_vectors):
+            time_vectors = None
+        return time_vectors
+
+    def _save(self, format="binary", verbose: bool = False, **save_kwargs):
         kwargs, job_kwargs = split_job_kwargs(save_kwargs)
 
         if format == "binary":
             folder = kwargs["folder"]
             file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
             dtype = kwargs.get("dtype", None) or self.get_dtype()
+            t_starts = self._get_t_starts()
 
             write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)
 
@@ -548,11 +559,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
             if kwargs.get("sharedmem", True):
                 from .numpyextractors import SharedMemoryRecording
 
-                cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
+                cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
             else:
                 from spikeinterface.core import NumpyRecording
 
-                cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
+                cached = NumpyRecording.from_recording(self, **job_kwargs)
 
         elif format == "zarr":
             from .zarrextractors import ZarrRecordingExtractor
@@ -575,11 +586,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
             probegroup = self.get_probegroup()
             cached.set_probegroup(probegroup)
 
-        for segment_index, rs in enumerate(self._recording_segments):
-            d = rs.get_times_kwargs()
-            time_vector = d["time_vector"]
-            if time_vector is not None:
-                cached._recording_segments[segment_index].time_vector = time_vector
+        time_vectors = self._get_time_vectors()
+        if time_vectors is not None:
+            for segment_index, time_vector in enumerate(time_vectors):
+                if time_vector is not None:
+                    cached.set_times(time_vector, segment_index=segment_index)
 
         return cached
 
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index 2cf03927a3..f4790817a8 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -81,8 +81,11 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
         }
 
     @staticmethod
-    def from_recording(source_recording, t_starts=None, **job_kwargs):
+    def from_recording(source_recording, **job_kwargs):
         traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
+
+        t_starts = source_recording._get_t_starts()
+
         if shms[0] is not None:
             # if the computation was done in parallel then traces_list is shared array
             # this can lead to problem
@@ -204,9 +207,11 @@ def __del__(self):
             shm.unlink()
 
     @staticmethod
-    def from_recording(source_recording, t_starts=None, **job_kwargs):
+    def from_recording(source_recording, **job_kwargs):
         traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
 
+        t_starts = source_recording._get_t_starts()
+
         recording = SharedMemoryRecording(
             shm_names=[shm.name for shm in shms],
             shape_list=[traces.shape for traces in traces_list],

From 989aa8b4eb29b70e04b3e9a730ab649eee2a084f Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 8 Jul 2024 15:45:44 +0100
Subject: [PATCH 04/10] Add more tests for time handling.

---
 .../core/tests/test_time_handling.py          | 280 ++++++++++++++++++
 1 file changed, 280 insertions(+)

diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index 487a893096..8a6971b0b7 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -1,9 +1,289 @@
+import copy
+
 import pytest
 import numpy as np
 
 from spikeinterface.core import generate_recording, generate_sorting
+import spikeinterface.full as si
+
+
+class TestTimeHandling:
+
+    # Fixtures #####
+    @pytest.fixture(scope="session")
+    def raw_recording(self):
+        """
+        A three-segment raw recording without times added.
+        """
+        durations = [10, 15, 20]
+        recording = generate_recording(num_channels=4, durations=durations)
+        return recording
+
+    @pytest.fixture(scope="session")
+    def time_vector_recording(self, raw_recording):
+        """
+        Add time vectors to the recording, returning the
+        raw recording, recording with time vectors added to
+        segments, and a list of the time vectors added to the recording.
+        """
+        return self._get_time_vector_recording(raw_recording)
+
+    @pytest.fixture(scope="session")
+    def t_start_recording(self, raw_recording):
+        """
+        Add t_starts to the recording, returning the
+        raw recording, recording with t_starts added to segments,
+        and a list of the time vectors generated from adding the
+        t_start to the recording times.
+        """
+        return self._get_t_start_recording(raw_recording)
+
+    def _get_time_vector_recording(self, raw_recording):
+        """
+        Loop through all recording segments, adding a different time
+        vector to each segment. The time vector is the original times with
+        a t_start and irregularly spaced offsets to mimic irregularly
+        spaced time series data. Return the original recording, the
+        recording with time vectors added, and a list of the added time vectors.
+ """ + times_recording = copy.deepcopy(raw_recording) + all_time_vectors = [] + for segment_index in range(raw_recording.get_num_segments()): + + t_start = segment_index + 1 * 100 + offsets = np.arange(times_recording.get_num_samples(segment_index)) * ( + 1 / times_recording.get_sampling_frequency() + ) + time_vector = t_start + times_recording.get_times(segment_index) + offsets + + all_time_vectors.append(time_vector) + times_recording.set_times(times=time_vector, segment_index=segment_index) + + assert np.array_equal( + times_recording._recording_segments[segment_index].time_vector, + time_vector, + ), "time_vector was not properly set during test setup" + + return (raw_recording, times_recording, all_time_vectors) + + def _get_t_start_recording(self, raw_recording): + """ + For each segment in the recording, add a different `t_start`. + Return a list of time vectors generating from the recording times + + the t_starts. + """ + t_start_recording = copy.deepcopy(raw_recording) + + all_t_starts = [] + for segment_index in range(raw_recording.get_num_segments()): + + t_start = (segment_index + 1) * 100 + + all_t_starts.append(t_start + t_start_recording.get_times(segment_index)) + t_start_recording._recording_segments[segment_index].t_start = t_start + + return (raw_recording, t_start_recording, all_t_starts) + + def _get_fixture_data(self, request, fixture_name): + """ + A convenience function to get the data from a fixture + based on the name. This is used to allow parameterising + tests across fixtures. + """ + time_recording_fixture = request.getfixturevalue(fixture_name) + raw_recording, times_recording, all_times = time_recording_fixture + return (raw_recording, times_recording, all_times) + + # Tests ##### + def test_has_time_vector(self, time_vector_recording): + """ + Test the `has_time_vector` function returns `False` before + a time vector is added and `True` afterwards. + """ + raw_recording, times_recording, _ = time_vector_recording + + for segment_idx in range(raw_recording.get_num_segments()): + + assert raw_recording.has_time_vector(segment_idx) is False + assert times_recording.has_time_vector(segment_idx) is True + + @pytest.mark.parametrize("mode", ["binary", "zarr"]) + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path): + """ + Test `t_start` or `time_vector` is propagated to a saved recording, + by saving, reloading, and checking times are correct. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + folder_name = "recording" + recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name) + + if mode == "zarr": + folder_name += ".zarr" + recording_load = si.load_extractor(tmp_path / folder_name) + + self._check_times_match(recording_cache, all_times) + self._check_times_match(recording_load, all_times) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("sharedmem", [True, False]) + def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem): + """ + Test t_start and time_vector are propagated to recording saved into memory. 
+ """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + recording_load = times_recording.save(format="memory", sharedmem=sharedmem) + + self._check_times_match(recording_load, all_times) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_time_propagated_to_select_segments(self, request, fixture_name): + """ + Test that when `recording.select_segments()` is used, the times + are propagated to the new recoridng object. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + for segment_index in range(times_recording.get_num_segments()): + segment = times_recording.select_segments(segment_index) + assert np.array_equal(segment.get_times(), all_times[segment_index]) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_sorting(self, request, fixture_name): + """ + Check that when attached to a sorting object, the times are propagated + to the object. This means that all spike times should respect the + `t_start` or `time_vector` added. + """ + raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name) + sorting = self._get_sorting_with_recording_attached( + recording_for_durations=raw_recording, recording_to_attach=times_recording + ) + for segment_index in range(raw_recording.get_num_segments()): + + if fixture_name == "time_vector_recording": + assert sorting.has_time_vector(segment_index=segment_index) + + self._check_spike_times_are_correct(sorting, times_recording, segment_index) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_time_sample_converters(self, request, fixture_name): + """ + Test the `recording.sample_time_to_index` and + `recording.time_to_sample_index` convenience functions. + """ + raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name) + with pytest.raises(ValueError) as e: + times_recording.sample_index_to_time(0) + assert "Provide 'segment_index'" in str(e) + + for segment_index in range(times_recording.get_num_segments()): + + sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index)) + time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index) + + assert time_ == all_times[segment_index][sample_index] + + new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index) + + assert new_sample_index == sample_index + + @pytest.mark.parametrize("time_type", ["time_vector", "t_start"]) + @pytest.mark.parametrize("bounds", ["start", "middle", "end"]) + def test_slice_recording(self, time_type, bounds): + """ + Test after `frame_slice` and `time_slice` a recording or + sorting (for `frame_slice`), the recording times are + correct with respect to the set `t_start` or `time_vector`. + """ + raw_recording = generate_recording(num_channels=4, durations=[10]) + + if time_type == "time_vector": + raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording) + else: + raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording) + + sorting = self._get_sorting_with_recording_attached( + recording_for_durations=raw_recording, recording_to_attach=times_recording + ) + + # Take some different times, including min and max bounds of + # the recording, and some arbitaray times in the middle (20% and 80%). 
+        if bounds == "start":
+            start_frame = 0
+            end_frame = int(times_recording.get_num_samples(0) * 0.8)
+        elif bounds == "end":
+            start_frame = int(times_recording.get_num_samples(0) * 0.2)
+            end_frame = times_recording.get_num_samples(0) - 1
+        elif bounds == "middle":
+            start_frame = int(times_recording.get_num_samples(0) * 0.2)
+            end_frame = int(times_recording.get_num_samples(0) * 0.8)
+
+        # Slice the recording and check the new times are correct
+        rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
+        sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)
+
+        assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
+
+        self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0)
+
+        # Test `time_slice`
+        start_time = times_recording.sample_index_to_time(start_frame)
+        end_time = times_recording.sample_index_to_time(end_frame)
+
+        rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time)
+
+        assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
+
+    # Helpers ####
+    def _check_times_match(self, recording, all_times):
+        """
+        For every segment in a recording, check that `get_times()`
+        matches the expected times in the list of time vectors, `all_times`.
+        """
+        for segment_index in range(recording.get_num_segments()):
+            assert np.array_equal(recording.get_times(segment_index), all_times[segment_index])
+
+    def _check_spike_times_are_correct(self, sorting, times_recording, segment_index):
+        """
+        For every unit in the `sorting`, for a particular segment, check that
+        the unit times match the times of the original recording as
+        retrieved with `get_times()`.
+        """
+        for unit_id in sorting.get_unit_ids():
+            spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
+            spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
+            rec_times = times_recording.get_times(segment_index=segment_index)
+
+            assert np.array_equal(
+                spike_times,
+                rec_times[spike_indexes],
+            )
+
+    def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach):
+        """
+        Convenience function to create a sorting object with
+        a recording attached. Typically use the raw recording
+        for the durations with which to make the sorting, as
+        `generate_sorting` is not set up to handle the
+        (strange) edge case of the irregularly spaced
+        test time vectors.
+        """
+        durations = [
+            recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments())
+        ]
+
+        sorting = generate_sorting(num_units=10, durations=durations)
+
+        sorting.register_recording(recording_to_attach)
+        assert sorting.has_recording()
+
+        return sorting
 
 
+# TODO: deprecate original implementations ###
 def test_time_handling(create_cache_folder):
     cache_folder = create_cache_folder
     durations = [[10], [10, 5]]

From 5e11c4effbe08c53f4154342138d2143e12b8021 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 8 Jul 2024 15:46:50 +0100
Subject: [PATCH 05/10] Remove some indirection in the fixtures.
---
 .../core/tests/test_time_handling.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index 8a6971b0b7..fb929ce5a9 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -10,15 +10,6 @@ class TestTimeHandling:
 
     # Fixtures #####
-    @pytest.fixture(scope="session")
-    def raw_recording(self):
-        """
-        A three-segment raw recording without times added.
-        """
-        durations = [10, 15, 20]
-        recording = generate_recording(num_channels=4, durations=durations)
-        return recording
-
     @pytest.fixture(scope="session")
     def time_vector_recording(self, raw_recording):
         """
@@ -26,6 +17,9 @@ def time_vector_recording(self, raw_recording):
         raw recording, recording with time vectors added to
         segments, and a list of the time vectors added to the recording.
         """
+        durations = [10, 15, 20]
+        raw_recording = generate_recording(num_channels=4, durations=durations)
+
         return self._get_time_vector_recording(raw_recording)
 
     @pytest.fixture(scope="session")
     def t_start_recording(self, raw_recording):
         """
@@ -36,6 +30,9 @@ def t_start_recording(self, raw_recording):
         t_start to the recording times.
         """
+        durations = [10, 15, 20]
+        raw_recording = generate_recording(num_channels=4, durations=durations)
+
         return self._get_t_start_recording(raw_recording)
 
     def _get_time_vector_recording(self, raw_recording):

From 3ebd3b50d00feb464ec025f8270d9caf60223a89 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 8 Jul 2024 15:49:49 +0100
Subject: [PATCH 06/10] Minor tidy, maintain order of parameterisation across tests.

---
 src/spikeinterface/core/tests/test_time_handling.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index fb929ce5a9..e80564eb14 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -123,8 +123,8 @@ def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_
         self._check_times_match(recording_cache, all_times)
         self._check_times_match(recording_load, all_times)
 
-    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
     @pytest.mark.parametrize("sharedmem", [True, False])
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
     def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem):
         """
         Test t_start and time_vector are propagated to a recording saved into memory.
@@ -191,9 +191,9 @@ def test_time_sample_converters(self, request, fixture_name):
     @pytest.mark.parametrize("bounds", ["start", "middle", "end"])
     def test_slice_recording(self, time_type, bounds):
         """
-        Test after `frame_slice` and `time_slice` a recording or
-        sorting (for `frame_slice`), the recording times are
-        correct with respect to the set `t_start` or `time_vector`.
+        Test times are correct after applying `frame_slice` or `time_slice`
+        to a recording or sorting (for `frame_slice`). The recording times
+        should be correct with respect to the set `t_start` or `time_vector`.

From 5011dd25fd8e0d92407c737a348f3833a7b86c68 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 8 Jul 2024 16:01:12 +0100
Subject: [PATCH 07/10] Fix tests.
---
 src/spikeinterface/core/tests/test_time_handling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index e80564eb14..49fa622f7a 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -11,7 +11,7 @@ class TestTimeHandling:
 
     # Fixtures #####
     @pytest.fixture(scope="session")
-    def time_vector_recording(self, raw_recording):
+    def time_vector_recording(self):
         """
         Add time vectors to the recording, returning the
         raw recording, recording with time vectors added to
@@ -23,7 +23,7 @@ def time_vector_recording(self, raw_recording):
         return self._get_time_vector_recording(raw_recording)
 
     @pytest.fixture(scope="session")
-    def t_start_recording(self, raw_recording):
+    def t_start_recording(self):
         """
         Add t_starts to the recording, returning the

From 08adf1304833436829eca0951f50c4081f3914ca Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 8 Jul 2024 16:07:40 +0100
Subject: [PATCH 08/10] Add class docstring.

---
 src/spikeinterface/core/tests/test_time_handling.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index 49fa622f7a..5d46fb3eed 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -8,6 +8,12 @@
 
 
 class TestTimeHandling:
+    """
+    This class tests how time is handled in SpikeInterface. Under the hood,
+    time can be represented as a full `time_vector` or only as a
+    `t_start` attribute on segments, from which a vector of times
+    is generated on the fly. Both time representations are tested here.
+    """
 
     # Fixtures #####
     @pytest.fixture(scope="session")

From b2bac4d46679c18d6b6a53079ec9ff4d02e507eb Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 8 Jul 2024 18:47:17 +0100
Subject: [PATCH 09/10] Make test time vector actually irregularly spaced!
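With the previous `offsets = np.arange(num_samples) * (1 / fs)`, every sample was
shifted by a linearly increasing amount, so consecutive time points stayed a
constant `2 / fs` apart and the vector was in fact still regularly spaced. Taking
`np.cumsum` of these small increasing numbers makes each inter-sample interval
larger than the previous one, giving genuinely irregular spacing.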
---
 src/spikeinterface/core/tests/test_time_handling.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index 5d46fb3eed..eb169b77d5 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -54,9 +54,12 @@ def _get_time_vector_recording(self, raw_recording):
         for segment_index in range(raw_recording.get_num_segments()):
 
             t_start = (segment_index + 1) * 100
-            offsets = np.arange(times_recording.get_num_samples(segment_index)) * (
+
+            some_small_increasing_numbers = np.arange(times_recording.get_num_samples(segment_index)) * (
                 1 / times_recording.get_sampling_frequency()
             )
+
+            offsets = np.cumsum(some_small_increasing_numbers)
             time_vector = t_start + times_recording.get_times(segment_index) + offsets

From d3cba031b3f5d44bcfd02f8e104d2716a2aef23b Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 9 Jul 2024 10:39:23 +0200
Subject: [PATCH 10/10] Remove deprecated tests

---
 .../core/tests/test_time_handling.py          | 66 -------------------
 1 file changed, 66 deletions(-)

diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index eb169b77d5..049d5ab6e5 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -287,69 +287,3 @@ def _get_sorting_with_recording_attached(self, recording_for_durations, recordin
         assert sorting.has_recording()
 
         return sorting
-
-
-# TODO: deprecate original implementations ###
-def test_time_handling(create_cache_folder):
-    cache_folder = create_cache_folder
-    durations = [[10], [10, 5]]
-
-    # test multi-segment
-    for i, dur in enumerate(durations):
-        rec = generate_recording(num_channels=4, durations=dur)
-        sort = generate_sorting(num_units=10, durations=dur)
-
-        for segment_index in range(rec.get_num_segments()):
-            original_times = rec.get_times(segment_index=segment_index)
-            new_times = original_times + 5
-            rec.set_times(new_times, segment_index=segment_index)
-
-        sort.register_recording(rec)
-        assert sort.has_recording()
-
-        rec_cache = rec.save(folder=cache_folder / f"rec{i}")
-
-        for segment_index in range(sort.get_num_segments()):
-            assert rec.has_time_vector(segment_index=segment_index)
-            assert sort.has_time_vector(segment_index=segment_index)
-
-            # times are correctly saved by the recording
-            assert np.allclose(
-                rec.get_times(segment_index=segment_index), rec_cache.get_times(segment_index=segment_index)
-            )
-
-            # spike times are correctly adjusted
-            for u in sort.get_unit_ids():
-                spike_times = sort.get_unit_spike_train(u, segment_index=segment_index, return_times=True)
-                rec_times = rec.get_times(segment_index=segment_index)
-                assert np.all(spike_times >= rec_times[0])
-                assert np.all(spike_times <= rec_times[-1])
-
-
-def test_frame_slicing():
-    duration = [10]
-
-    rec = generate_recording(num_channels=4, durations=duration)
-    sort = generate_sorting(num_units=10, durations=duration)
-
-    original_times = rec.get_times()
-    new_times = original_times + 5
-    rec.set_times(new_times)
-
-    sort.register_recording(rec)
-
-    start_frame = 3 * rec.get_sampling_frequency()
-    end_frame = 7 * rec.get_sampling_frequency()
-
-    rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame)
-    sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame)
-
-    for u in sort_slice.get_unit_ids():
-        spike_times = sort_slice.get_unit_spike_train(u, return_times=True)
-        rec_times = rec_slice.get_times()
-        assert np.all(spike_times >= rec_times[0])
-        assert np.all(spike_times <= rec_times[-1])
-
-
-if __name__ == "__main__":
-    test_frame_slicing()
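A minimal usage sketch of the behavior this series fixes and tests (not taken from
the patches themselves; it relies only on the public SpikeInterface APIs already
exercised in the tests above):

    import numpy as np
    from spikeinterface.core import generate_recording

    recording = generate_recording(num_channels=4, durations=[10])

    # `set_times` stores a full `time_vector` on the segment; a scalar
    # `t_start` is the lighter-weight representation also covered by the tests.
    recording.set_times(recording.get_times(segment_index=0) + 100, segment_index=0)

    # Before this series, in-memory saves dropped t_start/time_vector;
    # with it, times survive both the sharedmem and plain-NumPy paths.
    cached = recording.save(format="memory", sharedmem=True)
    assert np.array_equal(cached.get_times(segment_index=0), recording.get_times(segment_index=0))

    # Round-trip between a sample index and its time on the cached copy.
    t = cached.sample_index_to_time(100, segment_index=0)
    assert cached.time_to_sample_index(t, segment_index=0) == 100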