set-scale utility (#228)
* add update_scale_metadata.py

* add update-scale-metadata command to cli

* add zyx flags for update-scale-metadata utility

* add handling of missing zyx cli flags for update-scale-metadata

* update order of params in update_scale_metadata to be passed as z, y, x

* black

* isort

* flake8

* update to `iohub.ngff.models`

* move parsing utilities to iohub

* move update_scale_metadata inside cli folder

* typo

* update import for refactor

* require -z, -y, -x flags

* simplify interface and print statements

* clean up print statement

* update the last three dimensions for OME compatibility

* fix tests

* fix test

* helper functions for axis names

* set_scale API

* consolidate and clean CLI

* test get_axis_index

* test_set_scale

* case insensitive axis name

* tests don't overwrite data

* save old metadata in a namespace

* test multiple inputs to cli

* handle empty current_transforms

* improved empty handling

* unit test CLI plate expansion into positions

* cleanup test

* test OptionEatAll

* stronger plate-expansion test

* fix bug when scale transform does not exist

* create "iohub" dict if it doesn't exist, and update it

* test old_* metadata

* rename old_x to prior_x_scale

---------

Co-authored-by: Talon Chandler <[email protected]>
Co-authored-by: Ivan Ivanov <[email protected]>
3 people authored Sep 26, 2024
1 parent 5a0c90b commit 5e223ec
Showing 6 changed files with 411 additions and 0 deletions.
65 changes: 65 additions & 0 deletions iohub/cli/cli.py
@@ -2,7 +2,9 @@

import click

from iohub import open_ome_zarr
from iohub._version import __version__
from iohub.cli.parsing import input_position_dirpaths
from iohub.convert import TIFFConverter
from iohub.reader import print_info

@@ -87,3 +89,66 @@ def convert(input, output, grid_layout, chunks):
chunks=chunks,
)
converter()


@cli.command()
@click.help_option("-h", "--help")
@input_position_dirpaths()
@click.option(
"--t-scale",
"-t",
required=False,
type=float,
help="New t scale",
)
@click.option(
"--z-scale",
"-z",
required=False,
type=float,
help="New z scale",
)
@click.option(
"--y-scale",
"-y",
required=False,
type=float,
help="New y scale",
)
@click.option(
"--x-scale",
"-x",
required=False,
type=float,
help="New x scale",
)
def set_scale(
input_position_dirpaths,
t_scale=None,
z_scale=None,
y_scale=None,
x_scale=None,
):
"""Update scale metadata in OME-Zarr datasets.
>> iohub set-scale -i input.zarr/*/*/* -t 1.0 -z 1.0 -y 0.5 -x 0.5
Supports setting a single axis at a time:
>> iohub set-scale -i input.zarr/*/*/* -z 2.0
"""
for input_position_dirpath in input_position_dirpaths:
with open_ome_zarr(
input_position_dirpath, layout="fov", mode="r+"
) as dataset:
for name, value in zip(
["t", "z", "y", "x"], [t_scale, z_scale, y_scale, x_scale]
):
if value is None:
continue
old_value = dataset.scale[dataset.get_axis_index(name)]
print(
f"Updating {input_position_dirpath} {name} scale from "
f"{old_value} to {value}."
)
dataset.set_scale("0", name, value)
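For orientation (not part of the diff): the new command can also be driven from Python with click's test runner, as the tests added below do. The store path "input.zarr/B/03/0" is illustrative only.

from click.testing import CliRunner
from iohub.cli.cli import cli

# Invoke the new set-scale command on a single position path.
runner = CliRunner()
result = runner.invoke(
    cli,
    ["set-scale", "-i", "input.zarr/B/03/0", "-z", "2.0", "-y", "0.5", "-x", "0.5"],
)
print(result.output)  # one "Updating ..." line per axis that was changed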
88 changes: 88 additions & 0 deletions iohub/cli/parsing.py
@@ -0,0 +1,88 @@
from pathlib import Path
from typing import Callable, List

import click
from natsort import natsorted

from iohub.ngff import Plate, open_ome_zarr


def _validate_and_process_paths(
ctx: click.Context, opt: click.Option, value: List[str]
) -> list[Path]:
# Sort and validate the input paths,
# expanding plates into lists of positions
input_paths = [Path(path) for path in natsorted(value)]
for path in input_paths:
with open_ome_zarr(path, mode="r") as dataset:
if isinstance(dataset, Plate):
plate_path = input_paths.pop()
for position in dataset.positions():
input_paths.append(plate_path / position[0])

return input_paths


def input_position_dirpaths() -> Callable:
def decorator(f: Callable) -> Callable:
return click.option(
"--input-position-dirpaths",
"-i",
cls=OptionEatAll,
type=tuple,
required=True,
callback=_validate_and_process_paths,
help=(
"List of paths to input positions, "
"each with the same TCZYX shape. "
"Supports wildcards e.g. 'input.zarr/*/*/*'."
),
)(f)

return decorator


# Copied directly from https://stackoverflow.com/a/48394004
# Enables `-i ./input.zarr/*/*/*`
class OptionEatAll(click.Option):
def __init__(self, *args, **kwargs):
self.save_other_options = kwargs.pop("save_other_options", True)
nargs = kwargs.pop("nargs", -1)
assert nargs == -1, "nargs, if set, must be -1 not {}".format(nargs)
super(OptionEatAll, self).__init__(*args, **kwargs)
self._previous_parser_process = None
self._eat_all_parser = None

def add_to_parser(self, parser, ctx):
def parser_process(value, state):
# method to hook to the parser.process
done = False
value = [value]
if self.save_other_options:
# grab everything up to the next option
while state.rargs and not done:
for prefix in self._eat_all_parser.prefixes:
if state.rargs[0].startswith(prefix):
done = True
if not done:
value.append(state.rargs.pop(0))
else:
# grab everything remaining
value += state.rargs
state.rargs[:] = []
value = tuple(value)

# call the actual process
self._previous_parser_process(value, state)

retval = super(OptionEatAll, self).add_to_parser(parser, ctx)
for name in self.opts:
our_parser = parser._long_opt.get(name) or parser._short_opt.get(
name
)
if our_parser:
self._eat_all_parser = our_parser
self._previous_parser_process = our_parser.process
our_parser.process = parser_process
break
return retval
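A quick sketch of how OptionEatAll behaves on its own (not part of the diff; the option name "--paths" and the command are made up for illustration): it keeps consuming arguments until the next option prefix, which is what lets `-i` accept a shell-expanded glob.

import click

from iohub.cli.parsing import OptionEatAll

@click.command()
@click.option("--paths", "-i", cls=OptionEatAll, type=tuple)
def show(paths):
    """Echo every path captured by the greedy -i option."""
    for path in paths:
        click.echo(path)

# `show -i a.zarr b.zarr c.zarr` echoes all three paths, because the
# hooked parser.process collects values until it sees another option.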
80 changes: 80 additions & 0 deletions iohub/ngff/nodes.py
@@ -976,6 +976,33 @@ def scale(self) -> list[float]:
scale = [s1 * s2 for s1, s2 in zip(scale, trans.scale)]
return scale

@property
def axis_names(self) -> list[str]:
"""
Helper function for axis names of the highest resolution scale.
Returns lowercase axis names.
"""
return [
axis.name.lower() for axis in self.metadata.multiscales[0].axes
]

def get_axis_index(self, axis_name: str) -> int:
"""
Get the index of a given axis.
Parameters
----------
axis_name : str
Name of the axis. Case insensitive.
Returns
-------
int
Index of the axis.
"""
return self.axis_names.index(axis_name.lower())

def set_transform(
self,
image: str | Literal["*"],
@@ -1007,6 +1034,59 @@ def set_transform(
raise ValueError(f"Key {image} not recognized.")
self.dump_meta()

def set_scale(
self,
image: str | Literal["*"],
axis_name: str,
new_scale: float,
):
"""Set the scale for a named axis.
Either one image array or the whole FOV.
Parameters
----------
image : str | Literal["*"]
Name of one image array (e.g. "0") to transform,
or "*" for the whole FOV
axis_name : str
Name of the axis to set.
new_scale : float
Value of the new scale.
"""
if len(self.metadata.multiscales) > 1:
raise NotImplementedError(
"Cannot set scale for multi-resolution images."
)

if new_scale <= 0:
raise ValueError("New scale must be positive.")

axis_index = self.get_axis_index(axis_name)

# Append old scale to metadata
iohub_dict = {}
if "iohub" in self.zattrs:
iohub_dict = self.zattrs["iohub"]
iohub_dict.update({f"prior_{axis_name}_scale": self.scale[axis_index]})
self.zattrs["iohub"] = iohub_dict

# Update scale while preserving existing transforms
transforms = (
self.metadata.multiscales[0].datasets[0].coordinate_transformations
)
# Replace default identity transform with scale
if len(transforms) == 1 and transforms[0].type == "identity":
transforms = [TransformationMeta(type="scale", scale=[1] * 5)]
# Add scale transform if not present
if not any([transform.type == "scale" for transform in transforms]):
transforms.append(TransformationMeta(type="scale", scale=[1] * 5))

for transform in transforms:
if transform.type == "scale":
transform.scale[axis_index] = new_scale

self.set_transform(image, transforms)


class TiledPosition(Position):
"""Variant of the NGFF position node
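For context (not part of the diff), a minimal sketch of the new Position API added above; the store path "fov.zarr/0/0/0" and the scale value are illustrative.

from iohub import open_ome_zarr

with open_ome_zarr("fov.zarr/0/0/0", layout="fov", mode="r+") as fov:
    print(fov.axis_names)              # e.g. ['t', 'c', 'z', 'y', 'x']
    z_index = fov.get_axis_index("Z")  # case-insensitive lookup
    print(fov.scale[z_index])          # current z scale
    fov.set_scale("0", "z", 2.0)       # update image array "0" in place
    # the previous value is recorded in fov.zattrs["iohub"]["prior_z_scale"]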
68 changes: 68 additions & 0 deletions tests/cli/test_cli.py
@@ -1,9 +1,12 @@
import random
import re
from pathlib import Path
from unittest.mock import patch

import pytest
from click.testing import CliRunner

from iohub import open_ome_zarr
from iohub._version import __version__
from iohub.cli.cli import cli
from tests.conftest import (
@@ -13,6 +16,8 @@
ndtiff_v3_labeled_positions,
)

from ..ngff.test_ngff import _temp_copy


def pytest_generate_tests(metafunc):
if "mm2gamma_ome_tiff" in metafunc.fixturenames:
@@ -103,3 +108,66 @@ def test_cli_convert_ome_tiff(grid_layout, tmpdir):
result = runner.invoke(cli, cmd)
assert result.exit_code == 0, result.output
assert "Converting" in result.output


def test_cli_set_scale():
with _temp_copy(hcs_ref) as store_path:
store_path = Path(store_path)
position_path = Path(store_path) / "B" / "03" / "0"

with open_ome_zarr(
position_path, layout="fov", mode="r+"
) as input_dataset:
old_scale = input_dataset.scale

random_z = random.uniform(0, 1)

runner = CliRunner()
result_pos = runner.invoke(
cli,
[
"set-scale",
"-i",
str(position_path),
"-z",
random_z,
"-y",
0.5,
"-x",
0.5,
],
)
assert result_pos.exit_code == 0
assert "Updating" in result_pos.output

with open_ome_zarr(position_path, layout="fov") as output_dataset:
assert tuple(output_dataset.scale[-3:]) == (random_z, 0.5, 0.5)
assert output_dataset.scale != old_scale
assert (
output_dataset.zattrs["iohub"]["prior_x_scale"]
== old_scale[-1]
)
assert (
output_dataset.zattrs["iohub"]["prior_y_scale"]
== old_scale[-2]
)
assert (
output_dataset.zattrs["iohub"]["prior_z_scale"]
== old_scale[-3]
)

# Test plate-expands-into-positions behavior
runner = CliRunner()
result_pos = runner.invoke(
cli,
[
"set-scale",
"-i",
str(store_path),
"-x",
0.1,
],
)
with open_ome_zarr(position_path, layout="fov") as output_dataset:
assert output_dataset.scale[-1] == 0.1
assert output_dataset.zattrs["iohub"]["prior_x_scale"] == 0.5
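As a companion to the plate-expansion test above (not part of the diff), this is roughly what `_validate_and_process_paths` does when handed a plate root; the store name "hcs.zarr" is illustrative.

from pathlib import Path

from iohub import open_ome_zarr
from iohub.ngff import Plate

plate_path = Path("hcs.zarr")
with open_ome_zarr(plate_path, mode="r") as dataset:
    if isinstance(dataset, Plate):
        # e.g. [Path("hcs.zarr/B/03/0"), ...] for the reference HCS store
        positions = [plate_path / name for name, _ in dataset.positions()]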
