def _copy_hcs_ome_zarr_metadata(
    zarr_url_origin: str,
    zarr_url_new: str,
) -> None:
    """
    Updates the necessary metadata for a new copy of an OME-Zarr image

    Based on an existing OME-Zarr image in the same well, the image-level
    attributes are copied to the new image group. Additionally, the
    well-level metadata is updated to include this new image.

    Args:
        zarr_url_origin: zarr_url of the origin image
        zarr_url_new: zarr_url of the newly created image. The zarr-group
            already needs to exist, but metadata is written by this function.

    Raises:
        ValueError: Propagated from `_update_well_metadata`, when the origin
            image is not listed in the well metadata or the new image is
            already listed there.
    """
    # Copy the image-level OME-Zarr attributes from the origin image
    # See #681 for discussion for validation of this zattrs
    old_image_group = zarr.open_group(zarr_url_origin, mode="r")
    old_attrs = old_image_group.attrs.asdict()
    zarr_url_new = zarr_url_new.rstrip("/")
    new_image_group = zarr.group(zarr_url_new)
    new_image_group.attrs.put(old_attrs)

    # Update well metadata about adding the new image. Both image names are
    # derived with the same helper, so that trailing slashes are handled
    # identically for the origin and for the new image. The well is always
    # the one that contains the origin image.
    _, new_image_path = _split_well_path_image_path(zarr_url_new)
    well_url, old_image_path = _split_well_path_image_path(zarr_url_origin)
    _update_well_metadata(well_url, old_image_path, new_image_path)


def _update_well_metadata(
    well_url: str,
    old_image_path: str,
    new_image_path: str,
    timeout: float = 120,
) -> None:
    """
    Update the well metadata by adding the new_image_path to the image list.

    The content of new_image_path will be based on old_image_path, the origin
    for the new image that was created.
    This function aims to avoid race conditions with other processes that try
    to update the well metadata file by using FileLock & Timeouts

    Args:
        well_url: Path to the HCS OME-Zarr well that needs to be updated
        old_image_path: path relative to well_url where the original image is
            found
        new_image_path: path relative to well_url where the new image is placed
        timeout: Timeout in seconds for trying to get the file lock

    Raises:
        ValueError: If `new_image_path` is already part of the well metadata,
            or if `old_image_path` is not part of it.
        filelock.Timeout: If the lock cannot be acquired within `timeout`
            seconds. This is deliberately not caught here, since there is no
            sensible recovery at this level.
    """
    lock = FileLock(f"{well_url}/.zattrs.lock")
    # The whole read-modify-write cycle happens while holding the lock, so
    # that concurrent updates of the same well cannot overwrite each other.
    with lock.acquire(timeout=timeout):
        well_meta = load_NgffWellMeta(well_url)
        existing_well_images = [image.path for image in well_meta.well.images]
        if new_image_path in existing_well_images:
            raise ValueError(
                f"Could not add the {new_image_path=} image to the well "
                "metadata because an image with that name "
                f"already existed in the well metadata: {well_meta}"
            )

        # Look up the origin image; `None` as default avoids raising the
        # ValueError from within a StopIteration handler.
        well_meta_image_old = next(
            (
                image
                for image in well_meta.well.images
                if image.path == old_image_path
            ),
            None,
        )
        if well_meta_image_old is None:
            raise ValueError(
                f"Could not find an image with {old_image_path=} in the "
                "current well metadata."
            )

        # The new entry is a copy of the origin entry, apart from its path
        well_meta_image = copy.deepcopy(well_meta_image_old)
        well_meta_image.path = new_image_path
        well_meta.well.images.append(well_meta_image)
        well_meta.well.images = sorted(
            well_meta.well.images,
            key=lambda _image: _image.path,
        )

        well_group = zarr.group(well_url)
        well_group.attrs.put(well_meta.dict(exclude_none=True))


def _split_well_path_image_path(zarr_url: str) -> tuple[str, str]:
    """
    Split the `zarr_url` of an HCS OME-Zarr image into well and image paths.

    Trailing slashes are ignored.

    Args:
        zarr_url: Path or url of an image within an HCS OME-Zarr well.

    Returns:
        A `(well_path, img_path)` tuple, where `well_path` is everything
        before the last `/` and `img_path` is the last path component. If
        `zarr_url` contains no `/`, `well_path` is an empty string.
    """
    well_path, _, img_path = zarr_url.rstrip("/").rpartition("/")
    return well_path, img_path
