diff --git a/iohub/cli/cli.py b/iohub/cli/cli.py
index 901b92a9..bcf6d556 100644
--- a/iohub/cli/cli.py
+++ b/iohub/cli/cli.py
@@ -2,7 +2,9 @@
 import click
 
+from iohub import open_ome_zarr
 from iohub._version import __version__
+from iohub.cli.parsing import input_position_dirpaths
 from iohub.convert import TIFFConverter
 from iohub.reader import print_info
 
@@ -87,3 +89,66 @@ def convert(input, output, grid_layout, chunks):
         chunks=chunks,
     )
     converter()
+
+
+@cli.command()
+@click.help_option("-h", "--help")
+@input_position_dirpaths()
+@click.option(
+    "--t-scale",
+    "-t",
+    required=False,
+    type=float,
+    help="New t scale",
+)
+@click.option(
+    "--z-scale",
+    "-z",
+    required=False,
+    type=float,
+    help="New z scale",
+)
+@click.option(
+    "--y-scale",
+    "-y",
+    required=False,
+    type=float,
+    help="New y scale",
+)
+@click.option(
+    "--x-scale",
+    "-x",
+    required=False,
+    type=float,
+    help="New x scale",
+)
+def set_scale(
+    input_position_dirpaths,
+    t_scale=None,
+    z_scale=None,
+    y_scale=None,
+    x_scale=None,
+):
+    """Update scale metadata in OME-Zarr datasets.
+
+    >> iohub set-scale -i input.zarr/*/*/* -t 1.0 -z 1.0 -y 0.5 -x 0.5
+
+    Scales can also be set one axis at a time:
+
+    >> iohub set-scale -i input.zarr/*/*/* -z 2.0
+    """
+    for input_position_dirpath in input_position_dirpaths:
+        with open_ome_zarr(
+            input_position_dirpath, layout="fov", mode="r+"
+        ) as dataset:
+            for name, value in zip(
+                ["t", "z", "y", "x"], [t_scale, z_scale, y_scale, x_scale]
+            ):
+                if value is None:
+                    continue
+                old_value = dataset.scale[dataset.get_axis_index(name)]
+                print(
+                    f"Updating {input_position_dirpath} {name} scale from "
+                    f"{old_value} to {value}."
+                )
+                dataset.set_scale("0", name, value)
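The new command is a thin loop over the Position API added in iohub/ngff/nodes.py below, so the same update can be scripted directly. A minimal sketch, assuming iohub with this patch installed and an existing single-FOV store at the hypothetical path ./fov.zarr:

from iohub import open_ome_zarr

# Equivalent of `iohub set-scale -i ./fov.zarr -z 2.0`, done in Python.
with open_ome_zarr("./fov.zarr", layout="fov", mode="r+") as dataset:
    z_index = dataset.get_axis_index("z")
    print(f"Old z scale: {dataset.scale[z_index]}")
    dataset.set_scale("0", "z", 2.0)  # image array "0", axis name, new value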
diff --git a/iohub/cli/parsing.py b/iohub/cli/parsing.py
new file mode 100644
index 00000000..aed9319d
--- /dev/null
+++ b/iohub/cli/parsing.py
@@ -0,0 +1,88 @@
+from pathlib import Path
+from typing import Callable, List
+
+import click
+from natsort import natsorted
+
+from iohub.ngff import Plate, open_ome_zarr
+
+
+def _validate_and_process_paths(
+    ctx: click.Context, opt: click.Option, value: List[str]
+) -> list[Path]:
+    # Sort and validate the input paths,
+    # expanding plates into lists of positions
+    input_paths = [Path(path) for path in natsorted(value)]
+    for path in input_paths:
+        with open_ome_zarr(path, mode="r") as dataset:
+            if isinstance(dataset, Plate):
+                plate_path = input_paths.pop()
+                for position in dataset.positions():
+                    input_paths.append(plate_path / position[0])
+
+    return input_paths
+
+
+def input_position_dirpaths() -> Callable:
+    def decorator(f: Callable) -> Callable:
+        return click.option(
+            "--input-position-dirpaths",
+            "-i",
+            cls=OptionEatAll,
+            type=tuple,
+            required=True,
+            callback=_validate_and_process_paths,
+            help=(
+                "List of paths to input positions, "
+                "each with the same TCZYX shape. "
+                "Supports wildcards e.g. 'input.zarr/*/*/*'."
+            ),
+        )(f)
+
+    return decorator
+
+
+# Copied directly from https://stackoverflow.com/a/48394004
+# Enables `-i ./input.zarr/*/*/*`
+class OptionEatAll(click.Option):
+    def __init__(self, *args, **kwargs):
+        self.save_other_options = kwargs.pop("save_other_options", True)
+        nargs = kwargs.pop("nargs", -1)
+        assert nargs == -1, "nargs, if set, must be -1 not {}".format(nargs)
+        super(OptionEatAll, self).__init__(*args, **kwargs)
+        self._previous_parser_process = None
+        self._eat_all_parser = None
+
+    def add_to_parser(self, parser, ctx):
+        def parser_process(value, state):
+            # method to hook to the parser.process
+            done = False
+            value = [value]
+            if self.save_other_options:
+                # grab everything up to the next option
+                while state.rargs and not done:
+                    for prefix in self._eat_all_parser.prefixes:
+                        if state.rargs[0].startswith(prefix):
+                            done = True
+                    if not done:
+                        value.append(state.rargs.pop(0))
+            else:
+                # grab everything remaining
+                value += state.rargs
+                state.rargs[:] = []
+            value = tuple(value)
+
+            # call the actual process
+            self._previous_parser_process(value, state)
+
+        retval = super(OptionEatAll, self).add_to_parser(parser, ctx)
+        for name in self.opts:
+            our_parser = parser._long_opt.get(name) or parser._short_opt.get(
+                name
+            )
+            if our_parser:
+                self._eat_all_parser = our_parser
+                self._previous_parser_process = our_parser.process
+                our_parser.process = parser_process
+                break
+        return retval
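A minimal sketch of how the decorator above is meant to be consumed by a command (the `inspect_paths` command is hypothetical; the callback opens every path it receives, so the store must already exist):

import click

from iohub.cli.parsing import input_position_dirpaths


@click.command()
@input_position_dirpaths()
def inspect_paths(input_position_dirpaths):
    # OptionEatAll collects every token after -i, and the callback expands a
    # plate path into its positions, so the command receives a list of Paths.
    for path in input_position_dirpaths:
        click.echo(path)


if __name__ == "__main__":
    inspect_paths()

Invoked as `python inspect_paths.py -i input.zarr/*/*/*`, the shell expands the wildcard and OptionEatAll consumes all of the resulting tokens.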
diff --git a/iohub/ngff/nodes.py b/iohub/ngff/nodes.py
index 9624baeb..6d63a8e0 100644
--- a/iohub/ngff/nodes.py
+++ b/iohub/ngff/nodes.py
@@ -976,6 +976,33 @@ def scale(self) -> list[float]:
             scale = [s1 * s2 for s1, s2 in zip(scale, trans.scale)]
         return scale
 
+    @property
+    def axis_names(self) -> list[str]:
+        """
+        Axis names of the highest-resolution scale.
+
+        Names are returned in lowercase.
+        """
+        return [
+            axis.name.lower() for axis in self.metadata.multiscales[0].axes
+        ]
+
+    def get_axis_index(self, axis_name: str) -> int:
+        """
+        Get the index of a given axis.
+
+        Parameters
+        ----------
+        axis_name : str
+            Name of the axis. Case insensitive.
+
+        Returns
+        -------
+        int
+            Index of the axis.
+        """
+        return self.axis_names.index(axis_name.lower())
+
     def set_transform(
         self,
         image: str | Literal["*"],
@@ -1007,6 +1034,59 @@
             raise ValueError(f"Key {image} not recognized.")
         self.dump_meta()
 
+    def set_scale(
+        self,
+        image: str | Literal["*"],
+        axis_name: str,
+        new_scale: float,
+    ):
+        """Set the scale for a named axis,
+        either for one image array or for the whole FOV.
+
+        Parameters
+        ----------
+        image : str | Literal["*"]
+            Name of one image array (e.g. "0") to transform,
+            or "*" for the whole FOV.
+        axis_name : str
+            Name of the axis to set.
+        new_scale : float
+            Value of the new scale.
+        """
+        if len(self.metadata.multiscales) > 1:
+            raise NotImplementedError(
+                "Cannot set scale for multi-resolution images."
+            )
+
+        if new_scale <= 0:
+            raise ValueError("New scale must be positive.")
+
+        axis_index = self.get_axis_index(axis_name)
+
+        # Append old scale to metadata
+        iohub_dict = {}
+        if "iohub" in self.zattrs:
+            iohub_dict = self.zattrs["iohub"]
+        iohub_dict.update({f"prior_{axis_name}_scale": self.scale[axis_index]})
+        self.zattrs["iohub"] = iohub_dict
+
+        # Update scale while preserving existing transforms
+        transforms = (
+            self.metadata.multiscales[0].datasets[0].coordinate_transformations
+        )
+        # Replace default identity transform with scale
+        if len(transforms) == 1 and transforms[0].type == "identity":
+            transforms = [TransformationMeta(type="scale", scale=[1] * 5)]
+        # Add scale transform if not present
+        if not any([transform.type == "scale" for transform in transforms]):
+            transforms.append(TransformationMeta(type="scale", scale=[1] * 5))
+
+        for transform in transforms:
+            if transform.type == "scale":
+                transform.scale[axis_index] = new_scale
+
+        self.set_transform(image, transforms)
+
 
 class TiledPosition(Position):
     """Variant of the NGFF position node
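To make the bookkeeping in set_scale concrete: after a call, the scale transform of the named image carries the new value, and the value it replaced is recorded under the "iohub" key of .zattrs. A rough sketch, again using a hypothetical ./fov.zarr:

from iohub import open_ome_zarr

with open_ome_zarr("./fov.zarr", layout="fov", mode="r+") as dataset:
    dataset.set_scale("0", "z", 2.0)
    transforms = (
        dataset.metadata.multiscales[0].datasets[0].coordinate_transformations
    )
    print(transforms)  # the scale transform now has 2.0 at the z index
    print(dataset.zattrs["iohub"]["prior_z_scale"])  # the value it replaced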
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index 4f4f4b34..c0dc81f4 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -1,9 +1,12 @@
+import random
 import re
+from pathlib import Path
 from unittest.mock import patch
 
 import pytest
 from click.testing import CliRunner
 
+from iohub import open_ome_zarr
 from iohub._version import __version__
 from iohub.cli.cli import cli
 from tests.conftest import (
@@ -13,6 +16,8 @@
     ndtiff_v3_labeled_positions,
 )
 
+from ..ngff.test_ngff import _temp_copy
+
 
 def pytest_generate_tests(metafunc):
     if "mm2gamma_ome_tiff" in metafunc.fixturenames:
@@ -103,3 +108,66 @@ def test_cli_convert_ome_tiff(grid_layout, tmpdir):
     result = runner.invoke(cli, cmd)
     assert result.exit_code == 0, result.output
     assert "Converting" in result.output
+
+
+def test_cli_set_scale():
+    with _temp_copy(hcs_ref) as store_path:
+        store_path = Path(store_path)
+        position_path = Path(store_path) / "B" / "03" / "0"
+
+        with open_ome_zarr(
+            position_path, layout="fov", mode="r+"
+        ) as input_dataset:
+            old_scale = input_dataset.scale
+
+        random_z = random.uniform(0, 1)
+
+        runner = CliRunner()
+        result_pos = runner.invoke(
+            cli,
+            [
+                "set-scale",
+                "-i",
+                str(position_path),
+                "-z",
+                random_z,
+                "-y",
+                0.5,
+                "-x",
+                0.5,
+            ],
+        )
+        assert result_pos.exit_code == 0
+        assert "Updating" in result_pos.output
+
+        with open_ome_zarr(position_path, layout="fov") as output_dataset:
+            assert tuple(output_dataset.scale[-3:]) == (random_z, 0.5, 0.5)
+            assert output_dataset.scale != old_scale
+            assert (
+                output_dataset.zattrs["iohub"]["prior_x_scale"]
+                == old_scale[-1]
+            )
+            assert (
+                output_dataset.zattrs["iohub"]["prior_y_scale"]
+                == old_scale[-2]
+            )
+            assert (
+                output_dataset.zattrs["iohub"]["prior_z_scale"]
+                == old_scale[-3]
+            )
+
+        # Test plate-expands-into-positions behavior
+        runner = CliRunner()
+        result_pos = runner.invoke(
+            cli,
+            [
+                "set-scale",
+                "-i",
+                str(store_path),
+                "-x",
+                0.1,
+            ],
+        )
+        with open_ome_zarr(position_path, layout="fov") as output_dataset:
+            assert output_dataset.scale[-1] == 0.1
+            assert output_dataset.zattrs["iohub"]["prior_x_scale"] == 0.5
diff --git a/tests/cli/test_parsing.py b/tests/cli/test_parsing.py
new file mode 100644
index 00000000..296e537c
--- /dev/null
+++ b/tests/cli/test_parsing.py
@@ -0,0 +1,62 @@
+import click
+import numpy as np
+from click.testing import CliRunner
+
+from iohub import open_ome_zarr
+from iohub.cli.parsing import OptionEatAll, _validate_and_process_paths
+
+
+def test_validate_and_process_paths(tmpdir):
+    # Setup plate
+    plate_path = tmpdir / "dataset.zarr"
+    position_list = [("A", "1", "0"), ("B", "2", "0"), ("X", "4", "1")]
+    with open_ome_zarr(
+        plate_path, mode="w", layout="hcs", channel_names=["1", "2"]
+    ) as dataset:
+        for position in position_list:
+            pos = dataset.create_position(*position)
+            pos.create_zeros("0", shape=(1, 1, 1, 1, 1), dtype=np.uint8)
+
+    # Setup click
+    cmd = click.Command("test")
+    ctx = click.Context(cmd)
+    opt = click.Option(["--path"], type=click.Path(exists=True))
+
+    # Check plate expansion
+    processed = _validate_and_process_paths(ctx, opt, [str(plate_path)])
+    assert len(processed) == len(position_list)
+    for i, position in enumerate(position_list):
+        assert processed[i].parts[-3:] == position
+
+    # Check single position
+    processed = _validate_and_process_paths(
+        ctx, opt, [str(plate_path / "A" / "1" / "0")]
+    )
+    assert len(processed) == 1
+
+    # Check two positions
+    processed = _validate_and_process_paths(
+        ctx,
+        opt,
+        [str(plate_path / "A" / "1" / "0"), str(plate_path / "B" / "2" / "0")],
+    )
+    assert len(processed) == 2
+
+
+def test_option_eat_all():
+    @click.command()
+    @click.option(
+        "--test", cls=OptionEatAll
+    )  # tests will fail w/o OptionEatAll
+    def foo(test):
+        print(test)
+
+    runner = CliRunner()
+    result = runner.invoke(foo, ["--test", "a", "b", "c"])
+    assert "('a', 'b', 'c')" in result.output
+    assert "Error" not in result.output
+
+    result = runner.invoke(foo, ["--test", "a"])
+    assert "a" in result.output
+    assert "b" not in result.output
+    assert "Error" not in result.output
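test_cli_set_scale mutates a copy of the reference HCS store via `_temp_copy`, which lives in tests/ngff/test_ngff.py and is not part of this diff. A context manager along these lines would satisfy that usage (a sketch only, not the actual implementation):

import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def _temp_copy(src):
    """Yield a temporary copy of the store at `src`, removed on exit."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        dst = Path(tmp_dir) / Path(src).name
        shutil.copytree(src, dst)
        yield dst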
diff --git a/tests/ngff/test_ngff.py b/tests/ngff/test_ngff.py
index d1701529..50376a3a 100644
--- a/tests/ngff/test_ngff.py
+++ b/tests/ngff/test_ngff.py
@@ -403,6 +403,38 @@ def test_set_transform_fov(ch_shape_dtype, arr_name):
     ]
 
 
+@given(
+    ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
+)
+def test_set_scale(ch_shape_dtype):
+    channel_names, shape, dtype = ch_shape_dtype
+    transform = [
+        TransformationMeta(type="translation", translation=(1, 2, 3, 4, 5)),
+        TransformationMeta(type="scale", scale=(5, 4, 3, 2, 1)),
+    ]
+    with TemporaryDirectory() as temp_dir:
+        store_path = os.path.join(temp_dir, "ome.zarr")
+        with open_ome_zarr(
+            store_path, layout="fov", mode="w-", channel_names=channel_names
+        ) as dataset:
+            dataset.create_zeros(name="0", shape=shape, dtype=dtype)
+            dataset.set_transform(image="0", transform=transform)
+            dataset.set_scale(image="0", axis_name="z", new_scale=10.0)
+            assert dataset.scale[-3] == 10.0
+            assert (
+                dataset.metadata.multiscales[0]
+                .datasets[0]
+                .coordinate_transformations[0]
+                .translation[-1]
+                == 5
+            )
+
+            with pytest.raises(ValueError):
+                dataset.set_scale(image="0", axis_name="z", new_scale=-1.0)
+
+            assert dataset.zattrs["iohub"]["prior_z_scale"] == 3.0
+
+
 @given(channel_names=channel_names_st)
 @settings(max_examples=16)
 def test_create_tiled(channel_names):
@@ -560,6 +592,22 @@ def test_get_channel_index(wrong_channel_name):
         _ = dataset.get_channel_index(wrong_channel_name)
 
 
+def test_get_axis_index():
+    with open_ome_zarr(hcs_ref, layout="hcs", mode="r+") as dataset:
+        position = dataset["B/03/0"]
+
+        assert position.axis_names == ["c", "z", "y", "x"]
+
+        assert position.get_axis_index("z") == 1
+        assert position.get_axis_index("Z") == 1
+
+        with pytest.raises(ValueError):
+            _ = position.get_axis_index("t")
+
+        with pytest.raises(ValueError):
+            _ = position.get_axis_index("DOG")
+
+
 @given(
     row=short_alpha_numeric, col=short_alpha_numeric, pos=short_alpha_numeric
 )
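Taken together, the tests cover both input forms of the new command: a single position path updates one FOV, while a bare plate path is expanded by the parsing callback so every position is updated. A quick way to verify the plate-wide case after the fact (plate path hypothetical):

from iohub import open_ome_zarr
from iohub.ngff import Plate

# After `iohub set-scale -i plate.zarr -x 0.1`, every position should report
# the new x scale and remember the one it replaced.
with open_ome_zarr("plate.zarr", mode="r") as plate:
    assert isinstance(plate, Plate)
    for name, position in plate.positions():
        x_index = position.get_axis_index("x")
        print(name, position.scale[x_index])
        print(position.zattrs["iohub"]["prior_x_scale"])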