Skip to content
This repository has been archived by the owner on Apr 5, 2024. It is now read-only.

Commit

Permalink
define SingleImage and ScalarDict classes
Browse files Browse the repository at this point in the history
  • Loading branch information
ychiucco committed Feb 19, 2024
1 parent f816251 commit d8621a9
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 98 deletions.
58 changes: 0 additions & 58 deletions src/filters.py

This file was deleted.

77 changes: 70 additions & 7 deletions src/images.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,43 @@
# Example image
# image = {"path": "/tmp/asasd", "dimensions": 3}
# Example filters
# filters = {"dimensions": 2, "illumination_corrected": False}
from copy import copy
from typing import Union

from typing import Optional
from pydantic import BaseModel
from pydantic import Field
from utils import ipjson, pjson

ImageAttribute = Union[str, bool, int, None] # a scalar JSON object
SingleImage = dict[str, ImageAttribute]

def check_key_value(key, value):
    """Validate one (key, value) pair for use in a ``ScalarDict``.

    Args:
        key: must be a ``str``.
        value: must be a JSON scalar (``int``, ``float``, ``str``,
            ``bool``, or ``None``).

    Raises:
        TypeError: when ``key`` is not a string.
        ValueError: when ``value`` is not a scalar.
    """
    if not isinstance(key, str):
        raise TypeError("Key must be a string")
    scalar_types = (int, float, str, bool, type(None))
    if not isinstance(value, scalar_types):
        raise ValueError(
            "Value must be a scalar (int, float, str, bool, or None)"
        )

class ScalarDict(dict):
    """A ``dict`` restricted to string keys and JSON-scalar values.

    Every write path validates the pair via ``check_key_value``.  The
    original implementation only validated ``__init__`` keyword arguments
    and ``__setitem__``, so ``update()`` could silently bypass the
    invariant; ``update`` is now routed through ``__setitem__`` as well.
    """

    def __init__(self, **kwargs):
        # Validate each pair before delegating to the plain-dict init.
        for key, value in kwargs.items():
            check_key_value(key=key, value=value)
        super().__init__(**kwargs)

    def __setitem__(self, key, value):
        check_key_value(key=key, value=value)
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        # Funnel all entries through __setitem__ so validation applies.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

class SingleImage(BaseModel):
    """A single image: its filesystem path plus scalar attributes."""

    path: str
    # Arbitrary scalar metadata about the image (e.g. dimensions, flags).
    attributes: ScalarDict = Field(default_factory=ScalarDict)

    def apply_filter(self, filters: ScalarDict):
        """Return whether this image satisfies every filter entry.

        A filter entry whose value is ``None`` is ignored; any other entry
        must match the corresponding attribute exactly.
        """
        return all(
            self.attributes.get(name) == expected
            for name, expected in filters.items()
            if expected is not None
        )

