From 51dabac4dbc3bede9392d9b2dba1f52d00f4b2da Mon Sep 17 00:00:00 2001 From: lorenzo Date: Sat, 9 Nov 2024 10:21:29 +0100 Subject: [PATCH] improve color picking strategy --- src/ngio/ngff_meta/fractal_image_meta.py | 91 +++++++++++++++++------- src/ngio/ngff_meta/utils.py | 4 +- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/src/ngio/ngff_meta/fractal_image_meta.py b/src/ngio/ngff_meta/fractal_image_meta.py index e3f8be9..ad260cd 100644 --- a/src/ngio/ngff_meta/fractal_image_meta.py +++ b/src/ngio/ngff_meta/fractal_image_meta.py @@ -7,6 +7,7 @@ """ from collections.abc import Collection +from difflib import SequenceMatcher from enum import Enum from typing import Any, TypeVar @@ -49,15 +50,50 @@ class NgioColors(str, Enum): cyan = "00FFFF" gray = "808080" green = "00FF00" - random = "random" - def random_pick(self) -> "NgioColors": - """Pick a random color.""" - available_colors = [color for color in NgioColors if color != "random"] - return available_colors[np.random.randint(0, len(available_colors))] + @staticmethod + def semi_random_pick(channel_name: str | None = None) -> "NgioColors": + """Try to fuzzy match the color to the channel name. - -def valid_hex_color(v: str) -> str: + - If a channel name is given will try to match the channel name to the color. + - If name has the paatern 'channel_x' cyclic rotate over a list of colors + [cyan, magenta, yellow, green] + - If no channel name is given will return a random color. + """ + available_colors = NgioColors._member_names_ + + if channel_name is None: + # Purely random color + color_str = available_colors[np.random.randint(0, len(available_colors))] + return NgioColors.__members__[color_str] + + if channel_name.startswith("channel_"): + # Rotate over a list of colors + defaults_colors = [ + NgioColors.cyan, + NgioColors.magenta, + NgioColors.yellow, + NgioColors.green, + ] + + try: + index = int(channel_name.split("_")[-1]) % len(defaults_colors) + return defaults_colors[index] + except ValueError: + # If the name of the channel is something like + # channel_dapi this will fail an proceed to the + # standard fuzzy match + pass + + similarity = {} + for color in available_colors: + # try to match the color to the channel name + similarity[color] = SequenceMatcher(None, channel_name, color).ratio() + color_str = max(similarity, key=similarity.get) + return NgioColors.__members__[color_str] + + +def valid_hex_color(v: str) -> bool: """Validate a hexadecimal color. Check that `color` is made of exactly six elements which are letters @@ -70,15 +106,12 @@ def valid_hex_color(v: str) -> str: - Tommaso Comparin """ if len(v) != 6: - raise ValueError(f'color must have length 6 (given: "{v}")') + return False allowed_characters = "abcdefABCDEF0123456789" for character in v: if character not in allowed_characters: - raise ValueError( - "color must only include characters from " - f'"{allowed_characters}" (given: "{v}")' - ) - return v + return False + return True class ChannelVisualisation(BaseWithExtraFields): @@ -95,15 +128,15 @@ class ChannelVisualisation(BaseWithExtraFields): active(bool): Whether the channel is active. """ - color: str | NgioColors = NgioColors.random + color: str | NgioColors | None = Field(default=None, validate_default=True) min: int | float = 0 max: int | float = 65535 start: int | float = 0 end: int | float = 65535 active: bool = True - @classmethod @field_validator("color", mode="after") + @classmethod def validate_color(cls, value: str | NgioColors) -> str: """Color validator. @@ -112,18 +145,22 @@ def validate_color(cls, value: str | NgioColors) -> str: - A color name. - A NgioColors element. """ - if value in NgioColors: - value = NgioColors(value) - if value == NgioColors.random: - value = value.random_pick() + if value is None: + return NgioColors.semi_random_pick().value + if isinstance(value, str) and valid_hex_color(value): + return value + elif isinstance(value, str): + value_lower = value.lower() + return NgioColors.semi_random_pick(value_lower).value + elif isinstance(value, NgioColors): return value.value - - return valid_hex_color(value) + else: + raise ValueError("Invalid color value.") @classmethod def lazy_init( cls, - color: str = "random", + color: str | NgioColors | None = None, start: int | float | None = None, end: int | float | None = None, data_type: Any = np.uint16, @@ -138,7 +175,7 @@ def lazy_init( """ start = start if start is not None else np.iinfo(data_type).min end = end if end is not None else np.iinfo(data_type).max - return cls( + return ChannelVisualisation( color=color, min=np.iinfo(data_type).min, max=np.iinfo(data_type).max, @@ -166,7 +203,7 @@ def lazy_init( cls, label: str, wavelength_id: str | None = None, - color: str = "random", + color: str | NgioColors | None = None, start: int | float | None = None, end: int | float | None = None, data_type: Any = np.uint16, @@ -177,10 +214,16 @@ def lazy_init( label(str): The label of the channel. wavelength_id(str | None): The wavelength ID of the channel. color(str): The color of the channel in hexadecimal format or a color name. + If None, the color will be picked based on the label. start(int | float | None): The start value of the channel. end(int | float | None): The end value of the channel. data_type(Any): The data type of the channel. """ + if color is None: + # If no color is provided, try to pick a color based on the label + # See the NgioColors.semi_random_pick method for more details. + color = label + channel_visualization = ChannelVisualisation.lazy_init( color=color, start=start, end=end, data_type=data_type ) diff --git a/src/ngio/ngff_meta/utils.py b/src/ngio/ngff_meta/utils.py index 0786c87..4c52189 100644 --- a/src/ngio/ngff_meta/utils.py +++ b/src/ngio/ngff_meta/utils.py @@ -157,7 +157,9 @@ def create_image_metadata( ) if channel_visualization is None: - channel_visualization = [ChannelVisualisation() for _ in channel_labels] + channel_visualization = [ + ChannelVisualisation(color=label) for label in channel_labels + ] else: if len(channel_visualization) != len(channel_labels): raise ValueError(