diff --git a/src/ngio/core/__init__.py b/src/ngio/core/__init__.py index 544eaab..6ba57ed 100644 --- a/src/ngio/core/__init__.py +++ b/src/ngio/core/__init__.py @@ -1 +1,6 @@ """Core classes for the ngio library.""" + +from ngio.core.image_handler import Image +from ngio.core.label_handler import Label + +__all__ = ["Image", "Label"] diff --git a/src/ngio/core/dimensions.py b/src/ngio/core/dimensions.py new file mode 100644 index 0000000..bfdf3b2 --- /dev/null +++ b/src/ngio/core/dimensions.py @@ -0,0 +1,102 @@ +"""Dimension metadata. + +This is not related to the NGFF metadata, +but it is based on the actual metadata of the image data. +""" + +from zarr import Array + + +class Dimensions: + """Dimension metadata.""" + + def __init__( + self, array: Array, axes_names: list[str], axes_order: list[int] + ) -> None: + """Create a Dimension object from a Zarr array. + + Args: + array (Array): The Zarr array. + axes_names (list[str]): The names of the axes. + axes_order (list[int]): The order of the axes. + """ + # We init with the shape only but in the ZarrV3 + # we will have to validate the axes names too. + self._on_disk_shape = array.shape + + if len(self._on_disk_shape) != len(axes_names): + raise ValueError( + "The number of axes names must match the number of dimensions." 
+ ) + + self._axes_names = axes_names + self._axes_order = axes_order + self._shape = [self._on_disk_shape[i] for i in axes_order] + self._shape_dict = dict(zip(axes_names, self._shape, strict=True)) + + @property + def shape(self) -> tuple[int, ...]: + """Return the shape as a tuple.""" + return tuple(self._shape) + + @property + def on_disk_shape(self) -> tuple[int, ...]: + """Return the shape as a tuple.""" + return tuple(self._on_disk_shape) + + def ad_dict(self) -> dict[str, int]: + """Return the shape as a dictionary.""" + return self._shape_dict + + @property + def t(self) -> int: + """Return the time dimension.""" + return self._shape_dict.get("t", None) + + @property + def c(self) -> int: + """Return the channel dimension.""" + return self._shape_dict.get("c", None) + + @property + def z(self) -> int: + """Return the z dimension.""" + return self._shape_dict.get("z", None) + + @property + def y(self) -> int: + """Return the y dimension.""" + return self._shape_dict.get("y", None) + + @property + def x(self) -> int: + """Return the x dimension.""" + return self._shape_dict.get("x", None) + + @property + def on_disk_ndim(self) -> int: + """Return the number of dimensions on disk.""" + return len(self._on_disk_shape) + + @property + def ndim(self) -> int: + """Return the number of dimensions.""" + return len(self._shape) + + def is_3D(self) -> bool: + """Return whether the data is 3D.""" + if (self.z is None) or (self.z == 1): + return False + return True + + def is_time_series(self) -> bool: + """Return whether the data is a time series.""" + if (self.t is None) or (self.t == 1): + return False + return True + + def has_multiple_channels(self) -> bool: + """Return whether the data has multiple channels.""" + if (self.c is None) or (self.c == 1): + return False + return True diff --git a/src/ngio/core/image_like_handler.py b/src/ngio/core/image_like_handler.py index f0c3fbb..7df3f29 100644 --- a/src/ngio/core/image_like_handler.py +++ 
def _init_dataset(self, dataset: "Dataset") -> None:
    """Bind the handler to ``dataset`` and load its array and dimensions.

    This method is for internal use only.

    The handler state is only updated once the dataset has been validated,
    so a failed call leaves the previously bound dataset, array and
    dimensions untouched.

    Args:
        dataset (Dataset): The dataset to bind to this handler.

    Raises:
        ValueError: If ``dataset.path`` is not an array of the group.
    """
    if dataset.path not in self._group.array_keys():
        raise ValueError(f"Dataset {dataset.path} not found in the group.")

    array = self._group[dataset.path]
    dimensions = Dimensions(
        array=array,
        axes_names=dataset.axes_names,
        axes_order=dataset.axes_order,
    )

    # Commit all three attributes together, after everything that can raise.
    self._dataset = dataset
    self._array = array
    self._dimensions = dimensions


