Skip to content

Commit

Permalink
refactor channel viewer metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzocerrone committed Nov 8, 2024
1 parent 46ab28a commit dbd6efe
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/ngio/core/ngff_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def derive_new_image(
"name": name,
"channel_labels": image_0.channel_labels,
"channel_wavelengths": [ch.wavelength_id for ch in channels],
"channel_kwargs": [ch.extra_fields for ch in channels],
"channel_visualization": [ch.channel_visualisation for ch in channels],
"omero_kwargs": omero_kwargs,
"overwrite": overwrite,
"version": self.image_meta.version,
Expand Down
9 changes: 5 additions & 4 deletions src/ngio/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_ngff_image_meta_handler,
)
from ngio.ngff_meta.fractal_image_meta import (
ChannelVisualisation,
PixelSize,
TimeUnits,
)
Expand Down Expand Up @@ -99,11 +100,10 @@ def create_empty_ome_zarr_image(
time_spacing: float = 1.0,
time_units: TimeUnits | str = TimeUnits.s,
levels: int | list[str] = 5,
path_names: list[str] | None = None,
name: str | None = None,
channel_labels: list[str] | None = None,
channel_wavelengths: list[str] | None = None,
channel_kwargs: list[dict[str, Any]] | None = None,
channel_visualization: list[ChannelVisualisation] | None = None,
omero_kwargs: dict[str, Any] | None = None,
overwrite: bool = True,
version: str = "0.4",
Expand All @@ -126,7 +126,8 @@ def create_empty_ome_zarr_image(
name (str | None): The name of the image.
channel_labels (list[str] | None): The labels of the channels.
channel_wavelengths (list[str] | None): The wavelengths of the channels.
channel_kwargs (list[dict[str, Any]] | None): The extra fields for the channels.
channel_visualization (list[ChannelVisualisation] | None): A list of
channel visualisation objects.
omero_kwargs (dict[str, Any] | None): The extra fields for the image.
overwrite (bool): Whether to overwrite the image if it exists.
version (str): The version of the OME-Zarr format.
Expand Down Expand Up @@ -163,7 +164,7 @@ def create_empty_ome_zarr_image(
name=name,
channel_labels=channel_labels,
channel_wavelengths=channel_wavelengths,
channel_kwargs=channel_kwargs,
channel_visualization=channel_visualization,
omero_kwargs=omero_kwargs,
version=version,
)
Expand Down
247 changes: 218 additions & 29 deletions src/ngio/ngff_meta/fractal_image_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@

from collections.abc import Collection
from enum import Enum
from typing import Any
from typing import Any, TypeVar

import numpy as np
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Self

from ngio.utils._pydantic_utils import BaseWithExtraFields

T = TypeVar("T")


class NgffVersion(str, Enum):
"""Allowed NGFF versions."""
Expand All @@ -32,36 +34,120 @@ class NgffVersion(str, Enum):
#################################################################################################


class Window(BaseModel):
"""Window model to be used by the Viewer."""

min: int | float
max: int | float
start: int | float
end: int | float

@classmethod
def from_type(cls, data_type: str) -> "Window":
"""Create a Window object from a window type."""
type_info = np.iinfo(data_type)
return cls(
min=type_info.min, max=type_info.max, start=type_info.min, end=type_info.max
)
class NgioColors(str, Enum):
"""Default colors for the channels."""

dapi = "0000FF"
hoechst = "0000FF"
gfp = "00FF00"
cy3 = "FFFF00"
cy5 = "FF0000"
brightfield = "808080"
red = "FF0000"
yellow = "FFFF00"
magenta = "FF00FF"
cyan = "00FFFF"
gray = "808080"
green = "00FF00"
random = "random"

def random_pick(self) -> "NgioColors":
"""Pick a random color."""
available_colors = [color for color in NgioColors if color != "random"]
return available_colors[np.random.randint(0, len(available_colors))]


def valid_hex_color(v: str) -> str:
"""Validate a hexadecimal color.
Check that `color` is made of exactly six elements which are letters
(a-f or A-F) or digits (0-9).
If fail, raise a ValueError.
Implementation source:
https://github.com/fractal-analytics-platform/fractal-tasks-core/fractal_tasks_core/channels.py#L87
Original authors:
- Tommaso Comparin <[email protected]>
"""
if len(v) != 6:
raise ValueError(f'color must have length 6 (given: "{v}")')
allowed_characters = "abcdefABCDEF0123456789"
for character in v:
if character not in allowed_characters:
raise ValueError(
"color must only include characters from "
f'"{allowed_characters}" (given: "{v}")'
)
return v


class ChannelVisualisation(BaseWithExtraFields):
"""Channel visualisation model.
Contains the information about the visualisation of a channel.
Attributes:
color(str): The color of the channel in hexadecimal format or a color name.
min(int | float): The minimum value of the channel.
max(int | float): The maximum value of the channel.
start(int | float): The start value of the channel.
end(int | float): The end value of the channel.
active(bool): Whether the channel is active.
"""

color: str
window: Window
color: str | NgioColors = NgioColors.random
min: int | float = 0
max: int | float = 65535
start: int | float = 0
end: int | float = 65535
active: bool = True
inverted: bool = False

@classmethod
@field_validator("color", mode="after")
def validate_color(cls, value: str | NgioColors) -> str:
"""Color validator.
There are three possible values to set a color:
- A hexadecimal string.
- A color name.
- A NgioColors element.
"""
if value in NgioColors:
value = NgioColors(value)
if value == NgioColors.random:
value = value.random_pick()
return value.value

return valid_hex_color(value)

class Channel(BaseWithExtraFields):
@classmethod
def lazy_init(
cls,
color: str = "random",
start: int | float | None = None,
end: int | float | None = None,
data_type: Any = np.uint16,
) -> "ChannelVisualisation":
"""Create a ChannelVisualisation object with the default unit.
Args:
color(str): The color of the channel in hexadecimal format or a color name.
start(int | float | None): The start value of the channel.
end(int | float | None): The end value of the channel.
data_type(Any): The data type of the channel.
"""
start = start if start is not None else np.iinfo(data_type).min
end = end if end is not None else np.iinfo(data_type).max
return cls(
color=color,
min=np.iinfo(data_type).min,
max=np.iinfo(data_type).max,
start=start,
end=end,
)


class Channel(BaseModel):
"""Information about a channel in the image.
Attributes:
Expand All @@ -73,27 +159,57 @@ class Channel(BaseWithExtraFields):

label: str
wavelength_id: str | None = None
channel_visualisation: ChannelVisualisation

@classmethod
def lazy_init(
cls,
label: str,
wavelength_id: str | None = None,
color: str = "00FFFF",
color: str = "random",
start: int | float | None = None,
end: int | float | None = None,
data_type: Any = np.uint16,
) -> "Channel":
"""Create a Channel object with the default unit."""
channel_visualization = ChannelVisualisation(
color=color, window=Window.from_type(data_type)
)
"""Create a Channel object with the default unit.
Args:
label(str): The label of the channel.
wavelength_id(str | None): The wavelength ID of the channel.
color(str): The color of the channel in hexadecimal format or a color name.
start(int | float | None): The start value of the channel.
end(int | float | None): The end value of the channel.
data_type(Any): The data type of the channel.
"""
channel_visualization = ChannelVisualisation.lazy_init(
color=color, start=start, end=end, data_type=data_type
)
return cls(
label=label,
wavelength_id=wavelength_id,
**channel_visualization.model_dump(),
channel_visualisation=channel_visualization,
)


def _check_elements(elements: Collection[T], expected_type: Any) -> Collection[T]:
"""Check that the elements are of the same type."""
if len(elements) == 0:
raise ValueError("At least one element must be provided.")

for element in elements:
if not isinstance(element, expected_type):
raise ValueError(f"All elements must be of the same type {expected_type}.")

return elements


def _check_unique(elements: Collection[T]) -> Collection[T]:
"""Check that the elements are unique."""
if len(set(elements)) != len(elements):
raise ValueError("All elements must be unique.")
return elements


class Omero(BaseWithExtraFields):
"""Information about the OMERO metadata.
Expand All @@ -105,6 +221,75 @@ class Omero(BaseWithExtraFields):

channels: list[Channel] = Field(default_factory=list)

@classmethod
def lazy_init(
cls,
channels: Collection[str] | int,
wavelength_id: Collection[str] | None = None,
colors: Collection[str | NgioColors] | None = None,
start: Collection[int | float] | int | float | None = None,
end: Collection[int | float] | int | float | None = None,
data_type: Any = np.uint16,
**omero_kwargs: dict,
) -> "Omero":
"""Create an Omero object with the default unit.
Args:
channels(Collection[str] | int): The list of channels in the image.
If an integer is provided, the channels will be named "channel_i".
wavelength_id(Collection[str] | None): The wavelength ID of the channel.
If None, the wavelength ID will be the same as the channel name.
colors(Collection[str, NgioColors] | None): The list of colors for the
channels. If None, the colors will be random.
start(Collection[int | float] | int | float | None): The start value of the
channel. If None, the start value will be the minimum value of the
data type.
end(Collection[int | float] | int | float | None): The end value of the
channel. If None, the end value will be the maximum value of the
data type.
data_type(Any): The data type of the channel. Will be used to set the
min and max values of the channel.
omero_kwargs(dict): Extra fields to store in the omero attributes.
"""
if isinstance(channels, int):
channels = [f"channel_{i}" for i in range(channels)]

channels = _check_elements(channels, str)
channels = _check_unique(channels)

_wavelength_id: Collection[str | None] = [None] * len(channels)
if isinstance(wavelength_id, Collection):
_wavelength_id = _check_elements(wavelength_id, str)
_wavelength_id = _check_unique(wavelength_id)

_colors: Collection[str | NgioColors] = ["random"] * len(channels)
if isinstance(colors, Collection):
_colors = _check_elements(colors, str | NgioColors)

_start: Collection[int | float | None] = [None] * len(channels)
if isinstance(start, Collection):
_start = _check_elements(start, (int, float))

_end: Collection[int | float | None] = [None] * len(channels)
if isinstance(end, Collection):
_end = _check_elements(end, (int, float))

omero_channels = []
for ch_name, w_id, color, s, e in zip(
channels, _wavelength_id, _colors, _start, _end, strict=True
):
omero_channels.append(
Channel.lazy_init(
label=ch_name,
wavelength_id=w_id,
color=color,
start=s,
end=e,
data_type=data_type,
)
)
return cls(channels=omero_channels, **omero_kwargs)


################################################################################################
#
Expand Down Expand Up @@ -217,6 +402,10 @@ def from_list(
) -> "PixelSize":
"""Build a PixelSize object from a list of sizes.
Order of the sizes:
- for 2d: [y, x]
- for 3d: [z, y, x]
Note: The order of the sizes must be z, y, x.
Args:
Expand All @@ -235,12 +424,12 @@ def as_dict(self) -> dict:
return {"z": self.z, "y": self.y, "x": self.x}

@property
def zyx(self) -> tuple:
def zyx(self) -> tuple[float, float, float]:
"""Return the voxel size in z, y, x order."""
return self.z, self.y, self.x

@property
def yx(self) -> tuple:
def yx(self) -> tuple[float, float]:
"""Return the xy plane pixel size in y, x order."""
return self.y, self.x

Expand Down
Loading

0 comments on commit dbd6efe

Please sign in to comment.