Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX(psoct): avoid loading all slices when files are "old matlab format" #27

Merged
merged 11 commits into from
Nov 22, 2024
11 changes: 9 additions & 2 deletions linc_convert/modalities/psoct/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,24 @@ def generate_pyramid(
slice(i * max_load, min((i + 1) * max_load, n))
for i, n in zip(chunk_index, prev_shape)
]
fullshape = omz[str(level - 1)].shape
dat = omz[str(level - 1)][tuple(slicer)]

# Discard the last voxel along odd dimensions
crop = [0 if x == 1 else x % 2 for x in dat.shape[-ndim:]]
crop = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain what this is for? Why do we need to use this full shape instead? Since this change breaks another test from another modality.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We divide the previous resolution by 2. Thinking in 1D, to do this we reshape a dimension [N] into a dimension [N//2, 2] (essentially you get a stack of the odd and even voxels) and then average across the new small axis.

However, we can only do this if the original dimension is even. So if it's odd I crop the last voxel (by doing something like array[:-1]).

That said, I think that the current code assumes that the data is exactly 3D (no channel dimension). I might have fixed it in the other PR (to be merged).

What error do you get, and on what kind of data?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. The error from another modality is basically dimension not matching which makes sense to me now. Could you please point me to the other PR if it is not the one we just merged?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was the one you just merged

0 if y == 1 else x % 2 for x, y in zip(dat.shape[-ndim:], fullshape)
]
# Don't crop the axis not down-sampling
# cannot do if not no_pyramid_axis since it could be 0
if no_pyramid_axis is not None:
crop[no_pyramid_axis] = 0
slcr = [slice(-1) if x else slice(None) for x in crop]
dat = dat[tuple([Ellipsis, *slcr])]

if any(n == 0 for n in dat.shape):
# last strip had a single voxel, nothing to do
continue

patch_shape = dat.shape[-ndim:]

