-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
46ab28a
commit dbd6efe
Showing
7 changed files
with
330 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,14 +8,16 @@ | |
|
||
from collections.abc import Collection | ||
from enum import Enum | ||
from typing import Any | ||
from typing import Any, TypeVar | ||
|
||
import numpy as np | ||
from pydantic import BaseModel, Field | ||
from pydantic import BaseModel, Field, field_validator | ||
from typing_extensions import Self | ||
|
||
from ngio.utils._pydantic_utils import BaseWithExtraFields | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class NgffVersion(str, Enum): | ||
"""Allowed NGFF versions.""" | ||
|
@@ -32,36 +34,120 @@ class NgffVersion(str, Enum): | |
################################################################################################# | ||
|
||
|
||
class Window(BaseModel): | ||
"""Window model to be used by the Viewer.""" | ||
|
||
min: int | float | ||
max: int | float | ||
start: int | float | ||
end: int | float | ||
|
||
@classmethod | ||
def from_type(cls, data_type: str) -> "Window": | ||
"""Create a Window object from a window type.""" | ||
type_info = np.iinfo(data_type) | ||
return cls( | ||
min=type_info.min, max=type_info.max, start=type_info.min, end=type_info.max | ||
) | ||
class NgioColors(str, Enum): | ||
"""Default colors for the channels.""" | ||
|
||
dapi = "0000FF" | ||
hoechst = "0000FF" | ||
gfp = "00FF00" | ||
cy3 = "FFFF00" | ||
cy5 = "FF0000" | ||
brightfield = "808080" | ||
red = "FF0000" | ||
yellow = "FFFF00" | ||
magenta = "FF00FF" | ||
cyan = "00FFFF" | ||
gray = "808080" | ||
green = "00FF00" | ||
random = "random" | ||
|
||
def random_pick(self) -> "NgioColors": | ||
"""Pick a random color.""" | ||
available_colors = [color for color in NgioColors if color != "random"] | ||
return available_colors[np.random.randint(0, len(available_colors))] | ||
|
||
|
||
def valid_hex_color(v: str) -> str: | ||
"""Validate a hexadecimal color. | ||
Check that `color` is made of exactly six elements which are letters | ||
(a-f or A-F) or digits (0-9). | ||
If fail, raise a ValueError. | ||
Implementation source: | ||
https://github.com/fractal-analytics-platform/fractal-tasks-core/fractal_tasks_core/channels.py#L87 | ||
Original authors: | ||
- Tommaso Comparin <[email protected]> | ||
""" | ||
if len(v) != 6: | ||
raise ValueError(f'color must have length 6 (given: "{v}")') | ||
allowed_characters = "abcdefABCDEF0123456789" | ||
for character in v: | ||
if character not in allowed_characters: | ||
raise ValueError( | ||
"color must only include characters from " | ||
f'"{allowed_characters}" (given: "{v}")' | ||
) | ||
return v | ||
|
||
|
||
class ChannelVisualisation(BaseWithExtraFields): | ||
"""Channel visualisation model. | ||
Contains the information about the visualisation of a channel. | ||
Attributes: | ||
color(str): The color of the channel in hexadecimal format or a color name. | ||
min(int | float): The minimum value of the channel. | ||
max(int | float): The maximum value of the channel. | ||
start(int | float): The start value of the channel. | ||
end(int | float): The end value of the channel. | ||
active(bool): Whether the channel is active. | ||
""" | ||
|
||
color: str | ||
window: Window | ||
color: str | NgioColors = NgioColors.random | ||
min: int | float = 0 | ||
max: int | float = 65535 | ||
start: int | float = 0 | ||
end: int | float = 65535 | ||
active: bool = True | ||
inverted: bool = False | ||
|
||
@classmethod | ||
@field_validator("color", mode="after") | ||
def validate_color(cls, value: str | NgioColors) -> str: | ||
"""Color validator. | ||
There are three possible values to set a color: | ||
- A hexadecimal string. | ||
- A color name. | ||
- A NgioColors element. | ||
""" | ||
if value in NgioColors: | ||
value = NgioColors(value) | ||
if value == NgioColors.random: | ||
value = value.random_pick() | ||
return value.value | ||
|
||
return valid_hex_color(value) | ||
|
||
class Channel(BaseWithExtraFields): | ||
@classmethod | ||
def lazy_init( | ||
cls, | ||
color: str = "random", | ||
start: int | float | None = None, | ||
end: int | float | None = None, | ||
data_type: Any = np.uint16, | ||
) -> "ChannelVisualisation": | ||
"""Create a ChannelVisualisation object with the default unit. | ||
Args: | ||
color(str): The color of the channel in hexadecimal format or a color name. | ||
start(int | float | None): The start value of the channel. | ||
end(int | float | None): The end value of the channel. | ||
data_type(Any): The data type of the channel. | ||
""" | ||
start = start if start is not None else np.iinfo(data_type).min | ||
end = end if end is not None else np.iinfo(data_type).max | ||
return cls( | ||
color=color, | ||
min=np.iinfo(data_type).min, | ||
max=np.iinfo(data_type).max, | ||
start=start, | ||
end=end, | ||
) | ||
|
||
|
||
class Channel(BaseModel): | ||
"""Information about a channel in the image. | ||
Attributes: | ||
|
@@ -73,27 +159,57 @@ class Channel(BaseWithExtraFields): | |
|
||
label: str | ||
wavelength_id: str | None = None | ||
channel_visualisation: ChannelVisualisation | ||
|
||
@classmethod | ||
def lazy_init( | ||
cls, | ||
label: str, | ||
wavelength_id: str | None = None, | ||
color: str = "00FFFF", | ||
color: str = "random", | ||
start: int | float | None = None, | ||
end: int | float | None = None, | ||
data_type: Any = np.uint16, | ||
) -> "Channel": | ||
"""Create a Channel object with the default unit.""" | ||
channel_visualization = ChannelVisualisation( | ||
color=color, window=Window.from_type(data_type) | ||
) | ||
"""Create a Channel object with the default unit. | ||
Args: | ||
label(str): The label of the channel. | ||
wavelength_id(str | None): The wavelength ID of the channel. | ||
color(str): The color of the channel in hexadecimal format or a color name. | ||
start(int | float | None): The start value of the channel. | ||
end(int | float | None): The end value of the channel. | ||
data_type(Any): The data type of the channel. | ||
""" | ||
channel_visualization = ChannelVisualisation.lazy_init( | ||
color=color, start=start, end=end, data_type=data_type | ||
) | ||
return cls( | ||
label=label, | ||
wavelength_id=wavelength_id, | ||
**channel_visualization.model_dump(), | ||
channel_visualisation=channel_visualization, | ||
) | ||
|
||
|
||
def _check_elements(elements: Collection[T], expected_type: Any) -> Collection[T]: | ||
"""Check that the elements are of the same type.""" | ||
if len(elements) == 0: | ||
raise ValueError("At least one element must be provided.") | ||
|
||
for element in elements: | ||
if not isinstance(element, expected_type): | ||
raise ValueError(f"All elements must be of the same type {expected_type}.") | ||
|
||
return elements | ||
|
||
|
||
def _check_unique(elements: Collection[T]) -> Collection[T]: | ||
"""Check that the elements are unique.""" | ||
if len(set(elements)) != len(elements): | ||
raise ValueError("All elements must be unique.") | ||
return elements | ||
|
||
|
||
class Omero(BaseWithExtraFields): | ||
"""Information about the OMERO metadata. | ||
|
@@ -105,6 +221,75 @@ class Omero(BaseWithExtraFields): | |
|
||
channels: list[Channel] = Field(default_factory=list) | ||
|
||
@classmethod | ||
def lazy_init( | ||
cls, | ||
channels: Collection[str] | int, | ||
wavelength_id: Collection[str] | None = None, | ||
colors: Collection[str | NgioColors] | None = None, | ||
start: Collection[int | float] | int | float | None = None, | ||
end: Collection[int | float] | int | float | None = None, | ||
data_type: Any = np.uint16, | ||
**omero_kwargs: dict, | ||
) -> "Omero": | ||
"""Create an Omero object with the default unit. | ||
Args: | ||
channels(Collection[str] | int): The list of channels in the image. | ||
If an integer is provided, the channels will be named "channel_i". | ||
wavelength_id(Collection[str] | None): The wavelength ID of the channel. | ||
If None, the wavelength ID will be the same as the channel name. | ||
colors(Collection[str, NgioColors] | None): The list of colors for the | ||
channels. If None, the colors will be random. | ||
start(Collection[int | float] | int | float | None): The start value of the | ||
channel. If None, the start value will be the minimum value of the | ||
data type. | ||
end(Collection[int | float] | int | float | None): The end value of the | ||
channel. If None, the end value will be the maximum value of the | ||
data type. | ||
data_type(Any): The data type of the channel. Will be used to set the | ||
min and max values of the channel. | ||
omero_kwargs(dict): Extra fields to store in the omero attributes. | ||
""" | ||
if isinstance(channels, int): | ||
channels = [f"channel_{i}" for i in range(channels)] | ||
|
||
channels = _check_elements(channels, str) | ||
channels = _check_unique(channels) | ||
|
||
_wavelength_id: Collection[str | None] = [None] * len(channels) | ||
if isinstance(wavelength_id, Collection): | ||
_wavelength_id = _check_elements(wavelength_id, str) | ||
_wavelength_id = _check_unique(wavelength_id) | ||
|
||
_colors: Collection[str | NgioColors] = ["random"] * len(channels) | ||
if isinstance(colors, Collection): | ||
_colors = _check_elements(colors, str | NgioColors) | ||
|
||
_start: Collection[int | float | None] = [None] * len(channels) | ||
if isinstance(start, Collection): | ||
_start = _check_elements(start, (int, float)) | ||
|
||
_end: Collection[int | float | None] = [None] * len(channels) | ||
if isinstance(end, Collection): | ||
_end = _check_elements(end, (int, float)) | ||
|
||
omero_channels = [] | ||
for ch_name, w_id, color, s, e in zip( | ||
channels, _wavelength_id, _colors, _start, _end, strict=True | ||
): | ||
omero_channels.append( | ||
Channel.lazy_init( | ||
label=ch_name, | ||
wavelength_id=w_id, | ||
color=color, | ||
start=s, | ||
end=e, | ||
data_type=data_type, | ||
) | ||
) | ||
return cls(channels=omero_channels, **omero_kwargs) | ||
|
||
|
||
################################################################################################ | ||
# | ||
|
@@ -217,6 +402,10 @@ def from_list( | |
) -> "PixelSize": | ||
"""Build a PixelSize object from a list of sizes. | ||
Order of the sizes: | ||
- for 2d: [y, x] | ||
- for 3d: [z, y, x] | ||
Note: The order of the sizes must be z, y, x. | ||
Args: | ||
|
@@ -235,12 +424,12 @@ def as_dict(self) -> dict: | |
return {"z": self.z, "y": self.y, "x": self.x} | ||
|
||
@property | ||
def zyx(self) -> tuple: | ||
def zyx(self) -> tuple[float, float, float]: | ||
"""Return the voxel size in z, y, x order.""" | ||
return self.z, self.y, self.x | ||
|
||
@property | ||
def yx(self) -> tuple: | ||
def yx(self) -> tuple[float, float]: | ||
"""Return the xy plane pixel size in y, x order.""" | ||
return self.y, self.x | ||
|
||
|
Oops, something went wrong.