Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix t_starts not propagated to save_to_memory. #3120

Merged
merged 12 commits into from
Jul 9, 2024
27 changes: 19 additions & 8 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,24 +501,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
rs = self._recording_segments[segment_index]
return rs.time_to_sample_index(time_s)

def _save(self, format="binary", verbose: bool = False, **save_kwargs):
def _get_t_starts(self):
# handle t_starts
t_starts = []
has_time_vectors = []
for segment_index, rs in enumerate(self._recording_segments):
for rs in self._recording_segments:
d = rs.get_times_kwargs()
t_starts.append(d["t_start"])
has_time_vectors.append(d["time_vector"] is not None)

if all(t_start is None for t_start in t_starts):
t_starts = None
return t_starts

def _get_time_vectors(self):
    """
    Collect the `time_vector` entry of every recording segment.

    Returns
    -------
    list | None
        One entry per segment, taken from the "time_vector" key of
        `segment.get_times_kwargs()`. `None` is returned instead of
        the list when no segment carries a time vector.
    """
    per_segment = [segment.get_times_kwargs()["time_vector"] for segment in self._recording_segments]

    # Only hand the list back when at least one segment has a real vector.
    if any(vector is not None for vector in per_segment):
        return per_segment
    return None

def _save(self, format="binary", verbose: bool = False, **save_kwargs):
kwargs, job_kwargs = split_job_kwargs(save_kwargs)

if format == "binary":
folder = kwargs["folder"]
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
dtype = kwargs.get("dtype", None) or self.get_dtype()
t_starts = self._get_t_starts()

write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)

Expand Down Expand Up @@ -575,11 +586,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

