From e2b48f90822be9b445a033fd1bb3fbb5e8985151 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Wed, 6 Nov 2024 15:29:45 +0100 Subject: [PATCH 1/5] minor table refactor --- docs/notebooks/basic_usage.ipynb | 2 +- docs/notebooks/image.ipynb | 4 +-- docs/notebooks/processing.ipynb | 4 +-- pyproject.toml | 5 ++-- src/ngio/core/image_like_handler.py | 4 +-- src/ngio/tables/_utils.py | 37 ++++++++++++------------ src/ngio/tables/tables_group.py | 5 +--- src/ngio/tables/v1/_generic_table.py | 2 +- src/ngio/tables/v1/feature_tables.py | 8 ++--- src/ngio/tables/v1/masking_roi_tables.py | 12 ++++---- src/ngio/tables/v1/roi_tables.py | 12 ++++---- 11 files changed, 47 insertions(+), 48 deletions(-) diff --git a/docs/notebooks/basic_usage.ipynb b/docs/notebooks/basic_usage.ipynb index 04dcb84..8734acc 100644 --- a/docs/notebooks/basic_usage.ipynb +++ b/docs/notebooks/basic_usage.ipynb @@ -185,7 +185,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dev2", + "display_name": "ngio", "language": "python", "name": "python3" }, diff --git a/docs/notebooks/image.ipynb b/docs/notebooks/image.ipynb index 6f544bc..0d5b5b5 100644 --- a/docs/notebooks/image.ipynb +++ b/docs/notebooks/image.ipynb @@ -301,7 +301,7 @@ "\n", "print(f\"New list of feature table: {ngff_image.table.list(table_type='feature_table')}\")\n", "feat_table.set_table(feat_df)\n", - "feat_table.write()\n", + "feat_table.consolidate()\n", "\n", "feat_table.table" ] @@ -309,7 +309,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dev2", + "display_name": "ngio", "language": "python", "name": "python3" }, diff --git a/docs/notebooks/processing.ipynb b/docs/notebooks/processing.ipynb index 0e4c67e..36d08b2 100644 --- a/docs/notebooks/processing.ipynb +++ b/docs/notebooks/processing.ipynb @@ -167,7 +167,7 @@ " roi_list.append(roi)\n", "\n", "mip_roi_table.set_rois(roi_list, overwrite=True)\n", - "mip_roi_table.write()\n", + "mip_roi_table.consolidate()\n", "\n", "mip_roi_table.table" ] @@ -303,7 +303,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 495a3da..945c578 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,11 +50,12 @@ core = ["zarr<3", "dask[distributed]", "dask-image"] test = ["zarr<3", "pytest", "pytest-cov"] -# add anything else you like to have in your dev environment here dev2 = [ "zarr<3", "dask[distributed]", "dask-image", + "napari", + "pyqt5", "scikit-image", "matplotlib", "ipython", @@ -66,7 +67,7 @@ dev2 = [ "pre-commit", "rich", # https://github.com/Textualize/rich "ruff", -] +] # add anything else you like to have in your dev environment here dev3 = [ "zarr==v3.0.0-alpha.4", diff --git a/src/ngio/core/image_like_handler.py b/src/ngio/core/image_like_handler.py index ca560b2..b65df5e 100644 --- a/src/ngio/core/image_like_handler.py +++ b/src/ngio/core/image_like_handler.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Any, Literal -from warnings import warn import dask.array as da import numpy as np @@ -21,6 +20,7 @@ get_ngff_image_meta_handler, ) from ngio.pipes import DataTransformPipe, NaiveSlicer, RoiSlicer, on_disk_zoom +from ngio.utils import ngio_logger from ngio.utils._common_types import ArrayLike @@ -62,7 +62,7 @@ def __init__( _label_group (LabelGroup): The group containing the label data (internal use). 
""" if not strict: - warn("Strict mode is not fully supported yet.", UserWarning, stacklevel=2) + ngio_logger.warning("Strict mode is not fully supported yet.") self._mode = mode if not isinstance(store, zarr.Group): diff --git a/src/ngio/tables/_utils.py b/src/ngio/tables/_utils.py index 7ac9a28..90df988 100644 --- a/src/ngio/tables/_utils.py +++ b/src/ngio/tables/_utils.py @@ -6,12 +6,7 @@ import pandas as pd import pandas.api.types as ptypes - -class TableValidationError(Exception): - """Error raised when a table is not formatted correctly.""" - - pass - +from ngio.utils import NgioTableValidationError Validator = Callable[[pd.DataFrame], pd.DataFrame] @@ -19,7 +14,7 @@ class TableValidationError(Exception): def _check_for_mixed_types(series: pd.Series) -> None: """Check if the column has mixed types.""" if series.apply(type).nunique() > 1: - raise TableValidationError( + raise NgioTableValidationError( f"Column {series.name} has mixed types: " f"{series.apply(type).unique()}. " "Type of all elements must be the same." @@ -34,7 +29,7 @@ def _check_for_supported_types(series: pd.Series) -> Literal["str", "int", "nume return "int" if ptypes.is_numeric_dtype(series): return "numeric" - raise TableValidationError( + raise NgioTableValidationError( f"Column {series.name} has unsupported type: {series.dtype}." " Supported types are string and numerics." ) @@ -63,7 +58,9 @@ def _check_index_key( table_df = table_df.set_index(index_key) if table_df.index.name != index_key: - raise TableValidationError(f"index_key: {index_key} not found in data frame") + raise NgioTableValidationError( + f"index_key: {index_key} not found in data frame" + ) if index_type == "str": if ptypes.is_integer_dtype(table_df.index): @@ -71,7 +68,9 @@ def _check_index_key( table_df.index = table_df.index.astype(str) if not ptypes.is_string_dtype(table_df.index): - raise TableValidationError(f"index_key {index_key} must be of string type") + raise NgioTableValidationError( + f"index_key {index_key} must be of string type" + ) elif index_type == "int": if ptypes.is_string_dtype(table_df.index): @@ -80,7 +79,7 @@ def _check_index_key( table_df.index = table_df.index.astype(int) except ValueError as e: if "invalid literal for int() with base 10" in str(e): - raise TableValidationError( + raise NgioTableValidationError( f"index_key {index_key} must be of " "integer type, but found string. We " "tried implicit conversion failed." @@ -89,10 +88,12 @@ def _check_index_key( raise e from e if not ptypes.is_integer_dtype(table_df.index): - raise TableValidationError(f"index_key {index_key} must be of integer type") + raise NgioTableValidationError( + f"index_key {index_key} must be of integer type" + ) else: - raise TableValidationError(f"index_type {index_type} not recognized") + raise NgioTableValidationError(f"index_type {index_type} not recognized") return table_df @@ -219,7 +220,7 @@ def table_ad_to_df( elif table_ad.obs.index.name is not None: if validate_index_name: if table_ad.obs.index.name != index_key: - raise TableValidationError( + raise NgioTableValidationError( f"Index key {index_key} not found in AnnData object." ) table_df.index = table_ad.obs.index @@ -227,7 +228,7 @@ def table_ad_to_df( table_df.index = table_ad.obs.index table_df.index.name = index_key else: - raise TableValidationError( + raise NgioTableValidationError( f"Index key {index_key} not found in AnnData object." 
) @@ -270,7 +271,7 @@ def validate_columns( table_header = table_df.columns for column in required_columns: if column not in table_header: - raise TableValidationError(f"Column {column} is required in ROI table") + raise NgioTableValidationError(f"Column {column} is required in ROI table") if optional_columns is None: return table_df @@ -278,7 +279,7 @@ def validate_columns( possible_columns = [*required_columns, *optional_columns] for column in table_header: if column not in possible_columns: - raise TableValidationError( + raise NgioTableValidationError( f"Column {column} is not recognized in ROI table" ) @@ -292,6 +293,6 @@ def validate_unique_index(table_df: pd.DataFrame) -> pd.DataFrame: # Find the duplicates duplicates = table_df.index[table_df.index.duplicated()].tolist() - raise TableValidationError( + raise NgioTableValidationError( f"Index of the table contains duplicate values. Duplicates: {duplicates}" ) diff --git a/src/ngio/tables/tables_group.py b/src/ngio/tables/tables_group.py index 6468fe2..e1616e3 100644 --- a/src/ngio/tables/tables_group.py +++ b/src/ngio/tables/tables_group.py @@ -1,7 +1,4 @@ -"""Module for handling the /tables group in an OME-NGFF file. - -The /tables group contains t -""" +"""Module for handling the /tables group in an OME-NGFF file.""" from typing import Literal from warnings import warn diff --git a/src/ngio/tables/v1/_generic_table.py b/src/ngio/tables/v1/_generic_table.py index 828168a..1b5eabb 100644 --- a/src/ngio/tables/v1/_generic_table.py +++ b/src/ngio/tables/v1/_generic_table.py @@ -164,7 +164,7 @@ def add_validator(self, validator: Validator) -> None: self._validators = [] self._validators.append(validator) - def write(self, meta: BaseModel) -> None: + def consolidate(self, meta: BaseModel) -> None: """Write the current state of the table to the Zarr file.""" table = self.table table = validate_table( diff --git a/src/ngio/tables/v1/feature_tables.py b/src/ngio/tables/v1/feature_tables.py index 87cd0bb..6af8a5a 100644 --- a/src/ngio/tables/v1/feature_tables.py +++ b/src/ngio/tables/v1/feature_tables.py @@ -116,13 +116,13 @@ def table(self) -> pd.DataFrame: return self._table_handler.table @table.setter - def table(self, table: pd.DataFrame): + def table(self, table: pd.DataFrame) -> None: """Set the feature table.""" raise NotImplementedError( "Setting the table is not implemented. Please use the 'set_table' method." ) - def set_table(self, table: pd.DataFrame): + def set_table(self, table: pd.DataFrame) -> None: """Set the feature table.""" self._table_handler.set_table(table) @@ -141,6 +141,6 @@ def label_image_name(self, get_full_path: bool = False) -> str: return path.split("/")[-1] - def write(self): + def consolidate(self) -> None: """Write the table to the group.""" - self._table_handler.write(meta=self.meta) + self._table_handler.consolidate(meta=self.meta) diff --git a/src/ngio/tables/v1/masking_roi_tables.py b/src/ngio/tables/v1/masking_roi_tables.py index 20f2b79..475b282 100644 --- a/src/ngio/tables/v1/masking_roi_tables.py +++ b/src/ngio/tables/v1/masking_roi_tables.py @@ -84,7 +84,7 @@ def _new( label_image: str, instance_key: str = "label", overwrite: bool = False, - ): + ) -> "MaskingROITableV1": """Create a new Masking ROI table. Note this method is not meant to be called directly.
@@ -133,18 +133,18 @@ def table(self) -> pd.DataFrame: return self._table_handler.table @table.setter - def table(self, table: pd.DataFrame): + def table(self, table: pd.DataFrame) -> None: """Set the feature table.""" raise NotImplementedError( "Setting the table is not implemented. Please use the 'set_table' method." ) - def set_table(self, table: pd.DataFrame): + def set_table(self, table: pd.DataFrame) -> None: """Set the feature table.""" self._table_handler.set_table(table) @property - def list_labels(self) -> list[str]: + def list_labels(self) -> list[int]: """Return a list of all labels in the table.""" return self.table.index.tolist() @@ -192,6 +192,6 @@ def rois(self) -> list[WorldCooROI]: """List all ROIs in the table.""" return [self.get_roi(label) for label in self.list_labels] - def write(self) -> None: + def consolidate(self) -> None: """Write the current state of the table to the Zarr file.""" - self._table_handler.write(self.meta) + self._table_handler.consolidate(self.meta) diff --git a/src/ngio/tables/v1/roi_tables.py b/src/ngio/tables/v1/roi_tables.py index 996d120..09bca7d 100644 --- a/src/ngio/tables/v1/roi_tables.py +++ b/src/ngio/tables/v1/roi_tables.py @@ -104,7 +104,7 @@ def _new( include_origin: bool = False, include_translation: bool = False, overwrite: bool = False, - ): + ) -> "ROITableV1": """Create a new ROI table. Note this method is not meant to be called directly. @@ -152,13 +152,13 @@ def table(self) -> pd.DataFrame: return self._table_handler.table @table.setter - def table(self, table: pd.DataFrame): + def table(self, table: pd.DataFrame) -> None: """Set the feature table.""" raise NotImplementedError( "Setting the table is not implemented. Please use the 'set_table' method." ) - def set_table(self, table: pd.DataFrame): + def set_table(self, table: pd.DataFrame) -> None: """Set the feature table.""" self._table_handler.set_table(table) @@ -222,7 +222,7 @@ def _gater_optional_columns(self, series: pd.Series) -> dict: optional_dict[column] = series[column] return optional_dict - def get_roi(self, field_index) -> WorldCooROI: + def get_roi(self, field_index: str) -> WorldCooROI: """Get an ROI from the table.""" if field_index not in self.field_indexes: raise ValueError(f"Field index {field_index} is not in the table") @@ -245,6 +245,6 @@ def rois(self) -> list[WorldCooROI]: """List all ROIs in the table.""" return [self.get_roi(field_index) for field_index in self.field_indexes] - def write(self) -> None: + def consolidate(self) -> None: """Write the current state of the table to the Zarr file.""" - self._table_handler.write(self.meta) + self._table_handler.consolidate(self.meta) From ff60b9a1d6dbc422e50a5258b9a93fe67a02d294 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Wed, 6 Nov 2024 16:57:00 +0100 Subject: [PATCH 2/5] add zoom testing --- src/ngio/pipes/__init__.py | 2 +- .../pipes/{_zomm_utils.py => _zoom_utils.py} | 35 +++++---- tests/pipes/conftest.py | 46 +++++++++++ tests/pipes/test_zoom.py | 77 +++++++++++++++++++ 4 files changed, 145 insertions(+), 15 deletions(-) rename src/ngio/pipes/{_zomm_utils.py => _zoom_utils.py} (88%) create mode 100644 tests/pipes/conftest.py create mode 100644 tests/pipes/test_zoom.py diff --git a/src/ngio/pipes/__init__.py b/src/ngio/pipes/__init__.py index bd62231..1a69ba0 100644 --- a/src/ngio/pipes/__init__.py +++ b/src/ngio/pipes/__init__.py @@ -1,7 +1,7 @@ """A module to handle data transforms for image data.""" from ngio.pipes._slicer_transforms import NaiveSlicer, RoiSlicer -from ngio.pipes._zomm_utils
import on_disk_zoom +from ngio.pipes._zoom_utils import on_disk_zoom from ngio.pipes.data_pipe import DataTransformPipe __all__ = ["DataTransformPipe", "NaiveSlicer", "RoiSlicer", "on_disk_zoom"] diff --git a/src/ngio/pipes/_zomm_utils.py b/src/ngio/pipes/_zoom_utils.py similarity index 88% rename from src/ngio/pipes/_zomm_utils.py rename to src/ngio/pipes/_zoom_utils.py index 59fb763..027b66c 100644 --- a/src/ngio/pipes/_zomm_utils.py +++ b/src/ngio/pipes/_zoom_utils.py @@ -20,11 +20,12 @@ def _zoom_inputs_check( if scale is None: assert target_shape is not None, "Target shape must be provided" - assert len(target_shape) == source_array.ndim, ( - "Target shape must have the " - "same number of dimensions as " - "the source array" - ) + if len(target_shape) != source_array.ndim: + raise ValueError( + "Target shape must have the " + "same number of dimensions as " + "the source array" + ) _scale = np.array(target_shape) / np.array(source_array.shape) _target_shape = target_shape else: @@ -142,6 +143,7 @@ def on_disk_zoom( if mode == "numpy": target[...] = _numpy_zoom(source[...], target_shape=target.shape, order=order) + return None source_array = da.from_zarr(source) target_array = _dask_zoom(source_array, target_shape=target.shape, order=order) @@ -154,7 +156,6 @@ def on_disk_coarsen( source: zarr.Array, target: zarr.Array, aggregation_function: np.ufunc, - coarsening_setup: dict[int, int], ) -> None: """Apply a coarsening operation from a source zarr array to a target zarr array. @@ -162,19 +163,25 @@ def on_disk_coarsen( source (zarr.Array): The source array to coarsen. target (zarr.Array): The target array to save the coarsened result to. aggregation_function (np.ufunc): The aggregation function to use. - coarsening_setup (dict[int, int]): The coarsening setup to use. """ source_array = da.from_zarr(source) - for ax, factor in coarsening_setup.items(): - if ax >= source_array.ndim: + _scale, _target_shape = _zoom_inputs_check( + source_array=source_array, scale=None, target_shape=target.shape + ) + + assert ( + _target_shape == target.shape + ), "Target shape must match the target array shape" + coarsening_setup = {} + for i, s in enumerate(_scale): + factor = 1 / s + if factor.is_integer(): + coarsening_setup[i] = int(factor) + else: raise ValueError( - "Coarsening axis must be less than the number of dimensions" + "Coarsening factor must be an integer, got " f"{factor} on axis {i}" ) - if factor <= 0: - raise ValueError("Coarsening factor must be greater than 0") - - assert isinstance(factor, int), "Coarsening factor must be an integer" out_target = da.coarsen( aggregation_function, source_array, coarsening_setup, trim_excess=True diff --git a/tests/pipes/conftest.py b/tests/pipes/conftest.py new file mode 100644 index 0000000..d2001d2 --- /dev/null +++ b/tests/pipes/conftest.py @@ -0,0 +1,46 @@ +# create a zarr 3D array fixture +from pathlib import Path + +import numpy as np +import pytest +import zarr + + +@pytest.fixture +def zarr_zoom_3d_array(tmp_path: Path) -> tuple[zarr.Array, zarr.Array]: + source = zarr.zeros((3, 64, 64), store=tmp_path / "test_3d_s.zarr") + source[...] = np.random.rand(3, 64, 64) + target = zarr.zeros((3, 32, 32), store=tmp_path / "test_3d_t.zarr") + return source, target + + +@pytest.fixture +def zarr_zoom_2d_array(tmp_path: Path) -> tuple[zarr.Array, zarr.Array]: + source = zarr.zeros((64, 64), store=tmp_path / "test_2d_s.zarr") + source[...] 
= np.random.rand(64, 64) + target = zarr.zeros((32, 32), store=str(tmp_path / "test_2d_t.zarr")) + return source, target + + +@pytest.fixture +def zarr_zoom_4d_array(tmp_path: Path) -> tuple[zarr.Array, zarr.Array]: + source = zarr.zeros((3, 3, 64, 64), store=tmp_path / "test_4d_s.zarr") + source[...] = np.random.rand(3, 3, 64, 64) + target = zarr.zeros((3, 3, 32, 32), store=tmp_path / "test_4d_t.zarr") + return source, target + + +@pytest.fixture +def zarr_zoom_2d_array_not_int(tmp_path: Path) -> tuple[zarr.Array, zarr.Array]: + source = zarr.zeros((64, 64), store=tmp_path / "test_2d_s.zarr") + source[...] = np.random.rand(64, 64) + target = zarr.zeros((30, 30), store=str(tmp_path / "test_2d_t.zarr")) + return source, target + + +@pytest.fixture +def zarr_zoom_3d_array_shape_mismatch(tmp_path: Path) -> tuple[zarr.Array, zarr.Array]: + source = zarr.zeros((3, 3, 64, 64), store=tmp_path / "test_3d_s.zarr") + source[...] = np.random.rand(3, 3, 64, 64) + target = zarr.zeros((3, 32, 32), store=tmp_path / "test_3d_t.zarr") + return source, target diff --git a/tests/pipes/test_zoom.py b/tests/pipes/test_zoom.py new file mode 100644 index 0000000..4d37b07 --- /dev/null +++ b/tests/pipes/test_zoom.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest +import zarr + + +class TestZoom: + def _test_zoom( + self, source: zarr.Array, target: zarr.Array, order: int = 1, mode: str = "dask" + ) -> None: + from ngio.pipes import on_disk_zoom + + on_disk_zoom(source, target, order=order, mode=mode) + + def test_zoom_3d(self, zarr_zoom_3d_array: tuple[zarr.Array, zarr.Array]) -> None: + source, target = zarr_zoom_3d_array + + for mode in ["dask", "numpy"]: + for order in [0, 1, 2]: + self._test_zoom(source, target, order=order, mode=mode) + + def test_zoom_2d(self, zarr_zoom_2d_array: tuple[zarr.Array, zarr.Array]) -> None: + source, target = zarr_zoom_2d_array + self._test_zoom(source, target) + + def test_zoom_4d(self, zarr_zoom_4d_array: tuple[zarr.Array, zarr.Array]) -> None: + source, target = zarr_zoom_4d_array + self._test_zoom(source, target) + + def test_zoom_3d_fail( + self, zarr_zoom_3d_array_shape_mismatch: tuple[zarr.Array, zarr.Array] + ) -> None: + source, target = zarr_zoom_3d_array_shape_mismatch + with pytest.raises(ValueError): + self._test_zoom(source, target) + + with pytest.raises(ValueError): + self._test_zoom(source, target[...]) + + with pytest.raises(ValueError): + self._test_zoom(source[...], target) + + with pytest.raises(ValueError): + _target2 = target.astype("float32") + self._test_zoom(source, _target2) + + with pytest.raises(AssertionError): + self._test_zoom(source, target, mode="not_a_mode") + + def _test_coarsen(self, source: zarr.Array, target: zarr.Array) -> None: + from ngio.pipes._zoom_utils import on_disk_coarsen + + on_disk_coarsen(source, target, aggregation_function=np.mean) + + def test_coarsen_3d( + self, zarr_zoom_3d_array: tuple[zarr.Array, zarr.Array] + ) -> None: + source, target = zarr_zoom_3d_array + self._test_coarsen(source, target) + + def test_coarsen_2d( + self, zarr_zoom_2d_array: tuple[zarr.Array, zarr.Array] + ) -> None: + source, target = zarr_zoom_2d_array + self._test_coarsen(source, target) + + def test_coarsen_4d( + self, zarr_zoom_4d_array: tuple[zarr.Array, zarr.Array] + ) -> None: + source, target = zarr_zoom_4d_array + self._test_coarsen(source, target) + + def test_coarsen_2d_fail( + self, zarr_zoom_2d_array_not_int: tuple[zarr.Array, zarr.Array] + ) -> None: + source, target = zarr_zoom_2d_array_not_int + with pytest.raises(ValueError): 
+ self._test_coarsen(source, target) From 9e90980ca9f55a68289091bfa497c4f294557d3d Mon Sep 17 00:00:00 2001 From: lorenzo Date: Thu, 7 Nov 2024 09:56:52 +0100 Subject: [PATCH 3/5] add pixel_size testing --- src/ngio/ngff_meta/fractal_image_meta.py | 10 ++++++--- tests/ngff_meta/test_pixel_size.py | 27 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 tests/ngff_meta/test_pixel_size.py diff --git a/src/ngio/ngff_meta/fractal_image_meta.py b/src/ngio/ngff_meta/fractal_image_meta.py index fe43273..d090076 100644 --- a/src/ngio/ngff_meta/fractal_image_meta.py +++ b/src/ngio/ngff_meta/fractal_image_meta.py @@ -105,7 +105,9 @@ def __str__(self) -> str: return f"PixelSize(x={self.x}, y={self.y}, z={self.z}, unit={self.unit.value})" @classmethod - def from_list(cls, sizes: list[float], unit: SpaceUnits) -> "PixelSize": + def from_list( + cls, sizes: list[float], unit: SpaceUnits = SpaceUnits.micrometer + ) -> "PixelSize": """Build a PixelSize object from a list of sizes. Note: The order of the sizes must be z, y, x. @@ -115,7 +117,7 @@ def from_list(cls, sizes: list[float], unit: SpaceUnits) -> "PixelSize": unit(SpaceUnits): The unit of the sizes. """ if len(sizes) == 2: - return cls(y=sizes[0], x=sizes[1], unit=unit) + return cls(y=sizes[0], x=sizes[1], z=1, unit=unit) elif len(sizes) == 3: return cls(z=sizes[0], y=sizes[1], x=sizes[2], unit=unit) else: @@ -135,10 +137,12 @@ def yx(self) -> tuple: """Return the xy plane pixel size in y, x order.""" return self.y, self.x + @property def voxel_volume(self) -> float: """Return the volume of a voxel.""" - return self.y * self.x * (self.z or 1) + return self.y * self.x * self.z + @property def xy_plane_area(self) -> float: """Return the area of the xy plane.""" return self.y * self.x diff --git a/tests/ngff_meta/test_pixel_size.py b/tests/ngff_meta/test_pixel_size.py new file mode 100644 index 0000000..80db6b6 --- /dev/null +++ b/tests/ngff_meta/test_pixel_size.py @@ -0,0 +1,27 @@ +import pytest + + +class TestPixelSize: + def test_pixel_size_from_list(self) -> None: + from ngio.ngff_meta import PixelSize + + pix_size_2d = PixelSize.from_list([0.1625, 0.1625]) + assert pix_size_2d.zyx == (1.0, 0.1625, 0.1625) + + pix_size_3d = PixelSize.from_list([0.1625, 0.1625, 0.1625]) + assert pix_size_3d.zyx == (0.1625, 0.1625, 0.1625) + + with pytest.raises(ValueError): + PixelSize.from_list([0.1625, 0.1625, 0.1625, 0.1625]) + + def test_pixel_size(self) -> None: + from ngio.ngff_meta import PixelSize + + pixel_size = PixelSize(x=0.1625, y=0.1625, z=0.25) + assert pixel_size.zyx == (0.25, 0.1625, 0.1625) + assert pixel_size.yx == (0.1625, 0.1625) + assert pixel_size.voxel_volume == 0.1625 * 0.1625 * 0.25 + assert pixel_size.xy_plane_area == 0.1625 * 0.1625 + + pixel_size2 = PixelSize(x=0.1625, y=0.1625, z=0.5) + assert pixel_size.distance(pixel_size2) == 0.25 From 3422c174dc2d2f1c859306827479750cd69fc042 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Thu, 7 Nov 2024 14:54:14 +0100 Subject: [PATCH 4/5] derive api cleanup --- src/ngio/core/dimensions.py | 51 +-- src/ngio/core/image_like_handler.py | 4 + src/ngio/core/label_handler.py | 73 ++-- src/ngio/core/ngff_image.py | 4 +- src/ngio/core/utils.py | 23 +- src/ngio/ngff_meta/__init__.py | 8 - src/ngio/ngff_meta/fractal_image_meta.py | 417 ++++++++++++--------- src/ngio/ngff_meta/utils.py | 91 ++--- tests/core/test_image_like_handler.py | 4 +- tests/ngff_meta/test_fractal_image_meta.py | 34 -- tests/ngff_meta/test_utils.py | 10 +- 11 files changed, 362 insertions(+),
357 deletions(-) diff --git a/src/ngio/core/dimensions.py b/src/ngio/core/dimensions.py index f5ea317..0d9d0f9 100644 --- a/src/ngio/core/dimensions.py +++ b/src/ngio/core/dimensions.py @@ -4,6 +4,8 @@ but it is based on the actual metadata of the image data. """ +from collections import OrderedDict + class Dimensions: """Dimension metadata.""" @@ -37,7 +39,7 @@ def __init__( self._axes_order = axes_order self._shape = [self._on_disk_shape[i] for i in axes_order] - self._shape_dict = dict(zip(axes_names, self._shape, strict=True)) + self._shape_dict = OrderedDict(zip(axes_names, self._shape, strict=True)) def __str__(self) -> str: """Return the string representation of the object.""" @@ -64,33 +66,6 @@ def ad_dict(self) -> dict[str, int]: """Return the shape as a dictionary.""" return self._shape_dict - @property - def t(self) -> int | None: - """Return the time dimension.""" - return self._shape_dict.get("t", None) - - @property - def c(self) -> int | None: - """Return the channel dimension.""" - return self._shape_dict.get("c", None) - - @property - def z(self) -> int | None: - """Return the z dimension.""" - return self._shape_dict.get("z", None) - - @property - def y(self) -> int: - """Return the y dimension.""" - assert "y" in self._shape_dict - return self._shape_dict["y"] - - @property - def x(self) -> int: - """Return the x dimension.""" - assert "x" in self._shape_dict - return self._shape_dict["x"] - def get(self, ax_name: str, default: int = 1) -> int: """Return the dimension of the given axis name.""" return self._shape_dict.get(ax_name, default) @@ -103,14 +78,16 @@ def on_disk_ndim(self) -> int: @property def is_time_series(self) -> bool: """Return whether the data is a time series.""" - if (self.t is None) or (self.t == 1): + t = self._shape_dict.get("t", 1) + if t == 1: return False return True @property def is_2d(self) -> bool: """Return whether the data is 2D.""" - if (self.z is not None) and (self.z > 1): + z = self._shape_dict.get("z", 1) + if z != 1: return False return True @@ -122,9 +99,7 @@ def is_2d_time_series(self) -> bool: @property def is_3d(self) -> bool: """Return whether the data is 3D.""" - if (self.z is None) or (self.z == 1): - return False - return True + return not self.is_2d @property def is_3d_time_series(self) -> bool: @@ -134,6 +109,14 @@ def is_3d_time_series(self) -> bool: @property def is_multi_channels(self) -> bool: """Return whether the data has multiple channels.""" - if (self.c is None) or (self.c == 1): + c = self._shape_dict.get("c", 1) + if c == 1: return False return True + + def find_axis(self, ax_name: str) -> int | None: + """Return the index of the axis name.""" + for i, ax in enumerate(self._axes_names): + if ax == ax_name: + return i + return None diff --git a/src/ngio/core/image_like_handler.py b/src/ngio/core/image_like_handler.py index b65df5e..aeeac44 100644 --- a/src/ngio/core/image_like_handler.py +++ b/src/ngio/core/image_like_handler.py @@ -227,6 +227,10 @@ def on_disk_shape(self) -> tuple[int, ...]: """Return the shape of the image.""" return self.dimensions.on_disk_shape + def find_axis(self, axis_name: str) -> int | None: + """Return the index of the given axis name.""" + return self.dimensions.find_axis(axis_name) + # Methods to get the image data in the canonical order def init_lock(self, lock_id: str | None = None) -> None: """Set the lock for the Dask array.""" diff --git a/src/ngio/core/label_handler.py b/src/ngio/core/label_handler.py index b694ee4..07cf3fd 100644 --- a/src/ngio/core/label_handler.py +++ 
b/src/ngio/core/label_handler.py @@ -1,5 +1,6 @@ """A module to handle OME-NGFF images stored in Zarr format.""" +import builtins from typing import Any, Literal import zarr @@ -285,6 +286,8 @@ def get_label( def derive( self, name: str, + reference: ImageLike | None = None, + levels: int | builtins.list[str] = 5, overwrite: bool = False, **kwargs: dict, ) -> Label: @@ -292,6 +295,9 @@ def derive( Args: name (str): The name of the new label. + reference (ImageLike | None): The reference image to use for the new label. + levels (int | list[str]): The number of levels to create or + a list of paths names. overwrite (bool): If True, the label will be overwritten if it exists. Default is False. **kwargs: Additional keyword arguments to pass to the new label. @@ -311,44 +317,46 @@ def derive( # create the new label new_label_group = self._label_group.create_group(name, overwrite=overwrite) - if self._image_ref is None: - label_0 = self.get_label(list_of_labels[0]) - metadata = label_0.metadata - on_disk_shape = label_0.on_disk_shape - chunks = label_0.on_disk_array.chunks - dataset = label_0.dataset - else: - label_0 = self._image_ref - metadata = label_0.metadata - channel_index = metadata.index_mapping.get("c", None) - if channel_index is not None: - on_disk_shape = ( - label_0.on_disk_shape[:channel_index] - + label_0.on_disk_shape[channel_index + 1 :] - ) - chunks = ( - label_0.on_disk_array.chunks[:channel_index] - + label_0.on_disk_array.chunks[channel_index + 1 :] - ) - else: - on_disk_shape = label_0.on_disk_shape - chunks = label_0.on_disk_array.chunks + ref_0 = self._image_ref if reference is None else reference + assert isinstance(ref_0, ImageLike) + + if isinstance(levels, int): + paths = [str(i) for i in range(levels)] + elif isinstance(levels, list): + if not all(isinstance(level, str) for level in levels): + raise ValueError(f"All levels must be strings. 
Got: {levels}") + paths = levels + + on_disk_ch_index = ref_0.find_axis("c") + metadata = ref_0.metadata + if on_disk_ch_index is None: + on_disk_shape = ref_0.on_disk_shape + chunks = ref_0.on_disk_array.chunks + else: metadata = metadata.remove_axis("c") - dataset = metadata.get_highest_resolution_dataset() + on_disk_shape = ( + ref_0.on_disk_shape[:on_disk_ch_index] + + ref_0.on_disk_shape[on_disk_ch_index + 1 :] + ) + chunks = ( + ref_0.on_disk_array.chunks[:on_disk_ch_index] + + ref_0.on_disk_array.chunks[on_disk_ch_index + 1 :] + ) + dataset = metadata.get_dataset(path=paths[0]) default_kwargs = { "store": new_label_group, "shape": on_disk_shape, "chunks": chunks, - "dtype": label_0.on_disk_array.dtype, + "dtype": ref_0.on_disk_array.dtype, "on_disk_axis": dataset.on_disk_axes_names, "pixel_sizes": dataset.pixel_size, "xy_scaling_factor": metadata.xy_scaling_factor, "z_scaling_factor": metadata.z_scaling_factor, "time_spacing": dataset.time_spacing, "time_units": dataset.time_axis_unit, - "num_levels": metadata.num_levels, + "levels": paths, "name": name, "overwrite": overwrite, "version": metadata.version, @@ -357,7 +365,20 @@ def derive( default_kwargs.update(kwargs) create_empty_ome_zarr_label( - **default_kwargs, + store=new_label_group, + shape=on_disk_shape, + chunks=chunks, + dtype=ref_0.on_disk_array.dtype, + on_disk_axis=dataset.on_disk_axes_names, + pixel_sizes=dataset.pixel_size, + xy_scaling_factor=metadata.xy_scaling_factor, + z_scaling_factor=metadata.z_scaling_factor, + time_spacing=dataset.time_spacing, + time_units=dataset.time_axis_unit, + levels=paths, + name=name, + overwrite=overwrite, + version=metadata.version, ) if name not in self.list(): diff --git a/src/ngio/core/ngff_image.py b/src/ngio/core/ngff_image.py index 38bbf78..951e358 100644 --- a/src/ngio/core/ngff_image.py +++ b/src/ngio/core/ngff_image.py @@ -178,7 +178,7 @@ def derive_new_image( default_kwargs = { "store": store, - "shape": image_0.on_disk_shape, + "on_disk_shape": image_0.on_disk_shape, "chunks": image_0.on_disk_array.chunks, "dtype": image_0.on_disk_array.dtype, "on_disk_axis": image_0.dataset.on_disk_axes_names, @@ -187,7 +187,7 @@ def derive_new_image( "z_scaling_factor": self.image_meta.z_scaling_factor, "time_spacing": image_0.dataset.time_spacing, "time_units": image_0.dataset.time_axis_unit, - "num_levels": self.num_levels, + "levels": self.num_levels, "name": name, "channel_labels": image_0.channel_labels, "channel_wavelengths": [ch.wavelength_id for ch in channels], diff --git a/src/ngio/core/utils.py b/src/ngio/core/utils.py index aecd84f..aeb0ff6 100644 --- a/src/ngio/core/utils.py +++ b/src/ngio/core/utils.py @@ -74,16 +74,17 @@ def _build_empty_pyramid( def create_empty_ome_zarr_image( store: StoreLike, - shape: Collection[int], + on_disk_shape: Collection[int], + on_disk_axis: Collection[str] = ("t", "c", "z", "y", "x"), chunks: Collection[int] | None = None, dtype: str = "uint16", - on_disk_axis: Collection[str] = ("t", "c", "z", "y", "x"), pixel_sizes: PixelSize | None = None, xy_scaling_factor: float = 2.0, z_scaling_factor: float = 1.0, time_spacing: float = 1.0, time_units: TimeUnits | str = TimeUnits.s, - num_levels: int = 5, + levels: int | list[str] = 5, + path_names: list[str] | None = None, name: str | None = None, channel_labels: list[str] | None = None, channel_wavelengths: list[str] | None = None, @@ -93,16 +94,16 @@ def create_empty_ome_zarr_image( version: str = "0.4", ) -> None: """Create an empty OME-Zarr image with the given shape and metadata.""" - if 
len(shape) != len(on_disk_axis): + if len(on_disk_shape) != len(on_disk_axis): raise ValueError( "The number of dimensions in the shape must match the number of " "axes in the on-disk axis." ) if "c" in on_disk_axis: - shape = tuple(shape) + on_disk_shape = tuple(on_disk_shape) on_disk_axis = tuple(on_disk_axis) - num_channels = shape[on_disk_axis.index("c")] + num_channels = on_disk_shape[on_disk_axis.index("c")] if channel_labels is None: channel_labels = [f"C{i:02d}" for i in range(num_channels)] else: @@ -120,7 +121,7 @@ def create_empty_ome_zarr_image( z_scaling_factor=z_scaling_factor, time_spacing=time_spacing, time_units=time_units, - num_levels=num_levels, + levels=levels, name=name, channel_labels=channel_labels, channel_wavelengths=channel_wavelengths, @@ -141,7 +142,7 @@ def create_empty_ome_zarr_image( _build_empty_pyramid( group=group, image_meta=image_meta, - shape=shape, + shape=on_disk_shape, chunks=chunks, dtype=dtype, on_disk_axis=on_disk_axis, @@ -160,8 +161,8 @@ def create_empty_ome_zarr_label( xy_scaling_factor: float = 2.0, z_scaling_factor: float = 1.0, time_spacing: float = 1.0, - time_units: TimeUnits | str = TimeUnits.s, - num_levels: int = 5, + time_units: TimeUnits | str | None = None, + levels: int | list[str] = 5, name: str | None = None, overwrite: bool = True, version: str = "0.4", @@ -180,7 +181,7 @@ def create_empty_ome_zarr_label( z_scaling_factor=z_scaling_factor, time_spacing=time_spacing, time_units=time_units, - num_levels=num_levels, + levels=levels, name=name, version=version, ) diff --git a/src/ngio/ngff_meta/__init__.py b/src/ngio/ngff_meta/__init__.py index 0a9b67a..82bfb31 100644 --- a/src/ngio/ngff_meta/__init__.py +++ b/src/ngio/ngff_meta/__init__.py @@ -10,12 +10,8 @@ ) from ngio.ngff_meta.meta_handler import get_ngff_image_meta_handler from ngio.ngff_meta.utils import ( - add_axis_to_metadata, create_image_metadata, create_label_metadata, - derive_image_metadata, - derive_label_metadata, - remove_axis_from_metadata, ) __all__ = [ @@ -26,10 +22,6 @@ "PixelSize", "SpaceUnits", "get_ngff_image_meta_handler", - "add_axis_to_metadata", "create_image_metadata", "create_label_metadata", - "derive_image_metadata", - "derive_label_metadata", - "remove_axis_from_metadata", ] diff --git a/src/ngio/ngff_meta/fractal_image_meta.py b/src/ngio/ngff_meta/fractal_image_meta.py index d090076..2f6fc0a 100644 --- a/src/ngio/ngff_meta/fractal_image_meta.py +++ b/src/ngio/ngff_meta/fractal_image_meta.py @@ -6,7 +6,9 @@ can be converted to the OME standard. """ +from collections.abc import Collection from enum import Enum +from typing import Any import numpy as np from pydantic import BaseModel, Field @@ -15,6 +17,50 @@ from ngio.utils._pydantic_utils import BaseWithExtraFields +class NgffVersion(str, Enum): + """Allowed NGFF versions.""" + + v04 = "0.4" + + +################################################################################################ +# +# The Omero section of the metadata is used to store channel information and +# visualisation settings. +# This section is transitory and will likely be changed in the future.
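+# For example (illustrative values only, not mandated by the spec), a channel entry +# serialised from this section could look like {"label": "DAPI", "wavelength_id": "A01_C01", +# "color": "00FFFF", "window": {"min": 0, "max": 65535, "start": 0, "end": 65535}}, +# where the window bounds come from Window.from_type("uint16") below.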
+# +################################################################################################# + + +class Window(BaseModel): + """Window model to be used by the Viewer.""" + + min: int | float + max: int | float + start: int | float + end: int | float + + @classmethod + def from_type(cls, data_type: str) -> "Window": + """Create a Window object from a data type.""" + type_info = np.iinfo(data_type) + return cls( + min=type_info.min, max=type_info.max, start=type_info.min, end=type_info.max + ) + + +class ChannelVisualisation(BaseWithExtraFields): + """Channel visualisation model. + + Contains the information about the visualisation of a channel. + """ + + color: str + window: Window + active: bool = True + inverted: bool = False + + class Channel(BaseWithExtraFields): """Information about a channel in the image. @@ -28,6 +74,25 @@ class Channel(BaseWithExtraFields): label: str wavelength_id: str | None = None + @classmethod + def lazy_init( + cls, + label: str, + wavelength_id: str | None = None, + color: str = "00FFFF", + data_type: Any = np.uint16, + ) -> "Channel": + """Create a Channel object with default visualisation settings.""" + channel_visualization = ChannelVisualisation( + color=color, window=Window.from_type(data_type) + ) + + return cls( + label=label, + wavelength_id=wavelength_id, + **channel_visualization.model_dump(), + ) + + class Omero(BaseWithExtraFields): """Information about the OMERO metadata. @@ -41,6 +106,16 @@ class Omero(BaseWithExtraFields): channels: list[Channel] = Field(default_factory=list) +################################################################################################ +# +# Axis Types and Units +# We define a small set of axis types and units that can be used in the metadata. +# These axis types are more restrictive than the OME standard. +# We do this to simplify data processing. +# +################################################################################################# + + class AxisType(str, Enum): """Allowed axis types.""" @@ -91,6 +166,38 @@ def allowed_names(self) -> list[str]: return list(ChannelNames.__members__.keys()) +class TimeUnits(str, Enum): + """Allowed time units.""" + + seconds = "seconds" + s = "s" + + @classmethod + def allowed_names(self) -> list[str]: + """Get the allowed time units.""" + return list(TimeUnits.__members__.keys()) + + +class TimeNames(str, Enum): + """Allowed time axis names.""" + + t = "t" + + @classmethod + def allowed_names(self) -> list[str]: + """Get the allowed time axis names.""" + return list(TimeNames.__members__.keys()) + + +################################################################################################ +# +# PixelSize model +# The PixelSize model is used to store the pixel size in 3D space. +# The model does not store scaling factors and units for other axes.
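+# As a usage sketch (values are illustrative): PixelSize(x=0.1625, y=0.1625, z=0.25) +# has zyx == (0.25, 0.1625, 0.1625), voxel_volume == 0.1625 * 0.1625 * 0.25 and +# xy_plane_area == 0.1625 * 0.1625, mirroring tests/ngff_meta/test_pixel_size.py.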
+# +################################################################################################# + + class PixelSize(BaseModel): """PixelSize class to store the pixel size in 3D space.""" @@ -149,30 +256,24 @@ def xy_plane_area(self) -> float: def distance(self, other: "PixelSize") -> float: """Return the distance between two pixel sizes.""" - return np.linalg.norm(np.array(self.zyx) - np.array(other.zyx)) - - -class TimeUnits(str, Enum): - """Allowed time units.""" - - seconds = "seconds" - s = "s" - - @classmethod - def allowed_names(self) -> list[str]: - """Get the allowed time axis names.""" - return list(TimeUnits.__members__.keys()) - - -class TimeNames(str, Enum): - """Allowed time axis names.""" - - t = "t" - - @classmethod - def allowed_names(self) -> list[str]: - """Get the allowed time axis names.""" - return list(TimeNames.__members__.keys()) + return float(np.linalg.norm(np.array(self.zyx) - np.array(other.zyx))) + + +################################################################################################ +# +# Axis and Dataset models are the two core components of the OME-NGFF +# multiscale metadata. +# The Axis model is used to store the information about an axis (name, unit, type). +# The Dataset model is used to store the information about a +# dataset (path, axes, scale). +# +# The Dataset and Axis models have two representations: +# - on_disk: The representation of the metadata as stored on disk. This representation +# preserves the order of the axes and the scale transformation. +# - canonical: The representation of the metadata in the canonical order. +# This representation is used to simplify the data processing. +# +################################################################################################# class Axis: @@ -196,6 +297,7 @@ def __init__( name = name.value self._name = name + self._unit = unit if name in TimeNames.allowed_names(): self._type = AxisType.time @@ -243,7 +345,7 @@ def lazy_create( @classmethod def batch_create( cls, - axes_names: list[str | SpaceNames | TimeNames], + axes_names: Collection[str | SpaceNames | TimeNames], time_unit: TimeUnits | None = None, space_unit: SpaceUnits | None = None, ) -> list["Axis"]: @@ -277,7 +379,9 @@ def type(self) -> AxisType: def model_dump(self) -> dict: """Return the axis information in a dictionary.""" - return {"name": self.name, "unit": self.unit, "type": self.type} + _dict = {"name": self.name, "unit": self.unit, "type": self.type} + # Remove None values + return {k: v for k, v in _dict.items() if v is not None} class Dataset: @@ -356,9 +460,13 @@ def __init__( # Compute the index mapping between the canonical order and the actual order _map = {ax.name: i for i, ax in enumerate(on_disk_axes)} - self._index_mapping = { - name: _map.get(name, None) for name in self._canonical_order - } + + self._index_mapping = {} + for name in self._canonical_order: + _index = _map.get(name, None) + if _index is not None: + self._index_mapping[name] = _index + self._ordered_axes = [ on_disk_axes[i] for i in self._index_mapping.values() if i is not None ] @@ -397,12 +505,13 @@ def axes_order(self) -> list[int]: return [on_disk_axes.index(ax) for ax in canonical_axes] @property - def reverse_axes_order(self) -> list[str]: + def reverse_axes_order(self) -> list[int]: """Get the mapping between the on-disk order and the canonical order. It is the inverse of the axes_order.
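+ For example, an axes_order of [2, 0, 1] has a reverse_axes_order of + [1, 2, 0], since np.argsort inverts the permutation.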
""" - return np.argsort(self.axes_order) + sorted_order = np.argsort(self.axes_order).tolist() + return sorted_order # type: ignore @property def scale(self) -> list[float]: @@ -412,10 +521,11 @@ def scale(self) -> list[float]: @property def time_spacing(self) -> float: """Get the time spacing of the dataset.""" - if "t" not in self.axes_names: + t = self.index_mapping.get("t") + if t is None: return 1.0 - scale_t = self.scale[self.index_mapping.get("t")] + scale_t = self.scale[t] return scale_t @property @@ -446,7 +556,14 @@ def space_axes_unit(self) -> SpaceUnits: types = [ax.unit for ax in self.axes if ax.type == AxisType.space] if len(set(types)) > 1: raise ValueError("Inconsistent spatial axes units.") - return types[0] + return_type = types[0] + if return_type is None: + raise ValueError("Spatial axes must have a unit.") + if return_type not in SpaceUnits.allowed_names(): + raise ValueError(f"Invalid space unit {return_type}.") + if isinstance(return_type, str): + return_type = SpaceUnits(return_type) + return return_type @property def pixel_size(self) -> PixelSize: @@ -457,15 +574,24 @@ def pixel_size(self) -> PixelSize: if ax.type == AxisType.space: pixel_sizes[ax.name] = scale - return PixelSize(**pixel_sizes, unit=self.space_axes_unit) + return PixelSize( + x=pixel_sizes["x"], + y=pixel_sizes["y"], + z=pixel_sizes.get("z", 1.0), + unit=self.space_axes_unit, + ) @property def time_axis_unit(self) -> TimeUnits | None: """Get the unit of the time axis.""" types = [ax.unit for ax in self.axes if ax.type == AxisType.time] - if len(set(types)) > 1: - raise ValueError("Inconsistent time axis units.") - return types[0] if types else None + if len(types) == 0: + return None + elif len(types) == 1: + assert isinstance(types[0], TimeUnits) + return types[0] + else: + raise ValueError("Multiple time axes found. Only one time axis is allowed.") def remove_axis(self, axis_name: str) -> "Dataset": """Remove an axis from the dataset. @@ -501,82 +627,20 @@ def remove_axis(self, axis_name: str) -> "Dataset": canonical_order=self._canonical_order, ) - def add_axis( - self, axis_name: str, scale: float = 1.0, translation: float | None = None - ) -> "Dataset": - """Add an axis to the dataset. - - Args: - axis_name(str): The name of the axis to add. - scale(float): The scale of the axis. - translation(float | None): The translation of the axis. 
- """ - if axis_name in self.axes_names: - raise ValueError(f"Axis {axis_name} already exists in the dataset.") - - axis = Axis.lazy_create( - name=axis_name, - space_unit=self.space_axes_unit, - time_unit=self.time_axis_unit, - ) - - new_on_disk_axes = self._on_disk_axes.copy() - new_on_disk_axes.append(axis) - - new_scale = self._scale.copy() - new_scale.append(scale) - - if self._translation is not None: - new_translation = self._translation.copy() - new_translation.append(translation) - else: - new_translation = None - - return Dataset( - path=self.path, - on_disk_axes=new_on_disk_axes, - on_disk_scale=new_scale, - on_disk_translation=new_translation, - canonical_order=self._canonical_order, - ) - - def to_canonical_order(self) -> "Dataset": - """Return a new Dataset where the axes are in the canonical order.""" - new_axes = self._ordered_axes - new_scale = self.scale - new_translation = self.translation - return Dataset( - path=self.path, - on_disk_axes=new_axes, - on_disk_scale=new_scale, - on_disk_translation=new_translation, - ) - - def on_disk_model_dump(self) -> dict: - """Return the dataset information in the on_disk order.""" - return { - "path": self.path, - "axes": [ax.model_dump(exclude_none=True) for ax in self._on_disk_axes], - "scale": self._scale, - "translation": self._translation, - } - - def ordered_model_dump(self) -> dict: - """Return the dataset information in the canonical order.""" - return { - "path": self.path, - "axes": [ax.model_dump(exclude_none=True) for ax in self.axes], - "scale": self.scale, - "translation": self.translation, - } - +################################################################################################ +# +# BaseMeta, ImageMeta and LabelMeta are the core models to represent the multiscale the +# OME-NGFF spec on memory. The are the only interfaces to interact with +# the metadata on-disk and the metadata in memory. 
+# +################################################################################################# class BaseMeta: """Base class for ImageMeta and LabelMeta.""" - def __init__(self, version: str, name: str, datasets: list[Dataset]) -> None: + def __init__(self, version: str, name: str | None, datasets: list[Dataset]) -> None: """Initialize the BaseMeta object.""" - self._version = version + self._version = NgffVersion(version) self._name = name if len(datasets) == 0: @@ -585,12 +649,12 @@ def __init__(self, version: str, name: str, datasets: list[Dataset]) -> None: self._datasets = datasets @property - def version(self) -> str: + def version(self) -> NgffVersion: """Version of the OME-NGFF metadata used to build the object.""" return self._version @property - def name(self) -> str: + def name(self) -> str | None: """Name of the image.""" return self._name @@ -751,8 +815,12 @@ def _scaling_factors(self) -> list[float]: def xy_scaling_factor(self) -> float: """Get the xy scaling factor of the dataset.""" scaling_factors = self._scaling_factors() - x_scaling_f = scaling_factors[self.index_mapping.get("x")] - y_scaling_f = scaling_factors[self.index_mapping.get("y")] + x, y = self.index_mapping.get("x"), self.index_mapping.get("y") + if x is None or y is None: + raise ValueError("Mandatory axes x and y not found.") + + x_scaling_f = scaling_factors[x] + y_scaling_f = scaling_factors[y] if not np.allclose(x_scaling_f, y_scaling_f): raise ValueError("Inconsistent xy scaling factor.") @@ -762,9 +830,11 @@ def xy_scaling_factor(self) -> float: def z_scaling_factor(self) -> float: """Get the z scaling factor of the dataset.""" scaling_factors = self._scaling_factors() - if "z" not in self.axes_names: + z = self.index_mapping.get("z") + if z is None: return 1.0 - z_scaling_f = scaling_factors[self.index_mapping.get("z")] + + z_scaling_f = scaling_factors[z] return z_scaling_f def translation( @@ -798,28 +868,11 @@ def remove_axis(self, axis_name: str) -> Self: version=self.version, name=self.name, datasets=new_datasets ) - def add_axis( - self, axis_name: str, scale: float = 1.0, translation: float | None = None - ) -> Self: - """Add an axis to the metadata. - - Args: - axis_name(str): The name of the axis to add. - scale(float): The scale of the axis. - translation(float | None): The translation of the axis.
- """ - new_datasets = [ - dataset.add_axis(axis_name, scale, translation) for dataset in self.datasets - ] - return self.__class__( - version=self.version, name=self.name, datasets=new_datasets - ) - class LabelMeta(BaseMeta): """Label metadata model.""" - def __init__(self, version: str, name: str, datasets: list[Dataset]) -> None: + def __init__(self, version: str, name: str | None, datasets: list[Dataset]) -> None: """Initialize the ImageMeta object.""" super().__init__(version, name, datasets) @@ -828,24 +881,6 @@ def __init__(self, version: str, name: str, datasets: list[Dataset]) -> None: if ax.type == AxisType.channel: raise ValueError("Channel axes are not allowed in ImageMeta.") - def add_axis( - self, axis_name: str, scale: float = 1, translation: float | None = None - ) -> "LabelMeta": - """Add an axis to the metadata.""" - # Check if the axis is a channel - axis = Axis.lazy_create( - name=axis_name, - space_unit=self.space_axes_unit, - time_unit=self.time_axis_unit, - ) - if axis.type == AxisType.channel: - raise ValueError("Channel axes are not allowed in LabelMeta.") - - meta = super().add_axis( - axis_name=axis_name, scale=scale, translation=translation - ) - return meta - class ImageMeta(BaseMeta): """Image metadata model.""" @@ -853,7 +888,7 @@ class ImageMeta(BaseMeta): def __init__( self, version: str, - name: str, + name: str | None, datasets: list[Dataset], omero: Omero | None = None, ) -> None: @@ -861,6 +896,50 @@ def __init__( super().__init__(version=version, name=name, datasets=datasets) self._omero = omero + def build_omero( + self, + channels_names: list[str], + channels_wavelengths: list[str] | None = None, + channels_extra_fields: list[dict[str, Any]] | None = None, + omero_kwargs: dict[str, Any] | None = None, + ) -> None: + """Build a default OMERO metadata. + + Args: + channels_names(list[str]): The names of the channels. + channels_wavelengths(list[str] | None): The wavelength IDs of the channels. + channels_extra_fields(list[dict[str, Any]] | None): The extra fields of + the channels. + omero_kwargs(dict[str, Any] | None): Additional OMERO metadata. + """ + omero_kwargs = {} if omero_kwargs is None else omero_kwargs + + if channels_wavelengths is None: + channels_wavelengths = channels_names + else: + if len(channels_wavelengths) != len(channels_names): + raise ValueError( + "Channels names and wavelengths " "must have the same length." + ) + + if channels_extra_fields is None: + channels_extra_fields = [{} for _ in channels_names] + else: + if len(channels_extra_fields) != len(channels_names): + raise ValueError( + "Channels names and extra fields " "must have the same length." 
+ ) + channels = [] + for ch_name, ch_wavelength, ch_extra in zip( + channels_names, channels_wavelengths, channels_extra_fields, strict=True + ): + ch = Channel( + label=ch_name, wavelength_id=ch_wavelength, extra_fields=ch_extra + ) + channels.append(ch) + omero = Omero(channels=channels, extra_fields=omero_kwargs) + self._omero = omero + @property def omero(self) -> Omero | None: """Get the OMERO metadata.""" @@ -887,7 +966,11 @@ def channel_labels(self) -> list[str]: @property def channel_wavelength_ids(self) -> list[str]: """Get the wavelength IDs of the channels in the image.""" - return [channel.wavelength_id for channel in self.channels] + return [ + channel.wavelength_id + for channel in self.channels + if channel.wavelength_id is not None + ] def _get_channel_idx_by_label(self, label: str) -> int | None: """Get the index of a channel by its label.""" @@ -911,7 +994,7 @@ def _get_channel_idx_by_wavelength_id(self, wavelength_id: str) -> int | None: def get_channel_idx( self, label: str | None = None, wavelength_id: str | None = None - ) -> int: + ) -> int | None: """Get the index of a channel by its label or wavelength ID.""" # Only one of the arguments must be provided if sum([label is not None, wavelength_id is not None]) != 1: @@ -926,25 +1009,13 @@ def get_channel_idx( "get_channel_idx must receive either label or wavelength_id." ) - def remove_axis(self, axis_name: str) -> "ImageMeta": - """Remove an axis from the metadata. - - Args: - axis_name(str): The name of the axis to remove. - """ - new_image = super().remove_axis(axis_name=axis_name) - - # If the removed axis is a channel, remove the channel from the omero metadata - if axis_name in ChannelNames.allowed_names(): - new_omero = Omero(channels=[], **self.omero.extra_fields) - return ImageMeta( - version=new_image.version, - name=new_image.name, - datasets=new_image.datasets, - omero=new_omero, - ) - - return new_image + def to_label(self, name: str | None = None) -> LabelMeta: + """Convert the ImageMeta to a LabelMeta.""" + image_meta = self.remove_axis("c") + name = self.name if name is None else name + return LabelMeta( + version=self.version, name=name, datasets=image_meta.datasets + ) ImageLabelMeta = ImageMeta | LabelMeta diff --git a/src/ngio/ngff_meta/utils.py b/src/ngio/ngff_meta/utils.py index 66866a9..f8da643 100644 --- a/src/ngio/ngff_meta/utils.py +++ b/src/ngio/ngff_meta/utils.py @@ -26,9 +26,9 @@ def _create_multiscale_meta( z_scaling_factor: float = 1.0, pixel_units: SpaceUnits | str = SpaceUnits.micrometer, time_spacing: float = 1.0, - time_units: TimeUnits | str = TimeUnits.s, - num_levels: int = 5, -) -> tuple[list[Dataset], Omero]: + time_units: TimeUnits | str | None = None, + levels: int | list[str] = 5, +) -> list[Dataset]: """Create an image metadata object from scratch.""" allowed_axes_names = ( SpaceNames.allowed_names() @@ -41,6 +44,9 @@ f"Invalid axis name: {ax}, allowed names: {allowed_axes_names}" ) + if isinstance(pixel_units, str): + pixel_units = SpaceUnits(pixel_units) + if pixel_sizes is None: pixel_sizes = PixelSize(z=1.0, y=1.0, x=1.0, unit=pixel_units) @@ -53,9 +56,23 @@ "x": xy_scaling_factor, } + if time_units is None: + time_units = TimeUnits.s + + if isinstance(time_units, str): + time_units = TimeUnits(time_units) + axes = Axis.batch_create(on_disk_axis, time_unit=time_units, space_unit=pixel_units) datasets = [] - for level in range(num_levels): + + if isinstance(levels, int): + paths = [str(i) for i in
range(levels)] + elif isinstance(levels, list): + if not all(isinstance(level, str) for level in levels): + raise ValueError(f"All levels must be strings. Got: {levels}") + paths = levels + + for level, path in enumerate(paths): scale = [ pixel_sizes_dict.get(ax, 1.0) * scaling_factor_dict.get(ax, 1.0) ** level for ax in on_disk_axis @@ -63,7 +80,7 @@ def _create_multiscale_meta( datasets.append( Dataset( - path=str(level), + path=path, on_disk_axes=axes, on_disk_scale=scale, on_disk_translation=None, @@ -79,7 +96,7 @@ def create_image_metadata( z_scaling_factor: float = 1.0, time_spacing: float = 1.0, time_units: TimeUnits | str = TimeUnits.s, - num_levels: int = 5, + levels: int | list[str] = 5, name: str | None = None, channel_labels: list[str] | None = None, channel_wavelengths: list[str] | None = None, @@ -102,7 +119,7 @@ def create_image_metadata( different than 1.0 for the z axis. time_spacing: The time spacing (If the time axis is present). time_units: The units of the time spacing (If the time axis is present). - num_levels: The number of levels in the pyramid. + levels: The number of levels in the pyramid or the list of paths. name: The name of the metadata. channel_labels: The names of the channels. channel_wavelengths: The wavelengths of the channels. @@ -118,7 +135,7 @@ def create_image_metadata( z_scaling_factor=z_scaling_factor, time_spacing=time_spacing, time_units=time_units, - num_levels=num_levels, + levels=levels, ) if channel_labels is None: @@ -169,8 +186,8 @@ def create_label_metadata( xy_scaling_factor: float = 2.0, z_scaling_factor: float = 1.0, time_spacing: float = 1.0, - time_units: TimeUnits | str = TimeUnits.s, - num_levels: int = 5, + time_units: TimeUnits | str | None = None, + levels: int | list[str] = 5, name: str | None = None, version: str = "0.4", ) -> LabelMeta: @@ -189,7 +206,7 @@ def create_label_metadata( different than 1.0 for the z axis. time_spacing: The time spacing (If the time axis is present). time_units: The units of the time spacing (If the time axis is present). - num_levels: The number of levels in the pyramid. + levels: The number of levels in the pyramid or the list of paths. name: The name of the metadata. version: The version of NGFF metadata. """ @@ -200,60 +217,10 @@ def create_label_metadata( z_scaling_factor=z_scaling_factor, time_spacing=time_spacing, time_units=time_units, - num_levels=num_levels, + levels=levels, ) return LabelMeta( version=version, name=name, datasets=datasets, ) - - -def remove_axis_from_metadata( - metadata: ImageMeta, - *, - axis_name: str | None = None, -) -> ImageMeta: - """Remove an axis from the metadata. - - Args: - metadata: A ImageMeta object. - axis_name: The name of the axis to remove. - """ - return metadata.remove_axis(axis_name=axis_name) - - -def add_axis_to_metadata( - metadata: ImageMeta | LabelMeta, - axis_name: str, - scale: float = 1.0, -) -> ImageMeta | LabelMeta: - """Add an axis to the ImageMeta or LabelMeta object. - - Args: - metadata: A ImageMeta or LabelMeta object. - axis_name: The name of the axis to add. 
- scale: The scale of the axis - """ - return metadata.add_axis( - axis_name=axis_name, - scale=scale, - ) - - -def derive_image_metadata( - image: ImageMeta, - name: str, - start_level: int = 0, -) -> ImageMeta: - """Derive a new image metadata from an existing one.""" - pass - - -def derive_label_metadata( - image: ImageMeta, - name: str, - start_level: int = 0, -) -> LabelMeta: - """Derive a new label metadata from an existing one.""" - pass diff --git a/tests/core/test_image_like_handler.py b/tests/core/test_image_like_handler.py index 5b44b61..c478360 100644 --- a/tests/core/test_image_like_handler.py +++ b/tests/core/test_image_like_handler.py @@ -19,7 +19,7 @@ def test_ngff_image(self, ome_zarr_image_v04_path: Path) -> None: assert image_handler.dimensions.shape == (3, 10, 256, 256) shape = image_handler.dimensions.shape assert image_handler.shape == shape - assert image_handler.dimensions.z == 10 + assert image_handler.dimensions.get("z") == 10 assert image_handler.is_3d assert not image_handler.is_time_series assert image_handler.is_multi_channels @@ -54,7 +54,7 @@ def test_ngff_image_fs(self, ome_zarr_image_v04_fs: Path) -> None: assert image_handler.dimensions.shape == (2, 2, 4320, 2560) shape = image_handler.dimensions.shape assert image_handler.shape == shape - assert image_handler.dimensions.z == 2 + assert image_handler.dimensions.get("z") == 2 assert image_handler.is_3d assert not image_handler.is_time_series assert image_handler.is_multi_channels diff --git a/tests/ngff_meta/test_fractal_image_meta.py b/tests/ngff_meta/test_fractal_image_meta.py index 94b51fe..126b7d8 100644 --- a/tests/ngff_meta/test_fractal_image_meta.py +++ b/tests/ngff_meta/test_fractal_image_meta.py @@ -26,25 +26,6 @@ def test_basic_workflow(self, ome_zarr_image_v04_path): assert fractal_meta.space_axes_names == ["z", "y", "x"] assert fractal_meta.get_highest_resolution_dataset().path == "0" - def test_modify_axis_from_metadata(self, ome_zarr_image_v04_path): - from ngio.ngff_meta import get_ngff_image_meta_handler - from ngio.ngff_meta.utils import add_axis_to_metadata, remove_axis_from_metadata - - handler = get_ngff_image_meta_handler( - store=ome_zarr_image_v04_path, meta_mode="image" - ) - - fractal_meta = handler.load_meta() - meta_no_channel = remove_axis_from_metadata( - metadata=fractal_meta, axis_name="c" - ) - assert meta_no_channel.axes_names == ["z", "y", "x"] - - meta_add_channel = add_axis_to_metadata( - metadata=meta_no_channel, axis_name="c", scale=1.0 - ) - assert meta_add_channel.axes_names == fractal_meta.axes_names - def test_pixel_size(self, ome_zarr_image_v04_path): from ngio.ngff_meta import get_ngff_image_meta_handler @@ -54,18 +35,3 @@ def test_pixel_size(self, ome_zarr_image_v04_path): pixel_size = handler.load_meta().pixel_size(idx=0) assert pixel_size.zyx == (1.0, 0.1625, 0.1625) - - def test_modify_axis_from_label_metadata(self, ome_zarr_label_v04_path): - from ngio.ngff_meta import get_ngff_image_meta_handler - - handler = get_ngff_image_meta_handler( - store=ome_zarr_label_v04_path, meta_mode="label" - ) - - fractal_meta = handler.load_meta() - - meta_no_channel = fractal_meta.remove_axis(axis_name="z") - assert meta_no_channel.axes_names == ["y", "x"] - - meta_add_channel = meta_no_channel.add_axis(axis_name="z", scale=1.0) - assert meta_add_channel.axes_names == fractal_meta.axes_names diff --git a/tests/ngff_meta/test_utils.py b/tests/ngff_meta/test_utils.py index 8b628ed..4eb796c 100644 --- a/tests/ngff_meta/test_utils.py +++ b/tests/ngff_meta/test_utils.py @@ -12,7 
+12,7 @@ def test_create_fractal_meta_with_t(self): z_scaling_factor=1.0, time_spacing=1.0, time_units="s", - num_levels=5, + levels=5, name="test", channel_labels=["DAPI", "nanog", "Lamin B1"], channel_wavelengths=["A01_C01", "A02_C02", "A03_C03"], @@ -39,7 +39,7 @@ def test_create_fractal_meta(self): z_scaling_factor=1.0, time_spacing=1.0, time_units="s", - num_levels=5, + levels=5, name="test", channel_labels=["DAPI", "nanog", "Lamin B1"], channel_wavelengths=["A01_C01", "A02_C02", "A03_C03"], @@ -66,7 +66,7 @@ def test_create_fractal_meta_with_non_canonical_order(self): z_scaling_factor=1.0, time_spacing=1.0, time_units="s", - num_levels=5, + levels=5, name="test", channel_labels=["DAPI", "nanog", "Lamin B1"], channel_wavelengths=["A01_C01", "A02_C02", "A03_C03"], @@ -85,7 +85,7 @@ def test_create_fractal_meta_with_non_canonical_order(self): z_scaling_factor=1.0, time_spacing=1.0, time_units="s", - num_levels=5, + levels=5, name="test", channel_labels=["DAPI", "nanog", "Lamin B1"], channel_wavelengths=["A01_C01", "A02_C02", "A03_C03"], @@ -107,7 +107,7 @@ def test_create_fractal_label_meta(self): z_scaling_factor=1.0, time_spacing=1.0, time_units="s", - num_levels=5, + levels=5, name="test", version="0.4", ) From 6be8f07484f45b7f70892e344b1f5f5fb388a162 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Thu, 7 Nov 2024 15:55:25 +0100 Subject: [PATCH 5/5] expand table testing --- src/ngio/core/label_handler.py | 2 +- src/ngio/core/utils.py | 78 +++++++++++++++++++++---- src/ngio/tables/_utils.py | 7 ++- tests/tables/test_table_conversion.py | 82 +++++++++++++++++++++++++++ tests/tables/test_validation.py | 53 +++++++++++++++++ 5 files changed, 208 insertions(+), 14 deletions(-) create mode 100644 tests/tables/test_table_conversion.py create mode 100644 tests/tables/test_validation.py diff --git a/src/ngio/core/label_handler.py b/src/ngio/core/label_handler.py index 07cf3fd..1a3845c 100644 --- a/src/ngio/core/label_handler.py +++ b/src/ngio/core/label_handler.py @@ -366,7 +366,7 @@ def derive( create_empty_ome_zarr_label( store=new_label_group, - shape=on_disk_shape, + on_disk_shape=on_disk_shape, chunks=chunks, dtype=ref_0.on_disk_array.dtype, on_disk_axis=dataset.on_disk_axes_names, diff --git a/src/ngio/core/utils.py b/src/ngio/core/utils.py index aeb0ff6..b8648bf 100644 --- a/src/ngio/core/utils.py +++ b/src/ngio/core/utils.py @@ -25,7 +25,7 @@ def _build_empty_pyramid( group: Group, image_meta: ImageLabelMeta, - shape: Collection[int], + on_disk_shape: Collection[int], chunks: Collection[int] | None = None, dtype: str = "uint16", on_disk_axis: Collection[str] = ("t", "c", "z", "y"), @@ -42,6 +42,21 @@ def _build_empty_pyramid( else: scaling_factor.append(1.0) + if chunks is not None and len(on_disk_shape) != len(chunks): + raise ValueError( + "The shape and chunks must have the same number " "of dimensions." + ) + + if len(on_disk_shape) != len(scaling_factor): + raise ValueError( + "The shape and scaling factor must have the same number " "of dimensions." + ) + + if len(on_disk_shape) != len(on_disk_axis): + raise ValueError( + "The shape and on-disk axis must have the same number " "of dimensions." 
+            )
+
     for dataset in image_meta.datasets:
         path = dataset.path
 
@@ -52,7 +67,7 @@ def _build_empty_pyramid(
 
         group.zeros(
             name=path,
-            shape=shape,
+            shape=on_disk_shape,
             dtype=dtype,
             chunks=chunks,
             dimension_separator="/",
@@ -60,15 +75,15 @@ def _build_empty_pyramid(
 
         # Todo redo this with when a proper build of pyramid is implemented
         _shape = []
-        for s, sc in zip(shape, scaling_factor, strict=True):
+        for s, sc in zip(on_disk_shape, scaling_factor, strict=True):
            if math.floor(s / sc) % 2 == 0:
                 _shape.append(math.floor(s / sc))
             else:
                 _shape.append(math.ceil(s / sc))
 
-        shape = list(_shape)
+        on_disk_shape = list(_shape)
         if chunks is not None:
-            chunks = [min(c, s) for c, s in zip(chunks, shape, strict=True)]
+            chunks = [min(c, s) for c, s in zip(chunks, on_disk_shape, strict=True)]
     return None
 
@@ -93,7 +108,30 @@ def create_empty_ome_zarr_image(
     overwrite: bool = True,
     version: str = "0.4",
 ) -> None:
-    """Create an empty OME-Zarr image with the given shape and metadata."""
+    """Create an empty OME-Zarr image with the given shape and metadata.
+
+    Args:
+        store (StoreLike): The store to create the image in.
+        on_disk_shape (Collection[int]): The shape of the image on disk.
+        on_disk_axis (Collection[str]): The order of the axes on disk.
+        chunks (Collection[int] | None): The chunk shape for the image.
+        dtype (str): The data type of the image.
+        pixel_sizes (PixelSize | None): The pixel size of the image.
+        xy_scaling_factor (float): The scaling factor in the x and y dimensions.
+        z_scaling_factor (float): The scaling factor in the z dimension.
+        time_spacing (float): The spacing between time points.
+        time_units (TimeUnits | str): The units of the time axis.
+        levels (int | list[str]): The number of levels in the pyramid.
+        path_names (list[str] | None): The names of the paths in the image.
+        name (str | None): The name of the image.
+        channel_labels (list[str] | None): The labels of the channels.
+        channel_wavelengths (list[str] | None): The wavelengths of the channels.
+        channel_kwargs (list[dict[str, Any]] | None): The extra fields for the channels.
+        omero_kwargs (dict[str, Any] | None): The extra fields for the image.
+        overwrite (bool): Whether to overwrite the image if it exists.
+        version (str): The version of the OME-Zarr format.
+
+    """
     if len(on_disk_shape) != len(on_disk_axis):
         raise ValueError(
             "The number of dimensions in the shape must match the number of "
@@ -142,7 +180,7 @@ def create_empty_ome_zarr_image(
     _build_empty_pyramid(
         group=group,
         image_meta=image_meta,
-        shape=on_disk_shape,
+        on_disk_shape=on_disk_shape,
         chunks=chunks,
         dtype=dtype,
         on_disk_axis=on_disk_axis,
@@ -153,7 +191,7 @@ def create_empty_ome_zarr_image(
 
 def create_empty_ome_zarr_label(
     store: StoreLike,
-    shape: Collection[int],
+    on_disk_shape: Collection[int],
     chunks: Collection[int] | None = None,
     dtype: str = "uint16",
     on_disk_axis: Collection[str] = ("t", "z", "y", "x"),
@@ -167,8 +205,26 @@ def create_empty_ome_zarr_label(
     overwrite: bool = True,
     version: str = "0.4",
 ) -> None:
-    """Create an empty OME-Zarr image with the given shape and metadata."""
-    if len(shape) != len(on_disk_axis):
+    """Create an empty OME-Zarr label with the given shape and metadata.
+
+    Args:
+        store (StoreLike): The store to create the label in.
+        on_disk_shape (Collection[int]): The shape of the label on disk.
+        chunks (Collection[int] | None): The chunk shape for the label.
+        dtype (str): The data type of the label.
+        on_disk_axis (Collection[str]): The order of the axes on disk. 
+        pixel_sizes (PixelSize | None): The pixel size of the label.
+        xy_scaling_factor (float): The scaling factor in the x and y dimensions.
+        z_scaling_factor (float): The scaling factor in the z dimension.
+        time_spacing (float): The spacing between time points.
+        time_units (TimeUnits | str | None): The units of the time axis.
+        levels (int | list[str]): The number of levels in the pyramid.
+        name (str | None): The name of the label.
+        overwrite (bool): Whether to overwrite the label if it exists.
+        version (str): The version of the OME-Zarr format.
+
+    """
+    if len(on_disk_shape) != len(on_disk_axis):
         raise ValueError(
             "The number of dimensions in the shape must match the number of "
             "axes in the on-disk axis."
@@ -199,7 +255,7 @@ def create_empty_ome_zarr_label(
     _build_empty_pyramid(
         group=group,
         image_meta=image_meta,
-        shape=shape,
+        on_disk_shape=on_disk_shape,
         chunks=chunks,
         dtype=dtype,
         on_disk_axis=on_disk_axis,
diff --git a/src/ngio/tables/_utils.py b/src/ngio/tables/_utils.py
index 90df988..adfd7c0 100644
--- a/src/ngio/tables/_utils.py
+++ b/src/ngio/tables/_utils.py
@@ -271,7 +271,9 @@ def validate_columns(
     table_header = table_df.columns
     for column in required_columns:
         if column not in table_header:
-            raise NgioTableValidationError(f"Column {column} is required in ROI table")
+            raise NgioTableValidationError(
+                f"Could not find required column: {column} in the table"
+            )
 
     if optional_columns is None:
         return table_df
@@ -280,7 +282,8 @@ def validate_columns(
     for column in table_header:
         if column not in possible_columns:
             raise NgioTableValidationError(
-                f"Column {column} is not recognized in ROI table"
+                f"Could not find column: {column} in the list of possible columns. "
+                f"Possible columns are: {possible_columns}"
             )
 
     return table_df
diff --git a/tests/tables/test_table_conversion.py b/tests/tables/test_table_conversion.py
new file mode 100644
index 0000000..f95c7b4
--- /dev/null
+++ b/tests/tables/test_table_conversion.py
@@ -0,0 +1,82 @@
+import pandas as pd
+import pytest
+
+
+class TestTableConversion:
+    def test_table_conversion1(self) -> None:
+        from ngio.tables._utils import (
+            NgioTableValidationError,
+            table_ad_to_df,
+            table_df_to_ad,
+        )
+
+        df = pd.DataFrame.from_records(
+            data=[
+                {"label": 1, "feat1": 0.1},
+                {"label": 2, "feat1": 0.3},
+                {"label": 3, "feat1": 0.5},
+            ]
+        )
+
+        with pytest.raises(NgioTableValidationError):
+            table_df_to_ad(df, index_key="label2", index_type="str")
+
+        # Index as column
+        ad_table = table_df_to_ad(df, index_key="label", index_type="int")
+
+        df_out = table_ad_to_df(ad_table, index_key="label", index_type="int")
+
+        assert df_out["feat1"].to_list() == df["feat1"].to_list()
+
+        # Set index explicitly
+        df.set_index("label", inplace=True)
+        ad_table = table_df_to_ad(df, index_key="label", index_type="int")
+
+        df_out = table_ad_to_df(ad_table, index_key="label", index_type="int")
+
+        assert df_out["feat1"].to_list() == df["feat1"].to_list()
+
+    def test_table_conversion2(self) -> None:
+        from ngio.tables._utils import (
+            NgioTableValidationError,
+            table_ad_to_df,
+            table_df_to_ad,
+        )
+
+        df = pd.DataFrame.from_records(
+            data=[
+                {"label": "1a", "feat1": 0.1},
+                {"label": "2b", "feat1": 0.3},
+                {"label": "3c", "feat1": 0.5},
+            ]
+        )
+
+        with pytest.raises(NgioTableValidationError):
+            table_df_to_ad(df, index_key="label", index_type="int")
+        ad_table = table_df_to_ad(df, index_key="label", index_type="str")
+
+        df_out = table_ad_to_df(table_ad=ad_table, index_key="label", index_type="str")
+
+        assert df_out["feat1"].to_list() == df["feat1"].to_list()
+
+        with 
pytest.raises(NgioTableValidationError): + df_out = table_ad_to_df( + table_ad=ad_table, index_key="label", index_type="int" + ) + + def test_table_conversion3(self) -> None: + from ngio.tables._utils import ( + NgioTableValidationError, + table_df_to_ad, + ) + + df = pd.DataFrame.from_records( + data=[ + {"label": 1.3, "feat1": 0.1}, + {"label": 2.1, "feat1": 0.3}, + {"label": 3.4, "feat1": 0.5}, + ] + ) + + with pytest.raises(NgioTableValidationError): + table_df_to_ad(df, index_key="label", index_type="int") diff --git a/tests/tables/test_validation.py b/tests/tables/test_validation.py new file mode 100644 index 0000000..e269ec0 --- /dev/null +++ b/tests/tables/test_validation.py @@ -0,0 +1,53 @@ +import pandas as pd +import pytest + + +class TestValidation: + def test_validate_unique(self) -> None: + from ngio.tables._utils import NgioTableValidationError, validate_unique_index + + df = pd.DataFrame.from_records( + data=[ + {"id": 1, "x": 0.1}, + {"id": 2, "x": 0.3}, + {"id": 3, "x": 0.5}, + ] + ) + df.set_index("id", inplace=True) + out_df = validate_unique_index(df) + assert out_df.equals(df) + + df = pd.DataFrame.from_records( + data=[ + {"id": 1, "x": 0.1}, + {"id": 1, "x": 0.3}, + {"id": 3, "x": 0.5}, + ] + ) + df.set_index("id", inplace=True) + with pytest.raises(NgioTableValidationError): + validate_unique_index(df) + + def test_validate_column(self) -> None: + from ngio.tables._utils import NgioTableValidationError, validate_columns + + df = pd.DataFrame.from_records( + data=[ + {"id": 1, "x": 0.1}, + {"id": 2, "x": 0.3}, + {"id": 3, "x": 0.5}, + ] + ) + out_df = validate_columns( + df, required_columns=["id", "x"], optional_columns=["y"] + ) + assert out_df.equals(df) + + out_df = validate_columns(df, required_columns=["id", "x"]) + assert out_df.equals(df) + + with pytest.raises(NgioTableValidationError): + validate_columns(df, required_columns=["y"]) + + with pytest.raises(NgioTableValidationError): + validate_columns(df, required_columns=["id"], optional_columns=["y"])
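
Usage sketch (not part of the applied diffs): the snippet below ties together
two of the changes in this series, the `levels` argument that replaces
`num_levels` and the table round-trip helpers exercised by the new tests. The
`create_image_metadata` call assumes the remaining keyword arguments keep
their defaults; the table helpers are invoked exactly as in
tests/tables/test_table_conversion.py.

    import pandas as pd

    from ngio.ngff_meta.utils import create_image_metadata
    from ngio.tables._utils import table_ad_to_df, table_df_to_ad

    # `levels` accepts either a pyramid depth (int) or explicit path names.
    meta = create_image_metadata(
        levels=["0", "1", "2"],  # equivalent to the old num_levels=3
        name="example",
        channel_labels=["DAPI"],
        channel_wavelengths=["A01_C01"],
    )
    assert [dataset.path for dataset in meta.datasets] == ["0", "1", "2"]

    # Round-trip a feature table through AnnData with an integer index.
    df = pd.DataFrame({"feat1": [0.1, 0.3, 0.5]}, index=[1, 2, 3])
    df.index.name = "label"
    ad_table = table_df_to_ad(df, index_key="label", index_type="int")
    df_out = table_ad_to_df(ad_table, index_key="label", index_type="int")
    assert df_out["feat1"].to_list() == df["feat1"].to_list()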