diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 64c1b9b2e6..7a9a841eba 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -166,25 +166,17 @@ def get_binary_description(self): class BinaryRecordingSegment(BaseRecordingSegment): - def __init__(self, datfile, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset): + def __init__(self, file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset): BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start) self.num_channels = num_channels self.dtype = np.dtype(dtype) self.file_offset = file_offset self.time_axis = time_axis - self.datfile = datfile - self.file = open(self.datfile, "r") - self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_channels * np.dtype(dtype).itemsize) - if self.time_axis == 0: - self.shape = (self.num_samples, self.num_channels) - else: - self.shape = (self.num_channels, self.num_samples) - - byte_offset = self.file_offset - dtype_size_bytes = self.dtype.itemsize - data_size_bytes = dtype_size_bytes * self.num_samples * self.num_channels - self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY) - self.memmap_length = data_size_bytes + self.array_offset + self.file_path = file_path + self.file = open(self.file_path, "rb") + self.bytes_per_sample = self.num_channels * self.dtype.itemsize + self.data_size_in_bytes = Path(file_path).stat().st_size - file_offset + self.num_samples = self.data_size_in_bytes // self.bytes_per_sample def get_num_samples(self) -> int: """Returns the number of samples in this signal block @@ -200,23 +192,43 @@ def get_traces( end_frame: int | None = None, channel_indices: list | None = None, ) -> np.ndarray: - length = self.memmap_length - memmap_offset = self.memmap_offset + + # Calculate byte 
offsets for start and end frames + start_byte = self.file_offset + start_frame * self.bytes_per_sample + end_byte = self.file_offset + end_frame * self.bytes_per_sample + + # Calculate the length of the data chunk to load into memory + length = end_byte - start_byte + + # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY + memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) + memmap_offset *= mmap.ALLOCATIONGRANULARITY + + # Adjust the length so it includes the extra data from rounding down + # the memmap offset to a multiple of ALLOCATIONGRANULARITY + length += start_offset + + # Create the mmap object memmap_obj = mmap.mmap(self.file.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset) - array = np.ndarray.__new__( - np.ndarray, - shape=self.shape, + # Create a numpy array using the mmap object as the buffer + # Note that the shape must be recalculated based on the new data chunk + if self.time_axis == 0: + shape = ((end_frame - start_frame), self.num_channels) + else: + shape = (self.num_channels, (end_frame - start_frame)) + + # Now the entire array should correspond to the data between start_frame and end_frame, so we can use it directly + traces = np.ndarray( + shape=shape, dtype=self.dtype, buffer=memmap_obj, - order="C", - offset=self.array_offset, + offset=start_offset, ) if self.time_axis == 1: - array = array.T + traces = traces.T - traces = array[start_frame:end_frame] if channel_indices is not None: traces = traces[:, channel_indices] diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 1f2e644be6..996718dc42 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -659,3 +659,28 @@ def retrieve_importing_provenance(a_class): } return info + + +def measure_memory_allocation(measure_in_process: bool = True) -> float: + """ + A local utility to measure memory allocation at a specific point in time. 
+ Can measure either the process resident memory or system wide memory available + + Uses psutil package. + + Parameters + ---------- + measure_in_process : bool, True by default + Measure memory allocation in the current process only, if false then measures at the system + level. + """ + import psutil + + if measure_in_process: + process = psutil.Process() + memory = process.memory_info().rss + else: + mem_info = psutil.virtual_memory() + memory = mem_info.total - mem_info.available + + return memory diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 8ea99e3d04..ea5edc6e6e 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -1,8 +1,11 @@ import pytest import numpy as np +from pathlib import Path from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording +from spikeinterface.core.core_tools import measure_memory_allocation +from spikeinterface.core.generate import NoiseGeneratorRecording def test_BinaryRecordingExtractor(create_cache_folder): @@ -51,15 +54,75 @@ def test_round_trip(tmp_path): dtype=dtype, ) + # Test for full traces assert np.allclose(recording.get_traces(), binary_recorder.get_traces()) - start_frame = 200 - end_frame = 500 + # Test for a sub-set of the traces + start_frame = 20 + end_frame = 40 smaller_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) binary_smaller_traces = binary_recorder.get_traces(start_frame=start_frame, end_frame=end_frame) np.allclose(smaller_traces, binary_smaller_traces) +@pytest.fixture(scope="module") +def folder_with_binary_files(tmpdir_factory): + tmp_path = Path(tmpdir_factory.mktemp("spike_interface_test")) + folder = tmp_path / "test_binary_recording" + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + recording = 
NoiseGeneratorRecording( + durations=[1.0], + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + ) + dtype = recording.get_dtype() + recording.save(folder=folder, overwrite=True) + + return folder + + +def test_sequential_reading_of_small_traces(folder_with_binary_files): + # Test that memmap is read correctly when pointing to specific frames + folder = folder_with_binary_files + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + + file_paths = [folder / "traces_cached_seg0.raw"] + recording = BinaryRecordingExtractor( + num_chan=num_channels, + file_paths=file_paths, + sampling_frequency=sampling_frequency, + dtype=dtype, + ) + + full_traces = recording.get_traces() + + # Test for a sub-set of the traces + start_frame = 10 + end_frame = 15 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + # Test for a sub-set of the traces + start_frame = 1000 + end_frame = 1100 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + # Test for a sub-set of the traces + start_frame = 10_000 + end_frame = 11_000 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + if __name__ == "__main__": test_BinaryRecordingExtractor()