for segment_index, rs in enumerate(self._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
if time_vector is not None:
cached._recording_segments[segment_index].time_vector = time_vector
time_vectors = self._get_time_vectors()
if time_vectors is not None:
for segment_index, time_vector in enumerate(time_vectors):
if time_vector is not None:
cached.set_times(time_vector, segment_index=segment_index)

return cached

Expand Down
12 changes: 8 additions & 4 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
@staticmethod
def from_recording(source_recording, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)

t_starts = source_recording._get_t_starts()

if shms[0] is not None:
# if the computation was done in parallel then traces_list is shared array
# this can lead to problem
Expand All @@ -91,13 +94,14 @@ def from_recording(source_recording, **job_kwargs):
for shm in shms:
shm.close()
shm.unlink()
# TODO later : propagate t_starts ?

recording = NumpyRecording(
traces_list,
source_recording.get_sampling_frequency(),
t_starts=None,
t_starts=t_starts,
channel_ids=source_recording.channel_ids,
)
return recording


class NumpyRecordingSegment(BaseRecordingSegment):
Expand Down Expand Up @@ -206,15 +210,15 @@ def __del__(self):
def from_recording(source_recording, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)

# TODO later : propagate t_starts ?
t_starts = source_recording._get_t_starts()

recording = SharedMemoryRecording(
shm_names=[shm.name for shm in shms],
shape_list=[traces.shape for traces in traces_list],
dtype=source_recording.dtype,
sampling_frequency=source_recording.sampling_frequency,
channel_ids=source_recording.channel_ids,
t_starts=None,
t_starts=t_starts,
main_shm_owner=True,
)

Expand Down
283 changes: 283 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,292 @@
import copy

import pytest
import numpy as np

from spikeinterface.core import generate_recording, generate_sorting
import spikeinterface.full as si


class TestTimeHandling:
    """
    This class tests how time is handled in SpikeInterface. Under the hood,
    time can be represented as a full `time_vector` or only as
    `t_start` attribute on segments from which a vector of times
    is generated on the fly. Both time representations are tested here.
    """

    # Fixtures #####
    @pytest.fixture(scope="session")
    def time_vector_recording(self):
        """
        Add time vectors to the recording, returning the
        raw recording, the recording with time vectors added to
        its segments, and a list of the time vectors added to the recording.
        """
        durations = [10, 15, 20]
        raw_recording = generate_recording(num_channels=4, durations=durations)

        return self._get_time_vector_recording(raw_recording)

    @pytest.fixture(scope="session")
    def t_start_recording(self):
        """
        Add a `t_start` to each segment of the recording, returning the
        raw recording, the recording with t_starts added to its segments,
        and a list of the time vectors generated from adding the
        t_start to the recording times.
        """
        durations = [10, 15, 20]
        raw_recording = generate_recording(num_channels=4, durations=durations)

        return self._get_t_start_recording(raw_recording)

    def _get_time_vector_recording(self, raw_recording):
        """
        Loop through all recording segments, adding a different time
        vector to each segment. The time vector is the original times with
        a per-segment t_start and per-sample offsets
        (sample index / sampling frequency) added, so that the spacing
        differs from the nominal sampling period. Return the original
        recording, the recording with time vectors added, and a list
        including the added time vectors.
        """
        times_recording = copy.deepcopy(raw_recording)
        all_time_vectors = []
        for segment_index in range(raw_recording.get_num_segments()):

            # Parenthesised so that each segment gets a distinct start
            # (100, 200, 300, ...), matching `_get_t_start_recording`.
            t_start = (segment_index + 1) * 100
            offsets = np.arange(times_recording.get_num_samples(segment_index)) * (
                1 / times_recording.get_sampling_frequency()
            )
            time_vector = t_start + times_recording.get_times(segment_index) + offsets

            all_time_vectors.append(time_vector)
            times_recording.set_times(times=time_vector, segment_index=segment_index)

            assert np.array_equal(
                times_recording._recording_segments[segment_index].time_vector,
                time_vector,
            ), "time_vector was not properly set during test setup"

        return (raw_recording, times_recording, all_time_vectors)

    def _get_t_start_recording(self, raw_recording):
        """
        For each segment in the recording, add a different `t_start`.
        Return the original recording, the recording with t_starts set,
        and a list of time vectors generated from the recording times
        + the t_starts.
        """
        t_start_recording = copy.deepcopy(raw_recording)

        all_t_starts = []
        for segment_index in range(raw_recording.get_num_segments()):

            t_start = (segment_index + 1) * 100

            # The expected times are computed before `t_start` is set on the segment.
            all_t_starts.append(t_start + t_start_recording.get_times(segment_index))
            t_start_recording._recording_segments[segment_index].t_start = t_start

        return (raw_recording, t_start_recording, all_t_starts)

    def _get_fixture_data(self, request, fixture_name):
        """
        A convenience function to get the data from a fixture
        based on the name. This is used to allow parameterising
        tests across fixtures.
        """
        time_recording_fixture = request.getfixturevalue(fixture_name)
        raw_recording, times_recording, all_times = time_recording_fixture
        return (raw_recording, times_recording, all_times)

    # Tests #####
    def test_has_time_vector(self, time_vector_recording):
        """
        Test the `has_time_vector` function returns `False` before
        a time vector is added and `True` afterwards.
        """
        raw_recording, times_recording, _ = time_vector_recording

        for segment_idx in range(raw_recording.get_num_segments()):

            assert raw_recording.has_time_vector(segment_idx) is False
            assert times_recording.has_time_vector(segment_idx) is True

    @pytest.mark.parametrize("mode", ["binary", "zarr"])
    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
    def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path):
        """
        Test `t_start` or `time_vector` is propagated to a saved recording,
        by saving, reloading, and checking times are correct.
        """
        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)

        folder_name = "recording"
        recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name)

        # The zarr output is reloaded from a folder carrying a ".zarr" suffix.
        if mode == "zarr":
            folder_name += ".zarr"
        recording_load = si.load_extractor(tmp_path / folder_name)

        self._check_times_match(recording_cache, all_times)
        self._check_times_match(recording_load, all_times)

    @pytest.mark.parametrize("sharedmem", [True, False])
    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
    def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem):
        """
        Test t_start and time_vector are propagated to recording saved into memory.
        """
        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)

        recording_load = times_recording.save(format="memory", sharedmem=sharedmem)

        self._check_times_match(recording_load, all_times)

    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
    def test_time_propagated_to_select_segments(self, request, fixture_name):
        """
        Test that when `recording.select_segments()` is used, the times
        are propagated to the new recording object.
        """
        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)

        for segment_index in range(times_recording.get_num_segments()):
            segment = times_recording.select_segments(segment_index)
            assert np.array_equal(segment.get_times(), all_times[segment_index])

    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
    def test_times_propagated_to_sorting(self, request, fixture_name):
        """
        Check that when attached to a sorting object, the times are propagated
        to the object. This means that all spike times should respect the
        `t_start` or `time_vector` added.
        """
        raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
        sorting = self._get_sorting_with_recording_attached(
            recording_for_durations=raw_recording, recording_to_attach=times_recording
        )
        for segment_index in range(raw_recording.get_num_segments()):

            if fixture_name == "time_vector_recording":
                assert sorting.has_time_vector(segment_index=segment_index)

            self._check_spike_times_are_correct(sorting, times_recording, segment_index)

    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
    def test_time_sample_converters(self, request, fixture_name):
        """
        Test the `recording.sample_index_to_time` and
        `recording.time_to_sample_index` convenience functions.
        """
        raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)

        # The fixtures are multi-segment, so omitting `segment_index` must raise.
        with pytest.raises(ValueError) as e:
            times_recording.sample_index_to_time(0)
        # Inspect the raised exception itself rather than the ExceptionInfo wrapper.
        assert "Provide 'segment_index'" in str(e.value)

        for segment_index in range(times_recording.get_num_segments()):

            sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index))
            time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index)

            assert time_ == all_times[segment_index][sample_index]

            # Converting back must round-trip to the same sample index.
            new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index)

            assert new_sample_index == sample_index

    @pytest.mark.parametrize("time_type", ["time_vector", "t_start"])
    @pytest.mark.parametrize("bounds", ["start", "middle", "end"])
    def test_slice_recording(self, time_type, bounds):
        """
        Test times are correct after applying `frame_slice` or `time_slice`
        to a recording or sorting (for `frame_slice`). The recording times
        should be correct with respect to the set `t_start` or `time_vector`.
        """
        raw_recording = generate_recording(num_channels=4, durations=[10])

        if time_type == "time_vector":
            raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording)
        else:
            raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording)

        sorting = self._get_sorting_with_recording_attached(
            recording_for_durations=raw_recording, recording_to_attach=times_recording
        )

        # Take some different times, including min and max bounds of
        # the recording, and some arbitrary times in the middle (20% and 80%).
        if bounds == "start":
            start_frame = 0
            end_frame = int(times_recording.get_num_samples(0) * 0.8)
        elif bounds == "end":
            start_frame = int(times_recording.get_num_samples(0) * 0.2)
            end_frame = times_recording.get_num_samples(0) - 1
        elif bounds == "middle":
            start_frame = int(times_recording.get_num_samples(0) * 0.2)
            end_frame = int(times_recording.get_num_samples(0) * 0.8)

        # Slice the recording and check the new times are correct
        rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
        sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)

        assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

        self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0)

        # Test `time_slice`
        start_time = times_recording.sample_index_to_time(start_frame)
        end_time = times_recording.sample_index_to_time(end_frame)

        rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time)

        assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

    # Helpers ####
    def _check_times_match(self, recording, all_times):
        """
        For every segment in a recording, check the `get_times()`
        match the expected times in the list of time vectors, `all_times`.
        """
        for segment_index in range(recording.get_num_segments()):
            assert np.array_equal(recording.get_times(segment_index), all_times[segment_index])

    def _check_spike_times_are_correct(self, sorting, times_recording, segment_index):
        """
        For every unit in the `sorting`, for a particular segment, check that
        the unit times match the times of the original recording as
        retrieved with `get_times()`.
        """
        # The recording times do not depend on the unit, so fetch them once.
        rec_times = times_recording.get_times(segment_index=segment_index)

        for unit_id in sorting.get_unit_ids():
            spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
            spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)

            assert np.array_equal(
                spike_times,
                rec_times[spike_indexes],
            )

    def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach):
        """
        Convenience function to create a sorting object with
        a recording attached. Typically use the raw recording
        for the durations from which to generate the sorting, as
        `generate_sorting` is not set up to handle the
        (strange) edge case of the irregularly spaced
        test time vectors.
        """
        durations = [
            recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments())
        ]

        sorting = generate_sorting(num_units=10, durations=durations)

        sorting.register_recording(recording_to_attach)
        assert sorting.has_recording()

        return sorting


# TODO: deprecate original implementations ###
JoeZiminski marked this conversation as resolved.
Show resolved Hide resolved
def test_time_handling(create_cache_folder):
cache_folder = create_cache_folder
durations = [[10], [10, 5]]
Expand Down
Loading