- zarr_url_new = zarr_url + zarr_url_new = zarr_url.rstrip("/") else: - zarr_url_new = zarr_url + suffix + zarr_url_new = zarr_url.rstrip("/") + suffix t_start = time.perf_counter() logger.info("Start illumination_correction") @@ -226,6 +227,7 @@ def illumination_correction( overwrite=False, dimension_separator="/", ) + _copy_hcs_ome_zarr_metadata(zarr_url, zarr_url_new) # Iterate over FOV ROIs num_ROIs = len(list_indices) diff --git a/tests/data/generate_zarr_ones.py b/tests/data/generate_zarr_ones.py index 961135465..26e71ac00 100644 --- a/tests/data/generate_zarr_ones.py +++ b/tests/data/generate_zarr_ones.py @@ -28,6 +28,23 @@ if os.path.isdir(zarrurl): shutil.rmtree(zarrurl) + +plate_group = zarr.open_group(zarrurl) +plate_group.attrs.put( + { + "plate": { + "acquisitions": [{"id": 1, "name": "plate_ones"}], + "columns": [{"name": "03"}], + "rows": [{"name": "B"}], + "version": "0.4", + "wells": [{"columnIndex": 0, "path": "B/03", "rowIndex": 0}], + } + } +) + +well_group = zarr.open(f"{zarrurl}B/03/") +well_group.attrs.put({"well": {"images": [{"path": "0"}], "version": "0.4"}}) + component = "B/03/0/" for ind_level in range(num_levels): diff --git a/tests/data/plate_ones.zarr/.zattrs b/tests/data/plate_ones.zarr/.zattrs new file mode 100644 index 000000000..de73fc002 --- /dev/null +++ b/tests/data/plate_ones.zarr/.zattrs @@ -0,0 +1,28 @@ +{ + "plate": { + "acquisitions": [ + { + "id": 1, + "name": "plate_ones" + } + ], + "columns": [ + { + "name": "03" + } + ], + "rows": [ + { + "name": "B" + } + ], + "version": "0.4", + "wells": [ + { + "columnIndex": 0, + "path": "B/03", + "rowIndex": 0 + } + ] + } +} diff --git a/tests/data/plate_ones.zarr/B/03/.zattrs b/tests/data/plate_ones.zarr/B/03/.zattrs new file mode 100644 index 000000000..afe52e509 --- /dev/null +++ b/tests/data/plate_ones.zarr/B/03/.zattrs @@ -0,0 +1,10 @@ +{ + "well": { + "images": [ + { + "path": "0" + } + ], + "version": "0.4" + } +} diff --git 
a/tests/tasks/test_unit_illumination_correction.py b/tests/tasks/test_unit_illumination_correction.py index 237c2fdd9..f739c18b4 100644 --- a/tests/tasks/test_unit_illumination_correction.py +++ b/tests/tasks/test_unit_illumination_correction.py @@ -6,24 +6,31 @@ import anndata as ad import dask.array as da import numpy as np +import pytest from pytest import LogCaptureFixture from pytest import MonkeyPatch from fractal_tasks_core.ngff.zarr_utils import load_NgffImageMeta +from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta from fractal_tasks_core.roi import ( convert_ROI_table_to_indices, ) +from fractal_tasks_core.tasks._registration_utils import ( + _split_well_path_image_path, +) from fractal_tasks_core.tasks.illumination_correction import correct from fractal_tasks_core.tasks.illumination_correction import ( illumination_correction, ) +@pytest.mark.parametrize("overwrite_input", [True, False]) def test_illumination_correction( tmp_path: Path, testdata_path: Path, monkeypatch: MonkeyPatch, caplog: LogCaptureFixture, + overwrite_input: bool, ): # GIVEN a zarr pyramid on disk, made of all ones # WHEN I apply illumination_correction @@ -49,13 +56,9 @@ def test_illumination_correction( with open(zarr_url + ".zattrs") as fin: zattrs = json.load(fin) num_levels = len(zattrs["multiscales"][0]["datasets"]) - metadata: dict = { - "num_levels": num_levels, - "coarsening_xy": 2, - } - num_channels = 2 - num_levels = metadata["num_levels"] + num_channels = 2 + num_levels = num_levels # Read FOV ROIs and create corresponding indices ngff_image_meta = load_NgffImageMeta(zarr_url) pixels = ngff_image_meta.get_pixel_sizes_zyx(level=0) @@ -82,14 +85,15 @@ def patched_correct(*args, **kwargs): "fractal_tasks_core.tasks.illumination_correction.correct", patched_correct, ) - + suffix = "_illum_corr" # Call illumination correction task, with patched correct() illumination_correction( zarr_url=zarr_url, - overwrite_input=True, + overwrite_input=overwrite_input, 
illumination_profiles_folder=illumination_profiles_folder, illumination_profiles=illum_params, background=0, + suffix=suffix, ) print(caplog.text) @@ -100,13 +104,35 @@ def patched_correct(*args, **kwargs): tot_calls_correct = len(f.read().splitlines()) assert tot_calls_correct == expected_tot_calls_correct + old_urls = [testdata_path / "plate_ones.zarr/B/03/0"] + if overwrite_input: + new_zarr_url = zarr_url.rstrip("/") + else: + new_zarr_url = zarr_url.rstrip("/") + suffix + old_urls.append(zarr_url.rstrip("/")) + # Verify the output - for ind_level in range(num_levels): - old = da.from_zarr( - testdata_path / f"plate_ones.zarr/B/03/0/{ind_level}" - ) - new = da.from_zarr(f"{zarr_url}{ind_level}") - assert old.shape == new.shape - assert old.chunks == new.chunks - assert new.compute()[0, 0, 0, 0] == 1 - assert np.allclose(old.compute(), new.compute()) + for old_url in old_urls: + for ind_level in range(num_levels): + old = da.from_zarr(f"{old_url}/{ind_level}") + print(testdata_path / f"plate_ones.zarr/B/03/0/{ind_level}") + print(f"{zarr_url}{ind_level}") + new = da.from_zarr(f"{new_zarr_url}/{ind_level}") + assert old.shape == new.shape + assert old.chunks == new.chunks + assert new.compute()[0, 0, 0, 0] == 1 + assert np.allclose(old.compute(), new.compute()) + + # Verify that the new_zarr_url has valid OME-Zarr metadata + _ = load_NgffImageMeta(new_zarr_url) + + # Verify the well metadata: Are all the images in well present in the + # well metadata? 
@pytest.mark.parametrize("trailing_slash", [True, False])
def test_copy_hcs_ome_zarr_metadata(
    tmp_path: Path,
    testdata_path: Path,
    caplog: LogCaptureFixture,
    trailing_slash: bool,
):
    """
    Copy the metadata of image `0` to a new image `0_illum_corr`, and verify
    both the image-level attributes and the well-level image list, with and
    without trailing slashes in the zarr urls.
    """
    caplog.set_level(logging.INFO)

    # Copy a reference zarr into a temporary folder
    raw_zarrurl = (testdata_path / "plate_ones.zarr").as_posix()
    zarr_url = (tmp_path / "plate.zarr").resolve().as_posix()
    shutil.copytree(raw_zarrurl, zarr_url)
    zarr_url += "/B/03/0"
    suffix = "_illum_corr"
    new_zarr_url = zarr_url + suffix
    if trailing_slash:
        zarr_url += "/"
        new_zarr_url += "/"

    _copy_hcs_ome_zarr_metadata(
        zarr_url_origin=zarr_url, zarr_url_new=new_zarr_url
    )

    # Image-level attributes must be identical for origin and new image
    group = zarr.open_group(zarr_url, mode="r")
    old_attrs = group.attrs.asdict()
    group_new = zarr.open_group(new_zarr_url, mode="r")
    new_attrs = group_new.attrs.asdict()
    debug(old_attrs)
    assert old_attrs == new_attrs

    # Check well metadata:
    well_url, _ = _split_well_path_image_path(zarr_url=zarr_url)
    well_meta = load_NgffWellMeta(well_url)
    debug(well_meta)
    well_paths = [image.path for image in well_meta.well.images]
    assert well_paths == ["0", "0" + suffix]


