Skip to content

Commit

Permalink
Rework Z-range handling to use deskewed max, and to update automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
multimeric committed Nov 20, 2023
1 parent 5addbcc commit e36bcfd
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 52 deletions.
4 changes: 2 additions & 2 deletions core/lls_core/cmds/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pydantic import ValidationError

if TYPE_CHECKING:
from lls_core.models.utils import FieldAccessMixin
from lls_core.models.utils import FieldAccessModel
from typing import Type, Any
from rich.table import Table

Expand Down Expand Up @@ -55,7 +55,7 @@ class CliDeskewDirection(StrEnum):

app = Typer(add_completion=False, rich_markup_mode="rich", no_args_is_help=True)

def field_from_model(model: Type[FieldAccessMixin], field_name: str, extra_description: str = "", description: Optional[str] = None, default: Optional[Any] = None, **kwargs) -> Any:
def field_from_model(model: Type[FieldAccessModel], field_name: str, extra_description: str = "", description: Optional[str] = None, default: Optional[Any] = None, **kwargs) -> Any:
"""
Generates a type Field from a Pydantic model field
"""
Expand Down
9 changes: 6 additions & 3 deletions core/lls_core/models/crop.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Iterable, List, Tuple, Any
from pydantic import Field, NonNegativeInt, validator
from lls_core.models.utils import FieldAccessMixin
from lls_core.models.utils import FieldAccessModel
from lls_core.cropping import Roi

