From fb08a7dc6ff31e879df00cca47e754bf29604b4b Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 15:42:40 +0000 Subject: [PATCH 1/8] FIX/ENH(psoct): do not load all slices at once (if old mat files) --- linc_convert/modalities/psoct/multi_slice.py | 213 +++++++++++++------ 1 file changed, 146 insertions(+), 67 deletions(-) diff --git a/linc_convert/modalities/psoct/multi_slice.py b/linc_convert/modalities/psoct/multi_slice.py index 1c55be46..557ec33c 100644 --- a/linc_convert/modalities/psoct/multi_slice.py +++ b/linc_convert/modalities/psoct/multi_slice.py @@ -9,10 +9,9 @@ import json import math import os -from contextlib import contextmanager from functools import wraps from itertools import product -from typing import Any, Callable, Optional +from typing import Callable, Optional from warnings import warn import cyclopts @@ -38,54 +37,129 @@ def _automap(func: Callable) -> Callable: - """Decorator to automatically map the array in the mat file.""" # noqa: D401 + """Automatically maps the array in the mat file.""" @wraps(func) - def wrapper(inp: str, out: str = None, **kwargs: dict) -> Any: # noqa: ANN401 + def wrapper(inp: list[str], out: str = None, **kwargs: dict) -> callable: if out is None: out = os.path.splitext(inp[0])[0] out += ".nii.zarr" if kwargs.get("nii", False) else ".ome.zarr" kwargs["nii"] = kwargs.get("nii", False) or out.endswith(".nii.zarr") - with _mapmat(inp, kwargs.get("key", None)) as dat: - return func(dat, out, **kwargs) + dat = _mapmat(inp, kwargs.get("key", None)) + return func(dat, out, **kwargs) return wrapper -@contextmanager -def _mapmat(fnames: list[str], key: str = None) -> None: - """Load or memory-map an array stored in a .mat file.""" - loaded_data = [] - - for fname in fnames: - try: - # "New" .mat file - f = h5py.File(fname, "r") - except Exception: - # "Old" .mat file - f = loadmat(fname) +class _ArrayWrapper: + def _get_key(self, f) -> str: + key = self.key if key is None: if not len(f.keys()): - raise Exception(f"{fname} is empty") - key = list(f.keys())[0] + raise Exception(f"{self.file} is empty") + for key in f.keys(): + if key[:1] != '_': + break if len(f.keys()) > 1: warn( - f"More than one key in .mat file {fname}, " + f"More than one key in .mat file {self.file}, " f'arbitrarily loading "{key}"' ) if key not in f.keys(): - raise Exception(f"Key {key} not found in file {fname}") + raise Exception(f"Key {key} not found in file {self.file}") + + return key + + +class _H5ArrayWrapper(_ArrayWrapper): + + def __init__(self, file, key) -> None: + self.file = file + self.key = key + self.array = file.get(self._get_key(self.file)) + + def __del__(self) -> None: + if hasattr(self.file, 'close'): + self.file.close() + + def load(self) -> np.ndarray: + self.array = self.array[...] + if hasattr(self.file, 'close'): + self.file.close() + self.file = None + return self.array + + @property + def shape(self) -> list[int]: + return self.array.shape + + @property + def dtype(self) -> np.dtype: + return self.array.dtype + + def __len__(self) -> int: + return len(self.array) + + def __getitem__(self, index) -> np.ndarray: + return self.array[index] + + +class _MatArrayWrapper(_ArrayWrapper): + + def __init__(self, file, key) -> None: + self.file = file + self.key = key + self.array = None + + def __del__(self) -> None: + if hasattr(self.file, 'close'): + self.file.close() + + def load(self) -> np.ndarray: + f = loadmat(self.file) + self.array = f.get(self._get_key(f)) + self.file = None + return self.array + + @property + def shape(self) -> list[int]: + if self.array is None: + self.load() + return self.array.shape + + @property + def dtype(self) -> np.dtype: + if self.array is None: + self.load() + return self.array.dtype + + def __len__(self) -> int: + if self.array is None: + self.load() + return len(self.array) + + def __getitem__(self, index) -> np.ndarray: + if self.array is None: + self.load() + return self.array[index] + + +def _mapmat(fnames: list[str], key: str = None) -> list[_ArrayWrapper]: + """Load or memory-map an array stored in a .mat file.""" + # loaded_data = [] + + def make_wrapper(fname: str) -> callable: + try: + # "New" .mat file + f = h5py.File(fname, "r") + return _H5ArrayWrapper(f, key) + except Exception: + # "Old" .mat file + return _MatArrayWrapper(fname, key) - if len(fnames) == 1: - yield f.get(key) - if hasattr(f, "close"): - f.close() - break - loaded_data.append(f.get(key)) - yield loaded_data - # yield np.stack(loaded_data, axis=-1) + return [make_wrapper(fname) for fname in fnames] @multi_slice.default @@ -163,10 +237,10 @@ def convert( omz = zarr.storage.DirectoryStore(out) omz = zarr.group(store=omz, overwrite=True) - if not hasattr(inp[0], "dtype"): - raise Exception("Input is not numpy array. This is likely unexpected") - if len(inp[0].shape) != 2: - raise Exception("Input array is not 2d") + # if not hasattr(inp[0], "dtype"): + # raise Exception("Input is not an array. This is likely unexpected") + if len(inp[0].shape) < 2: + raise Exception("Input array is not 2d:", inp[0].shape) # Prepare chunking options opt = { "dimension_separator": r"/", @@ -177,13 +251,15 @@ def convert( } inp: list = inp inp_shape = (*inp[0].shape, len(inp)) - inp_chunk = [min(x, max_load) for x in inp_shape] - nk = ceildiv(inp_shape[0], inp_chunk[0]) - nj = ceildiv(inp_shape[1], inp_chunk[1]) - ni = ceildiv(inp_shape[2], inp_chunk[2]) + inp_chunk = [min(x, max_load) for x in inp_shape[-3:]] + nk = ceildiv(inp_shape[-3], inp_chunk[0]) + nj = ceildiv(inp_shape[-2], inp_chunk[1]) + ni = len(inp) nblevels = min( - [int(math.ceil(math.log2(x))) for i, x in enumerate(inp_shape) if i != no_pool] + [int(math.ceil(math.log2(x))) + for i, x in enumerate(inp_shape) + if i != no_pool] ) nblevels = min(nblevels, int(math.ceil(math.log2(max_load)))) nblevels = min(nblevels, max_levels) @@ -193,32 +269,32 @@ def convert( omz.create_dataset(str(0), shape=inp_shape, **opt) # iterate across input chunks - for i, j, k in product(range(ni), range(nj), range(nk)): - loaded_chunk = np.stack( - [ - inp[index][ - k * inp_chunk[0] : (k + 1) * inp_chunk[0], - j * inp_chunk[1] : (j + 1) * inp_chunk[1], - ] - for index in range(i * inp_chunk[2], (i + 1) * inp_chunk[2]) - ], - axis=-1, - ) - - print( - f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]", - "/", - f"[{ni:03d}, {nj:03d}, {nk:03d}]", - # f"({1 + level}/{nblevels})", - end="\r", - ) - - # save current chunk - omz["0"][ - k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0], - j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1], - i * inp_chunk[2] : i * inp_chunk[2] + loaded_chunk.shape[2], - ] = loaded_chunk + for i in range(ni): + + for j, k in product(range(nj), range(nk)): + loaded_chunk = inp[i][ + ..., + k * inp_chunk[0]: (k + 1) * inp_chunk[0], + j * inp_chunk[1]: (j + 1) * inp_chunk[1], + ] + + print( + f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]", + "/", + f"[{ni:03d}, {nj:03d}, {nk:03d}]", + # f"({1 + level}/{nblevels})", + end="\r", + ) + + # save current chunk + omz["0"][ + ..., + k * inp_chunk[0]: k * inp_chunk[0] + loaded_chunk.shape[0], + j * inp_chunk[1]: j * inp_chunk[1] + loaded_chunk.shape[1], + i, + ] = loaded_chunk + + inp[i] = None # no ref count -> delete array generate_pyramid(omz, nblevels - 1, mode="mean") @@ -234,7 +310,9 @@ def convert( no_pool=no_pool, space_unit=ome_unit, space_scale=vx, - multiscales_type=("2x2x2" if no_pool is None else "2x2") + "mean window", + multiscales_type=( + ("2x2x2" if no_pool is None else "2x2") + "mean window" + ), ) if not nii: @@ -250,5 +328,6 @@ def convert( if center: affine = center_affine(affine, shape[:3]) niftizarr_write_header( - omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), nifti_version=2 + omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), + nifti_version=2 ) From 4f71ddd203c32a5854772621ed8e086499e2bb05 Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 15:51:34 +0000 Subject: [PATCH 2/8] FIX(psoct): hints to please ruff --- linc_convert/modalities/psoct/multi_slice.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/linc_convert/modalities/psoct/multi_slice.py b/linc_convert/modalities/psoct/multi_slice.py index 557ec33c..03ec4e71 100644 --- a/linc_convert/modalities/psoct/multi_slice.py +++ b/linc_convert/modalities/psoct/multi_slice.py @@ -11,7 +11,7 @@ import os from functools import wraps from itertools import product -from typing import Callable, Optional +from typing import Callable, Mapping, Optional from warnings import warn import cyclopts @@ -53,7 +53,7 @@ def wrapper(inp: list[str], out: str = None, **kwargs: dict) -> callable: class _ArrayWrapper: - def _get_key(self, f) -> str: + def _get_key(self, f: Mapping) -> str: key = self.key if key is None: if not len(f.keys()): @@ -75,7 +75,7 @@ def _get_key(self, f) -> str: class _H5ArrayWrapper(_ArrayWrapper): - def __init__(self, file, key) -> None: + def __init__(self, file: h5py.File, key: str | None) -> None: self.file = file self.key = key self.array = file.get(self._get_key(self.file)) @@ -102,13 +102,13 @@ def dtype(self) -> np.dtype: def __len__(self) -> int: return len(self.array) - def __getitem__(self, index) -> np.ndarray: + def __getitem__(self, index: object) -> np.ndarray: return self.array[index] class _MatArrayWrapper(_ArrayWrapper): - def __init__(self, file, key) -> None: + def __init__(self, file: str, key: str | None) -> None: self.file = file self.key = key self.array = None @@ -140,7 +140,7 @@ def __len__(self) -> int: self.load() return len(self.array) - def __getitem__(self, index) -> np.ndarray: + def __getitem__(self, index: object) -> np.ndarray: if self.array is None: self.load() return self.array[index] From a2735f7e75d7552708f3b68ab5f68386a601e703 Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 15:51:56 +0000 Subject: [PATCH 3/8] style fixes by ruff --- linc_convert/modalities/psoct/multi_slice.py | 31 +++++++------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/linc_convert/modalities/psoct/multi_slice.py b/linc_convert/modalities/psoct/multi_slice.py index 03ec4e71..b771426c 100644 --- a/linc_convert/modalities/psoct/multi_slice.py +++ b/linc_convert/modalities/psoct/multi_slice.py @@ -52,14 +52,13 @@ def wrapper(inp: list[str], out: str = None, **kwargs: dict) -> callable: class _ArrayWrapper: - def _get_key(self, f: Mapping) -> str: key = self.key if key is None: if not len(f.keys()): raise Exception(f"{self.file} is empty") for key in f.keys(): - if key[:1] != '_': + if key[:1] != "_": break if len(f.keys()) > 1: warn( @@ -74,19 +73,18 @@ def _get_key(self, f: Mapping) -> str: class _H5ArrayWrapper(_ArrayWrapper): - def __init__(self, file: h5py.File, key: str | None) -> None: self.file = file self.key = key self.array = file.get(self._get_key(self.file)) def __del__(self) -> None: - if hasattr(self.file, 'close'): + if hasattr(self.file, "close"): self.file.close() def load(self) -> np.ndarray: self.array = self.array[...] - if hasattr(self.file, 'close'): + if hasattr(self.file, "close"): self.file.close() self.file = None return self.array @@ -107,14 +105,13 @@ def __getitem__(self, index: object) -> np.ndarray: class _MatArrayWrapper(_ArrayWrapper): - def __init__(self, file: str, key: str | None) -> None: self.file = file self.key = key self.array = None def __del__(self) -> None: - if hasattr(self.file, 'close'): + if hasattr(self.file, "close"): self.file.close() def load(self) -> np.ndarray: @@ -257,9 +254,7 @@ def convert( ni = len(inp) nblevels = min( - [int(math.ceil(math.log2(x))) - for i, x in enumerate(inp_shape) - if i != no_pool] + [int(math.ceil(math.log2(x))) for i, x in enumerate(inp_shape) if i != no_pool] ) nblevels = min(nblevels, int(math.ceil(math.log2(max_load)))) nblevels = min(nblevels, max_levels) @@ -270,12 +265,11 @@ def convert( # iterate across input chunks for i in range(ni): - for j, k in product(range(nj), range(nk)): loaded_chunk = inp[i][ ..., - k * inp_chunk[0]: (k + 1) * inp_chunk[0], - j * inp_chunk[1]: (j + 1) * inp_chunk[1], + k * inp_chunk[0] : (k + 1) * inp_chunk[0], + j * inp_chunk[1] : (j + 1) * inp_chunk[1], ] print( @@ -289,8 +283,8 @@ def convert( # save current chunk omz["0"][ ..., - k * inp_chunk[0]: k * inp_chunk[0] + loaded_chunk.shape[0], - j * inp_chunk[1]: j * inp_chunk[1] + loaded_chunk.shape[1], + k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0], + j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1], i, ] = loaded_chunk @@ -310,9 +304,7 @@ def convert( no_pool=no_pool, space_unit=ome_unit, space_scale=vx, - multiscales_type=( - ("2x2x2" if no_pool is None else "2x2") + "mean window" - ), + multiscales_type=(("2x2x2" if no_pool is None else "2x2") + "mean window"), ) if not nii: @@ -328,6 +320,5 @@ def convert( if center: affine = center_affine(affine, shape[:3]) niftizarr_write_header( - omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), - nifti_version=2 + omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), nifti_version=2 ) From 2e3d12c847c023d3ae44316e6ee23ff4e9493f17 Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 16:06:34 +0000 Subject: [PATCH 4/8] ENH(psoct.single_volume): more robust default key --- .../modalities/psoct/single_volume.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/linc_convert/modalities/psoct/single_volume.py b/linc_convert/modalities/psoct/single_volume.py index 6bf651b3..04ba73ce 100644 --- a/linc_convert/modalities/psoct/single_volume.py +++ b/linc_convert/modalities/psoct/single_volume.py @@ -12,7 +12,7 @@ from contextlib import contextmanager from functools import wraps from itertools import product -from typing import Any, Callable, Optional +from typing import Callable, Optional from warnings import warn import cyclopts @@ -38,10 +38,10 @@ def _automap(func: Callable) -> Callable: - """Decorator to automatically map the array in the mat file.""" # noqa: D401 + """Automatically map the array in the mat file.""" @wraps(func) - def wrapper(inp: str, out: str = None, **kwargs: dict) -> Any: # noqa: ANN401 + def wrapper(inp: str, out: str = None, **kwargs: dict) -> None: if out is None: out = os.path.splitext(inp[0])[0] out += ".nii.zarr" if kwargs.get("nii", False) else ".ome.zarr" @@ -65,9 +65,12 @@ def _mapmat(fname: str, key: str = None) -> None: if key is None: if not len(f.keys()): raise Exception(f"{fname} is empty") - key = list(f.keys())[0] + for key in f.keys(): + if key[:1] != '_': + break if len(f.keys()) > 1: - warn(f'More than one key in .mat file {fname}, arbitrarily loading "{key}"') + warn(f'More than one key in .mat file {fname}, ' + f'arbitrarily loading "{key}"') if key not in f.keys(): raise Exception(f"Key {key} not found in file {fname}") @@ -153,9 +156,9 @@ def convert( omz = zarr.group(store=omz, overwrite=True) if not hasattr(inp, "dtype"): - raise Exception("Input is not a numpy array. This is likely unexpected") + raise Exception("Input is not a numpy array. This is unexpected.") if len(inp.shape) < 3: - raise Exception("Input array is not 3d") + raise Exception("Input array is not 3d:", inp.shape) # Prepare chunking options opt = { "dimension_separator": r"/", @@ -171,7 +174,9 @@ def convert( ni = ceildiv(inp.shape[2], inp_chunk[2]) nblevels = min( - [int(math.ceil(math.log2(x))) for i, x in enumerate(inp.shape) if i != no_pool] + [int(math.ceil(math.log2(x))) + for i, x in enumerate(inp.shape) + if i != no_pool] ) nblevels = min(nblevels, int(math.ceil(math.log2(max_load)))) nblevels = min(nblevels, max_levels) @@ -183,9 +188,9 @@ def convert( # iterate across input chunks for i, j, k in product(range(ni), range(nj), range(nk)): loaded_chunk = inp[ - k * inp_chunk[0] : (k + 1) * inp_chunk[0], - j * inp_chunk[1] : (j + 1) * inp_chunk[1], - i * inp_chunk[2] : (i + 1) * inp_chunk[2], + k * inp_chunk[0]: (k + 1) * inp_chunk[0], + j * inp_chunk[1]: (j + 1) * inp_chunk[1], + i * inp_chunk[2]: (i + 1) * inp_chunk[2], ] print( @@ -198,9 +203,9 @@ def convert( # save current chunk omz["0"][ - k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0], - j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1], - i * inp_chunk[2] : i * inp_chunk[2] + loaded_chunk.shape[2], + k * inp_chunk[0]: k * inp_chunk[0] + loaded_chunk.shape[0], + j * inp_chunk[1]: j * inp_chunk[1] + loaded_chunk.shape[1], + i * inp_chunk[2]: i * inp_chunk[2] + loaded_chunk.shape[2], ] = loaded_chunk generate_pyramid(omz, nblevels - 1, mode="mean") @@ -217,7 +222,9 @@ def convert( no_pool=no_pool, space_unit=ome_unit, space_scale=vx, - multiscales_type=("2x2x2" if no_pool is None else "2x2") + "mean window", + multiscales_type=( + ("2x2x2" if no_pool is None else "2x2") + "mean window" + ), ) if not nii: @@ -233,5 +240,6 @@ def convert( if center: affine = center_affine(affine, shape[:3]) niftizarr_write_header( - omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), nifti_version=2 + omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), + nifti_version=2 ) From 3dc2636b46185f9ed8a78d8f30f59f1abfe2cbeb Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 16:06:53 +0000 Subject: [PATCH 5/8] DOC(psoct) --- linc_convert/modalities/psoct/multi_slice.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/linc_convert/modalities/psoct/multi_slice.py b/linc_convert/modalities/psoct/multi_slice.py index 03ec4e71..6de28033 100644 --- a/linc_convert/modalities/psoct/multi_slice.py +++ b/linc_convert/modalities/psoct/multi_slice.py @@ -183,8 +183,12 @@ def convert( """ Matlab to OME-Zarr. - Convert OCT volumes in raw matlab files - into a pyramidal OME-ZARR (or NIfTI-Zarr) hierarchy. + Convert OCT volumes in raw matlab files into a pyramidal + OME-ZARR (or NIfTI-Zarr) hierarchy. + + This command assumes that each slice in a volume is stored in a + different mat file. All slices must have the same shape, and will + be concatenated into a 3D Zarr. Parameters ---------- @@ -207,7 +211,7 @@ def convert( max_levels Maximum number of pyramid levels no_pool - Index of dimension to not pool when building pyramid + Index of dimension to not pool when building pyramid. nii Convert to nifti-zarr. True if path ends in ".nii.zarr" orientation From 36119115ef768080f2b581169f3e12867cf3096b Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 16:07:30 +0000 Subject: [PATCH 6/8] style fixes by ruff --- .../modalities/psoct/single_volume.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/linc_convert/modalities/psoct/single_volume.py b/linc_convert/modalities/psoct/single_volume.py index 04ba73ce..efdd64c3 100644 --- a/linc_convert/modalities/psoct/single_volume.py +++ b/linc_convert/modalities/psoct/single_volume.py @@ -66,11 +66,13 @@ def _mapmat(fname: str, key: str = None) -> None: if not len(f.keys()): raise Exception(f"{fname} is empty") for key in f.keys(): - if key[:1] != '_': + if key[:1] != "_": break if len(f.keys()) > 1: - warn(f'More than one key in .mat file {fname}, ' - f'arbitrarily loading "{key}"') + warn( + f"More than one key in .mat file {fname}, " + f'arbitrarily loading "{key}"' + ) if key not in f.keys(): raise Exception(f"Key {key} not found in file {fname}") @@ -174,9 +176,7 @@ def convert( ni = ceildiv(inp.shape[2], inp_chunk[2]) nblevels = min( - [int(math.ceil(math.log2(x))) - for i, x in enumerate(inp.shape) - if i != no_pool] + [int(math.ceil(math.log2(x))) for i, x in enumerate(inp.shape) if i != no_pool] ) nblevels = min(nblevels, int(math.ceil(math.log2(max_load)))) nblevels = min(nblevels, max_levels) @@ -188,9 +188,9 @@ def convert( # iterate across input chunks for i, j, k in product(range(ni), range(nj), range(nk)): loaded_chunk = inp[ - k * inp_chunk[0]: (k + 1) * inp_chunk[0], - j * inp_chunk[1]: (j + 1) * inp_chunk[1], - i * inp_chunk[2]: (i + 1) * inp_chunk[2], + k * inp_chunk[0] : (k + 1) * inp_chunk[0], + j * inp_chunk[1] : (j + 1) * inp_chunk[1], + i * inp_chunk[2] : (i + 1) * inp_chunk[2], ] print( @@ -203,9 +203,9 @@ def convert( # save current chunk omz["0"][ - k * inp_chunk[0]: k * inp_chunk[0] + loaded_chunk.shape[0], - j * inp_chunk[1]: j * inp_chunk[1] + loaded_chunk.shape[1], - i * inp_chunk[2]: i * inp_chunk[2] + loaded_chunk.shape[2], + k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0], + j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1], + i * inp_chunk[2] : i * inp_chunk[2] + loaded_chunk.shape[2], ] = loaded_chunk generate_pyramid(omz, nblevels - 1, mode="mean") @@ -222,9 +222,7 @@ def convert( no_pool=no_pool, space_unit=ome_unit, space_scale=vx, - multiscales_type=( - ("2x2x2" if no_pool is None else "2x2") + "mean window" - ), + multiscales_type=(("2x2x2" if no_pool is None else "2x2") + "mean window"), ) if not nii: @@ -240,6 +238,5 @@ def convert( if center: affine = center_affine(affine, shape[:3]) niftizarr_write_header( - omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), - nifti_version=2 + omz, shape, affine, omz["0"].dtype, to_nifti_unit(unit), nifti_version=2 ) From 41cbd1f8cc643186d1bfafb0d2ed6c0eeabfe41f Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 16:39:55 +0000 Subject: [PATCH 7/8] FIX(psoct): propagate no_pool to pyramid generator + FIX(generate_pyramid): do not crash if last chunk in a row only has a single voxel --- linc_convert/modalities/psoct/_utils.py | 12 ++++++++++-- linc_convert/modalities/psoct/multi_slice.py | 2 +- linc_convert/modalities/psoct/single_volume.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/linc_convert/modalities/psoct/_utils.py b/linc_convert/modalities/psoct/_utils.py index 380515e5..0d7026dc 100644 --- a/linc_convert/modalities/psoct/_utils.py +++ b/linc_convert/modalities/psoct/_utils.py @@ -203,10 +203,14 @@ def generate_pyramid( slice(i * max_load, min((i + 1) * max_load, n)) for i, n in zip(chunk_index, prev_shape) ] + fullshape = omz[str(level - 1)].shape dat = omz[str(level - 1)][tuple(slicer)] # Discard the last voxel along odd dimensions - crop = [0 if x == 1 else x % 2 for x in dat.shape[-ndim:]] + crop = [ + 0 if y == 1 else x % 2 + for x, y in zip(dat.shape[-ndim:], fullshape) + ] # Don't crop the axis not down-sampling # cannot do if not no_pyramid_axis since it could be 0 if no_pyramid_axis is not None: @@ -214,6 +218,10 @@ def generate_pyramid( slcr = [slice(-1) if x else slice(None) for x in crop] dat = dat[tuple([Ellipsis, *slcr])] + if any(n == 0 for n in dat.shape): + # last strip had a single voxel, nothing to do + continue + patch_shape = dat.shape[-ndim:] # Reshape into patches of shape 2x2x2 @@ -234,7 +242,7 @@ def generate_pyramid( # -> flatten patches smaller_shape = [max(n // 2, 1) for n in patch_shape] if no_pyramid_axis is not None: - smaller_shape[2 * no_pyramid_axis] = patch_shape[no_pyramid_axis] + smaller_shape[no_pyramid_axis] = patch_shape[no_pyramid_axis] dat = dat.reshape(batch + smaller_shape + [-1]) diff --git a/linc_convert/modalities/psoct/multi_slice.py b/linc_convert/modalities/psoct/multi_slice.py index 47838a9d..a4c7d625 100644 --- a/linc_convert/modalities/psoct/multi_slice.py +++ b/linc_convert/modalities/psoct/multi_slice.py @@ -294,7 +294,7 @@ def convert( inp[i] = None # no ref count -> delete array - generate_pyramid(omz, nblevels - 1, mode="mean") + generate_pyramid(omz, nblevels - 1, mode="mean", no_pyramid_axis=no_pool) print("") diff --git a/linc_convert/modalities/psoct/single_volume.py b/linc_convert/modalities/psoct/single_volume.py index 04ba73ce..669675b2 100644 --- a/linc_convert/modalities/psoct/single_volume.py +++ b/linc_convert/modalities/psoct/single_volume.py @@ -208,7 +208,7 @@ def convert( i * inp_chunk[2]: i * inp_chunk[2] + loaded_chunk.shape[2], ] = loaded_chunk - generate_pyramid(omz, nblevels - 1, mode="mean") + generate_pyramid(omz, nblevels - 1, mode="mean", no_pyramid_axis=no_pool) print("") From 204bcb23fc379daca0eada469b3f57baf7a428d4 Mon Sep 17 00:00:00 2001 From: balbasty Date: Fri, 22 Nov 2024 16:40:30 +0000 Subject: [PATCH 8/8] style fixes by ruff --- linc_convert/modalities/psoct/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/linc_convert/modalities/psoct/_utils.py b/linc_convert/modalities/psoct/_utils.py index 0d7026dc..ee0d7491 100644 --- a/linc_convert/modalities/psoct/_utils.py +++ b/linc_convert/modalities/psoct/_utils.py @@ -208,8 +208,7 @@ def generate_pyramid( # Discard the last voxel along odd dimensions crop = [ - 0 if y == 1 else x % 2 - for x, y in zip(dat.shape[-ndim:], fullshape) + 0 if y == 1 else x % 2 for x, y in zip(dat.shape[-ndim:], fullshape) ] # Don't crop the axis not down-sampling # cannot do if not no_pyramid_axis since it could be 0