def _star_update_well_metadata(args):
    """
    This is only needed because concurrent.futures executors have a `map`
    method but not a `starmap` one.
    """
    return _update_well_metadata(*args)


def test_update_well_metadata_concurrency(
    tmp_path: Path,
    testdata_path: Path,
    monkeypatch,
):
    """
    Run _update_well_metadata in parallel for adding N>1 new images to a given
    well. We artificially slow down each call by INTERVAL seconds, and verify
    that the test takes at least N x INTERVAL seconds (since each call to
    `_update_well_metadata` is blocking).

    In the last section of the test, we verify that a timeout error is raised
    when the timeout is too short.
    """

    N = 4
    INTERVAL = 0.5

    # Copy a reference zarr into a temporary folder
    raw_zarrurl = (testdata_path / "plate_ones.zarr").as_posix()
    zarr_url = (tmp_path / "plate.zarr").resolve().as_posix()
    shutil.copytree(raw_zarrurl, zarr_url)

    # Artificially slow down `_update_well_metadata`
    # NOTE(review): the patch is applied in the parent process only, so the
    # workers presumably see it because they are forked -- confirm this if
    # the multiprocessing start method is ever changed.
    import fractal_tasks_core.tasks._zarr_utils

    def _slow_load_NgffWellMeta(*args, **kwargs):
        time.sleep(INTERVAL)
        return load_NgffWellMeta(*args, **kwargs)

    monkeypatch.setattr(
        fractal_tasks_core.tasks._zarr_utils,
        "load_NgffWellMeta",
        _slow_load_NgffWellMeta,
    )

    # Prepare parallel-execution argument list
    well_url = Path(zarr_url, "B/03").as_posix()
    list_args = [(well_url, "0", f"0_new_{suffix}") for suffix in range(N)]

    time_start = time.perf_counter()
    # Use one worker per task: with fewer workers the calls would run (at
    # least partly) one after the other, and the short-timeout section below
    # would not be guaranteed to hit a held lock. The context manager shuts
    # the worker processes down at the end of the test.
    with ProcessPoolExecutor(max_workers=N) as executor:
        # Run `_update_well_metadata` N times
        res_iter = executor.map(_star_update_well_metadata, list_args)
        list(res_iter)  # This is needed, to wait for all results.
        time_end = time.perf_counter()

        # Check that time was at least N*INTERVAL seconds
        assert (time_end - time_start) > N * INTERVAL

        # Check that all new images were added
        well_meta = load_NgffWellMeta(well_url)
        well_image_paths = [img.path for img in well_meta.well.images]
        debug(well_image_paths)
        assert well_image_paths == [
            "0",
            "0_new_0",
            "0_new_1",
            "0_new_2",
            "0_new_3",
        ]

        # Prepare parallel-execution argument list with short timeout
        list_args = [
            (well_url, "0", f"0_new_{suffix}", INTERVAL / 100)
            for suffix in range(N, 2 * N)
        ]
        with pytest.raises(Timeout) as e:
            res_iter = executor.map(_star_update_well_metadata, list_args)
            list(res_iter)  # This is needed, to wait for all results.
        debug(e.value)


def test_update_well_metadata_failures(
    tmp_path: Path,
    testdata_path: Path,
):
    """
    When called with an invalid `old_image_path` or `new_image_path`,
    `_update_well_metadata` fails as expected.
    """

    # Copy a reference zarr into a temporary folder
    raw_zarrurl = (testdata_path / "plate_ones.zarr").as_posix()
    zarr_url = (tmp_path / "plate.zarr").resolve().as_posix()
    shutil.copytree(raw_zarrurl, zarr_url)
    well_url = Path(zarr_url, "B/03").as_posix()

    # Failure case 1: the origin image is not part of the well metadata
    with pytest.raises(ValueError) as e:
        _update_well_metadata(well_url, "INVALID_OLD_IMAGE_PATH", "0_new")

    assert "Could not find an image with old_image_path" in str(e.value)

    # Failure case 2: the new image is already part of the well metadata
    with pytest.raises(ValueError) as e:
        _update_well_metadata(well_url, "0", "0")

    assert "Could not add the new_image_path" in str(e.value)