Skip to content

Commit

Permalink
Merge pull request #1446 from fdeguire03/dev
Browse files Browse the repository at this point in the history
Add npy support to load_memmap
  • Loading branch information
pgunn authored Jan 10, 2025
2 parents 857ae12 + 67787fe commit beb0f9b
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions caiman/mmapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ def load_memmap(filename: str, mode: str = 'r') -> tuple[Any, tuple, int]:
"""
logger = logging.getLogger("caiman")
if pathlib.Path(filename).suffix != '.mmap':
allowed_extensions = {'.mmap', '.npy'}

extension = pathlib.Path(filename).suffix
if extension not in allowed_extensions:
logger.error(f"Unknown extension for file {filename}")
raise ValueError(f'Unknown file extension for file {filename} (should be .mmap)')
raise ValueError(f'Unknown file extension for file {filename} (should be .mmap or .npy)')
# Strip path components and use CAIMAN_DATA/example_movies
# TODO: Eventually get the code to save these in a different dir
#fn_without_path = os.path.split(filename)[-1]
Expand All @@ -63,7 +66,22 @@ def load_memmap(filename: str, mode: str = 'r') -> tuple[Any, tuple, int]:
#d1, d2, d3, T, order = int(fpart[-9]), int(fpart[-7]), int(fpart[-5]), int(fpart[-1]), fpart[-3]

filename = caiman.paths.fn_relocated(filename)
Yr = np.memmap(filename, mode=mode, shape=prepare_shape((d1 * d2 * d3, T)), dtype=np.float32, order=order)
shape = prepare_shape((d1 * d2 * d3, T))
if extension == '.mmap':
Yr = np.memmap(filename, mode=mode, shape=shape, dtype=np.float32, order=order)
elif extension == '.npy':
Yr = np.load(filename, mmap_mode=mode)
if Yr.shape != shape:
raise ValueError(f"Data in npy file was an unexpected shape: {Yr.shape}, expected: {shape}")
if Yr.dtype != np.float32:
raise ValueError(f"Data in npy file was an unexpected dtype: {Yr.dtype}, expected: np.float32")
if order == 'C' and not Yr.flags['C_CONTIGUOUS']:
raise ValueError("Data in npy file is not in C-contiguous order as expected.")
elif order == 'F' and not Yr.flags['F_CONTIGUOUS']:
raise ValueError("Data in npy file is not in Fortran-contiguous order as expected.")



if d3 == 1:
return (Yr, (d1, d2), T)
else:
Expand Down

0 comments on commit beb0f9b

Please sign in to comment.