From 4005c7588b5f0506c68e119528d07ded1c101998 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 1 Jul 2024 20:41:20 +0100 Subject: [PATCH] Fix t_starts not propagated to save memory. --- src/spikeinterface/core/baserecording.py | 4 ++-- src/spikeinterface/core/numpyextractors.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index aab7577b31..bb96fb06ca 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -545,11 +545,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if kwargs.get("sharedmem", True): from .numpyextractors import SharedMemoryRecording - cached = SharedMemoryRecording.from_recording(self, **job_kwargs) + cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs) else: from spikeinterface.core import NumpyRecording - cached = NumpyRecording.from_recording(self, **job_kwargs) + cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs) elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 0ba1c05417..83b99d6c44 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -85,7 +85,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N } @staticmethod - def from_recording(source_recording, **job_kwargs): + def from_recording(source_recording, t_starts=None, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) if shms[0] is not None: # if the computation was done in parallel then traces_list is shared array @@ -99,9 +99,10 @@ def from_recording(source_recording, **job_kwargs): recording = NumpyRecording( traces_list, source_recording.get_sampling_frequency(), - t_starts=None, + t_starts=t_starts, channel_ids=source_recording.channel_ids, ) + return recording class NumpyRecordingSegment(BaseRecordingSegment): @@ -211,7 +212,7 @@ def __del__(self): shm.unlink() @staticmethod - def from_recording(source_recording, **job_kwargs): + def from_recording(source_recording, t_starts=None, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) # TODO later : propagte t_starts ? @@ -222,7 +223,7 @@ def from_recording(source_recording, **job_kwargs): dtype=source_recording.dtype, sampling_frequency=source_recording.sampling_frequency, channel_ids=source_recording.channel_ids, - t_starts=None, + t_starts=t_starts, main_shm_owner=True, )