def _debug_set_new_dataset(
    self,
    new_dataset: "Dataset",
) -> None:
    """Debug method to change the dataset metadata.

    This method allows changing the dataset after initialization, which
    skips the OME-NGFF metadata validation.
    This method is for testing/debug purposes only.
    DO NOT USE THIS METHOD IN PRODUCTION CODE.

    Args:
        new_dataset (Dataset): The dataset to bind in place of the current one.
    """
    self._init_dataset(new_dataset)


# Method to get the data of the image
@property
def array(self) -> "zarr.Array":
    """Return the image data as a Zarr array."""
    return self._array


@property
def dimensions(self) -> "Dimensions":
    """Return the dimensions of the image."""
    return self._dimensions


@property
def shape(self) -> tuple[int, ...]:
    """Return the shape of the image (as reported by ``dimensions.shape``)."""
    return self.dimensions.shape


@property
def on_disk_shape(self) -> tuple[int, ...]:
    """Return the shape of the image as stored on disk."""
    return self.dimensions.on_disk_shape


def get_data(self) -> np.ndarray:
    """Return the whole image, read by indexing the array with ``...``."""
    return self.array[...]
class WorldCooROI(BaseModel):
    """Region of interest (ROI) metadata.

    The ROI is described by an origin (``x``, ``y``, ``z``) and an extent
    (``x_length``, ``y_length``, ``z_length``), all expressed in ``unit``.
    """

    field_index: str
    x: float
    y: float
    z: float
    x_length: float
    y_length: float
    z_length: float
    unit: SpaceUnits

    def _to_raster(self, value: float, pixel_size: float, max_shape: int) -> int:
        """Convert one world coordinate to a raster index.

        Args:
            value (float): The coordinate to convert.
            pixel_size (float): The pixel size along the same axis.
            max_shape (int): Upper bound applied to the returned index.

        Returns:
            int: ``value / pixel_size`` rounded to the nearest integer and
                capped at ``max_shape``.
        """
        round_value = int(np.round(value / pixel_size))
        return min(round_value, max_shape)

    def to_raster_coo(self, pixel_size: PixelSize, max_shape) -> "RasterCooROI":
        """Convert to raster coordinates.

        Args:
            pixel_size (PixelSize): The pixel size; its ``x``, ``y`` and ``z``
                attributes are used as per-axis divisors.
            max_shape: Currently unused.

        Returns:
            RasterCooROI: The ROI in raster coordinates, keeping a reference
                to this object as ``original_roi``.
        """
        # NOTE(review): only ``x`` is rounded and clamped (to a fixed 2**32);
        # the other fields are truncated with int() and ``max_shape`` is
        # ignored. This looks unfinished - confirm the intended behavior.
        return RasterCooROI(
            field_index=self.field_index,
            x=self._to_raster(value=self.x, pixel_size=pixel_size.x, max_shape=2**32),
            y=int(self.y / pixel_size.y),
            z=int(self.z / pixel_size.z),
            x_length=int(self.x_length / pixel_size.x),
            y_length=int(self.y_length / pixel_size.y),
            z_length=int(self.z_length / pixel_size.z),
            original_roi=self,
        )
class SlicerTransform(Protocol):
    """Interface of the transforms that read or write one slice of the data."""

    def get(self, data: ArrayLike) -> ArrayLike:
        """Return the selected slice of the data."""
        ...

    def push(self, data: ArrayLike, patch: ArrayLike) -> ArrayLike:
        """Write the patch into the selected slice of the data and return it."""
        ...
