From d8621a980ebab39023f5db089f94a3702cb7c144 Mon Sep 17 00:00:00 2001 From: Yuri Chiucconi Date: Mon, 19 Feb 2024 17:59:13 +0100 Subject: [PATCH] define SingleImage and ScalarDict classes --- src/filters.py | 58 --------------------------- src/images.py | 77 ++++++++++++++++++++++++++++++++---- src/models.py | 17 ++++---- src/runner.py | 6 +-- src/task_output.py | 9 ++--- src/tasks.py | 20 +++++----- tests/test_apply_workflow.py | 4 +- tests/test_filters.py | 2 +- tests/test_models.py | 6 +-- tests/test_runner.py | 4 +- 10 files changed, 105 insertions(+), 98 deletions(-) delete mode 100644 src/filters.py diff --git a/src/filters.py b/src/filters.py deleted file mode 100644 index 324c317..0000000 --- a/src/filters.py +++ /dev/null @@ -1,58 +0,0 @@ -from copy import copy -from typing import Optional - -from images import ImageAttribute -from images import SingleImage -from termcolor import cprint -from utils import ipjson -from utils import pjson - -FilterSet = dict[str, ImageAttribute] - - -def _filter_image_list( - images: list[SingleImage], - filters: Optional[FilterSet] = None, -) -> list[SingleImage]: - - if filters is None: - # When no filter is provided, return all images - return images - - filtered_images = [] - for this_image in images: - include_this_image = True - for key, value in filters.items(): - # If the FilterSet input includes the key-value pair - # "attribute": None, then we ignore "attribute" - if value is None: - continue - if this_image.get(key, None) != value: - include_this_image = False - break - if include_this_image: - filtered_images.append(copy(this_image)) - return filtered_images - - -def filter_images( - *, - dataset_images: list[SingleImage], - dataset_filters: Optional[FilterSet] = None, - wftask_filters: Optional[FilterSet] = None, -) -> list[SingleImage]: - def print(x): - return cprint(x, "red") - - current_filters = copy(dataset_filters) - current_filters.update(wftask_filters) - print(f"[filter_images] Dataset filters:\n{ipjson(dataset_filters)}") - print(f"[filter_images] WorkflowTask filters:\n{ipjson(wftask_filters)}") - print(f"[filter_images] Dataset images:\n{ipjson(dataset_images)}") - print(f"[filter_images] Current selection filters:\n{ipjson(current_filters)}") - filtered_images = _filter_image_list( - dataset_images, - filters=current_filters, - ) - print(f"[filter_images] Filtered image list: {pjson(filtered_images)}") - return filtered_images diff --git a/src/images.py b/src/images.py index 162ac10..f0d4529 100644 --- a/src/images.py +++ b/src/images.py @@ -1,15 +1,43 @@ -# Example image -# image = {"path": "/tmp/asasd", "dimensions": 3} -# Example filters -# filters = {"dimensions": 2, "illumination_corrected": False} from copy import copy from typing import Union - +from typing import Optional +from pydantic import BaseModel +from pydantic import Field +from utils import ipjson, pjson ImageAttribute = Union[str, bool, int, None] # a scalar JSON object -SingleImage = dict[str, ImageAttribute] +def check_key_value(key, value): + if not isinstance(key, str): + raise TypeError("Key must be a string") + if not isinstance(value, (int, float, str, bool, type(None))): + raise ValueError( + "Value must be a scalar (int, float, str, bool, or None)" + ) + +class ScalarDict(dict): + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + check_key_value(key=key, value=value) + super().__init__(**kwargs) + + def __setitem__(self, key, value): + check_key_value(key=key, value=value) + super().__setitem__(key, value) + +class SingleImage(BaseModel): + path: str + attributes: ScalarDict = Field(default_factory=ScalarDict) + def apply_filter(self, filters: ScalarDict): + for key, value in filters.items(): + if value is None: + continue + if self.attributes.get(key) != value: + return False + return True + def find_image_by_path( *, images: list[SingleImage], @@ -26,7 +54,7 @@ def find_image_by_path( The first image from `images` which has path equal to `path`. """ try: - image = next(image for image in images if image["path"] == path) + image = next(image for image in images if image.path == path) return copy(image) except StopIteration: raise ValueError(f"No image with {path=} found in image list.") @@ -41,3 +69,38 @@ def _deduplicate_list_of_dicts(list_of_dicts: list[dict]) -> list[dict]: if my_dict not in new_list_of_dicts: new_list_of_dicts.append(my_dict) return new_list_of_dicts + + +def _filter_image_list( + images: list[SingleImage], + filters: Optional[ScalarDict] = None, +) -> list[SingleImage]: + + if filters is None: + return images + filtered_images = [] + for this_image in images: + if this_image.apply_filter(filters): + filtered_images.append(copy(this_image)) + return filtered_images + + +def filter_images( + *, + dataset_images: list[SingleImage], + dataset_filters: Optional[ScalarDict] = None, + wftask_filters: Optional[ScalarDict] = None, +) -> list[SingleImage]: + + current_filters = copy(dataset_filters) + current_filters.update(wftask_filters) + print(f"[filter_images] Dataset filters:\n{ipjson(dataset_filters)}") + print(f"[filter_images] WorkflowTask filters:\n{ipjson(wftask_filters)}") + print(f"[filter_images] Dataset images:\n{ipjson(dataset_images)}") + print(f"[filter_images] Current selection filters:\n{ipjson(current_filters)}") + filtered_images = _filter_image_list( + dataset_images, + filters=current_filters, + ) + print(f"[filter_images] Filtered image list: {pjson(filtered_images)}") + return filtered_images diff --git a/src/models.py b/src/models.py index 0be2d08..2a8226f 100644 --- a/src/models.py +++ b/src/models.py @@ -3,12 +3,12 @@ from typing import Literal from typing import Optional -from filters import FilterSet +from images import ScalarDict from images import SingleImage from pydantic import BaseModel from pydantic import Field from pydantic import validator - +from task_output import TaskOutput KwargsType = dict[str, Any] @@ -19,7 +19,7 @@ class Dataset(BaseModel): # New in v2 root_dir: str images: list[SingleImage] = Field(default_factory=list) - filters: FilterSet = Field(default_factory=dict) + filters: ScalarDict = Field(default_factory=dict) # Temporary state buffer: Optional[dict[str, Any]] = None parallelization_list: Optional[list[dict[str, Any]]] = None @@ -32,17 +32,20 @@ def image_paths(self) -> list[str]: class Task(BaseModel): - function: Callable # mock of task.command + _function: Callable # mock of task.command meta: dict[str, Any] = Field(default_factory=dict) - new_filters: dict[str, Any] = Field(default_factory=dict) # FIXME: this is not using FilterSet any more! + new_filters: dict[str, Any] = Field(default_factory=dict) # FIXME: this is not using ScalarDict any more! task_type: Literal["non_parallel", "parallel"] = "non_parallel" + + def function(self, **kwargs): + return TaskOutput(self._function(**kwargs)) @validator("new_filters") def scalar_filters(cls, v): """ Check that values of new_filters are all JSON-scalar. - Replacement for `new_filters: FilterSet` attribute type, which + Replacement for `new_filters: ScalarDict` attribute type, which does not work in Pydantic. """ for value in v.values(): @@ -60,7 +63,7 @@ class WorkflowTask(BaseModel): args: dict[str, Any] = Field(default_factory=dict) meta: dict[str, Any] = Field(default_factory=dict) task: Optional[Task] = None - filters: FilterSet = Field(default_factory=dict) + filters: ScalarDict = Field(default_factory=dict) class Workflow(BaseModel): diff --git a/src/runner.py b/src/runner.py index caac951..7e746e1 100644 --- a/src/runner.py +++ b/src/runner.py @@ -1,8 +1,8 @@ from copy import copy from copy import deepcopy -from filters import filter_images -from filters import FilterSet +from images import filter_images +from images import ScalarDict from images import _deduplicate_list_of_dicts from images import find_image_by_path from images import SingleImage @@ -22,7 +22,7 @@ def _apply_attributes_to_image( *, image: SingleImage, - filters: FilterSet, + filters: ScalarDict, ) -> SingleImage: updated_image = copy(image) for key, value in filters.items(): diff --git a/src/task_output.py b/src/task_output.py index 6eb0086..e1af0ae 100644 --- a/src/task_output.py +++ b/src/task_output.py @@ -1,10 +1,9 @@ from typing import Any from typing import Optional -from filters import FilterSet from images import find_image_by_path +from images import ScalarDict from images import SingleImage -from models import KwargsType from pydantic import BaseModel from utils import pjson @@ -16,7 +15,7 @@ class TaskOutput(BaseModel): edited_images: Optional[list[SingleImage]] = None """List of images edited by a given task instance.""" - new_filters: Optional[FilterSet] = None # FIXME: this does not actually work in Pydantic + new_filters: Optional[ScalarDict] = None # FIXME: this does not actually work in Pydantic """ *Global* filters (common to all images) added by this task. @@ -31,7 +30,7 @@ class TaskOutput(BaseModel): companion task. """ - parallelization_list: Optional[list[KwargsType]] = None + parallelization_list: Optional[list[ScalarDict]] = None """ Used in the output of an init task, to expose customizable parallelization of the companion task. @@ -47,7 +46,7 @@ class Config: new_images: Optional[list[SingleImage]] = None edited_images: Optional[list[SingleImage]] = None - new_filters: Optional[FilterSet] = None # FIXME + new_filters: Optional[ScalarDict] = None # FIXME def merge_outputs( diff --git a/src/tasks.py b/src/tasks.py index 6dccff1..27327b7 100644 --- a/src/tasks.py +++ b/src/tasks.py @@ -394,22 +394,22 @@ def init_registration( TASK_LIST = { - "create_ome_zarr": Task(function=create_ome_zarr, task_type="non_parallel"), - "yokogawa_to_zarr": Task(function=yokogawa_to_zarr, task_type="parallel"), - "create_ome_zarr_multiplex": Task(function=create_ome_zarr_multiplex, task_type="non_parallel"), - "cellpose_segmentation": Task(function=cellpose_segmentation, task_type="parallel"), - "new_ome_zarr": Task(function=new_ome_zarr, task_type="non_parallel"), - "copy_data": Task(function=copy_data, task_type="parallel"), + "create_ome_zarr": Task(_function=create_ome_zarr, task_type="non_parallel"), + "yokogawa_to_zarr": Task(_function=yokogawa_to_zarr, task_type="parallel"), + "create_ome_zarr_multiplex": Task(_function=create_ome_zarr_multiplex, task_type="non_parallel"), + "cellpose_segmentation": Task(_function=cellpose_segmentation, task_type="parallel"), + "new_ome_zarr": Task(_function=new_ome_zarr, task_type="non_parallel"), + "copy_data": Task(_function=copy_data, task_type="parallel"), "illumination_correction": Task( - function=illumination_correction, + _function=illumination_correction, task_type="parallel", new_filters=dict(illumination_correction=True), ), "maximum_intensity_projection": Task( - function=maximum_intensity_projection, + _function=maximum_intensity_projection, task_type="parallel", new_filters=dict(data_dimensionality="2"), ), - "init_channel_parallelization": Task(function=init_channel_parallelization, task_type="non_parallel"), - "init_registration": Task(function=init_registration, task_type="non_parallel"), + "init_channel_parallelization": Task(_function=init_channel_parallelization, task_type="non_parallel"), + "init_registration": Task(_function=init_registration, task_type="non_parallel"), } diff --git a/tests/test_apply_workflow.py b/tests/test_apply_workflow.py index 473acc8..95a7a7c 100644 --- a/tests/test_apply_workflow.py +++ b/tests/test_apply_workflow.py @@ -15,7 +15,7 @@ def test_single_non_parallel_task(): WorkflowTask( task=Task( task_type="non_parallel", - function=create_images_from_scratch, + _function=create_images_from_scratch, ), args=dict(new_paths=NEW_PATHS), ) @@ -33,7 +33,7 @@ def test_single_parallel_task_no_parallization_list(): WorkflowTask( task=Task( task_type="parallel", - function=print_path, + _function=print_path, ) ) ] diff --git a/tests/test_filters.py b/tests/test_filters.py index 8c27338..484f22a 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,4 +1,4 @@ -from filters import _filter_image_list +from images import _filter_image_list images = [ dict( diff --git a/tests/test_models.py b/tests/test_models.py index 6229363..3423aa1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,15 +9,15 @@ def _dummy_function(): def test_model_task(): NEW_FILTERS = dict(x=True, y=1, z="asd", w=None) - task = Task(function=_dummy_function, new_filters=NEW_FILTERS) + task = Task(_function=_dummy_function, new_filters=NEW_FILTERS) debug(task.new_filters) debug(NEW_FILTERS) assert task.new_filters == NEW_FILTERS with pytest.raises(ValueError) as e: - Task(function=_dummy_function, new_filters=dict(key=[])) + Task(_function=_dummy_function, new_filters=dict(key=[])) debug(str(e.value)) with pytest.raises(ValueError) as e: - Task(function=_dummy_function, new_filters=dict(key={})) + Task(_function=_dummy_function, new_filters=dict(key={})) debug(str(e.value)) diff --git a/tests/test_runner.py b/tests/test_runner.py index 31716c3..29faab1 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -31,7 +31,7 @@ def test_max_parallelization_list_size(N: int): WorkflowTask( task=Task( task_type="parallel", - function=dummy_task, + _function=dummy_task, ), ) ] @@ -85,7 +85,7 @@ def test_image_attribute_propagation( WorkflowTask( task=Task( task_type="parallel", - function=_copy_and_edit_image, + _function=_copy_and_edit_image, ), ) ]