class CropParams(FieldAccessMixin):
class CropParams(FieldAccessModel):
"""
Parameters for the optional cropping step
Parameters for the optional cropping step.
Note that cropping is performed in the space of the deskewed shape.
This is to support the workflow of performing a preview deskew and using that
to calculate the cropping coordinates.
"""
roi_list: List[Roi] = Field(
description="List of regions of interest, each of which must be an NxD array, where N is the number of vertices and D the coordinates of each vertex.",
Expand Down
4 changes: 2 additions & 2 deletions core/lls_core/models/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from xarray import DataArray

from lls_core import DeconvolutionChoice
from lls_core.models.utils import enum_choices, FieldAccessMixin
from lls_core.models.utils import enum_choices, FieldAccessModel

from lls_core.types import image_like_to_image, ImageLike

Background = Union[float, Literal["auto", "second_last"]]
class DeconvolutionParams(FieldAccessMixin):
class DeconvolutionParams(FieldAccessModel):
"""
Parameters for the optional deconvolution step
"""
Expand Down
69 changes: 42 additions & 27 deletions core/lls_core/models/deskew.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
# class for initializing lattice data and setting metadata
# TODO: handle scenes
from pydantic import Field, NonNegativeFloat, validator, root_validator
from pydantic import Field, NonNegativeFloat, validator

from typing import Any, Tuple
from typing_extensions import Self, TYPE_CHECKING
Expand All @@ -11,16 +11,14 @@
from lls_core import DeskewDirection
from xarray import DataArray

from lls_core.models.utils import FieldAccessMixin, enum_choices
from lls_core.models.utils import FieldAccessModel, enum_choices
from lls_core.types import image_like_to_image
from lls_core.utils import get_deskewed_shape

if TYPE_CHECKING:
from aicsimageio.types import PhysicalPixelSizes

# DeskewDirection = Literal["X", "Y"]

class DefinedPixelSizes(FieldAccessMixin):
class DefinedPixelSizes(FieldAccessModel):
"""
Like PhysicalPixelSizes, but it's a dataclass, and
none of its fields are None
Expand All @@ -39,8 +37,21 @@ def from_physical(cls, pixels: PhysicalPixelSizes) -> Self:
Z=raise_if_none(pixels.Z, "All pixels must be defined"),
)

class DerivedDeskewFields(FieldAccessModel):
"""
Fields that are automatically calculated based on other fields in DeskewParams.
Grouping these together into one model makes validation simpler.
"""
deskew_vol_shape: Tuple[int, ...] = Field(
init_var=False,
default=None,
description="Dimensions of the deskewed output. This is set automatically based on other input parameters, and doesn't need to be provided by the user."
)

deskew_affine_transform: cle.AffineTransform3D = Field(init_var=False, default=None, description="Deskewing transformation function. This is set automatically based on other input parameters, and doesn't need to be provided by the user.")


class DeskewParams(FieldAccessMixin):
class DeskewParams(FieldAccessModel):
input_image: DataArray = Field(
description="A 3-5D array containing the image data."
)
Expand All @@ -53,17 +64,14 @@ class DeskewParams(FieldAccessMixin):
description="Angle of deskewing, in degrees."
)
physical_pixel_sizes: DefinedPixelSizes = Field(
default_factory=DefinedPixelSizes,
description="Pixel size of the microscope, in microns."
)
deskew_vol_shape: Tuple[int, ...] = Field(
default_factory=DefinedPixelSizes,
description="Pixel size of the microscope, in microns."
)
derived: DerivedDeskewFields = Field(
init_var=False,
default=None,
description="Dimensions of the deskewed output. This is set automatically based on other input parameters, and doesn't need to be provided by the user."
description="Refer to the DerivedDeskewFields docstring"
)

deskew_affine_transform: cle.AffineTransform3D = Field(init_var=False, default=None, description="Deskewing transformation function. This is set automatically based on other input parameters, and doesn't need to be provided by the user.")

# Hack to ensure that .skew_dir behaves identically to .skew
@property
def skew_dir(self) -> DeskewDirection:
Expand Down Expand Up @@ -170,19 +178,26 @@ def reshaping(cls, v: Any):
def get_3d_slice(self) -> DataArray:
return self.input_image.isel(C=0, T=0)

@root_validator(pre=False)
def set_deskew(cls, values: dict) -> dict:
@validator("derived", always=True)
def calculate_derived(cls, v: Any, values: dict) -> DerivedDeskewFields:
"""
Sets the default deskew shape values if the user has not provided them
"""
# process the file to get shape of final deskewed image
if "input_image" not in values:
return values
data: DataArray = cls.reshaping(values["input_image"])
if values.get('deskew_vol_shape') is None:
if values.get('deskew_affine_transform') is None:
# If neither has been set, calculate them ourselves
values["deskew_vol_shape"], values["deskew_affine_transform"] = get_deskewed_shape(data.isel(C=0, T=0).to_numpy(), values["angle"], values["physical_pixel_sizes"].X, values["physical_pixel_sizes"].Y, values["physical_pixel_sizes"].Z, values["skew"])
else:
raise ValueError("deskew_vol_shape and deskew_affine_transform must be either both specified or neither specified")
return values
data: DataArray = values["input_image"]
if isinstance(v, DerivedDeskewFields):
return v
elif v is None:
deskew_vol_shape, deskew_affine_transform = get_deskewed_shape(
data.isel(C=0, T=0).to_numpy(),
values["angle"],
values["physical_pixel_sizes"].X,
values["physical_pixel_sizes"].Y,
values["physical_pixel_sizes"].Z,
values["skew"]
)
return DerivedDeskewFields(
deskew_affine_transform=deskew_affine_transform,
deskew_vol_shape=deskew_vol_shape
)
else:
raise ValueError("Invalid derived fields")
9 changes: 5 additions & 4 deletions core/lls_core/models/lattice_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ def parse_workflow(cls, v: Any):
return v

@validator("crop")
def default_z_range(cls, v: CropParams, values: dict):
def default_z_range(cls, v: Optional[CropParams], values: dict) -> Optional[CropParams]:
if v is None:
return v
with ignore_keyerror():
# Fill in missing parts of the z range with the min/max z values
# Fill in missing parts of the z range
# The max allowed value is the length of the deskew Z axis
default_start = 0
default_end = values["input_image"].sizes["Z"]
default_end = values["derived"].deskew_vol_shape[0]

# Set defaults
if v.z_range is None:
Expand All @@ -110,7 +111,7 @@ def default_z_range(cls, v: CropParams, values: dict):
if v.z_range[0] < default_start:
raise ValueError(f"The z-index start of {v.z_range[0]} is outside the size of the z-axis")

return v
return v

@validator("time_range", pre=True, always=True)
def parse_time_range(cls, v: Any, values: dict) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions core/lls_core/models/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

from lls_core.models.utils import FieldAccessMixin, enum_choices
from lls_core.models.utils import FieldAccessModel, enum_choices

if TYPE_CHECKING:
pass
Expand All @@ -14,7 +14,7 @@ class SaveFileType(StrEnum):
h5 = "h5"
tiff = "tiff"

class OutputParams(FieldAccessMixin):
class OutputParams(FieldAccessModel):
save_dir: DirectoryPath = Field(
description="The directory where the output data will be saved"
)
Expand Down
4 changes: 2 additions & 2 deletions core/lls_core/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def ignore_keyerror():
except KeyError:
pass

class FieldAccessMixin(BaseModel):
class FieldAccessModel(BaseModel):
"""
Adds methods to a BaseModel for accessing useful field information
"""
Expand Down Expand Up @@ -54,7 +54,7 @@ def to_definition_dict(cls) -> dict:
"""
ret = {}
for key, value in cls.__fields__.items():
if isinstance(value.outer_type_, type) and issubclass(value.outer_type_, FieldAccessMixin):
if isinstance(value.outer_type_, type) and issubclass(value.outer_type_, FieldAccessModel):
value = value.outer_type_.to_definition_dict()
else:
value = value.field_info.description
Expand Down
5 changes: 5 additions & 0 deletions core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
def runner() -> CliRunner:
return CliRunner()

@pytest.fixture
def rbc_tiny():
with as_file(resources / "RBC_tiny.czi") as image_path:
yield image_path

@pytest.fixture(params=[
"RBC_tiny.czi",
"RBC_lattice.tif",
Expand Down
20 changes: 13 additions & 7 deletions core/tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from pathlib import Path
from lls_core.models.crop import CropParams
from lls_core.models.lattice_data import LatticeData
from lls_core.sample import resources
from importlib_resources import as_file


def test_default_save_dir():
def test_default_save_dir(rbc_tiny: Path):
# Test that the save dir is inferred to be the input dir
with as_file(resources / "RBC_tiny.czi") as path:
params = LatticeData(input_image=path)
assert params.save_dir == path.parent
params = LatticeData(input_image=rbc_tiny)
assert params.save_dir == rbc_tiny.parent

def test_auto_z_range(rbc_tiny: Path):
# Tests that the Z range is automatically set, and it is set
# based on the size of the deskewed volume
params = LatticeData(input_image=rbc_tiny, crop=CropParams(
roi_list=[]
))
assert params.crop.z_range == (0, 59)
19 changes: 16 additions & 3 deletions plugin/napari_lattice/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# FieldGroups that the users interact with to input data
import logging
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING
from typing_extensions import TypeVar
import pyclesperanto_prototype as cle
Expand Down Expand Up @@ -396,6 +397,10 @@ def _make_model(self) -> Optional[DeconvolutionParams]:
@magicclass
class CroppingFields(NapariFieldGroup):
# A counterpart to the CropParams Pydantic class
header = field(dedent("""
Note that all cropping, including the regions of interest and Z range, is performed in the space of the deskewed shape.
This is to support the workflow of performing a preview deskew and using that to calculate the cropping coordinates.
"""), widget_type="Label")
fields_enabled = field(False, label="Enabled")
shapes= field(List[Shapes], widget_type="Select", label = "ROI Shape Layers").with_options(choices=lambda _x, _y: get_layers(Shapes))
z_range = field(Tuple[int, int]).with_options(
Expand Down Expand Up @@ -432,9 +437,15 @@ def new_crop_layer(self):
shapes.name = "Napari Lattice Crop"

def _on_image_changed(self, img: DataArray):
# Update the maximum Z
deskew = self._get_deskew()
deskewed_zmax = deskew.derived.deskew_vol_shape[0]

# Update the allowed Z based the deskewed shape
for widget in self.z_range:
adjust_maximum(widget, img.sizes["Z"])
adjust_maximum(widget, deskewed_zmax)

# Update the current max value to be the max of the shape
self.z_range[1].value = deskewed_zmax

@fields_enabled.connect
@enable_if([shapes, z_range])
Expand All @@ -445,8 +456,10 @@ def _make_model(self) -> Optional[CropParams]:
import numpy as np
if self.fields_enabled.value:
return CropParams(
# Convert from the input image space to the deskewed image space
# We assume here that dx == dy which isn't ideal
roi_list=ShapesData([np.array(shape.data) / self._get_deskew().dy for shape in self.shapes.value]),
z_range=self.z_range.value,
z_range=tuple(np.array(self.z_range.value) / self._get_deskew().new_dz),
)
return None

Expand Down

0 comments on commit e36bcfd

Please sign in to comment.