Fix t_starts not propagated to save_to_memory. #3120

Merged · 12 commits · Jul 9, 2024
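
A minimal sketch of the behaviour this PR addresses, using only the public spikeinterface API (NumpyRecording, save(format="memory"), get_times); the 10.0 s start time and the array sizes are purely illustrative:

import numpy as np
from spikeinterface.core import NumpyRecording

# one segment of 1000 samples on 4 channels, starting at t = 10.0 s
traces = np.random.randn(1000, 4).astype("float32")
rec = NumpyRecording(traces, sampling_frequency=1000.0, t_starts=[10.0])

# before this fix, the in-memory copy silently lost the segment start time
mem_rec = rec.save(format="memory")
assert mem_rec.get_times(segment_index=0)[0] == 10.0
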
27 changes: 19 additions & 8 deletions src/spikeinterface/core/baserecording.py
@@ -498,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
         rs = self._recording_segments[segment_index]
         return rs.time_to_sample_index(time_s)
 
-    def _save(self, format="binary", verbose: bool = False, **save_kwargs):
+    def _get_t_starts(self):
         # handle t_starts
         t_starts = []
-        has_time_vectors = []
-        for segment_index, rs in enumerate(self._recording_segments):
+        for rs in self._recording_segments:
             d = rs.get_times_kwargs()
             t_starts.append(d["t_start"])
-            has_time_vectors.append(d["time_vector"] is not None)
 
         if all(t_start is None for t_start in t_starts):
             t_starts = None
+        return t_starts
+
+    def _get_time_vectors(self):
+        time_vectors = []
+        for rs in self._recording_segments:
+            d = rs.get_times_kwargs()
+            time_vectors.append(d["time_vector"])
+        if all(time_vector is None for time_vector in time_vectors):
+            time_vectors = None
+        return time_vectors
 
+    def _save(self, format="binary", verbose: bool = False, **save_kwargs):
         kwargs, job_kwargs = split_job_kwargs(save_kwargs)
 
         if format == "binary":
             folder = kwargs["folder"]
             file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
             dtype = kwargs.get("dtype", None) or self.get_dtype()
+            t_starts = self._get_t_starts()
 
             write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)
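
For reference, a hedged sketch of what the two new helpers return; the two-segment recording and its start times are made up for illustration:

import numpy as np
from spikeinterface.core import NumpyRecording

traces = [np.zeros((1000, 4), dtype="float32"), np.zeros((500, 4), dtype="float32")]
rec = NumpyRecording(traces, sampling_frequency=1000.0, t_starts=[10.0, 20.0])

print(rec._get_t_starts())      # [10.0, 20.0] -- None if no segment defines a t_start
print(rec._get_time_vectors())  # None, because no segment carries an explicit time vector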

@@ -572,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
             probegroup = self.get_probegroup()
             cached.set_probegroup(probegroup)
 
-        for segment_index, rs in enumerate(self._recording_segments):
-            d = rs.get_times_kwargs()
-            time_vector = d["time_vector"]
-            if time_vector is not None:
-                cached._recording_segments[segment_index].time_vector = time_vector
+        time_vectors = self._get_time_vectors()
+        if time_vectors is not None:
+            for segment_index, time_vector in enumerate(time_vectors):
+                if time_vector is not None:
+                    cached.set_times(time_vector, segment_index=segment_index)
 
         return cached
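
The rewritten loop above attaches time vectors to the cached copy through the public BaseRecording.set_times method instead of poking at the private segment attribute. A hedged usage sketch with an illustrative one-segment recording:

import numpy as np
from spikeinterface.core import NumpyRecording

rec = NumpyRecording(np.zeros((1000, 4), dtype="float32"), sampling_frequency=1000.0)

# attach an explicit time axis to segment 0, starting at 10.0 s
times = 10.0 + np.arange(rec.get_num_samples(segment_index=0)) / rec.get_sampling_frequency()
rec.set_times(times, segment_index=0)
assert rec.get_times(segment_index=0)[0] == 10.0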

12 changes: 8 additions & 4 deletions src/spikeinterface/core/numpyextractors.py
@@ -83,6 +83,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
     @staticmethod
     def from_recording(source_recording, **job_kwargs):
         traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
+
+        t_starts = source_recording._get_t_starts()
+
         if shms[0] is not None:
             # if the computation was done in parallel then traces_list is shared array
             # this can lead to problem
@@ -91,13 +94,14 @@ def from_recording(source_recording, **job_kwargs):
             for shm in shms:
                 shm.close()
                 shm.unlink()
-        # TODO later : propagte t_starts ?
 
         recording = NumpyRecording(
             traces_list,
             source_recording.get_sampling_frequency(),
-            t_starts=None,
+            t_starts=t_starts,
             channel_ids=source_recording.channel_ids,
         )
         return recording
 
 
 class NumpyRecordingSegment(BaseRecordingSegment):
@@ -206,15 +210,15 @@ def __del__(self):
     def from_recording(source_recording, **job_kwargs):
         traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
 
-        # TODO later : propagte t_starts ?
+        t_starts = source_recording._get_t_starts()
 
         recording = SharedMemoryRecording(
             shm_names=[shm.name for shm in shms],
             shape_list=[traces.shape for traces in traces_list],
             dtype=source_recording.dtype,
             sampling_frequency=source_recording.sampling_frequency,
             channel_ids=source_recording.channel_ids,
-            t_starts=None,
+            t_starts=t_starts,
             main_shm_owner=True,
         )
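
As a quick check of both code paths touched in this file, a hedged sketch (the source recording and its 5.0 s start time are illustrative):

import numpy as np
from spikeinterface.core.numpyextractors import NumpyRecording, SharedMemoryRecording

src = NumpyRecording(np.zeros((2000, 4), dtype="float32"), sampling_frequency=1000.0, t_starts=[5.0])

# both in-memory copies should now report the same start time as the source
copy_np = NumpyRecording.from_recording(src)
copy_shm = SharedMemoryRecording.from_recording(src)
assert copy_np.get_times(segment_index=0)[0] == 5.0
assert copy_shm.get_times(segment_index=0)[0] == 5.0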