# Reshape into patches of shape 2x2x2
Expand All @@ -234,7 +241,7 @@ def generate_pyramid(
# -> flatten patches
smaller_shape = [max(n // 2, 1) for n in patch_shape]
if no_pyramid_axis is not None:
smaller_shape[2 * no_pyramid_axis] = patch_shape[no_pyramid_axis]
smaller_shape[no_pyramid_axis] = patch_shape[no_pyramid_axis]

dat = dat.reshape(batch + smaller_shape + [-1])

Expand Down
216 changes: 145 additions & 71 deletions linc_convert/modalities/psoct/multi_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import json
import math
import os
from contextlib import contextmanager
from functools import wraps
from itertools import product
from typing import Any, Callable, Optional
from typing import Callable, Mapping, Optional
from warnings import warn

import cyclopts
Expand All @@ -38,54 +37,126 @@


def _automap(func: Callable) -> Callable:
    """Decorator that maps the arrays in .mat files before calling *func*.

    Parameters
    ----------
    func : Callable
        Converter expecting a list of array-like objects as first argument.

    Returns
    -------
    Callable
        Wrapper taking a list of .mat file paths instead of arrays.
    """

    @wraps(func)
    def wrapper(inp: list[str], out: str | None = None, **kwargs: dict) -> object:
        # Derive the output path from the first input file when not given.
        if out is None:
            out = os.path.splitext(inp[0])[0]
            out += ".nii.zarr" if kwargs.get("nii", False) else ".ome.zarr"
        # A ".nii.zarr" suffix implies NIfTI-Zarr output.
        kwargs["nii"] = kwargs.get("nii", False) or out.endswith(".nii.zarr")
        # Lazily map each file; slices are only read from disk on access.
        dat = _mapmat(inp, kwargs.get("key", None))
        return func(dat, out, **kwargs)

    return wrapper


@contextmanager
def _mapmat(fnames: list[str], key: str = None) -> None:
"""Load or memory-map an array stored in a .mat file."""
loaded_data = []

for fname in fnames:
try:
# "New" .mat file
f = h5py.File(fname, "r")
except Exception:
# "Old" .mat file
f = loadmat(fname)

class _ArrayWrapper:
def _get_key(self, f: Mapping) -> str:
key = self.key
if key is None:
if not len(f.keys()):
raise Exception(f"{fname} is empty")
key = list(f.keys())[0]
raise Exception(f"{self.file} is empty")
for key in f.keys():
if key[:1] != "_":
break
if len(f.keys()) > 1:
warn(
f"More than one key in .mat file {fname}, "
f"More than one key in .mat file {self.file}, "
f'arbitrarily loading "{key}"'
)

if key not in f.keys():
raise Exception(f"Key {key} not found in file {fname}")
raise Exception(f"Key {key} not found in file {self.file}")

return key


class _H5ArrayWrapper(_ArrayWrapper):
    """Lazy wrapper around a dataset in a "new-style" (HDF5) .mat file.

    Keeps the file handle open and exposes array-like access so that
    slices can be read on demand instead of loading the whole volume.
    """

    def __init__(self, file: h5py.File, key: str | None) -> None:
        # file: open h5py handle; key: dataset name, or None to auto-detect
        # via _ArrayWrapper._get_key.
        self.file = file
        self.key = key
        self.array = file.get(self._get_key(self.file))

    def __del__(self) -> None:
        # Close the underlying file when the wrapper is garbage-collected.
        if hasattr(self.file, "close"):
            self.file.close()

    def load(self) -> np.ndarray:
        """Read the full array into memory and close the file."""
        self.array = self.array[...]
        if hasattr(self.file, "close"):
            self.file.close()
        # Drop the handle so __del__ has nothing left to close.
        self.file = None
        return self.array

    @property
    def shape(self) -> tuple[int, ...]:
        return self.array.shape

    @property
    def dtype(self) -> np.dtype:
        return self.array.dtype

    def __len__(self) -> int:
        return len(self.array)

    def __getitem__(self, index: object) -> np.ndarray:
        # Slicing an h5py dataset reads only the requested chunk from disk.
        return self.array[index]


class _MatArrayWrapper(_ArrayWrapper):
    """Lazy wrapper around an array stored in an "old-style" .mat file.

    ``scipy.io.loadmat`` cannot memory-map, so reading the file is
    deferred until the data (or its metadata) is first accessed.
    """

    def __init__(self, file: str, key: str | None) -> None:
        # file: path to the .mat file; key: variable name, or None to
        # auto-detect via _ArrayWrapper._get_key.
        self.file = file
        self.key = key
        self.array = None

    def __del__(self) -> None:
        # No-op for string paths; kept for symmetry with _H5ArrayWrapper
        # in case a file-like object is ever passed in.
        if hasattr(self.file, "close"):
            self.file.close()

    def load(self) -> np.ndarray:
        """Read the full array into memory (idempotent)."""
        # Guard so a second call does not attempt loadmat(None) after
        # self.file has been dropped.
        if self.array is None:
            f = loadmat(self.file)
            self.array = f.get(self._get_key(f))
            self.file = None
        return self.array

    @property
    def shape(self) -> tuple[int, ...]:
        if self.array is None:
            self.load()
        return self.array.shape

    @property
    def dtype(self) -> np.dtype:
        if self.array is None:
            self.load()
        return self.array.dtype

    def __len__(self) -> int:
        if self.array is None:
            self.load()
        return len(self.array)

    def __getitem__(self, index: object) -> np.ndarray:
        if self.array is None:
            self.load()
        return self.array[index]


def _mapmat(fnames: list[str], key: str = None) -> list[_ArrayWrapper]:
    """Lazily map the arrays stored in a series of .mat files.

    Parameters
    ----------
    fnames : list[str]
        Paths to the .mat files, one per slice.
    key : str, optional
        Name of the variable to read in each file. The first
        non-private key is used when not provided.

    Returns
    -------
    list[_ArrayWrapper]
        One lazy wrapper per input file; data is only read on access.
    """

    def make_wrapper(fname: str) -> _ArrayWrapper:
        try:
            # "New" (HDF5-based) .mat file: keep the handle open so
            # slices can be read lazily without loading the volume.
            f = h5py.File(fname, "r")
            return _H5ArrayWrapper(f, key)
        except Exception:
            # "Old" .mat format: h5py cannot open it; defer the full
            # scipy.io.loadmat read until first access.
            return _MatArrayWrapper(fname, key)

    return [make_wrapper(fname) for fname in fnames]


@multi_slice.default
Expand All @@ -109,8 +180,12 @@ def convert(
"""
Matlab to OME-Zarr.

Convert OCT volumes in raw matlab files
into a pyramidal OME-ZARR (or NIfTI-Zarr) hierarchy.
Convert OCT volumes in raw matlab files into a pyramidal
OME-ZARR (or NIfTI-Zarr) hierarchy.

This command assumes that each slice in a volume is stored in a
different mat file. All slices must have the same shape, and will
be concatenated into a 3D Zarr.

Parameters
----------
Expand All @@ -133,7 +208,7 @@ def convert(
max_levels
Maximum number of pyramid levels
no_pool
Index of dimension to not pool when building pyramid
Index of dimension to not pool when building pyramid.
nii
Convert to nifti-zarr. True if path ends in ".nii.zarr"
orientation
Expand Down Expand Up @@ -163,10 +238,10 @@ def convert(
omz = zarr.storage.DirectoryStore(out)
omz = zarr.group(store=omz, overwrite=True)

if not hasattr(inp[0], "dtype"):
raise Exception("Input is not numpy array. This is likely unexpected")
if len(inp[0].shape) != 2:
raise Exception("Input array is not 2d")
# if not hasattr(inp[0], "dtype"):
# raise Exception("Input is not an array. This is likely unexpected")
if len(inp[0].shape) < 2:
raise Exception("Input array is not 2d:", inp[0].shape)
# Prepare chunking options
opt = {
"dimension_separator": r"/",
Expand All @@ -177,10 +252,10 @@ def convert(
}
inp: list = inp
inp_shape = (*inp[0].shape, len(inp))
inp_chunk = [min(x, max_load) for x in inp_shape]
nk = ceildiv(inp_shape[0], inp_chunk[0])
nj = ceildiv(inp_shape[1], inp_chunk[1])
ni = ceildiv(inp_shape[2], inp_chunk[2])
inp_chunk = [min(x, max_load) for x in inp_shape[-3:]]
nk = ceildiv(inp_shape[-3], inp_chunk[0])
nj = ceildiv(inp_shape[-2], inp_chunk[1])
ni = len(inp)

nblevels = min(
[int(math.ceil(math.log2(x))) for i, x in enumerate(inp_shape) if i != no_pool]
Expand All @@ -193,34 +268,33 @@ def convert(
omz.create_dataset(str(0), shape=inp_shape, **opt)

# iterate across input chunks
for i, j, k in product(range(ni), range(nj), range(nk)):
loaded_chunk = np.stack(
[
inp[index][
k * inp_chunk[0] : (k + 1) * inp_chunk[0],
j * inp_chunk[1] : (j + 1) * inp_chunk[1],
]
for index in range(i * inp_chunk[2], (i + 1) * inp_chunk[2])
],
axis=-1,
)

print(
f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]",
"/",
f"[{ni:03d}, {nj:03d}, {nk:03d}]",
# f"({1 + level}/{nblevels})",
end="\r",
)

# save current chunk
omz["0"][
k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0],
j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1],
i * inp_chunk[2] : i * inp_chunk[2] + loaded_chunk.shape[2],
] = loaded_chunk

generate_pyramid(omz, nblevels - 1, mode="mean")
for i in range(ni):
for j, k in product(range(nj), range(nk)):
loaded_chunk = inp[i][
...,
k * inp_chunk[0] : (k + 1) * inp_chunk[0],
j * inp_chunk[1] : (j + 1) * inp_chunk[1],
]

print(
f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]",
"/",
f"[{ni:03d}, {nj:03d}, {nk:03d}]",
# f"({1 + level}/{nblevels})",
end="\r",
)

# save current chunk
omz["0"][
...,
k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0],
j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1],
i,
] = loaded_chunk

inp[i] = None # no ref count -> delete array

generate_pyramid(omz, nblevels - 1, mode="mean", no_pyramid_axis=no_pool)

print("")

Expand All @@ -234,7 +308,7 @@ def convert(
no_pool=no_pool,
space_unit=ome_unit,
space_scale=vx,
multiscales_type=("2x2x2" if no_pool is None else "2x2") + "mean window",
multiscales_type=(("2x2x2" if no_pool is None else "2x2") + "mean window"),
)

if not nii:
Expand Down
23 changes: 14 additions & 9 deletions linc_convert/modalities/psoct/single_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from contextlib import contextmanager
from functools import wraps
from itertools import product
from typing import Any, Callable, Optional
from typing import Callable, Optional
from warnings import warn

import cyclopts
Expand All @@ -38,10 +38,10 @@


def _automap(func: Callable) -> Callable:
"""Decorator to automatically map the array in the mat file.""" # noqa: D401
"""Automatically map the array in the mat file."""

@wraps(func)
def wrapper(inp: str, out: str = None, **kwargs: dict) -> Any: # noqa: ANN401
def wrapper(inp: str, out: str = None, **kwargs: dict) -> None:
if out is None:
out = os.path.splitext(inp[0])[0]
out += ".nii.zarr" if kwargs.get("nii", False) else ".ome.zarr"
Expand All @@ -65,9 +65,14 @@ def _mapmat(fname: str, key: str = None) -> None:
if key is None:
if not len(f.keys()):
raise Exception(f"{fname} is empty")
key = list(f.keys())[0]
for key in f.keys():
if key[:1] != "_":
break
if len(f.keys()) > 1:
warn(f'More than one key in .mat file {fname}, arbitrarily loading "{key}"')
warn(
f"More than one key in .mat file {fname}, "
f'arbitrarily loading "{key}"'
)

if key not in f.keys():
raise Exception(f"Key {key} not found in file {fname}")
Expand Down Expand Up @@ -153,9 +158,9 @@ def convert(
omz = zarr.group(store=omz, overwrite=True)

if not hasattr(inp, "dtype"):
raise Exception("Input is not a numpy array. This is likely unexpected")
raise Exception("Input is not a numpy array. This is unexpected.")
if len(inp.shape) < 3:
raise Exception("Input array is not 3d")
raise Exception("Input array is not 3d:", inp.shape)
# Prepare chunking options
opt = {
"dimension_separator": r"/",
Expand Down Expand Up @@ -203,7 +208,7 @@ def convert(
i * inp_chunk[2] : i * inp_chunk[2] + loaded_chunk.shape[2],
] = loaded_chunk

generate_pyramid(omz, nblevels - 1, mode="mean")
generate_pyramid(omz, nblevels - 1, mode="mean", no_pyramid_axis=no_pool)

print("")

Expand All @@ -217,7 +222,7 @@ def convert(
no_pool=no_pool,
space_unit=ome_unit,
space_scale=vx,
multiscales_type=("2x2x2" if no_pool is None else "2x2") + "mean window",
multiscales_type=(("2x2x2" if no_pool is None else "2x2") + "mean window"),
)

if not nii:
Expand Down