+ + +class NaiveSlicer: + """A simple slicer that requires all axes to be specified.""" + + def __init__( + self, + on_disk_axes_name: list[str], + axes_order: list[int], + t: int | slice | None = None, + c: int | slice | None = None, + z: int | slice | None = None, + y: int | slice | None = None, + x: int | slice | None = None, + ): + """Initialize the NaiveSlicer object.""" + self.on_disk_axes_name = on_disk_axes_name + + # Check if axes_order is trivial + if axes_order != list(range(len(axes_order))): + self.axes_order = axes_order + else: + self.axes_order = None + + self.slices = { + "t": t if t is not None else slice(None), + "c": c if c is not None else slice(None), + "z": z if z is not None else slice(None), + "y": y if y is not None else slice(None), + "x": x if x is not None else slice(None), + } + + def get(self, data: ArrayLike) -> ArrayLike: + """Select a slice of the data and return the result.""" + slice_on_disk_order = [self.slices[axis] for axis in self.on_disk_axes_name] + patch = data[tuple(slice_on_disk_order)] + + # If sel.axis_order is trivial, skip the transpose + if self.axes_order is None: + return patch + + if isinstance(patch, np.ndarray): + patch = np.transpose(patch, self.axes_order) + elif isinstance(patch, da.core.Array): + patch = da.transpose(patch, self.axes_order) + return patch + + def push(self, data: ArrayLike, patch: ArrayLike) -> ArrayLike: + """Replace the slice of the data with the patch and return the result.""" + slice_on_disk_order = [self.slices[axis] for axis in self.on_disk_axes_name] + # If sel.axis_order is trivial, skip the transpose + if self.axes_order is not None: + if isinstance(patch, np.ndarray): + patch = np.transpose(patch, self.axes_order) + elif isinstance(patch, da.core.Array): + patch = da.transpose(patch, self.axes_order) + + data[tuple(slice_on_disk_order)] = patch + return data + + +class RoiSlicer(NaiveSlicer): + """A slicer that requires all axes to be specified.""" + + def __init__( + self, + 
class DataTransformPipe:
    """Chain a slicer with an ordered sequence of data transforms.

    For example, a pipeline of transforms can be:
        - Select a subset of the data
        - Shuffle the axes of the data
        - Normalize the data

    When a patch is pushed, the same steps are applied in reverse order.
    """

    def __init__(self, slicer: SlicerTransform, *data_transforms: Transform):
        """Initialize the DataTransformPipe object.

        Args:
            slicer (SlicerTransform): The slicer; it is always the first step
                when getting data and the last step when pushing a patch.
            *data_transforms (Transform): The transforms applied, in order,
                after the slicer.
        """
        self.list_of_transforms = data_transforms
        self.slicer = slicer

    def get(self, data: ArrayLike) -> ArrayLike:
        """Slice the data, then run every transform on it in order."""
        result = self.slicer.get(data)
        for step in self.list_of_transforms:
            result = step.get(result)
        return result

    def push(self, data: ArrayLike, patch: ArrayLike) -> ArrayLike:
        """Run the transforms backwards on the patch, then write it into the data."""
        reverted = patch
        # Walk the transforms from last to first.
        for step in self.list_of_transforms[::-1]:
            reverted = step.push(reverted)
        return self.slicer.push(data, reverted)
def test_open_group_wrapper(self, store_fixture):
    """Attributes written to a group opened in "r+" mode can be read back."""
    from ngio.io import open_group_wrapper

    # The fixture yields the store together with the Zarr format to use.
    target_store, fmt = store_fixture
    opened = open_group_wrapper(store=target_store, mode="r+", zarr_format=fmt)

    expected = self.test_attrs
    opened.attrs.update(expected)
    assert dict(opened.attrs) == expected
read_group_attrs from ngio.ngff_meta import get_ngff_image_meta_handler + from ngio.ngff_meta.v04.zarr_utils import NgffImageMeta04 handler = get_ngff_image_meta_handler( store=ome_zarr_image_v04_path, meta_mode="image" @@ -18,12 +18,14 @@ def test_basic_workflow(self, ome_zarr_image_v04_path): with open("tests/data/meta_v04/base_ome_zarr_image_meta.json") as f: base_ome_zarr_meta = json.load(f) - saved_meta = read_group_attrs(store=ome_zarr_image_v04_path, zarr_format=2) + saved_meta = NgffImageMeta04(**handler.group.attrs).model_dump( + exclude_none=True + ) assert saved_meta == base_ome_zarr_meta def test_basic_workflow_with_cache(self, ome_zarr_image_v04_path): - from ngio.io import read_group_attrs from ngio.ngff_meta import get_ngff_image_meta_handler + from ngio.ngff_meta.v04.zarr_utils import NgffImageMeta04 handler = get_ngff_image_meta_handler( store=ome_zarr_image_v04_path, meta_mode="image", cache=True @@ -35,7 +37,9 @@ def test_basic_workflow_with_cache(self, ome_zarr_image_v04_path): with open("tests/data/meta_v04/base_ome_zarr_image_meta.json") as f: base_ome_zarr_meta = json.load(f) - saved_meta = read_group_attrs(store=ome_zarr_image_v04_path, zarr_format=2) + saved_meta = NgffImageMeta04(**handler.group.attrs).model_dump( + exclude_none=True + ) assert saved_meta == base_ome_zarr_meta def test_wrong_axis_order(self):