From 5d0256389fb139238261ac7fb616d84679fbf8bb Mon Sep 17 00:00:00 2001 From: Calvin Chai Date: Mon, 25 Nov 2024 15:31:29 +0000 Subject: [PATCH] WIP --- linc_convert/modalities/lsm/mosaic.py | 145 ++++++++-------- linc_convert/models/zarr_config.py | 234 ++++++++++++++++++++++++++ tests/helper.py | 2 + tests/test_lsm.py | 4 +- 4 files changed, 315 insertions(+), 70 deletions(-) create mode 100644 linc_convert/models/zarr_config.py diff --git a/linc_convert/modalities/lsm/mosaic.py b/linc_convert/modalities/lsm/mosaic.py index e02c0db..a56453c 100644 --- a/linc_convert/modalities/lsm/mosaic.py +++ b/linc_convert/modalities/lsm/mosaic.py @@ -23,7 +23,7 @@ from linc_convert.utils.math import ceildiv from linc_convert.utils.orientation import center_affine, orientation_to_affine from linc_convert.utils.zarr import make_compressor - +from linc_convert.models.zarr_config import ZarrConfig mosaic = cyclopts.App(name="mosaic", help_format="markdown") lsm.command(mosaic) @@ -31,13 +31,9 @@ @mosaic.default def convert( inp: str, - out: str = None, *, - chunk: int = 128, - compressor: str = "blosc", - compressor_opt: str = "{}", + zarr_config: ZarrConfig, max_load: int = 512, - nii: bool = False, orientation: str = "coronal", center: bool = True, thickness: float | None = None, @@ -91,6 +87,12 @@ def convert( voxel_size Voxel size along the X, Y and Z dimension, in micron. 
""" + out: str = zarr_config.out + chunk: int = zarr_config.chunk + compressor: str = zarr_config.compressor + compressor_opt: str = zarr_config.compressor_opt + print(compressor_opt, type(compressor_opt)) + nii: bool = zarr_config.nii if isinstance(compressor_opt, str): compressor_opt = ast.literal_eval(compressor_opt) @@ -242,69 +244,74 @@ def convert( # build pyramid using median windows level = 0 - while any(x > 1 for x in omz[str(level)].shape[-3:]): - prev_array = omz[str(level)] - prev_shape = prev_array.shape[-3:] - level += 1 - - new_shape = list(map(lambda x: max(1, x // 2), prev_shape)) - if all(x < chunk for x in new_shape): - break - print("Compute level", level, "with shape", new_shape) - omz.create_dataset(str(level), shape=[nchannels, *new_shape], **opt) - new_array = omz[str(level)] - - nz, ny, nx = prev_array.shape[-3:] - ncz = ceildiv(nz, max_load) - ncy = ceildiv(ny, max_load) - ncx = ceildiv(nx, max_load) - - for cz in range(ncz): - for cy in range(ncy): - for cx in range(ncx): - print(f"chunk ({cz}, {cy}, {cx}) / ({ncz}, {ncy}, {ncx})", end="\r") - - dat = prev_array[ - ..., - cz * max_load : (cz + 1) * max_load, - cy * max_load : (cy + 1) * max_load, - cx * max_load : (cx + 1) * max_load, - ] - crop = [0 if x == 1 else x % 2 for x in dat.shape[-3:]] - slicer = [slice(-1) if x else slice(None) for x in crop] - dat = dat[(Ellipsis, *slicer)] - pz, py, px = dat.shape[-3:] - - dat = dat.reshape( - [ - nchannels, - max(pz // 2, 1), - min(pz, 2), - max(py // 2, 1), - min(py, 2), - max(px // 2, 1), - min(px, 2), - ] - ) - dat = dat.transpose([0, 1, 3, 5, 2, 4, 6]) - dat = dat.reshape( - [ - nchannels, - max(pz // 2, 1), - max(py // 2, 1), - max(px // 2, 1), - -1, - ] - ) - dat = np.median(dat, -1) - - new_array[ - ..., - cz * max_load // 2 : (cz + 1) * max_load // 2, - cy * max_load // 2 : (cy + 1) * max_load // 2, - cx * max_load // 2 : (cx + 1) * max_load // 2, - ] = dat - + # while any(x > 1 for x in omz[str(level)].shape[-3:]): + # prev_array = 
omz[str(level)] + # prev_shape = prev_array.shape[-3:] + # level += 1 + + # new_shape = list(map(lambda x: max(1, x // 2), prev_shape)) + # if all(x < chunk for x in new_shape): + # break + # print("Compute level", level, "with shape", new_shape) + # omz.create_dataset(str(level), shape=[nchannels, *new_shape], **opt) + # new_array = omz[str(level)] + + # nz, ny, nx = prev_array.shape[-3:] + # ncz = ceildiv(nz, max_load) + # ncy = ceildiv(ny, max_load) + # ncx = ceildiv(nx, max_load) + + # for cz in range(ncz): + # for cy in range(ncy): + # for cx in range(ncx): + # print(f"chunk ({cz}, {cy}, {cx}) / ({ncz}, {ncy}, {ncx})", end="\r") + + # dat = prev_array[ + # ..., + # cz * max_load : (cz + 1) * max_load, + # cy * max_load : (cy + 1) * max_load, + # cx * max_load : (cx + 1) * max_load, + # ] + # crop = [0 if x == 1 else x % 2 for x in dat.shape[-3:]] + # slicer = [slice(-1) if x else slice(None) for x in crop] + # dat = dat[(Ellipsis, *slicer)] + # pz, py, px = dat.shape[-3:] + + # dat = dat.reshape( + # [ + # nchannels, + # max(pz // 2, 1), + # min(pz, 2), + # max(py // 2, 1), + # min(py, 2), + # max(px // 2, 1), + # min(px, 2), + # ] + # ) + # dat = dat.transpose([0, 1, 3, 5, 2, 4, 6]) + # dat = dat.reshape( + # [ + # nchannels, + # max(pz // 2, 1), + # max(py // 2, 1), + # max(px // 2, 1), + # -1, + # ] + # ) + # dat = np.median(dat, -1) + + # new_array[ + # ..., + # cz * max_load // 2 : (cz + 1) * max_load // 2, + # cy * max_load // 2 : (cy + 1) * max_load // 2, + # cx * max_load // 2 : (cx + 1) * max_load // 2, + # ] = dat + # + nxyz = np.array(fullshape) + default_layers = int(np.ceil(np.log2(np.max(nxyz / chunk)))) + 1 + nblevel = max(default_layers, 1) + from linc_convert.modalities.psoct._utils import generate_pyramid + generate_pyramid(omz, levels=nblevel-1,ndim = len(fullshape)+1) print("") nblevel = level diff --git a/linc_convert/models/zarr_config.py b/linc_convert/models/zarr_config.py new file mode 100644 index 0000000..a60fae8 --- /dev/null +++ 
b/linc_convert/models/zarr_config.py @@ -0,0 +1,234 @@ +from dataclasses import dataclass +import os +from typing import Literal, Annotated +import abc +from cyclopts import Parameter +import numpy as np +import zarr + +@dataclass +class _ZarrConfig: + """ + Parameters + ---------- + out + Path to the output Zarr directory [\.ome.zarr] + chunk + Test + """ + out: str = "" + chunk: int = 128 + compressor: str = "blosc" + compressor_opt: str = "{}" + shard: list[int | str] | None = None + version: Literal[2, 3] = 3 + driver: Literal["zarr-python","tensorstore", "zarrita"] = "zarr-python" + nii: bool = False + + def __post_init__(self): + print(self) + +ZarrConfig = Annotated[_ZarrConfig, Parameter(name="*")] + +class AbstractZarrIO(abc.ABC): + + def __init__(self, config: _ZarrConfig): + self.config = config + + def __getitem__(self, index): + pass + def __setitem__(self, index): + pass + def create_dataset(self): + pass + +class ZarrPythonIO(AbstractZarrIO): + def __init__(self, config: _ZarrConfig, overwrite=True): + super().__init__(config) + omz = zarr.storage.DirectoryStore(config.out) + self.zgroup = zarr.group(store=omz, overwrite=overwrite) + def create_dataset(self, + chunk, + dtype, + dimension_separator = r"/", + order = "F", + fill_value = 0, + compressor = None): + + make_compressor(compressor, **compressor_opt) + + pass + def __getitem__(self, index): + return self.zgroup[index] + +class TensorStoreIO(AbstractZarrIO): + + pass + + + +def default_write_config( + path: os.PathLike | str, + shape: list[int], + dtype: np.dtype | str, + chunk: list[int] = [32], + shard: list[int] | Literal["auto"] | None = None, + compressor: str = "blosc", + compressor_opt: dict | None = None, + version: int = 3, +) -> dict: + """ + Generate a default TensorStore configuration. + Parameters + ---------- + chunk : list[int] + Chunk size. + shard : list[int], optional + Shard size. No sharding if `None`. 
+ compressor : str + Compressor name + version : int + Zarr version + Returns + ------- + config : dict + Configuration + """ + path = UPath(path) + if not path.protocol: + path = "file://" / path + + # Format compressor + if version == 3 and compressor == "zlib": + compressor = "gzip" + if version == 2 and compressor == "gzip": + compressor = "zlib" + compressor_opt = compressor_opt or {} + + # Prepare chunk size + if isinstance(chunk, int): + chunk = [chunk] + chunk = chunk[:1] * max(0, len(shape) - len(chunk)) + chunk + + # Prepare shard size + if shard: + if shard == "auto": + shard = auto_shard_size(shape, dtype) + if isinstance(shard, int): + shard = [shard] + shard = shard[:1] * max(0, len(shape) - len(shard)) + shard + + # Fix incompatibilities + shard, chunk = fix_shard_chunk(shard, chunk, shape) + + # ------------------------------------------------------------------ + # Zarr 3 + # ------------------------------------------------------------------ + if version == 3: + if compressor and compressor != "raw": + compressor = [make_compressor_v3(compressor, **compressor_opt)] + else: + compressor = [] + + codec_little_endian = {"name": "bytes", "configuration": {"endian": "little"}} + + if shard: + chunk_grid = { + "name": "regular", + "configuration": {"chunk_shape": shard}, + } + + sharding_codec = { + "name": "sharding_indexed", + "configuration": { + "chunk_shape": chunk, + "codecs": [ + codec_little_endian, + *compressor, + ], + "index_codecs": [ + codec_little_endian, + {"name": "crc32c"}, + ], + "index_location": "end", + }, + } + codecs = [sharding_codec] + + else: + chunk_grid = {"name": "regular", "configuration": {"chunk_shape": chunk}} + codecs = [ + codec_little_endian, + *compressor, + ] + + metadata = { + "chunk_grid": chunk_grid, + "codecs": codecs, + "data_type": np.dtype(dtype).name, + "fill_value": 0, + "chunk_key_encoding": { + "name": "default", + "configuration": {"separator": r"/"}, + }, + } + config = { + "driver": "zarr3", + 
"metadata": metadata, + } + + # ------------------------------------------------------------------ + # Zarr 2 + # ------------------------------------------------------------------ + else: + if compressor and compressor != "raw": + compressor = make_compressor_v2(compressor, **compressor_opt) + else: + compressor = None + + metadata = { + "chunks": chunk, + "order": "F", + "dtype": np.dtype(dtype).str, + "fill_value": 0, + "compressor": compressor, + } + config = { + "driver": "zarr", + "metadata": metadata, + "key_encoding": r"/", + } + + # Prepare store + config["metadata"]["shape"] = shape + config["kvstore"] = make_kvstore(path) + + return config + + +def default_read_config(path: os.PathLike | str) -> dict: + """ + Generate a TensorStore configuration to read an existing Zarr. + Parameters + ---------- + path : PathLike | str + Path to zarr array. + """ + path = UPath(path) + if not path.protocol: + path = "file://" / path + if (path / "zarr.json").exists(): + zarr_version = 3 + elif (path / ".zarray").exists(): + zarr_version = 2 + else: + raise ValueError("Cannot find zarr.json or .zarray file") + return { + "kvstore": make_kvstore(path), + "driver": "zarr3" if zarr_version == 3 else "zarr", + "open": True, + "create": False, + "delete_existing": False, + } + + diff --git a/tests/helper.py b/tests/helper.py index eb865d7..92b796e 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -27,6 +27,8 @@ def _cmp_zarr_archives(path1: str, path2: str) -> bool: return False if zarr1.attrs != zarr2.attrs: print("attrs mismatch") + print(dict(zarr1.attrs)) + print(dict(zarr2.attrs)) return False # Compare each array in both archives diff --git a/tests/test_lsm.py b/tests/test_lsm.py index 601da02..1318724 100644 --- a/tests/test_lsm.py +++ b/tests/test_lsm.py @@ -5,6 +5,7 @@ from helper import _cmp_zarr_archives from linc_convert.modalities.lsm import mosaic +from linc_convert.models.zarr_config import _ZarrConfig def _write_test_data(directory: str) -> None: @@ -24,5 
+25,6 @@ def _write_test_data(directory: str) -> None:
 def test_lsm(tmp_path):
     _write_test_data(tmp_path)
     output_zarr = tmp_path / "output.zarr"
-    mosaic.convert(str(tmp_path), str(output_zarr))
+    config=_ZarrConfig( str(output_zarr))
+    mosaic.convert(str(tmp_path), zarr_config=config)
     assert _cmp_zarr_archives(str(output_zarr), "data/lsm.zarr.zip")