def find_image_by_path(
*,
images: list[SingleImage],
Expand All @@ -26,7 +54,7 @@ def find_image_by_path(
The first image from `images` which has path equal to `path`.
"""
try:
image = next(image for image in images if image["path"] == path)
image = next(image for image in images if image.path == path)
return copy(image)
except StopIteration:
raise ValueError(f"No image with {path=} found in image list.")
Expand All @@ -41,3 +69,38 @@ def _deduplicate_list_of_dicts(list_of_dicts: list[dict]) -> list[dict]:
if my_dict not in new_list_of_dicts:
new_list_of_dicts.append(my_dict)
return new_list_of_dicts


def _filter_image_list(
    images: list[SingleImage],
    filters: Optional[ScalarDict] = None,
) -> list[SingleImage]:
    """Return copies of the images that satisfy ``filters``.

    With ``filters=None`` the original list object is returned unchanged
    (no copies are made in that case).
    """
    if filters is None:
        return images
    return [
        copy(image)
        for image in images
        if image.apply_filter(filters)
    ]


def filter_images(
    *,
    dataset_images: list[SingleImage],
    dataset_filters: Optional[ScalarDict] = None,
    wftask_filters: Optional[ScalarDict] = None,
) -> list[SingleImage]:
    """Select the dataset images matching the combined filters.

    Workflow-task filters take precedence over dataset filters on key
    collisions.  Either filter set may be ``None`` (treated as empty).

    Args:
        dataset_images: images to filter.
        dataset_filters: base filters attached to the dataset.
        wftask_filters: per-workflow-task filters, overriding the above.

    Returns:
        Copies of the images that satisfy the merged filters.
    """
    # Guard both None defaults: the previous implementation crashed with
    # AttributeError on copy(None).update(...) and with TypeError on
    # dict.update(None) whenever a filter argument was omitted.
    if dataset_filters is not None:
        current_filters = copy(dataset_filters)
    else:
        current_filters = ScalarDict()
    if wftask_filters is not None:
        current_filters.update(wftask_filters)
    print(f"[filter_images] Dataset filters:\n{ipjson(dataset_filters)}")
    print(f"[filter_images] WorkflowTask filters:\n{ipjson(wftask_filters)}")
    print(f"[filter_images] Dataset images:\n{ipjson(dataset_images)}")
    print(f"[filter_images] Current selection filters:\n{ipjson(current_filters)}")
    filtered_images = _filter_image_list(
        dataset_images,
        filters=current_filters,
    )
    print(f"[filter_images] Filtered image list: {pjson(filtered_images)}")
    return filtered_images
17 changes: 10 additions & 7 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import Literal
from typing import Optional

from filters import FilterSet
from images import ScalarDict
from images import SingleImage
from pydantic import BaseModel
from pydantic import Field
from pydantic import validator

from task_output import TaskOutput

KwargsType = dict[str, Any]

Expand All @@ -19,7 +19,7 @@ class Dataset(BaseModel):
# New in v2
root_dir: str
images: list[SingleImage] = Field(default_factory=list)
filters: FilterSet = Field(default_factory=dict)
filters: ScalarDict = Field(default_factory=dict)
# Temporary state
buffer: Optional[dict[str, Any]] = None
parallelization_list: Optional[list[dict[str, Any]]] = None
Expand All @@ -32,17 +32,20 @@ def image_paths(self) -> list[str]:


class Task(BaseModel):
function: Callable # mock of task.command
_function: Callable # mock of task.command
meta: dict[str, Any] = Field(default_factory=dict)
new_filters: dict[str, Any] = Field(default_factory=dict) # FIXME: this is not using FilterSet any more!
new_filters: dict[str, Any] = Field(default_factory=dict) # FIXME: this is not using ScalarDict any more!
task_type: Literal["non_parallel", "parallel"] = "non_parallel"

def function(self, **kwargs):
return TaskOutput(self._function(**kwargs))

@validator("new_filters")
def scalar_filters(cls, v):
"""
Check that values of new_filters are all JSON-scalar.
Replacement for `new_filters: FilterSet` attribute type, which
Replacement for `new_filters: ScalarDict` attribute type, which
does not work in Pydantic.
"""
for value in v.values():
Expand All @@ -60,7 +63,7 @@ class WorkflowTask(BaseModel):
args: dict[str, Any] = Field(default_factory=dict)
meta: dict[str, Any] = Field(default_factory=dict)
task: Optional[Task] = None
filters: FilterSet = Field(default_factory=dict)
filters: ScalarDict = Field(default_factory=dict)


class Workflow(BaseModel):
Expand Down
6 changes: 3 additions & 3 deletions src/runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from copy import copy
from copy import deepcopy

from filters import filter_images
from filters import FilterSet
from images import filter_images
from images import ScalarDict
from images import _deduplicate_list_of_dicts
from images import find_image_by_path
from images import SingleImage
Expand All @@ -22,7 +22,7 @@
def _apply_attributes_to_image(
*,
image: SingleImage,
filters: FilterSet,
filters: ScalarDict,
) -> SingleImage:
updated_image = copy(image)
for key, value in filters.items():
Expand Down
9 changes: 4 additions & 5 deletions src/task_output.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Any
from typing import Optional

from filters import FilterSet
from images import find_image_by_path
from images import ScalarDict
from images import SingleImage
from models import KwargsType
from pydantic import BaseModel
from utils import pjson

Expand All @@ -16,7 +15,7 @@ class TaskOutput(BaseModel):
edited_images: Optional[list[SingleImage]] = None
"""List of images edited by a given task instance."""

new_filters: Optional[FilterSet] = None # FIXME: this does not actually work in Pydantic
new_filters: Optional[ScalarDict] = None # FIXME: this does not actually work in Pydantic
"""
*Global* filters (common to all images) added by this task.
Expand All @@ -31,7 +30,7 @@ class TaskOutput(BaseModel):
companion task.
"""

parallelization_list: Optional[list[KwargsType]] = None
parallelization_list: Optional[list[ScalarDict]] = None
"""
Used in the output of an init task, to expose customizable parallelization
of the companion task.
Expand All @@ -47,7 +46,7 @@ class Config:

new_images: Optional[list[SingleImage]] = None
edited_images: Optional[list[SingleImage]] = None
new_filters: Optional[FilterSet] = None # FIXME
new_filters: Optional[ScalarDict] = None # FIXME


def merge_outputs(
Expand Down
20 changes: 10 additions & 10 deletions src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,22 +394,22 @@ def init_registration(


TASK_LIST = {
"create_ome_zarr": Task(function=create_ome_zarr, task_type="non_parallel"),
"yokogawa_to_zarr": Task(function=yokogawa_to_zarr, task_type="parallel"),
"create_ome_zarr_multiplex": Task(function=create_ome_zarr_multiplex, task_type="non_parallel"),
"cellpose_segmentation": Task(function=cellpose_segmentation, task_type="parallel"),
"new_ome_zarr": Task(function=new_ome_zarr, task_type="non_parallel"),
"copy_data": Task(function=copy_data, task_type="parallel"),
"create_ome_zarr": Task(_function=create_ome_zarr, task_type="non_parallel"),
"yokogawa_to_zarr": Task(_function=yokogawa_to_zarr, task_type="parallel"),
"create_ome_zarr_multiplex": Task(_function=create_ome_zarr_multiplex, task_type="non_parallel"),
"cellpose_segmentation": Task(_function=cellpose_segmentation, task_type="parallel"),
"new_ome_zarr": Task(_function=new_ome_zarr, task_type="non_parallel"),
"copy_data": Task(_function=copy_data, task_type="parallel"),
"illumination_correction": Task(
function=illumination_correction,
_function=illumination_correction,
task_type="parallel",
new_filters=dict(illumination_correction=True),
),
"maximum_intensity_projection": Task(
function=maximum_intensity_projection,
_function=maximum_intensity_projection,
task_type="parallel",
new_filters=dict(data_dimensionality="2"),
),
"init_channel_parallelization": Task(function=init_channel_parallelization, task_type="non_parallel"),
"init_registration": Task(function=init_registration, task_type="non_parallel"),
"init_channel_parallelization": Task(_function=init_channel_parallelization, task_type="non_parallel"),
"init_registration": Task(_function=init_registration, task_type="non_parallel"),
}
4 changes: 2 additions & 2 deletions tests/test_apply_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_single_non_parallel_task():
WorkflowTask(
task=Task(
task_type="non_parallel",
function=create_images_from_scratch,
_function=create_images_from_scratch,
),
args=dict(new_paths=NEW_PATHS),
)
Expand All @@ -33,7 +33,7 @@ def test_single_parallel_task_no_parallization_list():
WorkflowTask(
task=Task(
task_type="parallel",
function=print_path,
_function=print_path,
)
)
]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from filters import _filter_image_list
from images import _filter_image_list

images = [
dict(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ def _dummy_function():

def test_model_task():
NEW_FILTERS = dict(x=True, y=1, z="asd", w=None)
task = Task(function=_dummy_function, new_filters=NEW_FILTERS)
task = Task(_function=_dummy_function, new_filters=NEW_FILTERS)
debug(task.new_filters)
debug(NEW_FILTERS)
assert task.new_filters == NEW_FILTERS

with pytest.raises(ValueError) as e:
Task(function=_dummy_function, new_filters=dict(key=[]))
Task(_function=_dummy_function, new_filters=dict(key=[]))
debug(str(e.value))

with pytest.raises(ValueError) as e:
Task(function=_dummy_function, new_filters=dict(key={}))
Task(_function=_dummy_function, new_filters=dict(key={}))
debug(str(e.value))
4 changes: 2 additions & 2 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_max_parallelization_list_size(N: int):
WorkflowTask(
task=Task(
task_type="parallel",
function=dummy_task,
_function=dummy_task,
),
)
]
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_image_attribute_propagation(
WorkflowTask(
task=Task(
task_type="parallel",
function=_copy_and_edit_image,
_function=_copy_and_edit_image,
),
)
]
Expand Down

0 comments on commit d8621a9

Please sign in to comment.