diff --git a/fractal_tasks_core/tasks/__init__.py b/fractal_tasks_core/tasks/__init__.py
new file mode 100644
index 000000000..d959b02bb
--- /dev/null
+++ b/fractal_tasks_core/tasks/__init__.py
@@ -0,0 +1,3 @@
+"""
+Tasks subpackage (requires installation extra `fractal-tasks`).
+"""
diff --git a/fractal_tasks_core/tasks/_registration_utils.py b/fractal_tasks_core/tasks/_registration_utils.py
new file mode 100755
index 000000000..8cac07736
--- /dev/null
+++ b/fractal_tasks_core/tasks/_registration_utils.py
@@ -0,0 +1,238 @@
+# Copyright 2024 (C) BioVisionCenter
+#
+# Original authors:
+# Joel Lüthi
+#
+# This file is part of Fractal
+"""Utility functions for registration"""
+import copy
+
+import anndata as ad
+import dask.array as da
+import numpy as np
+import pandas as pd
+from image_registration import chi2_shift
+
+
+def calculate_physical_shifts(
+    shifts: np.ndarray,
+    level: int,
+    coarsening_xy: int,
+    full_res_pxl_sizes_zyx: list[float],
+) -> list[float]:
+    """
+    Calculates shifts in physical units based on pixel shifts
+
+    Args:
+        shifts: array of shifts, zyx or yx
+        level: resolution level
+        coarsening_xy: coarsening factor between levels
+        full_res_pxl_sizes_zyx: pixel sizes in physical units as zyx
+
+    Returns:
+        shifts in physical units as zyx
+    """
+
+    curr_pixel_size = np.array(full_res_pxl_sizes_zyx) * coarsening_xy**level
+    if len(shifts) == 3:
+        shifts_physical = shifts * curr_pixel_size
+    elif len(shifts) == 2:
+        shifts_physical = [
+            0,
+            shifts[0] * curr_pixel_size[1],
+            shifts[1] * curr_pixel_size[2],
+        ]
+    else:
+        raise ValueError(
+            f"Wrong input for calculate_physical_shifts ({shifts=})"
+        )
+    return shifts_physical
+
+
+def get_ROI_table_with_translation(
+    ROI_table: ad.AnnData,
+    new_shifts: dict[str, list[float]],
+) -> ad.AnnData:
+    """
+    Adds translation columns to a ROI table
+
+    Args:
+        ROI_table: Fractal ROI table
+        new_shifts: zyx list of shifts
+
+    Returns:
+        Fractal ROI table with 3 additional columns for calculated translations
+    """
+
+    shift_table = pd.DataFrame(new_shifts).T
+    shift_table.columns = ["translation_z", "translation_y", "translation_x"]
+    shift_table = shift_table.rename_axis("FieldIndex")
+    new_roi_table = ROI_table.to_df().merge(
+        shift_table, left_index=True, right_index=True
+    )
+    if len(new_roi_table) != len(ROI_table):
+        raise ValueError(
+            "New ROI table with registration info has a "
+            f"different length ({len(new_roi_table)=}) "
+            f"from the original ROI table ({len(ROI_table)=})"
+        )
+
+    adata = ad.AnnData(X=new_roi_table.astype(np.float32))
+    adata.obs_names = new_roi_table.index
+    adata.var_names = list(map(str, new_roi_table.columns))
+    return adata
+
+
+# Helper functions
+def add_zero_translation_columns(ad_table: ad.AnnData):
+    """
+    Add three zero-filled columns (`translation_{x,y,z}`) to an AnnData table.
+    """
+    columns = ["translation_z", "translation_y", "translation_x"]
+    if ad_table.var.index.isin(columns).any():
+        raise ValueError(
+            "The ROI table already contains translation columns. Did you "
+            "enter a wrong reference acquisition?"
+        )
+    df = pd.DataFrame(np.zeros([len(ad_table), 3]), columns=columns)
+    df.index = ad_table.obs.index
+    ad_new = ad.concat([ad_table, ad.AnnData(df)], axis=1)
+    return ad_new
+
+
+def calculate_min_max_across_dfs(tables_list):
+    """
+    Calculates the element-wise max and min across a list of dataframes.
+
+    All dataframes are expected to share the same index and columns.
+
+    Args:
+        tables_list: List of pandas dataframes with identical index/columns.
+
+    Returns:
+        Tuple of (max_df, min_df) containing the element-wise max and min.
+    """
+    # Initialize dataframes to store the max and min values
+    max_df = pd.DataFrame(
+        index=tables_list[0].index, columns=tables_list[0].columns
+    )
+    min_df = pd.DataFrame(
+        index=tables_list[0].index, columns=tables_list[0].columns
+    )
+
+    # Loop through the tables and calculate max and min values
+    for table in tables_list:
+        max_df = pd.DataFrame(
+            np.maximum(max_df.values, table.values),
+            columns=max_df.columns,
+            index=max_df.index,
+        )
+        min_df = pd.DataFrame(
+            np.minimum(min_df.values, table.values),
+            columns=min_df.columns,
+            index=min_df.index,
+        )
+
+    return max_df, min_df
+
+
+def apply_registration_to_single_ROI_table(
+    roi_table: ad.AnnData,
+    max_df: pd.DataFrame,
+    min_df: pd.DataFrame,
+) -> ad.AnnData:
+    """
+    Applies the registration to a ROI table
+
+    Calculates the new position as: p = position + max(shift, 0) - own_shift
+    Calculates the new len as: l = len - max(shift, 0) + min(shift, 0)
+
+    Args:
+        roi_table: AnnData table which contains a Fractal ROI table.
+            Rows are ROIs
+        max_df: Max translation shift in z, y, x for each ROI. Rows are ROIs,
+            columns are translation_z, translation_y, translation_x
+        min_df: Min translation shift in z, y, x for each ROI. Rows are ROIs,
+            columns are translation_z, translation_y, translation_x
+    Returns:
+        ROI table where all ROIs are registered to the smallest common area
+        across all acquisitions.
+    """
+    roi_table = copy.deepcopy(roi_table)
+    rois = roi_table.obs.index
+    if (rois != max_df.index).any() or (rois != min_df.index).any():
+        raise ValueError(
+            "ROI table and max & min translation need to contain the same "
+            f"ROIs, but they were {rois=}, {max_df.index=}, {min_df.index=}"
+        )
+
+    for roi in rois:
+        roi_table[[roi], ["z_micrometer"]] = (
+            roi_table[[roi], ["z_micrometer"]].X
+            + float(max_df.loc[roi, "translation_z"])
+            - roi_table[[roi], ["translation_z"]].X
+        )
+        roi_table[[roi], ["y_micrometer"]] = (
+            roi_table[[roi], ["y_micrometer"]].X
+            + float(max_df.loc[roi, "translation_y"])
+            - roi_table[[roi], ["translation_y"]].X
+        )
+        roi_table[[roi], ["x_micrometer"]] = (
+            roi_table[[roi], ["x_micrometer"]].X
+            + float(max_df.loc[roi, "translation_x"])
+            - roi_table[[roi], ["translation_x"]].X
+        )
+        # This calculation only works if all ROIs are the same size initially!
+        roi_table[[roi], ["len_z_micrometer"]] = (
+            roi_table[[roi], ["len_z_micrometer"]].X
+            - float(max_df.loc[roi, "translation_z"])
+            + float(min_df.loc[roi, "translation_z"])
+        )
+        roi_table[[roi], ["len_y_micrometer"]] = (
+            roi_table[[roi], ["len_y_micrometer"]].X
+            - float(max_df.loc[roi, "translation_y"])
+            + float(min_df.loc[roi, "translation_y"])
+        )
+        roi_table[[roi], ["len_x_micrometer"]] = (
+            roi_table[[roi], ["len_x_micrometer"]].X
+            - float(max_df.loc[roi, "translation_x"])
+            + float(min_df.loc[roi, "translation_x"])
+        )
+    return roi_table
+
+
+def chi2_shift_out(img_ref, img_cycle_x) -> list[np.ndarray]:
+    """
+    Helper function to get the output of chi2_shift into the same format as
+    phase_cross_correlation. Calculates the shift between two images using
+    the chi2_shift method.
+
+    Args:
+        img_ref (np.ndarray): First image.
+        img_cycle_x (np.ndarray): Second image.
+
+    Returns:
+        List containing numpy array of shift in y and x direction.
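+
+    Example (illustrative sketch only; `img_ref` and `img_cycle_x` are
+    arbitrary 2D arrays here, and the exact shift values depend on the
+    inputs):
+
+        >>> shifts = chi2_shift_out(img_ref, img_cycle_x)  # doctest: +SKIP
+        >>> shifts[0].shape  # doctest: +SKIP
+        (2,)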
+    """
+    x, y, a, b = chi2_shift(np.squeeze(img_ref), np.squeeze(img_cycle_x))
+
+    """
+    Running into issues when using direct float output for fractal.
+    When rounding to integer and using integer dtype, it typically works
+    but for some reasons fails when run over a whole 384 well plate (but
+    the well where it fails works fine when run alone). For now, rounding
+    to integer, but still using float64 dtype (like the scikit-image
+    phase cross correlation function) seems to be the safest option.
+    """
+    shifts = np.array([-np.round(y), -np.round(x)], dtype="float64")
+    # return as a list to adhere to the phase_cross_correlation output format
+    return [shifts]
+
+
+def is_3D(dask_array: da.Array) -> bool:
+    """
+    Check if a dask array is 3D.
+
+    Treats singleton Z dimensions as 2D images.
+    (1, 2000, 2000) => False
+    (10, 2000, 2000) => True
+
+    Args:
+        dask_array: Input array to be checked
+
+    Returns:
+        bool on whether the array is 3D
+    """
+    if len(dask_array.shape) == 3 and dask_array.shape[0] > 1:
+        return True
+    else:
+        return False
diff --git a/fractal_tasks_core/tasks/_utils.py b/fractal_tasks_core/tasks/_utils.py
new file mode 100644
index 000000000..a7c0a4b89
--- /dev/null
+++ b/fractal_tasks_core/tasks/_utils.py
@@ -0,0 +1,84 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Tommaso Comparin
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+"""
+Standard input/output interface for tasks.
+"""
+import json
+import logging
+from argparse import ArgumentParser
+from json import JSONEncoder
+from pathlib import Path
+from typing import Callable
+from typing import Optional
+
+
+class TaskParameterEncoder(JSONEncoder):
+    """
+    Custom JSONEncoder that transforms Path objects to strings.
+    """
+
+    def default(self, value):
+        """
+        Subclass implementation of `default`, to serialize Path objects as
+        strings.
+        """
+        if isinstance(value, Path):
+            return value.as_posix()
+        return JSONEncoder.default(self, value)
+
+
+def run_fractal_task(
+    *,
+    task_function: Callable,
+    logger_name: Optional[str] = None,
+):
+    """
+    Implement standard task interface and call task_function.
+
+    Args:
+        task_function: the callable function that runs the task.
+        logger_name: name of the logger used within the task.
+    """
+
+    # Parse `--args-json` and `--out-json` arguments
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--args-json", help="Read parameters from json file", required=True
+    )
+    parser.add_argument(
+        "--out-json",
+        help="Output file to redirect serialised returned data",
+        required=True,
+    )
+    parsed_args = parser.parse_args()
+
+    # Set logger
+    logger = logging.getLogger(logger_name)
+
+    # Preliminary check
+    if Path(parsed_args.out_json).exists():
+        logger.error(
+            f"Output file {parsed_args.out_json} already exists. 
Terminating" + ) + exit(1) + + # Read parameters dictionary + with open(parsed_args.args_json, "r") as f: + pars = json.load(f) + + # Run task + logger.info(f"START {task_function.__name__} task") + metadata_update = task_function(**pars) + logger.info(f"END {task_function.__name__} task") + + # Write output metadata to file, with custom JSON encoder + with open(parsed_args.out_json, "w") as fout: + json.dump(metadata_update, fout, cls=TaskParameterEncoder, indent=2) diff --git a/fractal_tasks_core/tasks/_zarr_utils.py b/fractal_tasks_core/tasks/_zarr_utils.py new file mode 100644 index 000000000..0f030272b --- /dev/null +++ b/fractal_tasks_core/tasks/_zarr_utils.py @@ -0,0 +1,205 @@ +import copy +import logging + +import anndata as ad +import zarr +from filelock import FileLock + +from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta +from fractal_tasks_core.tables import write_table +from fractal_tasks_core.tables.v1 import get_tables_list_v1 +from fractal_tasks_core.utils import _split_well_path_image_path + +logger = logging.getLogger(__name__) + + +def _copy_hcs_ome_zarr_metadata( + zarr_url_origin: str, + zarr_url_new: str, +) -> None: + """ + Updates the necessary metadata for a new copy of an OME-Zarr image + + Based on an existing OME-Zarr image in the same well, the metadata is + copied and added to the new zarr well. Additionally, the well-level + metadata is updated to include this new image. + + Args: + zarr_url_origin: zarr_url of the origin image + zarr_url_new: zarr_url of the newly created image. The zarr-group + already needs to exist, but metadata is written by this function. + """ + # Copy over OME-Zarr metadata for illumination_corrected image + # See #681 for discussion for validation of this zattrs + old_image_group = zarr.open_group(zarr_url_origin, mode="r") + old_attrs = old_image_group.attrs.asdict() + zarr_url_new = zarr_url_new.rstrip("/") + new_image_group = zarr.group(zarr_url_new) + new_image_group.attrs.put(old_attrs) + + # Update well metadata about adding the new image: + new_image_path = zarr_url_new.split("/")[-1] + well_url, old_image_path = _split_well_path_image_path(zarr_url_origin) + _update_well_metadata(well_url, old_image_path, new_image_path) + + +def _update_well_metadata( + well_url: str, + old_image_path: str, + new_image_path: str, + timeout: int = 120, +) -> None: + """ + Update the well metadata by adding the new_image_path to the image list. + + The content of new_image_path will be based on old_image_path, the origin + for the new image that was created. 
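+
+    For example (hypothetical paths, for illustration only): after copying
+    image "0" of well `my_plate.zarr/B/03` to a new image "0_illum_corr",
+    calling `_update_well_metadata("my_plate.zarr/B/03", "0", "0_illum_corr")`
+    registers the new image in the well metadata.
+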
+    This function aims to avoid race conditions with other processes that
+    try to update the well metadata file by using FileLock & timeouts.
+
+    Args:
+        well_url: Path to the HCS OME-Zarr well that needs to be updated
+        old_image_path: path relative to well_url where the original image is
+            found
+        new_image_path: path relative to well_url where the new image is placed
+        timeout: Timeout in seconds for trying to get the file lock
+    """
+    lock = FileLock(f"{well_url}/.zattrs.lock")
+    with lock.acquire(timeout=timeout):
+        well_meta = load_NgffWellMeta(well_url)
+        existing_well_images = [image.path for image in well_meta.well.images]
+        if new_image_path in existing_well_images:
+            raise ValueError(
+                f"Could not add the {new_image_path=} image to the well "
+                "metadata because an image with that name "
+                f"already exists in the well metadata: {well_meta}"
+            )
+        try:
+            well_meta_image_old = next(
+                image
+                for image in well_meta.well.images
+                if image.path == old_image_path
+            )
+        except StopIteration:
+            raise ValueError(
+                f"Could not find an image with {old_image_path=} in the "
+                "current well metadata."
+            )
+        well_meta_image = copy.deepcopy(well_meta_image_old)
+        well_meta_image.path = new_image_path
+        well_meta.well.images.append(well_meta_image)
+        well_meta.well.images = sorted(
+            well_meta.well.images,
+            key=lambda _image: _image.path,
+        )
+
+        well_group = zarr.group(well_url)
+        well_group.attrs.put(well_meta.model_dump(exclude_none=True))
+
+    # One could catch the timeout with a try except Timeout. But what to do
+    # with it?
+
+
+def _split_base_suffix(input: str) -> tuple[str, str]:
+    parts = input.split("_")
+    base = parts[0]
+    if len(parts) > 1:
+        suffix = "_".join(parts[1:])
+    else:
+        suffix = ""
+    return base, suffix
+
+
+def _get_matching_ref_acquisition_path_heuristic(
+    path_list: list[str], path: str
+) -> str:
+    """
+    Pick the best match from path_list to a given path
+
+    This is a workaround to find the reference registration acquisition when
+    there are multiple OME-Zarrs with the same acquisition identifier in the
+    well metadata and we need to find which one is the reference for a given
+    path.
+
+    Args:
+        path_list: List of paths to OME-Zarr images in the well metadata. For
+            example: ['0', '0_illum_corr']
+        path: A given path for which we want to find the reference image. For
+            example, '1_illum_corr'
+
+    Returns:
+        The best matching reference path. If no direct match is found, it
+        returns the most similar one based on suffix hierarchy or the base
+        path if applicable. For example, '0_illum_corr' with the example
+        inputs above.
+    """
+
+    # Extract the base number and suffix from the input path
+    base, suffix = _split_base_suffix(path)
+
+    # Sort path_list
+    sorted_path_list = sorted(path_list)
+
+    # Never return the input `path`
+    if path in sorted_path_list:
+        sorted_path_list.remove(path)
+
+    # First matching rule: a path with the same suffix
+    for p in sorted_path_list:
+        # Split the list path into base and suffix
+        p_base, p_suffix = _split_base_suffix(p)
+        # If suffixes match, it's the match.
+        if p_suffix == suffix:
+            return p
+
+    # If no match is found, return the first entry in the list
+    logger.warning(
+        "No heuristic reference acquisition match found, defaulting to first "
+        f"option {sorted_path_list[0]}."
+    )
+    return sorted_path_list[0]
+
+
+def _copy_tables_from_zarr_url(
+    origin_zarr_url: str,
+    target_zarr_url: str,
+    table_type: str = None,
+    overwrite: bool = True,
+) -> None:
+    """
+    Copies all tables from one Zarr image into a new Zarr image
+
+    Args:
+        origin_zarr_url: url of the OME-Zarr image that contains tables.
+            e.g. /path/to/my_plate.zarr/B/03/0
+        target_zarr_url: url of the new OME-Zarr image where tables are copied
+            to. e.g. /path/to/my_plate.zarr/B/03/0_illum_corr
+        table_type: Filter for specific table types that should be copied.
+        overwrite: Whether existing tables of the same name in the
+            target_zarr_url should be overwritten.
+    """
+    table_list = get_tables_list_v1(
+        zarr_url=origin_zarr_url, table_type=table_type
+    )
+
+    if table_list:
+        logger.info(
+            f"Copying the tables {table_list} from {origin_zarr_url} to "
+            f"{target_zarr_url}."
+        )
+        new_image_group = zarr.group(target_zarr_url)
+
+        for table in table_list:
+            logger.info(f"Copying table: {table}")
+            # Get the relevant metadata of the Zarr table & add it
+            table_url = f"{origin_zarr_url}/tables/{table}"
+            old_table_group = zarr.open_group(table_url, mode="r")
+            # Write the Zarr table
+            curr_table = ad.read_zarr(table_url)
+            write_table(
+                new_image_group,
+                table,
+                curr_table,
+                table_attrs=old_table_group.attrs.asdict(),
+                overwrite=overwrite,
+            )
diff --git a/fractal_tasks_core/tasks/apply_registration_to_image.py b/fractal_tasks_core/tasks/apply_registration_to_image.py
new file mode 100644
index 000000000..156f5b3e6
--- /dev/null
+++ b/fractal_tasks_core/tasks/apply_registration_to_image.py
@@ -0,0 +1,392 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Joel Lüthi
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+"""
+Applies registration to images by using a registered ROI table
+"""
+import logging
+import os
+import shutil
+import time
+from typing import Callable
+
+import anndata as ad
+import dask.array as da
+import numpy as np
+import zarr
+from pydantic import validate_call
+
+from fractal_tasks_core.ngff import load_NgffImageMeta
+from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta
+from fractal_tasks_core.pyramids import build_pyramid
+from fractal_tasks_core.roi import (
+    convert_indices_to_regions,
+)
+from fractal_tasks_core.roi import (
+    convert_ROI_table_to_indices,
+)
+from fractal_tasks_core.roi import is_standard_roi_table
+from fractal_tasks_core.roi import load_region
+from fractal_tasks_core.tables import write_table
+from fractal_tasks_core.tasks._zarr_utils import (
+    _get_matching_ref_acquisition_path_heuristic,
+)
+from fractal_tasks_core.tasks._zarr_utils import _update_well_metadata
+from fractal_tasks_core.utils import _get_table_path_dict
+from fractal_tasks_core.utils import (
+    _split_well_path_image_path,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@validate_call
+def apply_registration_to_image(
+    *,
+    # Fractal parameters
+    zarr_url: str,
+    # Core parameters
+    registered_roi_table: str,
+    reference_acquisition: int = 0,
+    overwrite_input: bool = True,
+):
+    """
+    Apply registration to images by using a registered ROI table
+
+    This task consists of 4 parts:
+
+    1. 
Mask all regions in images that are not available in the + registered ROI table and store each acquisition aligned to the + reference_acquisition (by looping over ROIs). + 2. Do the same for all label images. + 3. Copy all tables from the non-aligned image to the aligned image + (currently only works well if the only tables are well & FOV ROI tables + (registered and original). Not implemented for measurement tables and + other ROI tables). + 4. Clean up: Delete the old, non-aligned image and rename the new, + aligned image to take over its place. + + Args: + zarr_url: Path or url to the individual OME-Zarr image to be processed. + (standard argument for Fractal tasks, managed by Fractal server). + registered_roi_table: Name of the ROI table which has been registered + and will be applied to mask and shift the images. + Examples: `registered_FOV_ROI_table` => loop over the field of + views, `registered_well_ROI_table` => process the whole well as + one image. + reference_acquisition: Which acquisition to register against. Uses the + OME-NGFF HCS well metadata acquisition keys to find the reference + acquisition. + overwrite_input: Whether the old image data should be replaced with the + newly registered image data. Currently only implemented for + `overwrite_input=True`. + + """ + logger.info(zarr_url) + logger.info( + f"Running `apply_registration_to_image` on {zarr_url=}, " + f"{registered_roi_table=} and {reference_acquisition=}. " + f"Using {overwrite_input=}" + ) + + well_url, old_img_path = _split_well_path_image_path(zarr_url) + new_zarr_url = f"{well_url}/{zarr_url.split('/')[-1]}_registered" + # Get the zarr_url for the reference acquisition + acq_dict = load_NgffWellMeta(well_url).get_acquisition_paths() + if reference_acquisition not in acq_dict: + raise ValueError( + f"{reference_acquisition=} was not one of the available " + f"acquisitions in {acq_dict=} for well {well_url}" + ) + elif len(acq_dict[reference_acquisition]) > 1: + ref_path = _get_matching_ref_acquisition_path_heuristic( + acq_dict[reference_acquisition], old_img_path + ) + logger.warning( + "Running registration when there are multiple images of the same " + "acquisition in a well. Using a heuristic to match the reference " + f"acquisition. Using {ref_path} as the reference image." 
+ ) + else: + ref_path = acq_dict[reference_acquisition][0] + reference_zarr_url = f"{well_url}/{ref_path}" + + ROI_table_ref = ad.read_zarr( + f"{reference_zarr_url}/tables/{registered_roi_table}" + ) + ROI_table_acq = ad.read_zarr(f"{zarr_url}/tables/{registered_roi_table}") + + ngff_image_meta = load_NgffImageMeta(zarr_url) + coarsening_xy = ngff_image_meta.coarsening_xy + num_levels = ngff_image_meta.num_levels + + #################### + # Process images + #################### + logger.info("Write the registered Zarr image to disk") + write_registered_zarr( + zarr_url=zarr_url, + new_zarr_url=new_zarr_url, + ROI_table=ROI_table_acq, + ROI_table_ref=ROI_table_ref, + num_levels=num_levels, + coarsening_xy=coarsening_xy, + aggregation_function=np.mean, + ) + + #################### + # Process labels + #################### + try: + labels_group = zarr.open_group(f"{zarr_url}/labels", "r") + label_list = labels_group.attrs["labels"] + except (zarr.errors.GroupNotFoundError, KeyError): + label_list = [] + + if label_list: + logger.info(f"Processing the label images: {label_list}") + labels_group = zarr.group(f"{new_zarr_url}/labels") + labels_group.attrs["labels"] = label_list + + for label in label_list: + write_registered_zarr( + zarr_url=f"{zarr_url}/labels/{label}", + new_zarr_url=f"{new_zarr_url}/labels/{label}", + ROI_table=ROI_table_acq, + ROI_table_ref=ROI_table_ref, + num_levels=num_levels, + coarsening_xy=coarsening_xy, + aggregation_function=np.max, + ) + + #################### + # Copy tables + # 1. Copy all standard ROI tables from the reference acquisition. + # 2. Copy all tables that aren't standard ROI tables from the given + # acquisition. + #################### + table_dict_reference = _get_table_path_dict(reference_zarr_url) + table_dict_component = _get_table_path_dict(zarr_url) + + table_dict = {} + # Define which table should get copied: + for table in table_dict_reference: + if is_standard_roi_table(table): + table_dict[table] = table_dict_reference[table] + for table in table_dict_component: + if not is_standard_roi_table(table): + if reference_zarr_url != zarr_url: + logger.warning( + f"{zarr_url} contained a table that is not a standard " + "ROI table. The `Apply Registration To Image task` is " + "best used before additional tables are generated. It " + f"will copy the {table} from this acquisition without " + "applying any transformations. This will work well if " + f"{table} contains measurements. But if {table} is a " + "custom ROI table coming from another task, the " + "transformation is not applied and it will not match " + "with the registered image anymore." + ) + table_dict[table] = table_dict_component[table] + + if table_dict: + logger.info(f"Processing the tables: {table_dict}") + new_image_group = zarr.group(new_zarr_url) + + for table in table_dict.keys(): + logger.info(f"Copying table: {table}") + # Get the relevant metadata of the Zarr table & add it + # See issue #516 for the need for this workaround + max_retries = 20 + sleep_time = 5 + current_round = 0 + while current_round < max_retries: + try: + old_table_group = zarr.open_group( + table_dict[table], mode="r" + ) + current_round = max_retries + except zarr.errors.GroupNotFoundError: + logger.debug( + f"Table {table} not found in attempt {current_round}. " + f"Waiting {sleep_time} seconds before trying again." 
+                    )
+                    current_round += 1
+                    time.sleep(sleep_time)
+            # Write the Zarr table
+            curr_table = ad.read_zarr(table_dict[table])
+            write_table(
+                new_image_group,
+                table,
+                curr_table,
+                table_attrs=old_table_group.attrs.asdict(),
+                overwrite=True,
+            )
+
+    ####################
+    # Clean up Zarr file
+    ####################
+    if overwrite_input:
+        logger.info(
+            "Replace original zarr image with the newly created Zarr image"
+        )
+        # Potential for race conditions: Every acquisition reads the
+        # reference acquisition, but the reference acquisition also gets
+        # modified
+        # See issue #516 for the details
+        os.rename(zarr_url, f"{zarr_url}_tmp")
+        os.rename(new_zarr_url, zarr_url)
+        shutil.rmtree(f"{zarr_url}_tmp")
+        image_list_updates = dict(image_list_updates=[dict(zarr_url=zarr_url)])
+    else:
+        image_list_updates = dict(
+            image_list_updates=[dict(zarr_url=new_zarr_url, origin=zarr_url)]
+        )
+        # Update the metadata of the well
+        well_url, new_img_path = _split_well_path_image_path(new_zarr_url)
+        _update_well_metadata(
+            well_url=well_url,
+            old_image_path=old_img_path,
+            new_image_path=new_img_path,
+        )
+
+    return image_list_updates
+
+
+def write_registered_zarr(
+    zarr_url: str,
+    new_zarr_url: str,
+    ROI_table: ad.AnnData,
+    ROI_table_ref: ad.AnnData,
+    num_levels: int,
+    coarsening_xy: int = 2,
+    aggregation_function: Callable = np.mean,
+):
+    """
+    Write registered zarr array based on ROI tables
+
+    This function loads the image or label data from a zarr array based on the
+    ROI bounding-box coordinates and stores them into a new zarr array.
+    The new Zarr array has the same shape as the original array, but will have
+    0s where the ROI tables don't specify loading of the image data.
+    The ROIs loaded from `list_indices` will be written into the
+    `list_indices_ref` position, thus performing translational registration if
+    the two lists of ROI indices vary.
+
+    Args:
+        zarr_url: Path or url to the individual OME-Zarr image to be used as
+            the basis for the new OME-Zarr image.
+        new_zarr_url: Path or url to the new OME-Zarr image to be written
+        ROI_table: Fractal ROI table for the component
+        ROI_table_ref: Fractal ROI table for the reference acquisition
+        num_levels: Number of pyramid layers to be created (argument of
+            `build_pyramid`).
+        coarsening_xy: Coarsening factor between pyramid levels
+        aggregation_function: Function to be used when downsampling (argument
+            of `build_pyramid`).
+
+    """
+    # Read pixel sizes from Zarr attributes
+    ngff_image_meta = load_NgffImageMeta(zarr_url)
+    pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0)
+
+    # Create list of indices for 3D ROIs
+    list_indices = convert_ROI_table_to_indices(
+        ROI_table,
+        level=0,
+        coarsening_xy=coarsening_xy,
+        full_res_pxl_sizes_zyx=pxl_sizes_zyx,
+    )
+    list_indices_ref = convert_ROI_table_to_indices(
+        ROI_table_ref,
+        level=0,
+        coarsening_xy=coarsening_xy,
+        full_res_pxl_sizes_zyx=pxl_sizes_zyx,
+    )
+
+    old_image_group = zarr.open_group(zarr_url, mode="r")
+    old_ngff_image_meta = load_NgffImageMeta(zarr_url)
+    new_image_group = zarr.group(new_zarr_url)
+    new_image_group.attrs.put(old_image_group.attrs.asdict())
+
+    # Loop over all channels. For each channel, write full-res image data.
+    data_array = da.from_zarr(old_image_group["0"])
+    # Create dask array with 0s of same shape
+    new_array = da.zeros_like(data_array)
+
+    # TODO: Add sanity checks on the 2 ROI tables:
+    # 1. The number of ROIs need to match
+    # 2. 
The size of the ROIs need to match + # (otherwise, we can't assign them to the reference regions) + # ROI_table_ref vs ROI_table_acq + for i, roi_indices in enumerate(list_indices): + reference_region = convert_indices_to_regions(list_indices_ref[i]) + region = convert_indices_to_regions(roi_indices) + + axes_list = old_ngff_image_meta.axes_names + + if axes_list == ["c", "z", "y", "x"]: + num_channels = data_array.shape[0] + # Loop over channels + for ind_ch in range(num_channels): + idx = tuple( + [slice(ind_ch, ind_ch + 1)] + list(reference_region) + ) + new_array[idx] = load_region( + data_zyx=data_array[ind_ch], region=region, compute=False + ) + elif axes_list == ["z", "y", "x"]: + new_array[reference_region] = load_region( + data_zyx=data_array, region=region, compute=False + ) + elif axes_list == ["c", "y", "x"]: + # TODO: Implement cyx case (based on looping over xy case) + raise NotImplementedError( + "`write_registered_zarr` has not been implemented for " + f"a zarr with {axes_list=}" + ) + elif axes_list == ["y", "x"]: + # TODO: Implement yx case + raise NotImplementedError( + "`write_registered_zarr` has not been implemented for " + f"a zarr with {axes_list=}" + ) + else: + raise NotImplementedError( + "`write_registered_zarr` has not been implemented for " + f"a zarr with {axes_list=}" + ) + + new_array.to_zarr( + f"{new_zarr_url}/0", + overwrite=True, + dimension_separator="/", + write_empty_chunks=False, + ) + + # Starting from on-disk highest-resolution data, build and write to + # disk a pyramid of coarser levels + build_pyramid( + zarrurl=new_zarr_url, + overwrite=True, + num_levels=num_levels, + coarsening_xy=coarsening_xy, + chunksize=data_array.chunksize, + aggregation_function=aggregation_function, + ) + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + run_fractal_task( + task_function=apply_registration_to_image, + logger_name=logger.name, + ) diff --git a/fractal_tasks_core/tasks/calculate_registration_image_based.py b/fractal_tasks_core/tasks/calculate_registration_image_based.py new file mode 100755 index 000000000..6bb783662 --- /dev/null +++ b/fractal_tasks_core/tasks/calculate_registration_image_based.py @@ -0,0 +1,317 @@ +# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Tommaso Comparin +# Joel Lüthi +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. +# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. 
+""" +Calculates translation for image-based registration +""" +import logging +from enum import Enum + +import anndata as ad +import dask.array as da +import numpy as np +import zarr +from pydantic import validate_call +from skimage.exposure import rescale_intensity +from skimage.registration import phase_cross_correlation + +from fractal_tasks_core.channels import get_channel_from_image_zarr +from fractal_tasks_core.channels import OmeroChannel +from fractal_tasks_core.ngff import load_NgffImageMeta +from fractal_tasks_core.roi import check_valid_ROI_indices +from fractal_tasks_core.roi import ( + convert_indices_to_regions, +) +from fractal_tasks_core.roi import ( + convert_ROI_table_to_indices, +) +from fractal_tasks_core.roi import load_region +from fractal_tasks_core.tables import write_table +from fractal_tasks_core.tasks._registration_utils import ( + calculate_physical_shifts, +) +from fractal_tasks_core.tasks._registration_utils import chi2_shift_out +from fractal_tasks_core.tasks._registration_utils import ( + get_ROI_table_with_translation, +) +from fractal_tasks_core.tasks._registration_utils import is_3D +from fractal_tasks_core.tasks.io_models import InitArgsRegistration + +logger = logging.getLogger(__name__) + + +class RegistrationMethod(Enum): + """ + RegistrationMethod Enum class + + Attributes: + PHASE_CROSS_CORRELATION: phase cross correlation based on scikit-image + (works with 2D & 3D images). + CHI2_SHIFT: chi2 shift based on image-registration library + (only works with 2D images). + """ + + PHASE_CROSS_CORRELATION = "phase_cross_correlation" + CHI2_SHIFT = "chi2_shift" + + def register(self, img_ref, img_acq_x): + if self == RegistrationMethod.PHASE_CROSS_CORRELATION: + return phase_cross_correlation(img_ref, img_acq_x) + elif self == RegistrationMethod.CHI2_SHIFT: + return chi2_shift_out(img_ref, img_acq_x) + + +@validate_call +def calculate_registration_image_based( + *, + # Fractal arguments + zarr_url: str, + init_args: InitArgsRegistration, + # Core parameters + wavelength_id: str, + method: RegistrationMethod = RegistrationMethod.PHASE_CROSS_CORRELATION, + lower_rescale_quantile: float = 0.0, + upper_rescale_quantile: float = 0.99, + roi_table: str = "FOV_ROI_table", + level: int = 2, +) -> None: + """ + Calculate registration based on images + + This task consists of 3 parts: + + 1. Loading the images of a given ROI (=> loop over ROIs) + 2. Calculating the transformation for that ROI + 3. Storing the calculated transformation in the ROI table + + Args: + zarr_url: Path or url to the individual OME-Zarr image to be processed. + (standard argument for Fractal tasks, managed by Fractal server). + init_args: Intialization arguments provided by + `image_based_registration_hcs_init`. They contain the + reference_zarr_url that is used for registration. + (standard argument for Fractal tasks, managed by Fractal server). + wavelength_id: Wavelength that will be used for image-based + registration; e.g. `A01_C01` for Yokogawa, `C01` for MD. + method: Method to use for image registration. The available methods + are `phase_cross_correlation` (scikit-image package, works for 2D + & 3D) and "chi2_shift" (image_registration package, only works for + 2D images). + lower_rescale_quantile: Lower quantile for rescaling the image + intensities before applying registration. Can be helpful + to deal with image artifacts. Default is 0. + upper_rescale_quantile: Upper quantile for rescaling the image + intensities before applying registration. 
Can be helpful
+            to deal with image artifacts. Default is 0.99.
+        roi_table: Name of the ROI table over which the task loops to
+            calculate the registration. Examples: `FOV_ROI_table` => loop over
+            the field of views, `well_ROI_table` => process the whole well as
+            one image.
+        level: Pyramid level of the image to be used for registration.
+            Choose `0` to process at full resolution.
+
+    """
+    logger.info(
+        f"Running for {zarr_url=}.\n"
+        f"Calculating translation registration per {roi_table=} for "
+        f"{wavelength_id=}."
+    )
+
+    # Read some parameters from Zarr metadata
+    ngff_image_meta = load_NgffImageMeta(str(init_args.reference_zarr_url))
+    coarsening_xy = ngff_image_meta.coarsening_xy
+
+    # Get channel_index via wavelength_id.
+    # Initially only allow registration of the same wavelength
+    channel_ref: OmeroChannel = get_channel_from_image_zarr(
+        image_zarr_path=init_args.reference_zarr_url,
+        wavelength_id=wavelength_id,
+    )
+    channel_index_ref = channel_ref.index
+
+    channel_align: OmeroChannel = get_channel_from_image_zarr(
+        image_zarr_path=zarr_url,
+        wavelength_id=wavelength_id,
+    )
+    channel_index_align = channel_align.index
+
+    # Lazily load zarr array
+    data_reference_zyx = da.from_zarr(
+        f"{init_args.reference_zarr_url}/{level}"
+    )[channel_index_ref]
+    data_alignment_zyx = da.from_zarr(f"{zarr_url}/{level}")[
+        channel_index_align
+    ]
+
+    # Check if data is 3D (as not all registration methods work in 3D)
+    # TODO: Abstract this check into a higher-level Zarr loading class
+    if is_3D(data_reference_zyx):
+        if method == RegistrationMethod.CHI2_SHIFT:
+            raise ValueError(
+                f"The `{RegistrationMethod.CHI2_SHIFT}` registration method "
+                "has not been implemented for 3D images and the input image "
+                f"had a shape of {data_reference_zyx.shape}."
+            )
+
+    # Read ROIs
+    ROI_table_ref = ad.read_zarr(
+        f"{init_args.reference_zarr_url}/tables/{roi_table}"
+    )
+    ROI_table_x = ad.read_zarr(f"{zarr_url}/tables/{roi_table}")
+    logger.info(
+        f"Found {len(ROI_table_x)} ROIs in {roi_table=} to be processed."
+    )
+
+    # Check that table type of ROI_table_ref is valid. Note that
+    # "ngff:region_table" and None are accepted for backwards compatibility
+    valid_table_types = [
+        "roi_table",
+        "masking_roi_table",
+        "ngff:region_table",
+        None,
+    ]
+    ROI_table_ref_group = zarr.open_group(
+        f"{init_args.reference_zarr_url}/tables/{roi_table}",
+        mode="r",
+    )
+    ref_table_attrs = ROI_table_ref_group.attrs.asdict()
+    ref_table_type = ref_table_attrs.get("type")
+    if ref_table_type not in valid_table_types:
+        raise ValueError(
+            (
+                f"Table '{roi_table}' (with type '{ref_table_type}') is "
+                "not a valid ROI table."
+            )
+        )
+
+    # For each acquisition, get the relevant info
+    # TODO: Add additional checks on ROIs?
+    if (ROI_table_ref.obs.index != ROI_table_x.obs.index).any():
+        raise ValueError(
+            "Registration is only implemented for ROIs that match between the "
+            "acquisitions (e.g. well, FOV ROIs). Here, the ROIs in the "
+            f"reference acquisition were {ROI_table_ref.obs.index}, but the "
+            f"ROIs in the alignment acquisition were {ROI_table_x.obs.index}"
+        )
+    # TODO: Make this less restrictive? i.e. could we also run it if different
+    # acquisitions have different FOVs? But then how do we know which FOVs to
+    # match?
+    # If we relax this, downstream assumptions on matching based on order
+    # in the list will break. 
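+
+    # Illustration of the assumption above (comment only, not executed):
+    # both acquisitions must list the same ROIs in the same order, e.g.
+    #     reference: ["FOV_1", "FOV_2"], alignment: ["FOV_1", "FOV_2"]
+    # so that list_indices_ref[i] and list_indices_acq_x[i] (computed below)
+    # always refer to corresponding regions of the well.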
+ + # Read pixel sizes from zarr attributes + ngff_image_meta_acq_x = load_NgffImageMeta(zarr_url) + pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) + pxl_sizes_zyx_acq_x = ngff_image_meta_acq_x.get_pixel_sizes_zyx(level=0) + + if pxl_sizes_zyx != pxl_sizes_zyx_acq_x: + raise ValueError( + "Pixel sizes need to be equal between acquisitions for " + "registration." + ) + + # Create list of indices for 3D ROIs spanning the entire Z direction + list_indices_ref = convert_ROI_table_to_indices( + ROI_table_ref, + level=level, + coarsening_xy=coarsening_xy, + full_res_pxl_sizes_zyx=pxl_sizes_zyx, + ) + check_valid_ROI_indices(list_indices_ref, roi_table) + + list_indices_acq_x = convert_ROI_table_to_indices( + ROI_table_x, + level=level, + coarsening_xy=coarsening_xy, + full_res_pxl_sizes_zyx=pxl_sizes_zyx, + ) + check_valid_ROI_indices(list_indices_acq_x, roi_table) + + num_ROIs = len(list_indices_ref) + compute = True + new_shifts = {} + for i_ROI in range(num_ROIs): + logger.info( + f"Now processing ROI {i_ROI+1}/{num_ROIs} " + f"for channel {channel_align}." + ) + img_ref = load_region( + data_zyx=data_reference_zyx, + region=convert_indices_to_regions(list_indices_ref[i_ROI]), + compute=compute, + ) + img_acq_x = load_region( + data_zyx=data_alignment_zyx, + region=convert_indices_to_regions(list_indices_acq_x[i_ROI]), + compute=compute, + ) + + # Rescale the images + img_ref = rescale_intensity( + img_ref, + in_range=( + np.quantile(img_ref, lower_rescale_quantile), + np.quantile(img_ref, upper_rescale_quantile), + ), + ) + img_acq_x = rescale_intensity( + img_acq_x, + in_range=( + np.quantile(img_acq_x, lower_rescale_quantile), + np.quantile(img_acq_x, upper_rescale_quantile), + ), + ) + + ############## + # Calculate the transformation + ############## + if img_ref.shape != img_acq_x.shape: + raise NotImplementedError( + "This registration is not implemented for ROIs with " + "different shapes between acquisitions." + ) + + shifts = method.register(np.squeeze(img_ref), np.squeeze(img_acq_x))[0] + + ############## + # Store the calculated transformation ### + ############## + # Adapt ROIs for the given ROI table: + ROI_name = ROI_table_ref.obs.index[i_ROI] + new_shifts[ROI_name] = calculate_physical_shifts( + shifts, + level=level, + coarsening_xy=coarsening_xy, + full_res_pxl_sizes_zyx=pxl_sizes_zyx, + ) + + # Write physical shifts to disk (as part of the ROI table) + logger.info(f"Updating the {roi_table=} with translation columns") + image_group = zarr.group(zarr_url) + new_ROI_table = get_ROI_table_with_translation(ROI_table_x, new_shifts) + write_table( + image_group, + roi_table, + new_ROI_table, + overwrite=True, + table_attrs=ref_table_attrs, + ) + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + run_fractal_task( + task_function=calculate_registration_image_based, + logger_name=logger.name, + ) diff --git a/fractal_tasks_core/tasks/cellpose_segmentation.py b/fractal_tasks_core/tasks/cellpose_segmentation.py new file mode 100644 index 000000000..9ee23b8e8 --- /dev/null +++ b/fractal_tasks_core/tasks/cellpose_segmentation.py @@ -0,0 +1,627 @@ +# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Tommaso Comparin +# Marco Franzon +# Joel Lüthi +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. 
+# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. +""" +Image segmentation via Cellpose library. +""" +import logging +import os +import time +from typing import Literal +from typing import Optional + +import anndata as ad +import cellpose +import dask.array as da +import numpy as np +import zarr +from cellpose import models +from pydantic import Field +from pydantic import validate_call + +import fractal_tasks_core +from fractal_tasks_core.labels import prepare_label_group +from fractal_tasks_core.masked_loading import masked_loading_wrapper +from fractal_tasks_core.ngff import load_NgffImageMeta +from fractal_tasks_core.pyramids import build_pyramid +from fractal_tasks_core.roi import ( + array_to_bounding_box_table, +) +from fractal_tasks_core.roi import check_valid_ROI_indices +from fractal_tasks_core.roi import ( + convert_ROI_table_to_indices, +) +from fractal_tasks_core.roi import create_roi_table_from_df_list +from fractal_tasks_core.roi import ( + find_overlaps_in_ROI_indices, +) +from fractal_tasks_core.roi import get_overlapping_pairs_3D +from fractal_tasks_core.roi import is_ROI_table_valid +from fractal_tasks_core.roi import load_region +from fractal_tasks_core.tables import write_table +from fractal_tasks_core.tasks.cellpose_utils import ( + _normalize_cellpose_channels, +) +from fractal_tasks_core.tasks.cellpose_utils import ( + CellposeChannel1InputModel, +) +from fractal_tasks_core.tasks.cellpose_utils import ( + CellposeChannel2InputModel, +) +from fractal_tasks_core.tasks.cellpose_utils import ( + CellposeCustomNormalizer, +) +from fractal_tasks_core.tasks.cellpose_utils import ( + CellposeModelParams, +) +from fractal_tasks_core.utils import rescale_datasets + +logger = logging.getLogger(__name__) + +__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ + + +def segment_ROI( + x: np.ndarray, + num_labels_tot: dict[str, int], + model: models.CellposeModel = None, + do_3D: bool = True, + channels: list[int] = [0, 0], + diameter: float = 30.0, + normalize: CellposeCustomNormalizer = CellposeCustomNormalizer(), + normalize2: Optional[CellposeCustomNormalizer] = None, + label_dtype: Optional[np.dtype] = None, + relabeling: bool = True, + advanced_cellpose_model_params: CellposeModelParams = CellposeModelParams(), # noqa: E501 +) -> np.ndarray: + """ + Internal function that runs Cellpose segmentation for a single ROI. + + Args: + x: 4D numpy array. + num_labels_tot: Number of labels already in total image. Used for + relabeling purposes. Using a dict to have a mutable object that + can be edited from within the function without having to be passed + back through the masked_loading_wrapper. + model: An instance of `models.CellposeModel`. + do_3D: If `True`, cellpose runs in 3D mode: runs on xy, xz & yz planes, + then averages the flows. + channels: Which channels to use. If only one channel is provided, `[0, + 0]` should be used. If two channels are provided (the first + dimension of `x` has length of 2), `[1, 2]` should be used + (`x[0, :, :,:]` contains the membrane channel and + `x[1, :, :, :]` contains the nuclear channel). + diameter: Expected object diameter in pixels for cellpose. + normalize: By default, data is normalized so 0.0=1st percentile and + 1.0=99th percentile of image intensities in each channel. + This automatic normalization can lead to issues when the image to + be segmented is very sparse. You can turn off the default + rescaling. 
With the "custom" option, you can either provide your + own rescaling percentiles or fixed rescaling upper and lower + bound integers. + normalize2: Normalization options for channel 2. If one channel is + normalized with default settings, both channels need to be + normalized with default settings. + label_dtype: Label images are cast into this `np.dtype`. + relabeling: Whether relabeling based on num_labels_tot is performed. + advanced_cellpose_model_params: Advanced Cellpose model parameters + that are passed to the Cellpose `model.eval` method. + """ + + # Write some debugging info + logger.info( + "[segment_ROI] START |" + f" x: {type(x)}, {x.shape} |" + f" {do_3D=} |" + f" {model.diam_mean=} |" + f" {diameter=} |" + f" {advanced_cellpose_model_params.flow_threshold=} |" + f" {normalize.type=}" + ) + x = _normalize_cellpose_channels(x, channels, normalize, normalize2) + + # Actual labeling + t0 = time.perf_counter() + mask, _, _ = model.eval( + x, + channels=channels, + do_3D=do_3D, + net_avg=advanced_cellpose_model_params.net_avg, + augment=advanced_cellpose_model_params.augment, + diameter=diameter, + anisotropy=advanced_cellpose_model_params.anisotropy, + cellprob_threshold=advanced_cellpose_model_params.cellprob_threshold, + flow_threshold=advanced_cellpose_model_params.flow_threshold, + normalize=normalize.cellpose_normalize, + min_size=advanced_cellpose_model_params.min_size, + batch_size=advanced_cellpose_model_params.batch_size, + invert=advanced_cellpose_model_params.invert, + tile=advanced_cellpose_model_params.tile, + tile_overlap=advanced_cellpose_model_params.tile_overlap, + resample=advanced_cellpose_model_params.resample, + interp=advanced_cellpose_model_params.interp, + stitch_threshold=advanced_cellpose_model_params.stitch_threshold, + ) + + if mask.ndim == 2: + # If we get a 2D image, we still return it as a 3D array + mask = np.expand_dims(mask, axis=0) + t1 = time.perf_counter() + + # Write some debugging info + logger.info( + "[segment_ROI] END |" + f" Elapsed: {t1-t0:.3f} s |" + f" {mask.shape=}," + f" {mask.dtype=} (then {label_dtype})," + f" {np.max(mask)=} |" + f" {model.diam_mean=} |" + f" {diameter=} |" + f" {advanced_cellpose_model_params.flow_threshold=}" + ) + + # Shift labels and update relabeling counters + if relabeling: + num_labels_roi = np.max(mask) + mask[mask > 0] += num_labels_tot["num_labels_tot"] + num_labels_tot["num_labels_tot"] += num_labels_roi + + # Write some logs + logger.info(f"ROI had {num_labels_roi=}, {num_labels_tot=}") + + # Check that total number of labels is under control + if num_labels_tot["num_labels_tot"] > np.iinfo(label_dtype).max: + raise ValueError( + "ERROR in re-labeling:" + f"Reached {num_labels_tot} labels, " + f"but dtype={label_dtype}" + ) + + return mask.astype(label_dtype) + + +@validate_call +def cellpose_segmentation( + *, + # Fractal parameters + zarr_url: str, + # Core parameters + level: int, + channel: CellposeChannel1InputModel, + channel2: CellposeChannel2InputModel = Field( + default_factory=CellposeChannel2InputModel + ), + input_ROI_table: str = "FOV_ROI_table", + output_ROI_table: Optional[str] = None, + output_label_name: Optional[str] = None, + # Cellpose-related arguments + diameter_level0: float = 30.0, + # https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/401 # noqa E501 + model_type: Literal[tuple(models.MODEL_NAMES)] = "cyto2", # type: ignore + pretrained_model: Optional[str] = None, + relabeling: bool = True, + use_masks: bool = True, + advanced_cellpose_model_params: 
CellposeModelParams = Field( + default_factory=CellposeModelParams + ), + overwrite: bool = True, +) -> None: + """ + Run cellpose segmentation on the ROIs of a single OME-Zarr image. + + Args: + zarr_url: Path or url to the individual OME-Zarr image to be processed. + (standard argument for Fractal tasks, managed by Fractal server). + level: Pyramid level of the image to be segmented. Choose `0` to + process at full resolution. + channel: Primary channel for segmentation; requires either + `wavelength_id` (e.g. `A01_C01`) or `label` (e.g. `DAPI`), but not + both. Also contains normalization options. By default, data is + normalized so 0.0=1st percentile and 1.0=99th percentile of image + intensities in each channel. + This automatic normalization can lead to issues when the image to + be segmented is very sparse. You can turn off the default + rescaling. With the "custom" option, you can either provide your + own rescaling percentiles or fixed rescaling upper and lower + bound integers. + channel2: Second channel for segmentation (in the same format as + `channel`). If specified, cellpose runs in dual channel mode. + For dual channel segmentation of cells, the first channel should + contain the membrane marker, the second channel should contain the + nuclear marker. + input_ROI_table: Name of the ROI table over which the task loops to + apply Cellpose segmentation. Examples: `FOV_ROI_table` => loop over + the field of views, `organoid_ROI_table` => loop over the organoid + ROI table (generated by another task), `well_ROI_table` => process + the whole well as one image. + output_ROI_table: If provided, a ROI table with that name is created, + which will contain the bounding boxes of the newly segmented + labels. ROI tables should have `ROI` in their name. + output_label_name: Name of the output label image (e.g. `"organoids"`). + diameter_level0: Expected diameter of the objects that should be + segmented in pixels at level 0. Initial diameter is rescaled using + the `level` that was selected. The rescaled value is passed as + the diameter to the `CellposeModel.eval` method. + model_type: Parameter of `CellposeModel` class. Defines which model + should be used. Typical choices are `nuclei`, `cyto`, `cyto2`, etc. + pretrained_model: Parameter of `CellposeModel` class (takes + precedence over `model_type`). Allows you to specify the path of + a custom trained cellpose model. + relabeling: If `True`, apply relabeling so that label values are + unique for all objects in the well. + use_masks: If `True`, try to use masked loading and fall back to + `use_masks=False` if the ROI table is not suitable. Masked + loading is relevant when only a subset of the bounding box should + actually be processed (e.g. running within `organoid_ROI_table`). + advanced_cellpose_model_params: Advanced Cellpose model parameters + that are passed to the Cellpose `model.eval` method. + overwrite: If `True`, overwrite the task output. 
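+
+    Example (hypothetical parameter values, shown for illustration only):
+
+        >>> cellpose_segmentation(  # doctest: +SKIP
+        ...     zarr_url="/path/to/my_plate.zarr/B/03/0",
+        ...     level=2,
+        ...     channel=CellposeChannel1InputModel(wavelength_id="A01_C01"),
+        ...     output_label_name="nuclei",
+        ...     output_ROI_table="nuclei_ROI_table",
+        ... )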
+ """ + logger.info(f"Processing {zarr_url=}") + + # Preliminary checks on Cellpose model + if pretrained_model: + if not os.path.exists(pretrained_model): + raise ValueError(f"{pretrained_model=} does not exist.") + + # Read attributes from NGFF metadata + ngff_image_meta = load_NgffImageMeta(zarr_url) + num_levels = ngff_image_meta.num_levels + coarsening_xy = ngff_image_meta.coarsening_xy + full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) + actual_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=level) + logger.info(f"NGFF image has {num_levels=}") + logger.info(f"NGFF image has {coarsening_xy=}") + logger.info( + f"NGFF image has full-res pixel sizes {full_res_pxl_sizes_zyx}" + ) + logger.info( + f"NGFF image has level-{level} pixel sizes " + f"{actual_res_pxl_sizes_zyx}" + ) + + # Find channel index + omero_channel = channel.get_omero_channel(zarr_url) + if omero_channel: + ind_channel = omero_channel.index + else: + return + + # Find channel index for second channel, if one is provided + if channel2.is_set(): + omero_channel_2 = channel2.get_omero_channel(zarr_url) + if omero_channel_2: + ind_channel_c2 = omero_channel_2.index + else: + return + + # Set channel label + if output_label_name is None: + try: + channel_label = omero_channel.label + output_label_name = f"label_{channel_label}" + except (KeyError, IndexError): + output_label_name = f"label_{ind_channel}" + + # Load ZYX data + # Workaround for #788: Only load channel index when there is a channel + # dimension + if ngff_image_meta.axes_names[0] != "c": + data_zyx = da.from_zarr(f"{zarr_url}/{level}") + if channel2.is_set(): + raise ValueError( + "Dual channel input was specified for an OME-Zarr image " + "without a channel axis" + ) + else: + data_zyx = da.from_zarr(f"{zarr_url}/{level}")[ind_channel] + if channel2.is_set(): + data_zyx_c2 = da.from_zarr(f"{zarr_url}/{level}")[ind_channel_c2] + logger.info(f"Second channel: {data_zyx_c2.shape=}") + logger.info(f"{data_zyx.shape=}") + + # Read ROI table + ROI_table_path = f"{zarr_url}/tables/{input_ROI_table}" + ROI_table = ad.read_zarr(ROI_table_path) + + # Perform some checks on the ROI table + valid_ROI_table = is_ROI_table_valid( + table_path=ROI_table_path, use_masks=use_masks + ) + if use_masks and not valid_ROI_table: + logger.info( + f"ROI table at {ROI_table_path} cannot be used for masked " + "loading. Set use_masks=False." + ) + use_masks = False + logger.info(f"{use_masks=}") + + # Create list of indices for 3D ROIs spanning the entire Z direction + list_indices = convert_ROI_table_to_indices( + ROI_table, + level=level, + coarsening_xy=coarsening_xy, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, + ) + check_valid_ROI_indices(list_indices, input_ROI_table) + + # If we are not planning to use masked loading, fail for overlapping ROIs + if not use_masks: + overlap = find_overlaps_in_ROI_indices(list_indices) + if overlap: + raise ValueError( + f"ROI indices created from {input_ROI_table} table have " + "overlaps, but we are not using masked loading." 
+ ) + + # Select 2D/3D behavior and set some parameters + do_3D = data_zyx.shape[0] > 1 and len(data_zyx.shape) == 3 + if do_3D: + if advanced_cellpose_model_params.anisotropy is None: + # Compute anisotropy as pixel_size_z/pixel_size_x + advanced_cellpose_model_params.anisotropy = ( + actual_res_pxl_sizes_zyx[0] / actual_res_pxl_sizes_zyx[2] + ) + logger.info(f"Anisotropy: {advanced_cellpose_model_params.anisotropy}") + + # Rescale datasets (only relevant for level>0) + # Workaround for #788 + if ngff_image_meta.axes_names[0] != "c": + new_datasets = rescale_datasets( + datasets=[ds.model_dump() for ds in ngff_image_meta.datasets], + coarsening_xy=coarsening_xy, + reference_level=level, + remove_channel_axis=False, + ) + else: + new_datasets = rescale_datasets( + datasets=[ds.model_dump() for ds in ngff_image_meta.datasets], + coarsening_xy=coarsening_xy, + reference_level=level, + remove_channel_axis=True, + ) + + label_attrs = { + "image-label": { + "version": __OME_NGFF_VERSION__, + "source": {"image": "../../"}, + }, + "multiscales": [ + { + "name": output_label_name, + "version": __OME_NGFF_VERSION__, + "axes": [ + ax.dict() + for ax in ngff_image_meta.multiscale.axes + if ax.type != "channel" + ], + "datasets": new_datasets, + } + ], + } + + image_group = zarr.group(zarr_url) + label_group = prepare_label_group( + image_group, + output_label_name, + overwrite=overwrite, + label_attrs=label_attrs, + logger=logger, + ) + + logger.info( + f"Helper function `prepare_label_group` returned {label_group=}" + ) + logger.info(f"Output label path: {zarr_url}/labels/{output_label_name}/0") + store = zarr.storage.FSStore(f"{zarr_url}/labels/{output_label_name}/0") + label_dtype = np.uint32 + + # Ensure that all output shapes & chunks are 3D (for 2D data: (1, y, x)) + # https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/398 + shape = data_zyx.shape + if len(shape) == 2: + shape = (1, *shape) + chunks = data_zyx.chunksize + if len(chunks) == 2: + chunks = (1, *chunks) + mask_zarr = zarr.create( + shape=shape, + chunks=chunks, + dtype=label_dtype, + store=store, + overwrite=False, + dimension_separator="/", + ) + + logger.info( + f"mask will have shape {data_zyx.shape} " + f"and chunks {data_zyx.chunks}" + ) + + # Initialize cellpose + gpu = advanced_cellpose_model_params.use_gpu and cellpose.core.use_gpu() + if pretrained_model: + model = models.CellposeModel( + gpu=gpu, pretrained_model=pretrained_model + ) + else: + model = models.CellposeModel(gpu=gpu, model_type=model_type) + + # Initialize other things + logger.info(f"Start cellpose_segmentation task for {zarr_url}") + logger.info(f"relabeling: {relabeling}") + logger.info(f"do_3D: {do_3D}") + logger.info(f"use_gpu: {gpu}") + logger.info(f"level: {level}") + logger.info(f"model_type: {model_type}") + logger.info(f"pretrained_model: {pretrained_model}") + logger.info(f"anisotropy: {advanced_cellpose_model_params.anisotropy}") + logger.info("Total well shape/chunks:") + logger.info(f"{data_zyx.shape}") + logger.info(f"{data_zyx.chunks}") + if channel2.is_set(): + logger.info("Dual channel input for cellpose model") + logger.info(f"{data_zyx_c2.shape}") + logger.info(f"{data_zyx_c2.chunks}") + + # Counters for relabeling + num_labels_tot = {"num_labels_tot": 0} + + # Iterate over ROIs + num_ROIs = len(list_indices) + + if output_ROI_table: + bbox_dataframe_list = [] + + logger.info(f"Now starting loop over {num_ROIs} ROIs") + for i_ROI, indices in enumerate(list_indices): + # Define region + s_z, e_z, s_y, e_y, s_x, e_x = 
indices[:] + region = ( + slice(s_z, e_z), + slice(s_y, e_y), + slice(s_x, e_x), + ) + logger.info(f"Now processing ROI {i_ROI+1}/{num_ROIs}") + + # Prepare single-channel or dual-channel input for cellpose + if channel2.is_set(): + # Dual channel mode, first channel is the membrane channel + img_1 = load_region( + data_zyx, + region, + compute=True, + return_as_3D=True, + ) + img_np = np.zeros((2, *img_1.shape)) + img_np[0, :, :, :] = img_1 + img_np[1, :, :, :] = load_region( + data_zyx_c2, + region, + compute=True, + return_as_3D=True, + ) + channels = [1, 2] + else: + img_np = np.expand_dims( + load_region(data_zyx, region, compute=True, return_as_3D=True), + axis=0, + ) + channels = [0, 0] + + # Prepare keyword arguments for segment_ROI function + kwargs_segment_ROI = dict( + num_labels_tot=num_labels_tot, + model=model, + channels=channels, + do_3D=do_3D, + label_dtype=label_dtype, + diameter=diameter_level0 / coarsening_xy**level, + normalize=channel.normalize, + normalize2=channel2.normalize, + relabeling=relabeling, + advanced_cellpose_model_params=advanced_cellpose_model_params, + ) + + # Prepare keyword arguments for preprocessing function + preprocessing_kwargs = {} + if use_masks: + preprocessing_kwargs = dict( + region=region, + current_label_path=f"{zarr_url}/labels/{output_label_name}/0", + ROI_table_path=ROI_table_path, + ROI_positional_index=i_ROI, + ) + + # Call segment_ROI through the masked-loading wrapper, which includes + # pre/post-processing functions if needed + new_label_img = masked_loading_wrapper( + image_array=img_np, + function=segment_ROI, + kwargs=kwargs_segment_ROI, + use_masks=use_masks, + preprocessing_kwargs=preprocessing_kwargs, + ) + + if output_ROI_table: + bbox_df = array_to_bounding_box_table( + new_label_img, + actual_res_pxl_sizes_zyx, + origin_zyx=(s_z, s_y, s_x), + ) + + bbox_dataframe_list.append(bbox_df) + + overlap_list = get_overlapping_pairs_3D( + bbox_df, full_res_pxl_sizes_zyx + ) + if len(overlap_list) > 0: + logger.warning( + f"ROI {indices} has " + f"{len(overlap_list)} bounding-box pairs overlap" + ) + + # Compute and store 0-th level to disk + da.array(new_label_img).to_zarr( + url=mask_zarr, + region=region, + compute=True, + ) + + logger.info( + f"End cellpose_segmentation task for {zarr_url}, " + "now building pyramids." 
+ ) + + # Starting from on-disk highest-resolution data, build and write to disk a + # pyramid of coarser levels + build_pyramid( + zarrurl=f"{zarr_url}/labels/{output_label_name}", + overwrite=overwrite, + num_levels=num_levels, + coarsening_xy=coarsening_xy, + chunksize=chunks, + aggregation_function=np.max, + ) + + logger.info("End building pyramids") + + if output_ROI_table: + bbox_table = create_roi_table_from_df_list(bbox_dataframe_list) + + # Write to zarr group + image_group = zarr.group(zarr_url) + logger.info( + "Now writing bounding-box ROI table to " + f"{zarr_url}/tables/{output_ROI_table}" + ) + table_attrs = { + "type": "masking_roi_table", + "region": {"path": f"../labels/{output_label_name}"}, + "instance_key": "label", + } + write_table( + image_group, + output_ROI_table, + bbox_table, + overwrite=overwrite, + table_attrs=table_attrs, + ) + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + run_fractal_task( + task_function=cellpose_segmentation, + logger_name=logger.name, + ) diff --git a/fractal_tasks_core/tasks/cellpose_utils.py b/fractal_tasks_core/tasks/cellpose_utils.py new file mode 100644 index 000000000..6ba22a29e --- /dev/null +++ b/fractal_tasks_core/tasks/cellpose_utils.py @@ -0,0 +1,468 @@ +# Copyright 2023 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Joel Lüthi +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. +# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. +""" +Helper functions for image normalization in +""" +import logging +from typing import Literal +from typing import Optional + +import numpy as np +from pydantic import BaseModel +from pydantic import Field +from pydantic import model_validator +from typing_extensions import Self + +from fractal_tasks_core.channels import ChannelInputModel +from fractal_tasks_core.channels import ChannelNotFoundError +from fractal_tasks_core.channels import get_channel_from_image_zarr +from fractal_tasks_core.channels import OmeroChannel + + +logger = logging.getLogger(__name__) + + +class CellposeCustomNormalizer(BaseModel): + """ + Validator to handle different normalization scenarios for Cellpose models + + If `type="default"`, then Cellpose default normalization is + used and no other parameters can be specified. + If `type="no_normalization"`, then no normalization is used and no + other parameters can be specified. + If `type="custom"`, then either percentiles or explicit integer + bounds can be applied. + + Attributes: + type: + One of `default` (Cellpose default normalization), `custom` + (using the other custom parameters) or `no_normalization`. + lower_percentile: Specify a custom lower-bound percentile for rescaling + as a float value between 0 and 100. Set to 1 to run the same as + default). You can only specify percentiles or bounds, not both. + upper_percentile: Specify a custom upper-bound percentile for rescaling + as a float value between 0 and 100. Set to 99 to run the same as + default, set to e.g. 99.99 if the default rescaling was too harsh. + You can only specify percentiles or bounds, not both. + lower_bound: Explicit lower bound value to rescale the image at. + Needs to be an integer, e.g. 100. + You can only specify percentiles or bounds, not both. + upper_bound: Explicit upper bound value to rescale the image at. + Needs to be an integer, e.g. 
2000.
+            You can only specify percentiles or bounds, not both.
+    """
+
+    type: Literal["default", "custom", "no_normalization"] = "default"
+    lower_percentile: Optional[float] = Field(None, ge=0, le=100)
+    upper_percentile: Optional[float] = Field(None, ge=0, le=100)
+    lower_bound: Optional[int] = None
+    upper_bound: Optional[int] = None
+
+    # In the future, add an option to allow using precomputed percentiles
+    # that are stored in OME-Zarr histograms, and use this pydantic model to
+    # validate that those histograms actually exist
+
+    @model_validator(mode="after")
+    def validate_conditions(self: Self) -> Self:
+        # Extract values
+        type = self.type
+        lower_percentile = self.lower_percentile
+        upper_percentile = self.upper_percentile
+        lower_bound = self.lower_bound
+        upper_bound = self.upper_bound
+
+        # Verify that custom parameters are only provided when type="custom"
+        if type != "custom":
+            if lower_percentile is not None:
+                raise ValueError(
+                    f"Type='{type}' but {lower_percentile=}. "
+                    "Hint: set type='custom'."
+                )
+            if upper_percentile is not None:
+                raise ValueError(
+                    f"Type='{type}' but {upper_percentile=}. "
+                    "Hint: set type='custom'."
+                )
+            if lower_bound is not None:
+                raise ValueError(
+                    f"Type='{type}' but {lower_bound=}. "
+                    "Hint: set type='custom'."
+                )
+            if upper_bound is not None:
+                raise ValueError(
+                    f"Type='{type}' but {upper_bound=}. "
+                    "Hint: set type='custom'."
+                )
+
+        # The only valid options are:
+        # 1. Both percentiles are set and both bounds are unset
+        # 2. Both bounds are set and both percentiles are unset
+        are_percentiles_set = (
+            lower_percentile is not None,
+            upper_percentile is not None,
+        )
+        are_bounds_set = (
+            lower_bound is not None,
+            upper_bound is not None,
+        )
+        if len(set(are_percentiles_set)) != 1:
+            raise ValueError(
+                "Both lower_percentile and upper_percentile must be set "
+                "together."
+            )
+        if len(set(are_bounds_set)) != 1:
+            raise ValueError(
+                "Both lower_bound and upper_bound must be set together."
+            )
+        if lower_percentile is not None and lower_bound is not None:
+            raise ValueError(
+                "You cannot set both explicit bounds and percentile bounds "
+                "at the same time. Hint: use only one of the two options."
+            )
+
+        return self
+
+    @property
+    def cellpose_normalize(self) -> bool:
+        """
+        Determine whether cellpose should apply its internal normalization.
+
+        If type is set to `custom` or `no_normalization`, do not apply the
+        cellpose internal normalization.
+        """
+        return self.type == "default"
+
+
+class CellposeModelParams(BaseModel):
+    """
+    Advanced Cellpose Model Parameters
+
+    Attributes:
+        cellprob_threshold: Parameter of `CellposeModel.eval` method. Valid
+            values between -6 and 6. From Cellpose documentation: "Decrease
+            this threshold if cellpose is not returning as many ROIs as you'd
+            expect. Similarly, increase this threshold if cellpose is
+            returning too many ROIs, particularly from dim areas."
+        flow_threshold: Parameter of `CellposeModel.eval` method. Valid
+            values between 0.0 and 1.0. From Cellpose documentation: "Increase
+            this threshold if cellpose is not returning as many ROIs as you'd
+            expect. Similarly, decrease this threshold if cellpose is
+            returning too many ill-shaped ROIs."
+        anisotropy: Ratio of the pixel sizes along Z and XY axis (ignored if
+            the image is not three-dimensional). If unset, it is inferred from
+            the OME-NGFF metadata.
+        min_size: Parameter of `CellposeModel` class. Minimum size of the
+            segmented objects (in pixels). Use `-1` to turn off the size
+            filter.
+        augment: Parameter of `CellposeModel` class. Whether to use cellpose
+            augmentation to tile images with overlap.
+        net_avg: Parameter of `CellposeModel` class. Whether to use cellpose
+            net averaging to run the 4 built-in networks (useful for `nuclei`,
+            `cyto` and `cyto2`, not sure it works for the others).
+        use_gpu: If `False`, always use the CPU; if `True`, use the GPU if
+            possible (as defined in `cellpose.core.use_gpu()`) and fall back
+            to the CPU otherwise.
+        batch_size: Number of 224x224 patches to run simultaneously on the
+            GPU (can be made smaller or bigger depending on GPU memory usage).
+        invert: Invert image pixel intensity before running network (if True,
+            image is also normalized).
+        tile: Tile image to keep GPU/CPU memory usage limited (recommended).
+        tile_overlap: Fraction of overlap of tiles when computing flows.
+        resample: Run dynamics at original image size (will be slower but
+            create more accurate boundaries).
+        interp: Interpolate during 2D dynamics (not available in 3D).
+            In previous versions it was False, now it defaults to True.
+        stitch_threshold: If stitch_threshold>0.0 and not do_3D and equal
+            image sizes, masks are stitched in 3D to return volume
+            segmentation.
+    """
+
+    cellprob_threshold: float = 0.0
+    flow_threshold: float = 0.4
+    anisotropy: Optional[float] = None
+    min_size: int = 15
+    augment: bool = False
+    net_avg: bool = False
+    use_gpu: bool = True
+    batch_size: int = 8
+    invert: bool = False
+    tile: bool = True
+    tile_overlap: float = 0.1
+    resample: bool = True
+    interp: bool = True
+    stitch_threshold: float = 0.0
+
+
+class CellposeChannel1InputModel(ChannelInputModel):
+    """
+    Channel input for cellpose with normalization options.
+
+    Attributes:
+        wavelength_id: Unique ID for the channel wavelength, e.g. `A01_C01`.
+            Can only be specified if label is not set.
+        label: Name of the channel. Can only be specified if wavelength_id is
+            not set.
+        normalize: Validator to handle different normalization scenarios for
+            Cellpose models.
+    """
+
+    normalize: CellposeCustomNormalizer = Field(
+        default_factory=CellposeCustomNormalizer
+    )
+
+    def get_omero_channel(self, zarr_url) -> Optional[OmeroChannel]:
+        try:
+            return get_channel_from_image_zarr(
+                image_zarr_path=zarr_url,
+                wavelength_id=self.wavelength_id,
+                label=self.label,
+            )
+        except ChannelNotFoundError as e:
+            logger.warning(
+                f"Channel with wavelength_id: {self.wavelength_id} "
+                f"and label: {self.label} not found, exiting from the task.\n"
+                f"Original error: {str(e)}"
+            )
+            return None
+
+
+class CellposeChannel2InputModel(BaseModel):
+    """
+    Channel input for secondary cellpose channel with normalization options.
+
+    The secondary channel is optional, so both wavelength_id and label are
+    optional. The `is_set` method shows whether either value was set.
+
+    Attributes:
+        wavelength_id: Unique ID for the channel wavelength, e.g. `A01_C01`.
+            Can only be specified if label is not set.
+        label: Name of the channel. Can only be specified if wavelength_id is
+            not set.
+        normalize: Validator to handle different normalization scenarios for
+            Cellpose models.
+    """
+
+    wavelength_id: Optional[str] = None
+    label: Optional[str] = None
+    normalize: CellposeCustomNormalizer = Field(
+        default_factory=CellposeCustomNormalizer
+    )
+
+    @model_validator(mode="after")
+    def mutually_exclusive_channel_attributes(self: Self) -> Self:
+        """
+        Check that only 1 of `label` or `wavelength_id` is set.
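+
+        Raises:
+            ValueError: If both `label` and `wavelength_id` are set.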
+ """ + wavelength_id = self.wavelength_id + label = self.label + if (wavelength_id is not None) and (label is not None): + raise ValueError( + "`wavelength_id` and `label` cannot be both set " + f"(given {wavelength_id=} and {label=})." + ) + return self + + def is_set(self): + if self.wavelength_id or self.label: + return True + return False + + def get_omero_channel(self, zarr_url) -> OmeroChannel: + try: + return get_channel_from_image_zarr( + image_zarr_path=zarr_url, + wavelength_id=self.wavelength_id, + label=self.label, + ) + except ChannelNotFoundError as e: + logger.warning( + f"Second channel with wavelength_id: {self.wavelength_id} " + f"and label: {self.label} not found, exit from the task.\n" + f"Original error: {str(e)}" + ) + return None + + +def _normalize_cellpose_channels( + x: np.ndarray, + channels: list[int], + normalize: CellposeCustomNormalizer, + normalize2: CellposeCustomNormalizer, +) -> np.ndarray: + """ + Normalize a cellpose input array by channel. + + Args: + x: 4D numpy array. + channels: Which channels to use. If only one channel is provided, `[0, + 0]` should be used. If two channels are provided (the first + dimension of `x` has length of 2), `[1, 2]` should be used + (`x[0, :, :,:]` contains the membrane channel and + `x[1, :, :, :]` contains the nuclear channel). + normalize: By default, data is normalized so 0.0=1st percentile and + 1.0=99th percentile of image intensities in each channel. + This automatic normalization can lead to issues when the image to + be segmented is very sparse. You can turn off the default + rescaling. With the "custom" option, you can either provide your + own rescaling percentiles or fixed rescaling upper and lower + bound integers. + normalize2: Normalization options for channel 2. If one channel is + normalized with default settings, both channels need to be + normalized with default settings. + + """ + # Optionally perform custom normalization + # normalize channels separately, if normalize2 is provided: + if channels == [1, 2]: + # If run in single channel mode, fails (specified as channel = [0, 0]) + if (normalize.type == "default") != (normalize2.type == "default"): + raise ValueError( + "ERROR in normalization:" + f" {normalize.type=} and {normalize2.type=}." + " Either both need to be 'default', or none of them." 
+            )
+        if normalize.type == "custom":
+            x[channels[0] - 1 : channels[0]] = normalized_img(
+                x[channels[0] - 1 : channels[0]],
+                lower_p=normalize.lower_percentile,
+                upper_p=normalize.upper_percentile,
+                lower_bound=normalize.lower_bound,
+                upper_bound=normalize.upper_bound,
+            )
+        if normalize2.type == "custom":
+            x[channels[1] - 1 : channels[1]] = normalized_img(
+                x[channels[1] - 1 : channels[1]],
+                lower_p=normalize2.lower_percentile,
+                upper_p=normalize2.upper_percentile,
+                lower_bound=normalize2.lower_bound,
+                upper_bound=normalize2.upper_bound,
+            )
+
+    # otherwise, use the first normalize to normalize all channels:
+    else:
+        if normalize.type == "custom":
+            x = normalized_img(
+                x,
+                lower_p=normalize.lower_percentile,
+                upper_p=normalize.upper_percentile,
+                lower_bound=normalize.lower_bound,
+                upper_bound=normalize.upper_bound,
+            )
+
+    return x
+
+
+def normalized_img(
+    img: np.ndarray,
+    axis: int = -1,
+    invert: bool = False,
+    lower_p: float = 1.0,
+    upper_p: float = 99.0,
+    lower_bound: Optional[int] = None,
+    upper_bound: Optional[int] = None,
+):
+    """
+    Normalize each channel of the image so that 0.0 corresponds to the lower
+    percentile (or lower bound) and 1.0 corresponds to the upper percentile
+    (or upper bound) of image intensities, with optional inversion.
+
+    The normalization can result in values < 0 or > 1 (no clipping).
+
+    Based on https://github.com/MouseLand/cellpose/blob/4f5661983c3787efa443bbbd3f60256f4fd8bf53/cellpose/transforms.py#L375 # noqa E501
+
+    Args:
+        img: ND-array (at least 3 dimensions).
+        axis: Channel axis to loop over for normalization.
+        invert: Invert image (useful if cells are dark instead of bright).
+        lower_p: Lower percentile for rescaling.
+        upper_p: Upper percentile for rescaling.
+        lower_bound: Lower fixed value used for rescaling.
+        upper_bound: Upper fixed value used for rescaling.
+
+    Returns:
+        Normalized ND-array of the same size, as float32.
+    """
+    if img.ndim < 3:
+        error_message = "Image needs to have at least 3 dimensions"
+        logger.critical(error_message)
+        raise ValueError(error_message)
+
+    img = img.astype(np.float32)
+    img = np.moveaxis(img, axis, 0)
+    for k in range(img.shape[0]):
+        if lower_p is not None:
+            # Skip rescaling when the percentile range is (near-)degenerate,
+            # which can happen for empty or constant images
+            i99 = np.percentile(img[k], upper_p)
+            i1 = np.percentile(img[k], lower_p)
+            if i99 - i1 > 1e-3:
+                img[k] = normalize_percentile(
+                    img[k], lower=lower_p, upper=upper_p
+                )
+                if invert:
+                    img[k] = -1 * img[k] + 1
+            else:
+                img[k] = 0
+        elif lower_bound is not None:
+            if upper_bound - lower_bound > 1e-3:
+                img[k] = normalize_bounds(
+                    img[k], lower=lower_bound, upper=upper_bound
+                )
+                if invert:
+                    img[k] = -1 * img[k] + 1
+            else:
+                img[k] = 0
+        else:
+            raise ValueError("No normalization method specified")
+    img = np.moveaxis(img, 0, axis)
+    return img
+
+
+def normalize_percentile(Y: np.ndarray, lower: float = 1, upper: float = 99):
+    """
+    Normalize image so that 0.0 is the lower percentile and 1.0 is the upper
+    percentile. Percentiles are passed as floats (must be between 0 and 100).
+
+    Args:
+        Y: The image to be normalized
+        lower: Lower percentile
+        upper: Upper percentile
+
+    """
+    X = Y.copy()
+    x01 = np.percentile(X, lower)
+    x99 = np.percentile(X, upper)
+    X = (X - x01) / (x99 - x01)
+    return X
+
+
+def normalize_bounds(Y: np.ndarray, lower: int = 0, upper: int = 65535):
+    """
+    Normalize image so that 0.0 is the lower value and 1.0 is the upper value.
+
+    Args:
+        Y: The image to be normalized
+        lower: Lower normalization value
+        upper: Upper normalization value
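+
+    Example:
+        A minimal sketch with illustrative values:
+
+        >>> import numpy as np
+        >>> img = np.array([0.0, 100.0, 200.0], dtype=np.float32)
+        >>> normalize_bounds(img, lower=100, upper=200)
+        array([-1.,  0.,  1.], dtype=float32)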
+ + """ + X = Y.copy() + X = (X - lower) / (upper - lower) + return X diff --git a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_compute.py b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_compute.py new file mode 100644 index 000000000..a316e07db --- /dev/null +++ b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_compute.py @@ -0,0 +1,238 @@ +# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Tommaso Comparin +# Marco Franzon +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. +# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. +""" +Task that writes image data to an existing OME-NGFF zarr array. +""" +import logging + +import dask.array as da +import zarr +from anndata import read_zarr +from dask.array.image import imread +from pydantic import validate_call + +from fractal_tasks_core.cellvoyager.filenames import ( + glob_with_multiple_patterns, +) +from fractal_tasks_core.cellvoyager.filenames import parse_filename +from fractal_tasks_core.channels import get_omero_channel_list +from fractal_tasks_core.channels import OmeroChannel +from fractal_tasks_core.ngff import load_NgffImageMeta +from fractal_tasks_core.pyramids import build_pyramid +from fractal_tasks_core.roi import check_valid_ROI_indices +from fractal_tasks_core.roi import ( + convert_ROI_table_to_indices, +) +from fractal_tasks_core.tasks.io_models import InitArgsCellVoyager + + +logger = logging.getLogger(__name__) + + +def sort_fun(filename: str) -> list[int]: + """ + Takes a string (filename of a Yokogawa image), extract site and + z-index metadata and returns them as a list of integers. + + Args: + filename: Name of the image file. + """ + + filename_metadata = parse_filename(filename) + site = int(filename_metadata["F"]) + z_index = int(filename_metadata["Z"]) + return [site, z_index] + + +@validate_call +def cellvoyager_to_ome_zarr_compute( + *, + # Fractal parameters + zarr_url: str, + init_args: InitArgsCellVoyager, +): + """ + Convert Yokogawa output (png, tif) to zarr file. + + This task is run after an init task (typically + `cellvoyager_to_ome_zarr_init` or + `cellvoyager_to_ome_zarr_init_multiplex`), and it populates the empty + OME-Zarr files that were prepared. + + Note that the current task always overwrites existing data. To avoid this + behavior, set the `overwrite` argument of the init task to `False`. + + Args: + zarr_url: Path or url to the individual OME-Zarr image to be processed. + (standard argument for Fractal tasks, managed by Fractal server). + init_args: Intialization arguments provided by + `create_cellvoyager_ome_zarr_init`. 
+ """ + zarr_url = zarr_url.rstrip("/") + # Read attributes from NGFF metadata + ngff_image_meta = load_NgffImageMeta(zarr_url) + num_levels = ngff_image_meta.num_levels + coarsening_xy = ngff_image_meta.coarsening_xy + full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) + logger.info(f"NGFF image has {num_levels=}") + logger.info(f"NGFF image has {coarsening_xy=}") + logger.info( + f"NGFF image has full-res pixel sizes {full_res_pxl_sizes_zyx}" + ) + + channels: list[OmeroChannel] = get_omero_channel_list( + image_zarr_path=zarr_url + ) + wavelength_ids = [c.wavelength_id for c in channels] + + # Read useful information from ROI table + adata = read_zarr(f"{zarr_url}/tables/FOV_ROI_table") + fov_indices = convert_ROI_table_to_indices( + adata, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, + ) + check_valid_ROI_indices(fov_indices, "FOV_ROI_table") + adata_well = read_zarr(f"{zarr_url}/tables/well_ROI_table") + well_indices = convert_ROI_table_to_indices( + adata_well, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, + ) + check_valid_ROI_indices(well_indices, "well_ROI_table") + if len(well_indices) > 1: + raise ValueError(f"Something wrong with {well_indices=}") + + max_z = well_indices[0][1] + max_y = well_indices[0][3] + max_x = well_indices[0][5] + + # Load a single image, to retrieve useful information + include_patterns = [ + f"{init_args.plate_prefix}_{init_args.well_ID}_*." + f"{init_args.image_extension}" + ] + if init_args.include_glob_patterns: + include_patterns.extend(init_args.include_glob_patterns) + + exclude_patterns = [] + if init_args.exclude_glob_patterns: + exclude_patterns.extend(init_args.exclude_glob_patterns) + + tmp_images = glob_with_multiple_patterns( + folder=init_args.image_dir, + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + sample = imread(tmp_images.pop()) + + # Initialize zarr + chunksize = (1, 1, sample.shape[1], sample.shape[2]) + canvas_zarr = zarr.create( + shape=(len(wavelength_ids), max_z, max_y, max_x), + chunks=chunksize, + dtype=sample.dtype, + store=zarr.storage.FSStore(zarr_url + "/0"), + overwrite=True, + dimension_separator="/", + ) + + # Loop over channels + for i_c, wavelength_id in enumerate(wavelength_ids): + A, C = wavelength_id.split("_") + + include_patterns = [ + f"{init_args.plate_prefix}_{init_args.well_ID}_*{A}*{C}*." 
+ f"{init_args.image_extension}" + ] + if init_args.include_glob_patterns: + include_patterns.extend(init_args.include_glob_patterns) + filenames_set = glob_with_multiple_patterns( + folder=init_args.image_dir, + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + filenames = sorted(list(filenames_set), key=sort_fun) + if len(filenames) == 0: + raise ValueError( + "Error in yokogawa_to_ome_zarr: len(filenames)=0.\n" + f" image_dir: {init_args.image_dir}\n" + f" wavelength_id: {wavelength_id},\n" + f" patterns: {include_patterns}\n" + f" exclusion patterns: {exclude_patterns}\n" + ) + # Loop over 3D FOV ROIs + for indices in fov_indices: + s_z, e_z, s_y, e_y, s_x, e_x = indices[:] + region = ( + slice(i_c, i_c + 1), + slice(s_z, e_z), + slice(s_y, e_y), + slice(s_x, e_x), + ) + FOV_3D = da.concatenate( + [imread(img) for img in filenames[:e_z]], + ) + FOV_4D = da.expand_dims(FOV_3D, axis=0) + filenames = filenames[e_z:] + da.array(FOV_4D).to_zarr( + url=canvas_zarr, + region=region, + compute=True, + ) + + # Starting from on-disk highest-resolution data, build and write to disk a + # pyramid of coarser levels + build_pyramid( + zarrurl=zarr_url, + overwrite=True, + num_levels=num_levels, + coarsening_xy=coarsening_xy, + chunksize=chunksize, + ) + + # Generate image list updates + # TODO: Can we check for dimensionality more robustly? Just checks for the + # last FOV of the last wavelength now + if FOV_4D.shape[-3] > 1: + is_3D = True + else: + is_3D = False + # FIXME: Get plate name from zarr_url => works for duplicate plate names + # with suffixes + print(zarr_url) + plate_name = zarr_url.split("/")[-4] + attributes = { + "plate": plate_name, + "well": init_args.well_ID, + } + if init_args.acquisition is not None: + attributes["acquisition"] = init_args.acquisition + + image_list_updates = dict( + image_list_updates=[ + dict( + zarr_url=zarr_url, + attributes=attributes, + types={"is_3D": is_3D}, + ) + ] + ) + + return image_list_updates + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + run_fractal_task( + task_function=cellvoyager_to_ome_zarr_compute, + logger_name=logger.name, + ) diff --git a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init.py b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init.py new file mode 100644 index 000000000..113f0113f --- /dev/null +++ b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init.py @@ -0,0 +1,494 @@ +# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Tommaso Comparin +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. +# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. +""" +Create structure for OME-NGFF zarr array. 
+""" +import os +from pathlib import Path +from typing import Any +from typing import Optional + +import pandas as pd +from pydantic import validate_call + +import fractal_tasks_core +from fractal_tasks_core.cellvoyager.filenames import ( + glob_with_multiple_patterns, +) +from fractal_tasks_core.cellvoyager.filenames import parse_filename +from fractal_tasks_core.cellvoyager.metadata import ( + parse_yokogawa_metadata, +) +from fractal_tasks_core.cellvoyager.wells import generate_row_col_split +from fractal_tasks_core.cellvoyager.wells import get_filename_well_id +from fractal_tasks_core.channels import check_unique_wavelength_ids +from fractal_tasks_core.channels import define_omero_channels +from fractal_tasks_core.channels import OmeroChannel +from fractal_tasks_core.ngff.specs import NgffImageMeta +from fractal_tasks_core.ngff.specs import Plate +from fractal_tasks_core.ngff.specs import Well +from fractal_tasks_core.roi import prepare_FOV_ROI_table +from fractal_tasks_core.roi import prepare_well_ROI_table +from fractal_tasks_core.roi import remove_FOV_overlaps +from fractal_tasks_core.tables import write_table +from fractal_tasks_core.tasks.io_models import InitArgsCellVoyager +from fractal_tasks_core.zarr_utils import open_zarr_group_with_overwrite + +__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ + +import logging + +logger = logging.getLogger(__name__) + + +@validate_call +def cellvoyager_to_ome_zarr_init( + *, + # Fractal parameters + zarr_urls: list[str], + zarr_dir: str, + # Core parameters + image_dirs: list[str], + allowed_channels: list[OmeroChannel], + # Advanced parameters + include_glob_patterns: Optional[list[str]] = None, + exclude_glob_patterns: Optional[list[str]] = None, + num_levels: int = 5, + coarsening_xy: int = 2, + image_extension: str = "tif", + metadata_table_file: Optional[str] = None, + overwrite: bool = False, +) -> dict[str, Any]: + """ + Create a OME-NGFF zarr folder, without reading/writing image data. + + Find plates (for each folder in input_paths): + + - glob image files, + - parse metadata from image filename to identify plates, + - identify populated channels. + + Create a zarr folder (for each plate): + + - parse mlf metadata, + - identify wells and field of view (FOV), + - create FOV ZARR, + - verify that channels are uniform (i.e., same channels). + + Args: + zarr_urls: List of paths or urls to the individual OME-Zarr image to + be processed. Not used by the converter task. + (standard argument for Fractal tasks, managed by Fractal server). + zarr_dir: path of the directory where the new OME-Zarrs will be + created. + (standard argument for Fractal tasks, managed by Fractal server). + image_dirs: list of paths to the folders that contains the Cellvoyager + image files. Each entry is a path to a folder that contains the + image files themselves for a multiwell plate and the + MeasurementData & MeasurementDetail metadata files. + allowed_channels: A list of `OmeroChannel` s, where each channel must + include the `wavelength_id` attribute and where the + `wavelength_id` values must be unique across the list. + include_glob_patterns: If specified, only parse images with filenames + that match with all these patterns. Patterns must be defined as in + https://docs.python.org/3/library/fnmatch.html, Example: + `image_glob_pattern=["*_B03_*"]` => only process well B03 + `image_glob_pattern=["*_C09_*", "*F016*", "*Z[0-5][0-9]C*"]` => + only process well C09, field of view 16 and Z planes 0-59. 
+ Can interact with exclude_glob_patterns: All included images - all + excluded images gives the final list of images to process + exclude_glob_patterns: If specified, exclude any image where the + filename matches any of the exclusion patterns. Patterns are + specified the same as for include_glob_patterns. + num_levels: Number of resolution-pyramid levels. If set to `5`, there + will be the full-resolution level and 4 levels of + downsampled images. + coarsening_xy: Linear coarsening factor between subsequent levels. + If set to `2`, level 1 is 2x downsampled, level 2 is + 4x downsampled etc. + image_extension: Filename extension of images (e.g. `"tif"` or `"png"`) + metadata_table_file: If `None`, parse Yokogawa metadata from mrf/mlf + files in the input_path folder; else, the full path to a csv file + containing the parsed metadata table. + overwrite: If `True`, overwrite the task output. + + Returns: + A metadata dictionary containing important metadata about the OME-Zarr + plate, the images and some parameters required by downstream tasks + (like `num_levels`). + """ + + # Preliminary checks on metadata_table_file + if metadata_table_file: + if not metadata_table_file.endswith(".csv"): + raise ValueError(f"{metadata_table_file=} is not a csv file") + if not os.path.isfile(metadata_table_file): + raise FileNotFoundError(f"{metadata_table_file=} does not exist") + + # Identify all plates and all channels, across all input folders + plates = [] + actual_wavelength_ids = None + dict_plate_paths = {} + dict_plate_prefixes: dict[str, Any] = {} + + # Preliminary checks on allowed_channels argument + check_unique_wavelength_ids(allowed_channels) + + for image_dir in image_dirs: + # Glob image filenames + include_patterns = [f"*.{image_extension}"] + exclude_patterns = [] + if include_glob_patterns: + include_patterns.extend(include_glob_patterns) + if exclude_glob_patterns: + exclude_patterns.extend(exclude_glob_patterns) + input_filenames = glob_with_multiple_patterns( + folder=image_dir, + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + + tmp_wavelength_ids = [] + tmp_plates = [] + for fn in input_filenames: + try: + filename_metadata = parse_filename(Path(fn).name) + plate_prefix = filename_metadata["plate_prefix"] + plate = filename_metadata["plate"] + if plate not in dict_plate_prefixes.keys(): + dict_plate_prefixes[plate] = plate_prefix + tmp_plates.append(plate) + A = filename_metadata["A"] + C = filename_metadata["C"] + tmp_wavelength_ids.append(f"A{A}_C{C}") + except ValueError as e: + logger.warning( + f'Skipping "{Path(fn).name}". 
Original error: ' + str(e) + ) + tmp_plates = sorted(list(set(tmp_plates))) + tmp_wavelength_ids = sorted(list(set(tmp_wavelength_ids))) + + info = ( + "Listing plates/channels:\n" + f"Folder: {image_dir}\n" + f"Include Patterns: {include_patterns}\n" + f"Exclude Patterns: {exclude_patterns}\n" + f"Plates: {tmp_plates}\n" + f"Channels: {tmp_wavelength_ids}\n" + ) + + # Check that only one plate is found + if len(tmp_plates) > 1: + raise ValueError(f"{info}ERROR: {len(tmp_plates)} plates detected") + elif len(tmp_plates) == 0: + raise ValueError(f"{info}ERROR: No plates detected") + plate = tmp_plates[0] + + # If plate already exists in other folder, add suffix + if plate in plates: + ind = 1 + new_plate = f"{plate}_{ind}" + while new_plate in plates: + new_plate = f"{plate}_{ind}" + ind += 1 + logger.info( + f"WARNING: {plate} already exists, renaming it as {new_plate}" + ) + plates.append(new_plate) + dict_plate_prefixes[new_plate] = dict_plate_prefixes[plate] + plate = new_plate + else: + plates.append(plate) + + # Check that channels are the same as in previous plates + if actual_wavelength_ids is None: + actual_wavelength_ids = tmp_wavelength_ids[:] + else: + if actual_wavelength_ids != tmp_wavelength_ids: + raise ValueError( + f"ERROR\n{info}\nERROR:" + f" expected channels {actual_wavelength_ids}" + ) + + # Update dict_plate_paths + dict_plate_paths[plate] = image_dir + + # Check that all channels are in the allowed_channels + allowed_wavelength_ids = [ + channel.wavelength_id for channel in allowed_channels + ] + if not set(actual_wavelength_ids).issubset(set(allowed_wavelength_ids)): + msg = "ERROR in create_ome_zarr\n" + msg += f"actual_wavelength_ids: {actual_wavelength_ids}\n" + msg += f"allowed_wavelength_ids: {allowed_wavelength_ids}\n" + raise ValueError(msg) + + # Create actual_channels, i.e. a list of the channel dictionaries which are + # present + actual_channels = [ + channel + for channel in allowed_channels + if channel.wavelength_id in actual_wavelength_ids + ] + + ################################################################ + # Create well/image OME-Zarr folders on disk, and prepare output + # metadata + parallelization_list = [] + + for plate in plates: + # Define plate zarr + relative_zarrurl = f"{plate}.zarr" + in_path = dict_plate_paths[plate] + logger.info(f"Creating {relative_zarrurl}") + # Call zarr.open_group wrapper, which handles overwrite=True/False + group_plate = open_zarr_group_with_overwrite( + str(Path(zarr_dir) / relative_zarrurl), + overwrite=overwrite, + ) + + # Obtain FOV-metadata dataframe + if metadata_table_file is None: + mrf_path = f"{in_path}/MeasurementDetail.mrf" + mlf_path = f"{in_path}/MeasurementData.mlf" + + site_metadata, number_images_mlf = parse_yokogawa_metadata( + mrf_path, + mlf_path, + include_patterns=include_glob_patterns, + exclude_patterns=exclude_glob_patterns, + ) + site_metadata = remove_FOV_overlaps(site_metadata) + + # If a metadata table was passed, load it and use it directly + else: + logger.warning( + "Since a custom metadata table was provided, there will " + "be no additional check on the number of image files." 
+ ) + site_metadata = pd.read_csv(metadata_table_file) + site_metadata.set_index(["well_id", "FieldIndex"], inplace=True) + + # Extract pixel sizes and bit_depth + pixel_size_z = site_metadata["pixel_size_z"].iloc[0] + pixel_size_y = site_metadata["pixel_size_y"].iloc[0] + pixel_size_x = site_metadata["pixel_size_x"].iloc[0] + bit_depth = site_metadata["bit_depth"].iloc[0] + + if min(pixel_size_z, pixel_size_y, pixel_size_x) < 1e-9: + raise ValueError(pixel_size_z, pixel_size_y, pixel_size_x) + + # Identify all wells + plate_prefix = dict_plate_prefixes[plate] + + include_patterns = [f"{plate_prefix}_*.{image_extension}"] + if include_glob_patterns: + include_patterns.extend(include_glob_patterns) + plate_images = glob_with_multiple_patterns( + folder=str(in_path), + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + + wells = [ + parse_filename(os.path.basename(fn))["well"] for fn in plate_images + ] + wells = sorted(list(set(wells))) + + # Verify that all wells have all channels + for well in wells: + include_patterns = [f"{plate_prefix}_{well}_*.{image_extension}"] + if include_glob_patterns: + include_patterns.extend(include_glob_patterns) + well_images = glob_with_multiple_patterns( + folder=str(in_path), + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + + # Check number of images matches with expected one + if metadata_table_file is None: + num_images_glob = len(well_images) + num_images_expected = number_images_mlf[well] + if num_images_glob != num_images_expected: + raise ValueError( + f"Wrong number of images for {well=}\n" + f"Expected {num_images_expected} (from mlf file)\n" + f"Found {num_images_glob} files\n" + "Other parameters:\n" + f" {image_extension=}\n" + f" {include_glob_patterns=}\n" + f" {exclude_glob_patterns=}\n" + ) + + well_wavelength_ids = [] + for fpath in well_images: + try: + filename_metadata = parse_filename(os.path.basename(fpath)) + well_wavelength_ids.append( + f"A{filename_metadata['A']}_C{filename_metadata['C']}" + ) + except IndexError: + logger.info(f"Skipping {fpath}") + well_wavelength_ids = sorted(list(set(well_wavelength_ids))) + if well_wavelength_ids != actual_wavelength_ids: + raise ValueError( + f"ERROR: well {well} in plate {plate} (prefix: " + f"{plate_prefix}) has missing channels.\n" + f"Expected: {actual_channels}\n" + f"Found: {well_wavelength_ids}.\n" + ) + + well_rows_columns = generate_row_col_split(wells) + + row_list = [ + well_row_column[0] for well_row_column in well_rows_columns + ] + col_list = [ + well_row_column[1] for well_row_column in well_rows_columns + ] + row_list = sorted(list(set(row_list))) + col_list = sorted(list(set(col_list))) + + plate_attrs = { + "acquisitions": [{"id": 0, "name": plate}], + "columns": [{"name": col} for col in col_list], + "rows": [{"name": row} for row in row_list], + "version": __OME_NGFF_VERSION__, + "wells": [ + { + "path": well_row_column[0] + "/" + well_row_column[1], + "rowIndex": row_list.index(well_row_column[0]), + "columnIndex": col_list.index(well_row_column[1]), + } + for well_row_column in well_rows_columns + ], + } + + # Validate plate attrs: + Plate(**plate_attrs) + + group_plate.attrs["plate"] = plate_attrs + + for row, column in well_rows_columns: + parallelization_list.append( + { + "zarr_url": f"{zarr_dir}/{plate}.zarr/{row}/{column}/0", + "init_args": InitArgsCellVoyager( + image_dir=in_path, + plate_prefix=plate_prefix, + well_ID=get_filename_well_id(row, column), + image_extension=image_extension, + 
include_glob_patterns=include_glob_patterns, + exclude_glob_patterns=exclude_glob_patterns, + ).model_dump(), + } + ) + group_well = group_plate.create_group(f"{row}/{column}/") + + well_attrs = { + "images": [{"path": "0"}], + "version": __OME_NGFF_VERSION__, + } + + # Validate well attrs: + Well(**well_attrs) + group_well.attrs["well"] = well_attrs + + group_image = group_well.create_group("0") # noqa: F841 + group_image.attrs["multiscales"] = [ + { + "version": __OME_NGFF_VERSION__, + "axes": [ + {"name": "c", "type": "channel"}, + { + "name": "z", + "type": "space", + "unit": "micrometer", + }, + { + "name": "y", + "type": "space", + "unit": "micrometer", + }, + { + "name": "x", + "type": "space", + "unit": "micrometer", + }, + ], + "datasets": [ + { + "path": f"{ind_level}", + "coordinateTransformations": [ + { + "type": "scale", + "scale": [ + 1, + pixel_size_z, + pixel_size_y + * coarsening_xy**ind_level, + pixel_size_x + * coarsening_xy**ind_level, + ], + } + ], + } + for ind_level in range(num_levels) + ], + } + ] + + group_image.attrs["omero"] = { + "id": 1, # TODO does this depend on the plate number? + "name": "TBD", + "version": __OME_NGFF_VERSION__, + "channels": define_omero_channels( + channels=actual_channels, bit_depth=bit_depth + ), + } + + # Validate Image attrs + NgffImageMeta(**group_image.attrs) + + # Prepare AnnData tables for FOV/well ROIs + well_id = get_filename_well_id(row, column) + FOV_ROIs_table = prepare_FOV_ROI_table(site_metadata.loc[well_id]) + well_ROIs_table = prepare_well_ROI_table( + site_metadata.loc[well_id] + ) + + # Write AnnData tables into the `tables` zarr group + write_table( + group_image, + "FOV_ROI_table", + FOV_ROIs_table, + overwrite=overwrite, + table_attrs={"type": "roi_table"}, + ) + write_table( + group_image, + "well_ROI_table", + well_ROIs_table, + overwrite=overwrite, + table_attrs={"type": "roi_table"}, + ) + + return dict(parallelization_list=parallelization_list) + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + run_fractal_task( + task_function=cellvoyager_to_ome_zarr_init, + logger_name=logger.name, + ) diff --git a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init_multiplex.py b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init_multiplex.py new file mode 100644 index 000000000..d9598a049 --- /dev/null +++ b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init_multiplex.py @@ -0,0 +1,541 @@ +# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Tommaso Comparin +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. +# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. +""" +Create OME-NGFF zarr group, for multiplexing dataset. 
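+
+Image data is written afterwards by the `cellvoyager_to_ome_zarr_compute`
+task, once for each well and acquisition.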
+""" +import os +from pathlib import Path +from typing import Any +from typing import Optional + +import pandas as pd +import zarr +from pydantic import validate_call +from zarr.errors import ContainsGroupError + +import fractal_tasks_core +from fractal_tasks_core.cellvoyager.filenames import ( + glob_with_multiple_patterns, +) +from fractal_tasks_core.cellvoyager.filenames import parse_filename +from fractal_tasks_core.cellvoyager.metadata import ( + parse_yokogawa_metadata, +) +from fractal_tasks_core.cellvoyager.wells import generate_row_col_split +from fractal_tasks_core.cellvoyager.wells import get_filename_well_id +from fractal_tasks_core.channels import check_unique_wavelength_ids +from fractal_tasks_core.channels import check_well_channel_labels +from fractal_tasks_core.channels import define_omero_channels +from fractal_tasks_core.ngff.specs import NgffImageMeta +from fractal_tasks_core.ngff.specs import Plate +from fractal_tasks_core.ngff.specs import Well +from fractal_tasks_core.roi import prepare_FOV_ROI_table +from fractal_tasks_core.roi import prepare_well_ROI_table +from fractal_tasks_core.roi import remove_FOV_overlaps +from fractal_tasks_core.tables import write_table +from fractal_tasks_core.tasks.io_models import InitArgsCellVoyager +from fractal_tasks_core.tasks.io_models import MultiplexingAcquisition +from fractal_tasks_core.zarr_utils import open_zarr_group_with_overwrite + +__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ + +import logging + +logger = logging.getLogger(__name__) + + +@validate_call +def cellvoyager_to_ome_zarr_init_multiplex( + *, + # Fractal parameters + zarr_urls: list[str], + zarr_dir: str, + # Core parameters + acquisitions: dict[str, MultiplexingAcquisition], + # Advanced parameters + include_glob_patterns: Optional[list[str]] = None, + exclude_glob_patterns: Optional[list[str]] = None, + num_levels: int = 5, + coarsening_xy: int = 2, + image_extension: str = "tif", + metadata_table_files: Optional[dict[str, str]] = None, + overwrite: bool = False, +) -> dict[str, Any]: + """ + Create OME-NGFF structure and metadata to host a multiplexing dataset. + + This task takes a set of image folders (i.e. different multiplexing + acquisitions) and build the internal structure and metadata of a OME-NGFF + zarr group, without actually loading/writing the image data. + + Each element in input_paths should be treated as a different acquisition. + + Args: + zarr_urls: List of paths or urls to the individual OME-Zarr image to + be processed. Not used by the converter task. + (standard argument for Fractal tasks, managed by Fractal server). + zarr_dir: path of the directory where the new OME-Zarrs will be + created. + (standard argument for Fractal tasks, managed by Fractal server). + acquisitions: dictionary of acquisitions. Each key is the acquisition + identifier (normally 0, 1, 2, 3 etc.). Each item defines the + acquisition by providing the image_dir and the allowed_channels. + include_glob_patterns: If specified, only parse images with filenames + that match with all these patterns. Patterns must be defined as in + https://docs.python.org/3/library/fnmatch.html, Example: + `image_glob_pattern=["*_B03_*"]` => only process well B03 + `image_glob_pattern=["*_C09_*", "*F016*", "*Z[0-5][0-9]C*"]` => + only process well C09, field of view 16 and Z planes 0-59. 
+ Can interact with exclude_glob_patterns: All included images - all + excluded images gives the final list of images to process + exclude_glob_patterns: If specified, exclude any image where the + filename matches any of the exclusion patterns. Patterns are + specified the same as for include_glob_patterns. + num_levels: Number of resolution-pyramid levels. If set to `5`, there + will be the full-resolution level and 4 levels of downsampled + images. + coarsening_xy: Linear coarsening factor between subsequent levels. + If set to `2`, level 1 is 2x downsampled, level 2 is 4x downsampled + etc. + image_extension: Filename extension of images + (e.g. `"tif"` or `"png"`). + metadata_table_files: If `None`, parse Yokogawa metadata from mrf/mlf + files in the input_path folder; else, a dictionary of key-value + pairs like `(acquisition, path)` with `acquisition` a string like + the key of the `acquisitions` dict and `path` pointing to a csv + file containing the parsed metadata table. + overwrite: If `True`, overwrite the task output. + + Returns: + A metadata dictionary containing important metadata about the OME-Zarr + plate, the images and some parameters required by downstream tasks + (like `num_levels`). + """ + + if metadata_table_files: + # Checks on the dict: + # 1. Acquisitions in acquisitions dict and metadata_table_files match + # 2. Files end with ".csv" + # 3. Files exist. + if set(acquisitions.keys()) != set(metadata_table_files.keys()): + raise ValueError( + "Mismatch in acquisition keys between " + f"{acquisitions.keys()=} and " + f"{metadata_table_files.keys()=}" + ) + for f in metadata_table_files.values(): + if not f.endswith(".csv"): + raise ValueError( + f"{f} (in metadata_table_file) is not a csv file." + ) + if not os.path.isfile(f): + raise ValueError( + f"{f} (in metadata_table_file) does not exist." + ) + + # Preliminary checks on acquisitions + # Note that in metadata the keys of dictionary arguments should be + # strings (and not integers), so that they can be read from a JSON file + for key, values in acquisitions.items(): + if not isinstance(key, str): + raise ValueError(f"{acquisitions=} has non-string keys") + check_unique_wavelength_ids(values.allowed_channels) + try: + int(key) + except ValueError: + raise ValueError("Acquisition dictionary keys need to be integers") + + # Identify all plates and all channels, per input folders + dict_acquisitions: dict = {} + acquisitions_sorted = sorted(list(acquisitions.keys())) + for acquisition in acquisitions_sorted: + acq_input = acquisitions[acquisition] + dict_acquisitions[acquisition] = {} + + actual_wavelength_ids = [] + plates = [] + plate_prefixes = [] + + # Loop over all images + include_patterns = [f"*.{image_extension}"] + exclude_patterns = [] + if include_glob_patterns: + include_patterns.extend(include_glob_patterns) + if exclude_glob_patterns: + exclude_patterns.extend(exclude_glob_patterns) + input_filenames = glob_with_multiple_patterns( + folder=acq_input.image_dir, + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + for fn in input_filenames: + try: + filename_metadata = parse_filename(Path(fn).name) + plate = filename_metadata["plate"] + plates.append(plate) + plate_prefix = filename_metadata["plate_prefix"] + plate_prefixes.append(plate_prefix) + A = filename_metadata["A"] + C = filename_metadata["C"] + actual_wavelength_ids.append(f"A{A}_C{C}") + except ValueError as e: + logger.warning( + f'Skipping "{Path(fn).name}". 
Original error: ' + str(e) + ) + plates = sorted(list(set(plates))) + actual_wavelength_ids = sorted(list(set(actual_wavelength_ids))) + + info = ( + "Listing all plates/channels:\n" + f"Include patterns: {include_patterns}\n" + f"Exclude patterns: {exclude_patterns}\n" + f"Plates: {plates}\n" + f"Actual wavelength IDs: {actual_wavelength_ids}\n" + ) + + # Check that a folder includes a single plate + if len(plates) > 1: + raise ValueError(f"{info}ERROR: {len(plates)} plates detected") + elif len(plates) == 0: + raise ValueError(f"{info}ERROR: No plates detected") + original_plate = plates[0] + plate_prefix = plate_prefixes[0] + + # Replace plate with the one of the first acquisition + if acquisition != acquisitions_sorted[0]: + plate = dict_acquisitions[acquisitions_sorted[0]]["plate"] + logger.warning( + f"For {acquisition=}, we replace {original_plate=} with " + f"{plate=} (the one for acquisition {acquisitions_sorted[0]})" + ) + + # Check that all channels are in the allowed_channels + allowed_wavelength_ids = [ + c.wavelength_id for c in acq_input.allowed_channels + ] + if not set(actual_wavelength_ids).issubset( + set(allowed_wavelength_ids) + ): + msg = "ERROR in create_ome_zarr\n" + msg += f"actual_wavelength_ids: {actual_wavelength_ids}\n" + msg += f"allowed_wavelength_ids: {allowed_wavelength_ids}\n" + raise ValueError(msg) + + # Create actual_channels, i.e. a list of the channel dictionaries which + # are present + actual_channels = [ + channel + for channel in acq_input.allowed_channels + if channel.wavelength_id in actual_wavelength_ids + ] + + logger.info(f"plate: {plate}") + logger.info(f"actual_channels: {actual_channels}") + + dict_acquisitions[acquisition] = {} + dict_acquisitions[acquisition]["plate"] = plate + dict_acquisitions[acquisition]["original_plate"] = original_plate + dict_acquisitions[acquisition]["plate_prefix"] = plate_prefix + dict_acquisitions[acquisition]["image_folder"] = acq_input.image_dir + dict_acquisitions[acquisition]["original_paths"] = [ + acq_input.image_dir + ] + dict_acquisitions[acquisition]["actual_channels"] = actual_channels + dict_acquisitions[acquisition][ + "actual_wavelength_ids" + ] = actual_wavelength_ids + + parallelization_list = [] + current_plates = [item["plate"] for item in dict_acquisitions.values()] + if len(set(current_plates)) > 1: + raise ValueError(f"{current_plates=}") + plate = current_plates[0] + + zarrurl = dict_acquisitions[acquisitions_sorted[0]]["plate"] + ".zarr" + full_zarrurl = str(Path(zarr_dir) / zarrurl) + logger.info(f"Creating {full_zarrurl=}") + # Call zarr.open_group wrapper, which handles overwrite=True/False + group_plate = open_zarr_group_with_overwrite( + full_zarrurl, overwrite=overwrite + ) + group_plate.attrs["plate"] = { + "acquisitions": [ + { + "id": int(acquisition), + "name": dict_acquisitions[acquisition]["original_plate"], + } + for acquisition in acquisitions_sorted + ] + } + + zarrurls: dict[str, list[str]] = {"well": [], "image": []} + zarrurls["plate"] = [f"{plate}.zarr"] + + ################################################################ + logging.info(f"{acquisitions_sorted=}") + + for i, acquisition in enumerate(acquisitions_sorted): + # Define plate zarr + image_folder = dict_acquisitions[acquisition]["image_folder"] + logger.info(f"Looking at {image_folder=}") + + # Obtain FOV-metadata dataframe + if metadata_table_files is None: + mrf_path = f"{image_folder}/MeasurementDetail.mrf" + mlf_path = f"{image_folder}/MeasurementData.mlf" + site_metadata, _ = parse_yokogawa_metadata( + 
mrf_path, + mlf_path, + include_patterns=include_glob_patterns, + exclude_patterns=exclude_glob_patterns, + ) + site_metadata = remove_FOV_overlaps(site_metadata) + else: + site_metadata = pd.read_csv(metadata_table_files[acquisition]) + site_metadata.set_index(["well_id", "FieldIndex"], inplace=True) + + # Extract pixel sizes and bit_depth + pixel_size_z = site_metadata["pixel_size_z"].iloc[0] + pixel_size_y = site_metadata["pixel_size_y"].iloc[0] + pixel_size_x = site_metadata["pixel_size_x"].iloc[0] + bit_depth = site_metadata["bit_depth"].iloc[0] + + if min(pixel_size_z, pixel_size_y, pixel_size_x) < 1e-9: + raise ValueError(pixel_size_z, pixel_size_y, pixel_size_x) + + # Identify all wells + plate_prefix = dict_acquisitions[acquisition]["plate_prefix"] + include_patterns = [f"{plate_prefix}_*.{image_extension}"] + if include_glob_patterns: + include_patterns.extend(include_glob_patterns) + plate_images = glob_with_multiple_patterns( + folder=str(image_folder), + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + + wells = [ + parse_filename(os.path.basename(fn))["well"] for fn in plate_images + ] + wells = sorted(list(set(wells))) + logger.info(f"{wells=}") + + # Verify that all wells have all channels + actual_channels = dict_acquisitions[acquisition]["actual_channels"] + for well in wells: + include_patterns = [f"{plate_prefix}_{well}_*.{image_extension}"] + if include_glob_patterns: + include_patterns.extend(include_glob_patterns) + well_images = glob_with_multiple_patterns( + folder=str(image_folder), + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + ) + + well_wavelength_ids = [] + for fpath in well_images: + try: + filename_metadata = parse_filename(os.path.basename(fpath)) + A = filename_metadata["A"] + C = filename_metadata["C"] + well_wavelength_ids.append(f"A{A}_C{C}") + except IndexError: + logger.info(f"Skipping {fpath}") + well_wavelength_ids = sorted(list(set(well_wavelength_ids))) + actual_wavelength_ids = dict_acquisitions[acquisition][ + "actual_wavelength_ids" + ] + if well_wavelength_ids != actual_wavelength_ids: + raise ValueError( + f"ERROR: well {well} in plate {plate} (prefix: " + f"{plate_prefix}) has missing channels.\n" + f"Expected: {actual_wavelength_ids}\n" + f"Found: {well_wavelength_ids}.\n" + ) + + well_rows_columns = generate_row_col_split(wells) + row_list = [ + well_row_column[0] for well_row_column in well_rows_columns + ] + col_list = [ + well_row_column[1] for well_row_column in well_rows_columns + ] + row_list = sorted(list(set(row_list))) + col_list = sorted(list(set(col_list))) + + plate_attrs = group_plate.attrs["plate"] + plate_attrs["columns"] = [{"name": col} for col in col_list] + plate_attrs["rows"] = [{"name": row} for row in row_list] + plate_attrs["wells"] = [ + { + "path": well_row_column[0] + "/" + well_row_column[1], + "rowIndex": row_list.index(well_row_column[0]), + "columnIndex": col_list.index(well_row_column[1]), + } + for well_row_column in well_rows_columns + ] + plate_attrs["version"] = __OME_NGFF_VERSION__ + # Validate plate attrs + Plate(**plate_attrs) + group_plate.attrs["plate"] = plate_attrs + + for row, column in well_rows_columns: + parallelization_list.append( + { + "zarr_url": ( + f"{zarr_dir}/{plate}.zarr/{row}/{column}/" f"{i}" + ), + "init_args": InitArgsCellVoyager( + image_dir=acquisitions[acquisition].image_dir, + plate_prefix=plate_prefix, + well_ID=get_filename_well_id(row, column), + image_extension=image_extension, + 
include_glob_patterns=include_glob_patterns, + exclude_glob_patterns=exclude_glob_patterns, + acquisition=acquisition, + ).model_dump(), + } + ) + try: + group_well = group_plate.create_group(f"{row}/{column}/") + logging.info(f"Created new group_well at {row}/{column}/") + well_attrs = { + "images": [ + { + "path": f"{i}", + "acquisition": int(acquisition), + } + ], + "version": __OME_NGFF_VERSION__, + } + # Validate well attrs: + Well(**well_attrs) + group_well.attrs["well"] = well_attrs + zarrurls["well"].append(f"{plate}.zarr/{row}/{column}") + except ContainsGroupError: + group_well = zarr.open_group( + f"{full_zarrurl}/{row}/{column}/", mode="r+" + ) + logging.info( + f"Loaded group_well from {full_zarrurl}/{row}/{column}" + ) + current_images = group_well.attrs["well"]["images"] + [ + {"path": f"{i}", "acquisition": int(acquisition)} + ] + well_attrs = dict( + images=current_images, + version=group_well.attrs["well"]["version"], + ) + # Validate well attrs: + Well(**well_attrs) + group_well.attrs["well"] = well_attrs + + group_image = group_well.create_group(f"{i}/") # noqa: F841 + logging.info(f"Created image group {row}/{column}/{i}") + image = f"{plate}.zarr/{row}/{column}/{i}" + zarrurls["image"].append(image) + + group_image.attrs["multiscales"] = [ + { + "version": __OME_NGFF_VERSION__, + "axes": [ + {"name": "c", "type": "channel"}, + { + "name": "z", + "type": "space", + "unit": "micrometer", + }, + { + "name": "y", + "type": "space", + "unit": "micrometer", + }, + { + "name": "x", + "type": "space", + "unit": "micrometer", + }, + ], + "datasets": [ + { + "path": f"{ind_level}", + "coordinateTransformations": [ + { + "type": "scale", + "scale": [ + 1, + pixel_size_z, + pixel_size_y + * coarsening_xy**ind_level, + pixel_size_x + * coarsening_xy**ind_level, + ], + } + ], + } + for ind_level in range(num_levels) + ], + } + ] + + group_image.attrs["omero"] = { + "id": 1, # FIXME does this depend on the plate number? + "name": "TBD", + "version": __OME_NGFF_VERSION__, + "channels": define_omero_channels( + channels=actual_channels, + bit_depth=bit_depth, + label_prefix=i, + ), + } + # Validate Image attrs + NgffImageMeta(**group_image.attrs) + + # Prepare AnnData tables for FOV/well ROIs + well_id = get_filename_well_id(row, column) + FOV_ROIs_table = prepare_FOV_ROI_table(site_metadata.loc[well_id]) + well_ROIs_table = prepare_well_ROI_table( + site_metadata.loc[well_id] + ) + + # Write AnnData tables into the `tables` zarr group + write_table( + group_image, + "FOV_ROI_table", + FOV_ROIs_table, + overwrite=overwrite, + table_attrs={"type": "roi_table"}, + ) + write_table( + group_image, + "well_ROI_table", + well_ROIs_table, + overwrite=overwrite, + table_attrs={"type": "roi_table"}, + ) + + # Check that the different images (e.g. 
different acquisitions) in each
+    # well have unique labels
+    for well_path in zarrurls["well"]:
+        check_well_channel_labels(
+            well_zarr_path=str(Path(zarr_dir) / well_path)
+        )
+
+    return dict(parallelization_list=parallelization_list)
+
+
+if __name__ == "__main__":
+    from fractal_tasks_core.tasks._utils import run_fractal_task
+
+    run_fractal_task(
+        task_function=cellvoyager_to_ome_zarr_init_multiplex,
+        logger_name=logger.name,
+    )
diff --git a/fractal_tasks_core/tasks/copy_ome_zarr_hcs_plate.py b/fractal_tasks_core/tasks/copy_ome_zarr_hcs_plate.py
new file mode 100644
index 000000000..4ec5e8111
--- /dev/null
+++ b/fractal_tasks_core/tasks/copy_ome_zarr_hcs_plate.py
@@ -0,0 +1,302 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Tommaso Comparin
+# Marco Franzon
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+"""
+Task that copies the structure of an OME-NGFF zarr array to a new one.
+"""
+import logging
+from typing import Any
+
+from pydantic import validate_call
+
+import fractal_tasks_core
+from fractal_tasks_core.ngff.specs import NgffPlateMeta
+from fractal_tasks_core.ngff.specs import WellInPlate
+from fractal_tasks_core.ngff.zarr_utils import load_NgffPlateMeta
+from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta
+from fractal_tasks_core.tasks.io_models import InitArgsMIP
+from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod
+from fractal_tasks_core.zarr_utils import open_zarr_group_with_overwrite
+
+logger = logging.getLogger(__name__)
+
+
+__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__
+
+
+def _get_plate_url_from_image_url(zarr_url: str) -> str:
+    """
+    Given the absolute `zarr_url` for an OME-Zarr image within an HCS plate,
+    return the path to the plate zarr group.
+    """
+    zarr_url = zarr_url.rstrip("/")
+    plate_path = "/".join(zarr_url.split("/")[:-3])
+    return plate_path
+
+
+def _get_well_sub_url(zarr_url: str) -> str:
+    """
+    Given the absolute `zarr_url` for an OME-Zarr image within an HCS plate,
+    return the relative path to the well zarr group.
+    """
+    zarr_url = zarr_url.rstrip("/")
+    well_url = "/".join(zarr_url.split("/")[-3:-1])
+    return well_url
+
+
+def _get_image_sub_url(zarr_url: str) -> str:
+    """
+    Given the absolute `zarr_url` for an OME-Zarr image, return the image
+    zarr-group name.
+    """
+    zarr_url = zarr_url.rstrip("/")
+    image_sub_url = zarr_url.split("/")[-1]
+    return image_sub_url
+
+
+def _generate_wells_rows_columns(
+    well_list: list[str],
+) -> tuple[list[WellInPlate], list[str], list[str]]:
+    """
+    Generate the plate well metadata based on the list of wells.
+    """
+    rows = []
+    columns = []
+    wells = []
+    for well in well_list:
+        rows.append(well.split("/")[0])
+        columns.append(well.split("/")[1])
+    rows = sorted(list(set(rows)))
+    columns = sorted(list(set(columns)))
+    for well in well_list:
+        wells.append(
+            WellInPlate(
+                path=well,
+                rowIndex=rows.index(well.split("/")[0]),
+                columnIndex=columns.index(well.split("/")[1]),
+            )
+        )
+
+    return wells, rows, columns
+
+
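A quick, runnable sketch of how the three `_get_*` helpers above decompose an HCS image URL (the example path is hypothetical):

```python
from fractal_tasks_core.tasks.copy_ome_zarr_hcs_plate import (
    _get_image_sub_url,
    _get_plate_url_from_image_url,
    _get_well_sub_url,
)

zarr_url = "/data/plate.zarr/B/03/0"  # hypothetical HCS image URL
assert _get_plate_url_from_image_url(zarr_url) == "/data/plate.zarr"
assert _get_well_sub_url(zarr_url) == "B/03"
assert _get_image_sub_url(zarr_url) == "0"
```
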
+def _generate_plate_well_metadata(
+    zarr_urls: list[str],
+) -> tuple[dict[str, dict], dict[str, dict[str, dict]], dict[str, dict]]:
+    """
+    Generate metadata for OME-Zarr HCS plates & wells.
+
+    Based on the list of zarr_urls, generate metadata for all plates and all
+    their wells.
+
+    Args:
+        zarr_urls: List of paths or urls to the individual OME-Zarr image to
+            be processed.
+
+    Returns:
+        plate_metadata_dicts: Dictionary of plate metadata. The structure
+            is: {"old_plate_name": NgffPlateMeta (as dict)}.
+        new_well_image_attrs: Dictionary of image lists for the new wells.
+            The structure is: {"old_plate_name": {"old_well_name":
+            [ImageInWell(as dict)]}}
+        well_image_attrs: Dictionary of Image attributes of the existing wells.
+    """
+    # TODO: Simplify this block. Currently complicated, because we need to loop
+    # through all potential plates, all their wells & their images to build up
+    # the metadata for the plate & well.
+    plate_metadata_dicts = {}
+    plate_wells = {}
+    well_image_attrs = {}
+    new_well_image_attrs = {}
+    for zarr_url in zarr_urls:
+        # Extract plate/well/image parts of `zarr_url`
+        old_plate_url = _get_plate_url_from_image_url(zarr_url)
+        well_sub_url = _get_well_sub_url(zarr_url)
+        curr_img_sub_url = _get_image_sub_url(zarr_url)
+
+        # The first time a plate is found, create its metadata
+        if old_plate_url not in plate_metadata_dicts:
+            logger.info(f"Reading plate metadata of {old_plate_url=}")
+            old_plate_meta = load_NgffPlateMeta(old_plate_url)
+            plate_metadata = dict(
+                plate=dict(
+                    acquisitions=old_plate_meta.plate.acquisitions,
+                    field_count=old_plate_meta.plate.field_count,
+                    name=old_plate_meta.plate.name,
+                    # The new field count could be different from the old
+                    # field count
+                    version=old_plate_meta.plate.version,
+                )
+            )
+            plate_metadata_dicts[old_plate_url] = plate_metadata
+            plate_wells[old_plate_url] = []
+            well_image_attrs[old_plate_url] = {}
+            new_well_image_attrs[old_plate_url] = {}
+
+        # The first time a plate/well pair is found, create the well metadata
+        if well_sub_url not in plate_wells[old_plate_url]:
+            plate_wells[old_plate_url].append(well_sub_url)
+            old_well_url = f"{old_plate_url}/{well_sub_url}"
+            logger.info(f"Reading well metadata of {old_well_url}")
+            well_attrs = load_NgffWellMeta(old_well_url)
+            well_image_attrs[old_plate_url][well_sub_url] = well_attrs.well
+            new_well_image_attrs[old_plate_url][well_sub_url] = []
+
+        # Find images of the current well with name matching the current image
+        # TODO: clarify whether this list must always have length 1
+        curr_well_image_list = [
+            img
+            for img in well_image_attrs[old_plate_url][well_sub_url].images
+            if img.path == curr_img_sub_url
+        ]
+        new_well_image_attrs[old_plate_url][
+            well_sub_url
+        ] += curr_well_image_list
+
+    # Fill in the plate metadata based on all available wells
+    for old_plate_url in plate_metadata_dicts:
+        well_list, row_list, column_list = _generate_wells_rows_columns(
+            plate_wells[old_plate_url]
+        )
+        plate_metadata_dicts[old_plate_url]["plate"]["columns"] = []
+        for column in column_list:
+            plate_metadata_dicts[old_plate_url]["plate"]["columns"].append(
+                {"name": column}
+            )
+
+        plate_metadata_dicts[old_plate_url]["plate"]["rows"] = []
+        for row in row_list:
+            plate_metadata_dicts[old_plate_url]["plate"]["rows"].append(
+                {"name": row}
+            )
+        plate_metadata_dicts[old_plate_url]["plate"]["wells"] = well_list
+
+        # Validate with NgffPlateMeta model
+        plate_metadata_dicts[old_plate_url] = NgffPlateMeta(
+            **plate_metadata_dicts[old_plate_url]
+        ).model_dump(exclude_none=True)
+
+    return plate_metadata_dicts, new_well_image_attrs, well_image_attrs
+
+
+@validate_call
+def copy_ome_zarr_hcs_plate(
+    *,
+    # Fractal parameters
+    zarr_urls: list[str],
+
zarr_dir: str, + method: DaskProjectionMethod = DaskProjectionMethod.MIP, + # Advanced parameters + overwrite: bool = False, +) -> dict[str, Any]: + """ + Duplicate the OME-Zarr HCS structure for a set of zarr_urls. + + This task only processes the zarr images in the zarr_urls, not all the + images in the plate. It copies all the plate & well structure, but none + of the image metadata or the actual image data: + + - For each plate, create a new OME-Zarr HCS plate with the attributes for + all the images in zarr_urls + - For each well (in each plate), create a new zarr subgroup with the + same attributes as the original one. + + Note: this task makes use of methods from the `Attributes` class, see + https://zarr.readthedocs.io/en/stable/api/attrs.html. + + Args: + zarr_urls: List of paths or urls to the individual OME-Zarr image to + be processed. + (standard argument for Fractal tasks, managed by Fractal server). + zarr_dir: path of the directory where the new OME-Zarrs will be + created. + (standard argument for Fractal tasks, managed by Fractal server). + method: Choose which method to use for intensity projection along the + Z axis. mip is the default and performs a maximum intensity + projection. minip performs a minimum intensity projection, meanip + a mean intensity projection and sumip a sum intensity projection. + overwrite: If `True`, overwrite the task output. + + Returns: + A parallelization list to be used in a compute task to fill the wells + with OME-Zarr images. + """ + + parallelization_list = [] + + # Generate parallelization list + for zarr_url in zarr_urls: + old_plate_url = _get_plate_url_from_image_url(zarr_url) + well_sub_url = _get_well_sub_url(zarr_url) + old_plate_name = old_plate_url.split(".zarr")[-2].split("/")[-1] + new_plate_name = f"{old_plate_name}_{method.value}" + zarrurl_plate_new = f"{zarr_dir}/{new_plate_name}.zarr" + curr_img_sub_url = _get_image_sub_url(zarr_url) + new_zarr_url = f"{zarrurl_plate_new}/{well_sub_url}/{curr_img_sub_url}" + parallelization_item = dict( + zarr_url=new_zarr_url, + init_args=dict( + origin_url=zarr_url, + method=method.value, + overwrite=overwrite, + new_plate_name=f"{new_plate_name}.zarr", + ), + ) + InitArgsMIP(**parallelization_item["init_args"]) + parallelization_list.append(parallelization_item) + + # Generate the plate metadata & parallelization list + ( + plate_attrs_dicts, + new_well_image_attrs, + well_image_attrs, + ) = _generate_plate_well_metadata(zarr_urls=zarr_urls) + + # Create the new OME-Zarr HCS plate + for old_plate_url, plate_attrs in plate_attrs_dicts.items(): + old_plate_name = old_plate_url.split(".zarr")[-2].split("/")[-1] + new_plate_name = f"{old_plate_name}_{method.value}" + zarrurl_new = f"{zarr_dir}/{new_plate_name}.zarr" + logger.info(f"{old_plate_url=}") + logger.info(f"{zarrurl_new=}") + new_plate_group = open_zarr_group_with_overwrite( + zarrurl_new, overwrite=overwrite + ) + new_plate_group.attrs.put(plate_attrs) + + # Write well groups: + for well_sub_url in new_well_image_attrs[old_plate_url]: + new_well_group = new_plate_group.create_group(f"{well_sub_url}") + well_attrs = dict( + well=dict( + images=[ + img.model_dump(exclude_none=True) + for img in new_well_image_attrs[old_plate_url][ + well_sub_url + ] + ], + version=well_image_attrs[old_plate_url][ + well_sub_url + ].version, + ) + ) + new_well_group.attrs.put(well_attrs) + + return dict(parallelization_list=parallelization_list) + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + 
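# Note: for an input image `/data/plate.zarr/B/03/0` and `method=mip`, the
+ # task above writes to `/zarr_dir/plate_mip.zarr/B/03/0` (hypothetical
+ # paths; the new plate suffix follows the projection method).
+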
run_fractal_task(
        task_function=copy_ome_zarr_hcs_plate,
        logger_name=logger.name,
    )
diff --git a/fractal_tasks_core/tasks/find_registration_consensus.py b/fractal_tasks_core/tasks/find_registration_consensus.py
new file mode 100644
index 000000000..f0aadb010
--- /dev/null
+++ b/fractal_tasks_core/tasks/find_registration_consensus.py
@@ -0,0 +1,173 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Tommaso Comparin
+# Joel Lüthi
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+"""
+Applies the multiplexing translation to all ROI tables
+"""
+import logging
+from typing import Optional
+
+import anndata as ad
+import zarr
+from pydantic import validate_call
+
+from fractal_tasks_core.roi import (
+    are_ROI_table_columns_valid,
+)
+from fractal_tasks_core.tables import write_table
+from fractal_tasks_core.tasks._registration_utils import (
+    add_zero_translation_columns,
+)
+from fractal_tasks_core.tasks._registration_utils import (
+    apply_registration_to_single_ROI_table,
+)
+from fractal_tasks_core.tasks._registration_utils import (
+    calculate_min_max_across_dfs,
+)
+from fractal_tasks_core.tasks.io_models import InitArgsRegistrationConsensus
+
+
+logger = logging.getLogger(__name__)
+
+
+@validate_call
+def find_registration_consensus(
+    *,
+    # Fractal parameters
+    zarr_url: str,
+    init_args: InitArgsRegistrationConsensus,
+    # Core parameters
+    roi_table: str = "FOV_ROI_table",
+    # Advanced parameters
+    new_roi_table: Optional[str] = None,
+):
+    """
+    Applies pre-calculated registration to ROI tables.
+
+    Apply pre-calculated registration such that resulting ROIs contain
+    the consensus alignment region across all acquisitions.
+
+    Parallelization level: well
+
+    Args:
+        zarr_url: Path or url to the individual OME-Zarr image to be processed.
+            Refers to the zarr_url of the reference acquisition.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        init_args: Initialization arguments provided by
+            `init_group_by_well_for_multiplexing`. It contains the
+            zarr_url_list listing all the zarr_urls in the same well as the
+            zarr_url of the reference acquisition that are being processed.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        roi_table: Name of the ROI table over which the task loops to
+            calculate the registration. Examples: `FOV_ROI_table` => loop over
+            the field of views, `well_ROI_table` => process the whole well as
+            one image.
+        new_roi_table: Optional name for the new, registered ROI table. If no
+            name is given, it will default to "registered_" + `roi_table`
+
+    """
+    if not new_roi_table:
+        new_roi_table = "registered_" + roi_table
+    logger.info(
+        f"Running for {zarr_url=} & the other acquisitions in that well. \n"
+        f"Applying translation registration to {roi_table=} and storing it as "
+        f"{new_roi_table=}."
+    )
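As an aside: the consensus logic below hinges on `calculate_min_max_across_dfs`, which takes the per-acquisition translation tables and returns the element-wise maximum and minimum shift per ROI. A minimal sketch of its behavior, with made-up shift values for one ROI across two acquisitions:

```python
import pandas as pd

from fractal_tasks_core.tasks._registration_utils import (
    calculate_min_max_across_dfs,
)

cols = ["translation_z", "translation_y", "translation_x"]
# Reference acquisition has zero shifts; a second acquisition is shifted
# by +2 px in y and -1.5 px in x (hypothetical numbers)
ref = pd.DataFrame([[0.0, 0.0, 0.0]], columns=cols, index=["FOV_1"])
acq = pd.DataFrame([[0.0, 2.0, -1.5]], columns=cols, index=["FOV_1"])

max_df, min_df = calculate_min_max_across_dfs([ref, acq])
# Expected bounds: max_df -> [0.0, 2.0, 0.0], min_df -> [0.0, 0.0, -1.5];
# together they delimit the region shared by all acquisitions.
```
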
+
+    # Collect all the ROI tables
+    roi_tables = {}
+    roi_tables_attrs = {}
+    for acq_zarr_url in init_args.zarr_url_list:
+        curr_ROI_table = ad.read_zarr(f"{acq_zarr_url}/tables/{roi_table}")
+        curr_ROI_table_group = zarr.open_group(
+            f"{acq_zarr_url}/tables/{roi_table}", mode="r"
+        )
+        curr_ROI_table_attrs = curr_ROI_table_group.attrs.asdict()
+
+        # For reference_acquisition, handle the fact that it doesn't
+        # have the shifts
+        if acq_zarr_url == zarr_url:
+            curr_ROI_table = add_zero_translation_columns(curr_ROI_table)
+        # Check for valid ROI tables
+        are_ROI_table_columns_valid(table=curr_ROI_table)
+        translation_columns = [
+            "translation_z",
+            "translation_y",
+            "translation_x",
+        ]
+        if curr_ROI_table.var.index.isin(translation_columns).sum() != 3:
+            raise ValueError(
+                f"{roi_table=} in {acq_zarr_url} does not contain the "
+                f"translation columns {translation_columns} necessary to use "
+                "this task."
+            )
+        roi_tables[acq_zarr_url] = curr_ROI_table
+        roi_tables_attrs[acq_zarr_url] = curr_ROI_table_attrs
+
+    # Check that all acquisitions have the same ROIs
+    rois = roi_tables[list(roi_tables.keys())[0]].obs.index
+    for acq_zarr_url, acq_roi_table in roi_tables.items():
+        if not (acq_roi_table.obs.index == rois).all():
+            raise ValueError(
+                f"Acquisition {acq_zarr_url} does not contain the same ROIs "
+                f"as the reference acquisition {zarr_url}:\n"
+                f"{acq_zarr_url}: {acq_roi_table.obs.index}\n"
+                f"{zarr_url}: {rois}"
+            )
+
+    roi_table_dfs = [
+        roi_table.to_df().loc[:, translation_columns]
+        for roi_table in roi_tables.values()
+    ]
+    logger.info("Calculating min & max translation across acquisitions.")
+    max_df, min_df = calculate_min_max_across_dfs(roi_table_dfs)
+    shifted_rois = {}
+
+    # Loop over acquisitions
+    for acq_zarr_url in init_args.zarr_url_list:
+        shifted_rois[acq_zarr_url] = apply_registration_to_single_ROI_table(
+            roi_tables[acq_zarr_url], max_df, min_df
+        )
+
+        # TODO: Drop translation columns from this table?
+
+        logger.info(
+            f"Write the registered ROI table {new_roi_table} for "
+            f"{acq_zarr_url=}"
+        )
+        # Save the shifted ROI table as a new table
+        image_group = zarr.group(acq_zarr_url)
+        write_table(
+            image_group,
+            new_roi_table,
+            shifted_rois[acq_zarr_url],
+            table_attrs=roi_tables_attrs[acq_zarr_url],
+            overwrite=True,
+        )
+
+    # TODO: Optionally apply registration to other tables as well?
+    # e.g. to well_ROI_table based on FOV_ROI_table
+    # => out of scope for the initial task, apply registration separately
+    # to each table
+    # Easiest implementation: Apply average shift calculated here to other
+    # ROIs. From many to 1 (e.g. FOV => well) => average shift, but crop len
+    # From well to many (e.g. well to FOVs) => average shift, crop len by that
+    # amount
+    # Many to many (FOVs to organoids) => tricky because of matching
+
+
+if __name__ == "__main__":
+    from fractal_tasks_core.tasks._utils import run_fractal_task
+
+    run_fractal_task(
+        task_function=find_registration_consensus,
+        logger_name=logger.name,
+    )
diff --git a/fractal_tasks_core/tasks/illumination_correction.py b/fractal_tasks_core/tasks/illumination_correction.py
new file mode 100644
index 000000000..edb045b7b
--- /dev/null
+++ b/fractal_tasks_core/tasks/illumination_correction.py
@@ -0,0 +1,292 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Tommaso Comparin
+# Marco Franzon
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. +""" +Apply illumination correction to all fields of view. +""" +import logging +import time +import warnings +from pathlib import Path +from typing import Any + +import anndata as ad +import dask.array as da +import numpy as np +import zarr +from pydantic import validate_call +from skimage.io import imread + +from fractal_tasks_core.channels import get_omero_channel_list +from fractal_tasks_core.channels import OmeroChannel +from fractal_tasks_core.ngff import load_NgffImageMeta +from fractal_tasks_core.pyramids import build_pyramid +from fractal_tasks_core.roi import check_valid_ROI_indices +from fractal_tasks_core.roi import ( + convert_ROI_table_to_indices, +) +from fractal_tasks_core.tasks._zarr_utils import _copy_hcs_ome_zarr_metadata +from fractal_tasks_core.tasks._zarr_utils import _copy_tables_from_zarr_url + +logger = logging.getLogger(__name__) + + +def correct( + img_stack: np.ndarray, + corr_img: np.ndarray, + background: int = 110, +): + """ + Corrects a stack of images, using a given illumination profile (e.g. bright + in the center of the image, dim outside). + + Args: + img_stack: 4D numpy array (czyx), with dummy size along c. + corr_img: 2D numpy array (yx) + background: Background value that is subtracted from the image before + the illumination correction is applied. + """ + + logger.info(f"Start correct, {img_stack.shape}") + + # Check shapes + if corr_img.shape != img_stack.shape[2:] or img_stack.shape[0] != 1: + raise ValueError( + "Error in illumination_correction:\n" + f"{img_stack.shape=}\n{corr_img.shape=}" + ) + + # Store info about dtype + dtype = img_stack.dtype + dtype_max = np.iinfo(dtype).max + + # Background subtraction + img_stack[img_stack <= background] = 0 + img_stack[img_stack > background] -= background + + # Apply the normalized correction matrix (requires a float array) + # img_stack = img_stack.astype(np.float64) + new_img_stack = img_stack / (corr_img / np.max(corr_img))[None, None, :, :] + + # Handle edge case: corrected image may have values beyond the limit of + # the encoding, e.g. beyond 65535 for 16bit images. This clips values + # that surpass this limit and triggers a warning + if np.sum(new_img_stack > dtype_max) > 0: + warnings.warn( + "Illumination correction created values beyond the max range of " + f"the current image type. These have been clipped to {dtype_max=}." + ) + new_img_stack[new_img_stack > dtype_max] = dtype_max + + logger.info("End correct") + + # Cast back to original dtype and return + return new_img_stack.astype(dtype) + + +@validate_call +def illumination_correction( + *, + # Fractal parameters + zarr_url: str, + # Core parameters + illumination_profiles_folder: str, + illumination_profiles: dict[str, str], + background: int = 0, + input_ROI_table: str = "FOV_ROI_table", + overwrite_input: bool = True, + # Advanced parameters + suffix: str = "_illum_corr", +) -> dict[str, Any]: + """ + Applies illumination correction to the images in the OME-Zarr. + + Assumes that the illumination correction profiles were generated before + separately and that the same background subtraction was used during + calculation of the illumination correction (otherwise, it will not work + well & the correction may only be partial). + + Args: + zarr_url: Path or url to the individual OME-Zarr image to be processed. + (standard argument for Fractal tasks, managed by Fractal server). 
+        illumination_profiles_folder: Path of folder of illumination profiles.
+        illumination_profiles: Dictionary where keys match the `wavelength_id`
+            attributes of existing channels (e.g. `A01_C01`) and values are
+            the filenames of the corresponding illumination profiles.
+        background: Background value that is subtracted from the image before
+            the illumination correction is applied. Set it to `0` if you don't
+            want any background subtraction.
+        input_ROI_table: Name of the ROI table that contains the information
+            about the location of the individual field of views (FOVs) to
+            which the illumination correction shall be applied. Defaults to
+            "FOV_ROI_table", the default name Fractal converters give the ROI
+            tables that list all FOVs separately. If you generated your
+            OME-Zarr with a different converter and used Import OME-Zarr to
+            generate the ROI tables, `image_ROI_table` is the right choice if
+            you only have 1 FOV per Zarr image and `grid_ROI_table` if you
+            have multiple FOVs per Zarr image and set the right grid options
+            during import.
+        overwrite_input: If `True`, the results of this task will overwrite
+            the input image data. If `False`, a new image is generated and the
+            illumination corrected data is saved there.
+        suffix: What suffix to append to the illumination corrected images.
+            Only relevant if `overwrite_input=False`.
+    """
+
+    # Define old/new zarr URLs
+    if overwrite_input:
+        zarr_url_new = zarr_url.rstrip("/")
+    else:
+        zarr_url_new = zarr_url.rstrip("/") + suffix
+
+    t_start = time.perf_counter()
+    logger.info("Start illumination_correction")
+    logger.info(f"  {overwrite_input=}")
+    logger.info(f"  {zarr_url=}")
+    logger.info(f"  {zarr_url_new=}")
+
+    # Read attributes from NGFF metadata
+    ngff_image_meta = load_NgffImageMeta(zarr_url)
+    num_levels = ngff_image_meta.num_levels
+    coarsening_xy = ngff_image_meta.coarsening_xy
+    full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0)
+    logger.info(f"NGFF image has {num_levels=}")
+    logger.info(f"NGFF image has {coarsening_xy=}")
+    logger.info(
+        f"NGFF image has full-res pixel sizes {full_res_pxl_sizes_zyx}"
+    )
+
+    # Read channels from .zattrs
+    channels: list[OmeroChannel] = get_omero_channel_list(
+        image_zarr_path=zarr_url
+    )
+    num_channels = len(channels)
+
+    # Read FOV ROIs
+    FOV_ROI_table = ad.read_zarr(f"{zarr_url}/tables/{input_ROI_table}")
+
+    # Create list of indices for 3D FOVs spanning the entire Z direction
+    list_indices = convert_ROI_table_to_indices(
+        FOV_ROI_table,
+        level=0,
+        coarsening_xy=coarsening_xy,
+        full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx,
+    )
+    check_valid_ROI_indices(list_indices, input_ROI_table)
+
+    # Extract image size from FOV-ROI indices.
Note: this works at level=0, + # where FOVs should all be of the exact same size (in pixels) + ref_img_size = None + for indices in list_indices: + img_size = (indices[3] - indices[2], indices[5] - indices[4]) + if ref_img_size is None: + ref_img_size = img_size + else: + if img_size != ref_img_size: + raise ValueError( + "ERROR: inconsistent image sizes in list_indices" + ) + img_size_y, img_size_x = img_size[:] + + # Assemble dictionary of matrices and check their shapes + corrections = {} + for channel in channels: + wavelength_id = channel.wavelength_id + corrections[wavelength_id] = imread( + ( + Path(illumination_profiles_folder) + / illumination_profiles[wavelength_id] + ).as_posix() + ) + if corrections[wavelength_id].shape != (img_size_y, img_size_x): + raise ValueError( + "Error in illumination_correction, " + "correction matrix has wrong shape." + ) + + # Lazily load highest-res level from original zarr array + data_czyx = da.from_zarr(f"{zarr_url}/0") + + # Create zarr for output + if overwrite_input: + new_zarr = zarr.open(f"{zarr_url_new}/0") + else: + new_zarr = zarr.create( + shape=data_czyx.shape, + chunks=data_czyx.chunksize, + dtype=data_czyx.dtype, + store=zarr.storage.FSStore(f"{zarr_url_new}/0"), + overwrite=False, + dimension_separator="/", + ) + _copy_hcs_ome_zarr_metadata(zarr_url, zarr_url_new) + # Copy ROI tables from the old zarr_url to keep ROI tables and other + # tables available in the new Zarr + _copy_tables_from_zarr_url(zarr_url, zarr_url_new) + + # Iterate over FOV ROIs + num_ROIs = len(list_indices) + for i_c, channel in enumerate(channels): + for i_ROI, indices in enumerate(list_indices): + # Define region + s_z, e_z, s_y, e_y, s_x, e_x = indices[:] + region = ( + slice(i_c, i_c + 1), + slice(s_z, e_z), + slice(s_y, e_y), + slice(s_x, e_x), + ) + logger.info( + f"Now processing ROI {i_ROI+1}/{num_ROIs} " + f"for channel {i_c+1}/{num_channels}" + ) + # Execute illumination correction + corrected_fov = correct( + data_czyx[region].compute(), + corrections[channel.wavelength_id], + background=background, + ) + # Write to disk + da.array(corrected_fov).to_zarr( + url=new_zarr, + region=region, + compute=True, + ) + + # Starting from on-disk highest-resolution data, build and write to disk a + # pyramid of coarser levels + build_pyramid( + zarrurl=zarr_url_new, + overwrite=True, + num_levels=num_levels, + coarsening_xy=coarsening_xy, + chunksize=data_czyx.chunksize, + ) + + t_end = time.perf_counter() + logger.info(f"End illumination_correction, elapsed: {t_end-t_start}") + + if overwrite_input: + image_list_updates = dict(image_list_updates=[dict(zarr_url=zarr_url)]) + else: + image_list_updates = dict( + image_list_updates=[dict(zarr_url=zarr_url_new, origin=zarr_url)] + ) + return image_list_updates + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + run_fractal_task( + task_function=illumination_correction, + logger_name=logger.name, + ) diff --git a/fractal_tasks_core/tasks/image_based_registration_hcs_init.py b/fractal_tasks_core/tasks/image_based_registration_hcs_init.py new file mode 100644 index 000000000..db76eb595 --- /dev/null +++ b/fractal_tasks_core/tasks/image_based_registration_hcs_init.py @@ -0,0 +1,98 @@ +# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Tommaso Comparin +# Joel Lüthi +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. 
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+"""
+Initializes the parallelization list for registration in HCS plates.
+"""
+import logging
+from typing import Any
+
+from pydantic import validate_call
+
+from fractal_tasks_core.utils import (
+    create_well_acquisition_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@validate_call
+def image_based_registration_hcs_init(
+    *,
+    # Fractal parameters
+    zarr_urls: list[str],
+    zarr_dir: str,
+    # Core parameters
+    reference_acquisition: int = 0,
+) -> dict[str, list[dict[str, Any]]]:
+    """
+    Initializes the calculate-registration task.
+
+    This task prepares a parallelization list of all zarr_urls that need to be
+    used to calculate the registration between acquisitions (all zarr_urls
+    except the reference acquisition vs. the reference acquisition).
+    This task only works for HCS OME-Zarrs for 2 reasons: Only HCS OME-Zarrs
+    currently have defined acquisition metadata to determine reference
+    acquisitions. And we have only implemented the grouping of images for
+    HCS OME-Zarrs by well (with the assumption that every well just has 1
+    image per acquisition).
+
+    Args:
+        zarr_urls: List of paths or urls to the individual OME-Zarr image to
+            be processed.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        zarr_dir: path of the directory where the new OME-Zarrs will be
+            created. Not used by this task.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        reference_acquisition: Which acquisition to register against. Needs to
+            match the acquisition metadata in the OME-Zarr image.
+
+    Returns:
+        task_output: Dictionary for Fractal server that contains a
+            parallelization list.
+    """
+    logger.info(
+        f"Running `image_based_registration_hcs_init` for {zarr_urls=}"
+    )
+    image_groups = create_well_acquisition_dict(zarr_urls)
+
+    # Create the parallelization list
+    parallelization_list = []
+    for key, image_group in image_groups.items():
+        # Assert that all image groups have the reference acquisition present
+        if reference_acquisition not in image_group.keys():
+            raise ValueError(
+                f"Registration with {reference_acquisition=} can only work if "
+                "all wells have the reference acquisition present. It was not "
+                f"found for well {key}."
+            )
+        # Add all zarr_urls except the reference acquisition to the
+        # parallelization list
+        for acquisition, zarr_url in image_group.items():
+            if acquisition != reference_acquisition:
+                reference_zarr_url = image_group[reference_acquisition]
+                parallelization_list.append(
+                    dict(
+                        zarr_url=zarr_url,
+                        init_args=dict(reference_zarr_url=reference_zarr_url),
+                    )
+                )
+
+    return dict(parallelization_list=parallelization_list)
+
+
+if __name__ == "__main__":
+    from fractal_tasks_core.tasks._utils import run_fractal_task
+
+    run_fractal_task(
+        task_function=image_based_registration_hcs_init,
+        logger_name=logger.name,
+    )
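For orientation, the parallelization list assembled above pairs every non-reference image with its reference. An illustration for one well with acquisitions {0, 1, 2} and `reference_acquisition=0` (paths are hypothetical):

```python
parallelization_list = [
    dict(
        zarr_url="/data/plate.zarr/B/03/1",
        init_args=dict(reference_zarr_url="/data/plate.zarr/B/03/0"),
    ),
    dict(
        zarr_url="/data/plate.zarr/B/03/2",
        init_args=dict(reference_zarr_url="/data/plate.zarr/B/03/0"),
    ),
]
```
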
diff --git a/fractal_tasks_core/tasks/import_ome_zarr.py b/fractal_tasks_core/tasks/import_ome_zarr.py
new file mode 100644
index 000000000..f35bdd48f
--- /dev/null
+++ b/fractal_tasks_core/tasks/import_ome_zarr.py
@@ -0,0 +1,314 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Tommaso Comparin
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+"""
+Task to import an existing OME-Zarr.
+"""
+import logging
+from typing import Any
+from typing import Optional
+
+import dask.array as da
+import zarr
+from pydantic import validate_call
+
+from fractal_tasks_core.channels import update_omero_channels
+from fractal_tasks_core.ngff import detect_ome_ngff_type
+from fractal_tasks_core.ngff import NgffImageMeta
+from fractal_tasks_core.roi import get_image_grid_ROIs
+from fractal_tasks_core.roi import get_single_image_ROI
+from fractal_tasks_core.tables import write_table
+
+logger = logging.getLogger(__name__)
+
+
+def _process_single_image(
+    image_path: str,
+    add_image_ROI_table: bool,
+    add_grid_ROI_table: bool,
+    update_omero_metadata: bool,
+    *,
+    grid_YX_shape: Optional[tuple[int, int]] = None,
+    overwrite: bool = False,
+) -> dict[str, str]:
+    """
+    Validate OME-NGFF metadata and optionally generate ROI tables.
+
+    This task:
+
+    1. Validates OME-NGFF image metadata, via `NgffImageMeta`;
+    2. Optionally generates and writes two ROI tables;
+    3. Optionally updates OME-NGFF omero metadata;
+    4. Returns dataset types.
+
+    Args:
+        image_path: Absolute path to the image Zarr group.
+        add_image_ROI_table: Whether to add a `image_ROI_table` table
+            (argument propagated from `import_ome_zarr`).
+        add_grid_ROI_table: Whether to add a `grid_ROI_table` table (argument
+            propagated from `import_ome_zarr`).
+        update_omero_metadata: Whether to update Omero-channels metadata
+            (argument propagated from `import_ome_zarr`).
+        grid_YX_shape: YX shape of the ROI grid (it must not be `None` if
+            `add_grid_ROI_table=True`).
+        overwrite: Whether new ROI tables can overwrite existing ones
+            (argument propagated from `import_ome_zarr`).
+    """
+
+    # Note from zarr docs: `r+` means read/write (must exist)
+    image_group = zarr.open_group(image_path, mode="r+")
+    image_meta = NgffImageMeta(**image_group.attrs.asdict())
+
+    # Preliminary checks
+    if add_grid_ROI_table and (grid_YX_shape is None):
+        raise ValueError(
+            f"_process_single_image called with {add_grid_ROI_table=}, "
+            f"but {grid_YX_shape=}."
+        )
+
+    pixels_ZYX = image_meta.get_pixel_sizes_zyx(level=0)
+
+    # Read zarr array
+    dataset_subpath = image_meta.datasets[0].path
+    array = da.from_zarr(f"{image_path}/{dataset_subpath}")
+
+    # Prepare image_ROI_table and write it into the zarr group
+    if add_image_ROI_table:
+        image_ROI_table = get_single_image_ROI(array.shape, pixels_ZYX)
+        write_table(
+            image_group,
+            "image_ROI_table",
+            image_ROI_table,
+            overwrite=overwrite,
+            table_attrs={"type": "roi_table"},
+        )
+
+    # Prepare grid_ROI_table and write it into the zarr group
+    if add_grid_ROI_table:
+        grid_ROI_table = get_image_grid_ROIs(
+            array.shape,
+            pixels_ZYX,
+            grid_YX_shape,
+        )
+        write_table(
+            image_group,
+            "grid_ROI_table",
+            grid_ROI_table,
+            overwrite=overwrite,
+            table_attrs={"type": "roi_table"},
+        )
+
+    # Update Omero-channels metadata
+    if update_omero_metadata:
+        # Extract number of channels from zarr array
+        try:
+            channel_axis_index = image_meta.axes_names.index("c")
+        except ValueError:
+            logger.error(f"Existing axes: {image_meta.axes_names}")
+            msg = (
+                "OME-Zarrs with no channel axis are not currently "
+                "supported in fractal-tasks-core. Upcoming flexibility "
+                "improvements are tracked in https://github.com/"
+                "fractal-analytics-platform/fractal-tasks-core/issues/150."
+            )
+            logger.error(msg)
+            raise NotImplementedError(msg)
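+
+        # For context: `update_omero_channels` below fills in
+        # Fractal-compatible defaults for each channel dict. Called as
+        # update_omero_channels([{}, {}, {}]), it returns one completed
+        # entry per input channel (the exact defaults, e.g. labels and
+        # colors, are library-defined).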
+        logger.info(f"Existing axes: {image_meta.axes_names}")
+        logger.info(f"Channel-axis index: {channel_axis_index}")
+        num_channels_zarr = array.shape[channel_axis_index]
+        logger.info(
+            f"{num_channels_zarr} channel(s) found in Zarr array "
+            f"at {image_path}/{dataset_subpath}"
+        )
+        # Update or create omero channels metadata
+        old_omero = image_group.attrs.get("omero", {})
+        old_channels = old_omero.get("channels", [])
+        if len(old_channels) > 0:
+            logger.info(
+                f"{len(old_channels)} channel(s) found in NGFF omero metadata"
+            )
+            if len(old_channels) != num_channels_zarr:
+                error_msg = (
+                    "Channels-number mismatch: Number of channels in the "
+                    f"zarr array ({num_channels_zarr}) differs from number "
+                    "of channels listed in NGFF omero metadata "
+                    f"({len(old_channels)})."
+                )
+                logging.error(error_msg)
+                raise ValueError(error_msg)
+        else:
+            old_channels = [{} for ind in range(num_channels_zarr)]
+        new_channels = update_omero_channels(old_channels)
+        new_omero = old_omero.copy()
+        new_omero["channels"] = new_channels
+        image_group.attrs.update(omero=new_omero)
+
+    # Determine image types:
+    # Later: also provide a has_T flag.
+    # TODO: Potentially also load acquisition metadata if available in a Zarr
+    is_3D = False
+    if "z" in image_meta.axes_names:
+        if array.shape[-3] > 1:
+            is_3D = True
+    types = dict(is_3D=is_3D)
+    return types
+
+
+@validate_call
+def import_ome_zarr(
+    *,
+    # Fractal parameters
+    zarr_urls: list[str],
+    zarr_dir: str,
+    # Core parameters
+    zarr_name: str,
+    update_omero_metadata: bool = True,
+    add_image_ROI_table: bool = True,
+    add_grid_ROI_table: bool = True,
+    # Advanced parameters
+    grid_y_shape: int = 2,
+    grid_x_shape: int = 2,
+    overwrite: bool = False,
+) -> dict[str, Any]:
+    """
+    Import a single OME-Zarr into Fractal.
+
+    The single OME-Zarr can be a full OME-Zarr HCS plate or an individual
+    OME-Zarr image. The image needs to be in the zarr_dir as specified by the
+    dataset. The current version of this task:
+
+    1. Creates the appropriate components-related metadata, needed for
+       processing an existing OME-Zarr through Fractal.
+    2. Optionally adds new ROI tables to the existing OME-Zarr.
+
+    Args:
+        zarr_urls: List of paths or urls to the individual OME-Zarr image to
+            be processed. Not used.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        zarr_dir: path of the directory where the new OME-Zarrs will be
+            created.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        zarr_name: The OME-Zarr name, without its parent folder. The parent
+            folder is provided by zarr_dir; e.g. `zarr_name="array.zarr"`,
+            if the OME-Zarr path is in `/zarr_dir/array.zarr`.
+        add_image_ROI_table: Whether to add a `image_ROI_table` table to each
+            image, with a single ROI covering the whole image.
+        add_grid_ROI_table: Whether to add a `grid_ROI_table` table to each
+            image, with the image split into a rectangular grid of ROIs.
+        grid_y_shape: Y shape of the ROI grid in `grid_ROI_table`.
+        grid_x_shape: X shape of the ROI grid in `grid_ROI_table`.
+        update_omero_metadata: Whether to update Omero-channels metadata, to
+            make them Fractal-compatible.
+        overwrite: Whether new ROI tables (added when `add_image_ROI_table`
+            and/or `add_grid_ROI_table` are `True`) can overwrite existing
+            ones.
+    """
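+
+    # For orientation, one entry of the returned image list for a plate
+    # image looks like this (illustrative values):
+    #     dict(
+    #         zarr_url="/zarr_dir/array.zarr/B/03/0",
+    #         attributes=dict(plate="array.zarr", well="B03"),
+    #         types=dict(is_3D=True),
+    #     )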
+
+    # Is this based on the zarr_dir or the zarr_urls?
+    if len(zarr_urls) > 0:
+        logger.warning(
+            "Running import while there are already items from the image list "
+            "provided to the task. The following inputs were provided: "
+            f"{zarr_urls=}. "
+            "This task will not process the existing images, but look for "
+            f"zarr files named {zarr_name=} in the {zarr_dir=} instead."
+        )
+
+    zarr_path = f"{zarr_dir.rstrip('/')}/{zarr_name}"
+    logger.info(f"Zarr path: {zarr_path}")
+
+    root_group = zarr.open_group(zarr_path, mode="r")
+    ngff_type = detect_ome_ngff_type(root_group)
+    grid_YX_shape = (grid_y_shape, grid_x_shape)
+
+    image_list_updates = []
+    if ngff_type == "plate":
+        for well in root_group.attrs["plate"]["wells"]:
+            well_path = well["path"]
+
+            well_group = zarr.open_group(zarr_path, path=well_path, mode="r")
+            for image in well_group.attrs["well"]["images"]:
+                image_path = image["path"]
+                zarr_url = f"{zarr_path}/{well_path}/{image_path}"
+                types = _process_single_image(
+                    zarr_url,
+                    add_image_ROI_table,
+                    add_grid_ROI_table,
+                    update_omero_metadata,
+                    grid_YX_shape=grid_YX_shape,
+                    overwrite=overwrite,
+                )
+                image_list_updates.append(
+                    dict(
+                        zarr_url=zarr_url,
+                        attributes=dict(
+                            plate=zarr_name,
+                            well=well_path.replace("/", ""),
+                        ),
+                        types=types,
+                    )
+                )
+    elif ngff_type == "well":
+        logger.warning(
+            "Only OME-Zarr for plates are fully supported in Fractal; "
+            f"e.g. the current one ({ngff_type=}) cannot be "
+            "processed via the `maximum_intensity_projection` task."
+        )
+        for image in root_group.attrs["well"]["images"]:
+            image_path = image["path"]
+            zarr_url = f"{zarr_path}/{image_path}"
+            well_name = "".join(zarr_path.split("/")[-2:])
+            types = _process_single_image(
+                zarr_url,
+                add_image_ROI_table,
+                add_grid_ROI_table,
+                update_omero_metadata,
+                grid_YX_shape=grid_YX_shape,
+                overwrite=overwrite,
+            )
+            image_list_updates.append(
+                dict(
+                    zarr_url=zarr_url,
+                    attributes=dict(
+                        well=well_name,
+                    ),
+                    types=types,
+                )
+            )
+    elif ngff_type == "image":
+        logger.warning(
+            "Only OME-Zarr for plates are fully supported in Fractal; "
+            f"e.g. the current one ({ngff_type=}) cannot be "
+            "processed via the `maximum_intensity_projection` task."
+        )
+        zarr_url = zarr_path
+        types = _process_single_image(
+            zarr_url,
+            add_image_ROI_table,
+            add_grid_ROI_table,
+            update_omero_metadata,
+            grid_YX_shape=grid_YX_shape,
+            overwrite=overwrite,
+        )
+        image_list_updates.append(
+            dict(
+                zarr_url=zarr_url,
+                types=types,
+            )
+        )
+
+    image_list_changes = dict(image_list_updates=image_list_updates)
+    return image_list_changes
+
+
+if __name__ == "__main__":
+    from fractal_tasks_core.tasks._utils import run_fractal_task
+
+    run_fractal_task(
+        task_function=import_ome_zarr,
+        logger_name=logger.name,
+    )
diff --git a/fractal_tasks_core/tasks/init_group_by_well_for_multiplexing.py b/fractal_tasks_core/tasks/init_group_by_well_for_multiplexing.py
new file mode 100644
index 000000000..1d0072671
--- /dev/null
+++ b/fractal_tasks_core/tasks/init_group_by_well_for_multiplexing.py
@@ -0,0 +1,91 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Tommaso Comparin
+# Joel Lüthi
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+""" +Applies the multiplexing translation to all ROI tables +""" +import logging + +from pydantic import validate_call + +from fractal_tasks_core.utils import ( + create_well_acquisition_dict, +) + +logger = logging.getLogger(__name__) + + +@validate_call +def init_group_by_well_for_multiplexing( + *, + # Fractal parameters + zarr_urls: list[str], + zarr_dir: str, + # Core parameters + reference_acquisition: int = 0, +) -> dict[str, list[str]]: + """ + Finds images for all acquisitions per well. + + Returns the parallelization_list to run `find_registration_consensus`. + + Args: + zarr_urls: List of paths or urls to the individual OME-Zarr image to + be processed. + (standard argument for Fractal tasks, managed by Fractal server). + zarr_dir: path of the directory where the new OME-Zarrs will be + created. Not used by this task. + (standard argument for Fractal tasks, managed by Fractal server). + reference_acquisition: Which acquisition to register against. Uses the + OME-NGFF HCS well metadata acquisition keys to find the reference + acquisition. + """ + logger.info( + f"Running `init_group_by_well_for_multiplexing` for {zarr_urls=}" + ) + image_groups = create_well_acquisition_dict(zarr_urls) + + # Create the parallelization list + parallelization_list = [] + for key, image_group in image_groups.items(): + # Assert that all image groups have the reference acquisition present + if reference_acquisition not in image_group.keys(): + raise ValueError( + f"Registration with {reference_acquisition=} can only work if " + "all wells have the reference acquisition present. It was not " + f"found for well {key}." + ) + + # Create a parallelization list entry for each image group + zarr_url_list = [] + for acquisition, zarr_url in image_group.items(): + if acquisition == reference_acquisition: + reference_zarr_url = zarr_url + + zarr_url_list.append(zarr_url) + + parallelization_list.append( + dict( + zarr_url=reference_zarr_url, + init_args=dict(zarr_url_list=zarr_url_list), + ) + ) + + return dict(parallelization_list=parallelization_list) + + +if __name__ == "__main__": + from fractal_tasks_core.tasks._utils import run_fractal_task + + run_fractal_task( + task_function=init_group_by_well_for_multiplexing, + logger_name=logger.name, + ) diff --git a/fractal_tasks_core/tasks/io_models.py b/fractal_tasks_core/tasks/io_models.py new file mode 100644 index 000000000..c061afc89 --- /dev/null +++ b/fractal_tasks_core/tasks/io_models.py @@ -0,0 +1,185 @@ +from typing import Literal +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field +from pydantic import model_validator +from typing_extensions import Self + +from fractal_tasks_core.channels import ChannelInputModel +from fractal_tasks_core.channels import OmeroChannel + + +class InitArgsRegistration(BaseModel): + """ + Registration init args. + + Passed from `image_based_registration_hcs_init` to + `calculate_registration_image_based`. + + Attributes: + reference_zarr_url: zarr_url for the reference image + """ + + reference_zarr_url: str + + +class InitArgsRegistrationConsensus(BaseModel): + """ + Registration consensus init args. + + Provides the list of zarr_urls for all acquisitions for a given well + + Attributes: + zarr_url_list: List of zarr_urls for all the OME-Zarr images in the + well. 
+ """ + + zarr_url_list: list[str] + + +class InitArgsCellVoyager(BaseModel): + """ + Arguments to be passed from cellvoyager converter init to compute + + Attributes: + image_dir: Directory where the raw images are found + plate_prefix: part of the image filename needed for finding the + right subset of image files + well_ID: part of the image filename needed for finding the + right subset of image files + image_extension: part of the image filename needed for finding the + right subset of image files + include_glob_patterns: Additional glob patterns to filter the available + images with. + exclude_glob_patterns: Glob patterns to exclude. + acquisition: Acquisition metadata needed for multiplexing + """ + + image_dir: str + plate_prefix: str + well_ID: str + image_extension: str + include_glob_patterns: Optional[list[str]] = None + exclude_glob_patterns: Optional[list[str]] = None + acquisition: Optional[int] = None + + +class InitArgsIllumination(BaseModel): + """ + Dummy model description. + + Attributes: + raw_path: dummy attribute description. + subsets: dummy attribute description. + """ + + raw_path: str + subsets: dict[Literal["C_index"], int] = Field(default_factory=dict) + + +class InitArgsMIP(BaseModel): + """ + Init Args for MIP task. + + Attributes: + origin_url: Path to the zarr_url with the 3D data + method: Projection method to be used. See `DaskProjectionMethod` + overwrite: If `True`, overwrite the task output. + new_plate_name: Name of the new OME-Zarr HCS plate + """ + + origin_url: str + method: str + overwrite: bool + new_plate_name: str + + +class MultiplexingAcquisition(BaseModel): + """ + Input class for Multiplexing Cellvoyager converter + + Attributes: + image_dir: Path to the folder that contains the Cellvoyager image + files for that acquisition and the MeasurementData & + MeasurementDetail metadata files. + allowed_channels: A list of `OmeroChannel` objects, where each channel + must include the `wavelength_id` attribute and where the + `wavelength_id` values must be unique across the list. + """ + + image_dir: str + allowed_channels: list[OmeroChannel] + + +class NapariWorkflowsOutput(BaseModel): + """ + A value of the `output_specs` argument in `napari_workflows_wrapper`. + + Attributes: + type: Output type (either `label` or `dataframe`). + label_name: Label name (for label outputs, it is used as the name of + the label; for dataframe outputs, it is used to fill the + `region["path"]` field). + table_name: Table name (for dataframe outputs only). + """ + + type: Literal["label", "dataframe"] + label_name: str + table_name: Optional[str] = None + + @model_validator(mode="after") + def table_name_only_for_dataframe_type(self: Self) -> Self: + """ + Check that table_name is set only for dataframe outputs. + """ + _type = self.type + _table_name = self.table_name + if (_type == "dataframe" and (not _table_name)) or ( + _type != "dataframe" and _table_name + ): + raise ValueError( + f"Output item has type={_type} but table_name={_table_name}." + ) + return self + + +class NapariWorkflowsInput(BaseModel): + """ + A value of the `input_specs` argument in `napari_workflows_wrapper`. + + Attributes: + type: Input type (either `image` or `label`). + label_name: Label name (for label inputs only). + channel: `ChannelInputModel` object (for image inputs only). 
+ """ + + type: Literal["image", "label"] + label_name: Optional[str] = None + channel: Optional[ChannelInputModel] = None + + @model_validator(mode="after") + def label_name_is_present(self: Self) -> Self: + """ + Check that label inputs have `label_name` set. + """ + label_name = self.label_name + _type = self.type + if _type == "label" and label_name is None: + raise ValueError( + f"Input item has type={_type} but label_name={label_name}." + ) + return self + + @model_validator(mode="after") + def channel_is_present(self: Self) -> Self: + """ + Check that image inputs have `channel` set. + """ + _type = self.type + channel = self.channel + if _type == "image" and channel is None: + raise ValueError( + f"Input item has type={_type} but channel={channel}." + ) + return self diff --git a/fractal_tasks_core/tasks/napari_workflows_wrapper.py b/fractal_tasks_core/tasks/napari_workflows_wrapper.py new file mode 100644 index 000000000..b381a5960 --- /dev/null +++ b/fractal_tasks_core/tasks/napari_workflows_wrapper.py @@ -0,0 +1,638 @@ +# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and +# University of Zurich +# +# Original authors: +# Tommaso Comparin +# Marco Franzon +# +# This file is part of Fractal and was originally developed by eXact lab S.r.l. +# under contract with Liberali Lab from the Friedrich Miescher +# Institute for Biomedical Research and Pelkmans Lab from the University of +# Zurich. +""" +Wrapper of napari-workflows. +""" +import logging +from typing import Any + +import anndata as ad +import dask.array as da +import napari_workflows +import numpy as np +import pandas as pd +import zarr +from napari_workflows._io_yaml_v1 import load_workflow +from pydantic import validate_call + +import fractal_tasks_core +from fractal_tasks_core.channels import get_channel_from_image_zarr +from fractal_tasks_core.labels import prepare_label_group +from fractal_tasks_core.ngff import load_NgffImageMeta +from fractal_tasks_core.pyramids import build_pyramid +from fractal_tasks_core.roi import check_valid_ROI_indices +from fractal_tasks_core.roi import ( + convert_ROI_table_to_indices, +) +from fractal_tasks_core.roi import load_region +from fractal_tasks_core.tables import write_table +from fractal_tasks_core.tasks.io_models import ( + NapariWorkflowsInput, +) +from fractal_tasks_core.tasks.io_models import ( + NapariWorkflowsOutput, +) +from fractal_tasks_core.upscale_array import upscale_array +from fractal_tasks_core.utils import rescale_datasets + + +__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ + + +logger = logging.getLogger(__name__) + + +class OutOfTaskScopeError(NotImplementedError): + """ + Encapsulates features that are out-of-scope for the current wrapper task. + """ + + pass + + +@validate_call +def napari_workflows_wrapper( + *, + # Fractal parameters + zarr_url: str, + # Core parameters + workflow_file: str, + input_specs: dict[str, NapariWorkflowsInput], + output_specs: dict[str, NapariWorkflowsOutput], + input_ROI_table: str = "FOV_ROI_table", + level: int = 0, + # Advanced parameters + relabeling: bool = True, + expected_dimensions: int = 3, + overwrite: bool = True, +): + """ + Run a napari-workflow on the ROIs of a single OME-NGFF image. + + This task takes images and labels and runs a napari-workflow on them that + can produce a label and tables as output. 
+
+    Examples of allowed entries for `input_specs` and `output_specs`:
+
+    ```
+    input_specs = {
+        "in_1": {"type": "image", "channel": {"wavelength_id": "A01_C02"}},
+        "in_2": {"type": "image", "channel": {"label": "DAPI"}},
+        "in_3": {"type": "label", "label_name": "label_DAPI"},
+    }
+
+    output_specs = {
+        "out_1": {"type": "label", "label_name": "label_DAPI_new"},
+        "out_2": {"type": "dataframe", "table_name": "measurements"},
+    }
+    ```
+
+    Args:
+        zarr_url: Path or url to the individual OME-Zarr image to be processed.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        workflow_file: Absolute path to napari-workflows YAML file
+        input_specs: A dictionary of `NapariWorkflowsInput` values.
+        output_specs: A dictionary of `NapariWorkflowsOutput` values.
+        input_ROI_table: Name of the ROI table over which the task loops to
+            apply napari workflows.
+            Examples:
+            `FOV_ROI_table`
+            => loop over the field of views;
+            `organoid_ROI_table`
+            => loop over the organoid ROI table (generated by another task);
+            `well_ROI_table`
+            => process the whole well as one image.
+        level: Pyramid level of the image to be used as input for
+            napari-workflows. Choose `0` to process at full resolution.
+            Levels > 0 are currently only supported for workflows that only
+            have intensity images as input and only produce label images as
+            output.
+        relabeling: If `True`, apply relabeling so that label values are
+            unique across all ROIs in the well.
+        expected_dimensions: Expected dimensions (either `2` or `3`). Useful
+            when loading 2D images that are stored in a 3D array with shape
+            `(1, size_x, size_y)` [which is the default way Fractal stores 2D
+            images], but you want to make sure the napari workflow gets a 2D
+            array to process. Also useful to set to `2` when loading a 2D
+            OME-Zarr that is saved as `(size_x, size_y)`.
+        overwrite: If `True`, overwrite the task output.
+    """
+    wf: napari_workflows.Workflow = load_workflow(workflow_file)
+    logger.info(f"Loaded workflow from {workflow_file}")
+
+    # Validation of input/output specs
+    if not (set(wf.leafs()) <= set(output_specs.keys())):
+        msg = f"Some item of {wf.leafs()=} is not part of {output_specs=}."
+        logger.warning(msg)
+    if not (set(wf.roots()) <= set(input_specs.keys())):
+        msg = f"Some item of {wf.roots()=} is not part of {input_specs=}."
+        logger.error(msg)
+        raise ValueError(msg)
+    list_outputs = sorted(output_specs.keys())
+
+    # Characterization of workflow and scope restriction
+    input_types = [in_params.type for (name, in_params) in input_specs.items()]
+    output_types = [
+        out_params.type for (name, out_params) in output_specs.items()
+    ]
+    are_inputs_all_images = set(input_types) == {"image"}
+    are_outputs_all_labels = set(output_types) == {"label"}
+    are_outputs_all_dataframes = set(output_types) == {"dataframe"}
+    is_labeling_workflow = are_inputs_all_images and are_outputs_all_labels
+    is_measurement_only_workflow = are_outputs_all_dataframes
+    # Level-related constraint
+    logger.info(f"This workflow acts at {level=}")
+    logger.info(
+        f"Is the current workflow a labeling one? {is_labeling_workflow}"
+    )
+    if level > 0 and not is_labeling_workflow:
+        msg = (
+            f"{level=}>0 is currently only accepted for labeling workflows, "
+            "i.e. those going from image(s) to label(s)"
+        )
+        logger.error(msg)
+        raise OutOfTaskScopeError(msg)
+    # Relabeling-related (soft) constraint
+    if is_measurement_only_workflow and relabeling:
+        logger.warning(
+            "This is a measurement-output-only workflow, setting "
+            "relabeling=False."
+ ) + relabeling = False + if relabeling: + max_label_for_relabeling = 0 + + label_dtype = np.uint32 + + # Read ROI table + ROI_table = ad.read_zarr(f"{zarr_url}/tables/{input_ROI_table}") + + # Load image metadata + ngff_image_meta = load_NgffImageMeta(zarr_url) + num_levels = ngff_image_meta.num_levels + coarsening_xy = ngff_image_meta.coarsening_xy + + # Read pixel sizes from zattrs file + full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) + + # Create list of indices for 3D FOVs spanning the entire Z direction + list_indices = convert_ROI_table_to_indices( + ROI_table, + level=level, + coarsening_xy=coarsening_xy, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, + ) + check_valid_ROI_indices(list_indices, input_ROI_table) + num_ROIs = len(list_indices) + logger.info( + f"Completed reading ROI table {input_ROI_table}," + f" found {num_ROIs} ROIs." + ) + + # Input preparation: "image" type + image_inputs = [ + (name, in_params) + for (name, in_params) in input_specs.items() + if in_params.type == "image" + ] + input_image_arrays = {} + if image_inputs: + img_array = da.from_zarr(f"{zarr_url}/{level}") + # Loop over image inputs and assign corresponding channel of the image + for name, params in image_inputs: + channel = get_channel_from_image_zarr( + image_zarr_path=zarr_url, + wavelength_id=params.channel.wavelength_id, + label=params.channel.label, + ) + channel_index = channel.index + input_image_arrays[name] = img_array[channel_index] + + # Handle dimensions + shape = input_image_arrays[name].shape + if expected_dimensions == 3 and shape[0] == 1: + logger.warning( + f"Input {name} has shape {shape} " + f"but {expected_dimensions=}" + ) + if expected_dimensions == 2: + if len(shape) == 2: + # We already load the data as a 2D array + pass + elif shape[0] == 1: + input_image_arrays[name] = input_image_arrays[name][ + 0, :, : + ] + else: + msg = ( + f"Input {name} has shape {shape} " + f"but {expected_dimensions=}" + ) + logger.error(msg) + raise ValueError(msg) + logger.info(f"Prepared input with {name=} and {params=}") + logger.info(f"{input_image_arrays=}") + + # Input preparation: "label" type + label_inputs = [ + (name, in_params) + for (name, in_params) in input_specs.items() + if in_params.type == "label" + ] + if label_inputs: + # Set target_shape for upscaling labels + if not image_inputs: + logger.warning( + f"{len(label_inputs)=} but num_image_inputs=0. " + "Label array(s) will not be upscaled." 
+            )
+            upscale_labels = False
+        else:
+            target_shape = list(input_image_arrays.values())[0].shape
+            upscale_labels = True
+        # Loop over label inputs and load corresponding (upscaled) image
+        input_label_arrays = {}
+        for name, params in label_inputs:
+            label_name = params.label_name
+            label_array_raw = da.from_zarr(
+                f"{zarr_url}/labels/{label_name}/{level}"
+            )
+            input_label_arrays[name] = label_array_raw
+
+            # Handle dimensions
+            shape = input_label_arrays[name].shape
+            if expected_dimensions == 3 and shape[0] == 1:
+                logger.warning(
+                    f"Input {name} has shape {shape} "
+                    f"but {expected_dimensions=}"
+                )
+            if expected_dimensions == 2:
+                if len(shape) == 2:
+                    # We already load the data as a 2D array
+                    pass
+                elif shape[0] == 1:
+                    input_label_arrays[name] = input_label_arrays[name][
+                        0, :, :
+                    ]
+                else:
+                    msg = (
+                        f"Input {name} has shape {shape} "
+                        f"but {expected_dimensions=}"
+                    )
+                    logger.error(msg)
+                    raise ValueError(msg)
+
+            if upscale_labels:
+                # Check that dimensionality matches the image
+                if len(input_label_arrays[name].shape) != len(target_shape):
+                    raise ValueError(
+                        f"Label {name} has shape "
+                        f"{input_label_arrays[name].shape}. "
+                        "But the corresponding image has shape "
+                        f"{target_shape}. Those dimensionalities do not "
+                        f"match. Is {expected_dimensions=} the correct "
+                        "setting?"
+                    )
+                if expected_dimensions == 3:
+                    upscaling_axes = [1, 2]
+                else:
+                    upscaling_axes = [0, 1]
+                input_label_arrays[name] = upscale_array(
+                    array=input_label_arrays[name],
+                    target_shape=target_shape,
+                    axis=upscaling_axes,
+                    pad_with_zeros=True,
+                )
+
+            logger.info(f"Prepared input with {name=} and {params=}")
+        logger.info(f"{input_label_arrays=}")
+
+    # Output preparation: "label" type
+    label_outputs = [
+        (name, out_params)
+        for (name, out_params) in output_specs.items()
+        if out_params.type == "label"
+    ]
+    if label_outputs:
+        # Preliminary scope checks
+        if len(label_outputs) > 1:
+            raise OutOfTaskScopeError(
+                "Multiple label outputs would break label-inputs-only "
+                f"workflows (found {len(label_outputs)=})."
+            )
+        if len(label_outputs) > 1 and relabeling:
+            raise OutOfTaskScopeError(
+                "Multiple label outputs would break relabeling in labeling+"
+                f"measurement workflows (found {len(label_outputs)=})."
+            )
+
+        # We only support two cases:
+        # 1. If there exist some input images, then use the first one to
+        #    determine output-label array properties
+        # 2. If there are no input images, but there are input labels, then
+        #    (A) re-load the pixel sizes and re-build ROI indices, and
+        #    (B) use the first input label to determine output-label array
+        #    properties
+        if image_inputs:
+            reference_array = list(input_image_arrays.values())[0]
+        elif label_inputs:
+            reference_array = list(input_label_arrays.values())[0]
+            # Re-load pixel size, matching the correct level
+            input_label_name = label_inputs[0][1].label_name
+            ngff_label_image_meta = load_NgffImageMeta(
+                f"{zarr_url}/labels/{input_label_name}"
+            )
+            full_res_pxl_sizes_zyx = ngff_label_image_meta.get_pixel_sizes_zyx(
+                level=0
+            )
+            # Create list of indices for 3D FOVs spanning the whole Z direction
+            list_indices = convert_ROI_table_to_indices(
+                ROI_table,
+                level=level,
+                coarsening_xy=coarsening_xy,
+                full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx,
+            )
+            check_valid_ROI_indices(list_indices, input_ROI_table)
+            num_ROIs = len(list_indices)
+            logger.info(
+                f"Re-create ROI indices from ROI table {input_ROI_table}, "
+                f"using {full_res_pxl_sizes_zyx=}. "
+                "This is necessary because label-input-only workflows may "
+                "have label inputs that are at a different resolution and "
+                "are not upscaled."
+            )
+        else:
+            msg = (
+                "Missing image_inputs and label_inputs, we cannot assign"
+                " label output properties"
+            )
+            raise OutOfTaskScopeError(msg)
+
+        # Extract label properties from reference_array, and make sure they
+        # are for three dimensions
+        label_shape = reference_array.shape
+        label_chunksize = reference_array.chunksize
+        if len(label_shape) == 2 and len(label_chunksize) == 2:
+            if expected_dimensions == 3:
+                raise ValueError(
+                    f"Something wrong: {label_shape=} but "
+                    f"{expected_dimensions=}"
+                )
+            label_shape = (1, label_shape[0], label_shape[1])
+            label_chunksize = (1, label_chunksize[0], label_chunksize[1])
+        logger.info(f"{label_shape=}")
+        logger.info(f"{label_chunksize=}")
+
+        # Loop over label outputs and (1) set zattrs, (2) create zarr group
+        output_label_zarr_groups: dict[str, Any] = {}
+        for name, out_params in label_outputs:
+
+            # (1a) Rescale OME-NGFF datasets (relevant for level>0)
+            if not ngff_image_meta.multiscale.axes[0].name == "c":
+                raise ValueError(
+                    "Cannot set `remove_channel_axis=True` for multiscale "
+                    f"metadata with axes={ngff_image_meta.multiscale.axes}. "
+                    'First axis should have name "c".'
+                )
+            new_datasets = rescale_datasets(
+                datasets=[
+                    ds.model_dump()
+                    for ds in ngff_image_meta.multiscale.datasets
+                ],
+                coarsening_xy=coarsening_xy,
+                reference_level=level,
+                remove_channel_axis=True,
+            )
+
+            # (1b) Prepare attrs for label group
+            label_name = out_params.label_name
+            label_attrs = {
+                "image-label": {
+                    "version": __OME_NGFF_VERSION__,
+                    "source": {"image": "../../"},
+                },
+                "multiscales": [
+                    {
+                        "name": label_name,
+                        "version": __OME_NGFF_VERSION__,
+                        "axes": [
+                            ax.model_dump()
+                            for ax in ngff_image_meta.multiscale.axes
+                            if ax.type != "channel"
+                        ],
+                        "datasets": new_datasets,
+                    }
+                ],
+            }
+
+            # (2) Prepare label group
+            image_group = zarr.group(zarr_url)
+            label_group = prepare_label_group(
+                image_group,
+                label_name,
+                overwrite=overwrite,
+                label_attrs=label_attrs,
+                logger=logger,
+            )
+            logger.info(
+                "Helper function `prepare_label_group` returned "
+                f"{label_group=}"
+            )
+
+            # (3) Create zarr group at level=0
+            store = zarr.storage.FSStore(f"{zarr_url}/labels/{label_name}/0")
+            mask_zarr = zarr.create(
+                shape=label_shape,
+                chunks=label_chunksize,
+                dtype=label_dtype,
+                store=store,
+                overwrite=overwrite,
+                dimension_separator="/",
+            )
+            output_label_zarr_groups[name] = mask_zarr
+            logger.info(f"Prepared output with {name=} and {out_params=}")
+        logger.info(f"{output_label_zarr_groups=}")
+
+    # Output preparation: "dataframe" type
+    dataframe_outputs = [
+        (name, out_params)
+        for (name, out_params) in output_specs.items()
+        if out_params.type == "dataframe"
+    ]
+    output_dataframe_lists: dict[str, list] = {}
+    for name, out_params in dataframe_outputs:
+        output_dataframe_lists[name] = []
+        logger.info(f"Prepared output with {name=} and {out_params=}")
+    logger.info(f"{output_dataframe_lists=}")
+
+    #####
+
+    for i_ROI, indices in enumerate(list_indices):
+        s_z, e_z, s_y, e_y, s_x, e_x = indices[:]
+        region = (slice(s_z, e_z), slice(s_y, e_y), slice(s_x, e_x))
+
+        logger.info(f"ROI {i_ROI+1}/{num_ROIs}: {region=}")
+
+        # Always re-load napari workflow
+        wf = load_workflow(workflow_file)
+
+        # Set inputs
+        for input_name in input_specs.keys():
+            input_type = input_specs[input_name].type
+
+            if input_type == "image":
+                wf.set(
+                    input_name,
+                    load_region(
+                        input_image_arrays[input_name],
+                        region,
+                        compute=True,
+                        return_as_3D=False,
+                    ),
+                )
+            elif input_type == "label":
+                wf.set(
+                    input_name,
+                    load_region(
+                        input_label_arrays[input_name],
+                        region,
+                        compute=True,
+                        return_as_3D=False,
+                    ),
+                )
+
+        # Get outputs
+        outputs = wf.get(list_outputs)
+
+        # Iterate first over dataframe outputs (to use the correct
+        # max_label_for_relabeling, if needed)
+        for ind_output, output_name in enumerate(list_outputs):
+            if output_specs[output_name].type != "dataframe":
+                continue
+            df = outputs[ind_output]
+            if relabeling:
+                df["label"] += max_label_for_relabeling
+                logger.info(
+                    f'ROI {i_ROI+1}/{num_ROIs}: Relabeling "{output_name}" '
+                    f"dataframe output, with {max_label_for_relabeling=}"
+                )
+
+            # Append the new-ROI dataframe to the all-ROIs list
+            output_dataframe_lists[output_name].append(df)
+
+        # After all dataframe outputs, iterate over label outputs (which
+        # can actually only be 0 or 1)
+        for ind_output, output_name in enumerate(list_outputs):
+            if output_specs[output_name].type != "label":
+                continue
+            mask = outputs[ind_output]
+
+            # Check dimensions
+            if len(mask.shape) != expected_dimensions:
+                msg = (
+                    f"Output {output_name} has shape {mask.shape} "
+                    f"but {expected_dimensions=}"
+                )
+                logger.error(msg)
+                raise ValueError(msg)
+            elif expected_dimensions == 2:
+                mask = np.expand_dims(mask, axis=0)
+
+            # Sanity check: issue warning for non-consecutive labels
+            unique_labels = np.unique(mask)
+            num_unique_labels_in_this_ROI = len(unique_labels)
+            if np.min(unique_labels) == 0:
+                num_unique_labels_in_this_ROI -= 1
+            num_labels_in_this_ROI = int(np.max(mask))
+            if num_labels_in_this_ROI != num_unique_labels_in_this_ROI:
+                logger.warning(
+                    f'ROI {i_ROI+1}/{num_ROIs}: "{output_name}" label output '
+                    f"has non-consecutive labels: {num_labels_in_this_ROI=} "
+                    f"but {num_unique_labels_in_this_ROI=}"
+                )
+
+            if relabeling:
+                mask[mask > 0] += max_label_for_relabeling
+                logger.info(
+                    f'ROI {i_ROI+1}/{num_ROIs}: Relabeling "{output_name}" '
+                    f"label output, with {max_label_for_relabeling=}"
+                )
+                max_label_for_relabeling += num_labels_in_this_ROI
+                logger.info(
+                    f"ROI {i_ROI+1}/{num_ROIs}: label-number update with "
+                    f"{num_labels_in_this_ROI=}; "
+                    f"new {max_label_for_relabeling=}"
+                )
+
+            da.array(mask).to_zarr(
+                url=output_label_zarr_groups[output_name],
+                region=region,
+                compute=True,
+                overwrite=overwrite,
+            )
+        logger.info(f"ROI {i_ROI+1}/{num_ROIs}: output handling complete")
+
+    # Output handling: "dataframe" type (for each output, concatenate ROI
+    # dataframes, clean up, and store in an AnnData table on-disk)
+    for name, out_params in dataframe_outputs:
+        table_name = out_params.table_name
+        # Concatenate all FOV dataframes
+        list_dfs = output_dataframe_lists[name]
+        if len(list_dfs) == 0:
+            measurement_table = ad.AnnData()
+        else:
+            df_well = pd.concat(list_dfs, axis=0, ignore_index=True)
+            # Extract labels and drop them from df_well
+            labels = pd.DataFrame(df_well["label"].astype(str))
+            df_well.drop(labels=["label"], axis=1, inplace=True)
+            # Convert all to float (warning: some would be int, in principle)
+            measurement_dtype = np.float32
+            df_well = df_well.astype(measurement_dtype)
+            df_well.index = df_well.index.map(str)
+            # Convert to anndata
+            measurement_table = ad.AnnData(df_well, dtype=measurement_dtype)
+            measurement_table.obs = labels
+
+        # Write to zarr group
+        image_group = zarr.group(zarr_url)
+        table_attrs = dict(
+            type="feature_table",
+            region=dict(path=f"../labels/{out_params.label_name}"),
+            instance_key="label",
+        )
+        write_table(
+            image_group,
+            table_name,
+            measurement_table,
+            overwrite=overwrite,
+            table_attrs=table_attrs,
+        )
+
+    # Output handling: "label" type (for each output, build and write to disk
+    # pyramid of coarser levels)
+    for name, out_params in label_outputs:
+        label_name = out_params.label_name
+        build_pyramid(
+            zarrurl=f"{zarr_url}/labels/{label_name}",
+            overwrite=overwrite,
+            num_levels=num_levels,
+            coarsening_xy=coarsening_xy,
+            chunksize=label_chunksize,
+            aggregation_function=np.max,
+        )
+
+
+if __name__ == "__main__":
+    from fractal_tasks_core.tasks._utils import run_fractal_task
+
+    run_fractal_task(
+        task_function=napari_workflows_wrapper,
+        logger_name=logger.name,
+    )
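
The relabeling bookkeeping above is the subtle part of the ROI loop: each ROI's mask comes back from the napari workflow labeled from 1, so foreground pixels must be shifted by the running label maximum before being written into the shared zarr array, or labels would collide across ROIs. A minimal sketch of that arithmetic on toy numpy masks (the array values and ROI count are invented for illustration):

    import numpy as np

    # Two toy per-ROI label masks, each labeled from 1 (illustrative values)
    roi_masks = [
        np.array([[0, 1], [1, 2]], dtype=np.uint32),  # ROI 1: labels 1..2
        np.array([[1, 0], [2, 3]], dtype=np.uint32),  # ROI 2: labels 1..3
    ]

    max_label_for_relabeling = 0
    relabeled = []
    for mask in roi_masks:
        # Count labels before offsetting, as the wrapper does
        num_labels_in_this_ROI = int(np.max(mask))
        out = mask.copy()
        # Shift only foreground pixels, so that background (0) stays 0
        out[out > 0] += max_label_for_relabeling
        relabeled.append(out)
        # Advance the running offset by this ROI's label count
        max_label_for_relabeling += num_labels_in_this_ROI

    print(relabeled[1])  # ROI 2 labels became 3..5: [[3 0] [4 5]]

This is also why the non-consecutive-label warning matters: if a workflow emits labels with gaps, `np.max(mask)` over-counts the labels and the offsets drift apart from the dataframe outputs.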
diff --git a/fractal_tasks_core/tasks/projection.py b/fractal_tasks_core/tasks/projection.py
new file mode 100644
index 000000000..723cc1ba5
--- /dev/null
+++ b/fractal_tasks_core/tasks/projection.py
@@ -0,0 +1,145 @@
+# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+# University of Zurich
+#
+# Original authors:
+# Tommaso Comparin
+# Marco Franzon
+#
+# This file is part of Fractal and was originally developed by eXact lab S.r.l.
+# under contract with Liberali Lab from the Friedrich Miescher
+# Institute for Biomedical Research and Pelkmans Lab from the University of
+# Zurich.
+"""
+Task for 3D->2D intensity projection along the Z axis.
+"""
+import logging
+from typing import Any
+
+import dask.array as da
+from ngio import NgffImage
+from ngio.core import Image
+from pydantic import validate_call
+
+from fractal_tasks_core.tasks.io_models import InitArgsMIP
+from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_new_shape(source_image: Image) -> tuple[int, ...]:
+    """Compute the new shape of the image after the projection.
+
+    The new shape is the same as the original one,
+    except for the z-axis, which is set to 1.
+    """
+    on_disk_shape = source_image.on_disk_shape
+    logger.info(f"Source {on_disk_shape=}")
+
+    on_disk_z_index = source_image.dataset.on_disk_axes_names.index("z")
+
+    dest_on_disk_shape = list(on_disk_shape)
+    dest_on_disk_shape[on_disk_z_index] = 1
+    logger.info(f"Destination {dest_on_disk_shape=}")
+    return tuple(dest_on_disk_shape)
+
+
+@validate_call
+def projection(
+    *,
+    # Fractal parameters
+    zarr_url: str,
+    init_args: InitArgsMIP,
+) -> dict[str, Any]:
+    """
+    Perform intensity projection along the Z axis with a chosen method.
+
+    Note: this task stores the output in a new zarr file.
+
+    Args:
+        zarr_url: Path or url to the individual OME-Zarr image to be processed.
+            (standard argument for Fractal tasks, managed by Fractal server).
+        init_args: Initialization arguments provided by
+            `create_cellvoyager_ome_zarr_init`.
+    """
+    method = DaskProjectionMethod(init_args.method)
+    logger.info(f"{init_args.origin_url=}")
+    logger.info(f"{zarr_url=}")
+    logger.info(f"{method=}")
+
+    # Read image metadata
+    original_ngff_image = NgffImage(init_args.origin_url)
+    original_image = original_ngff_image.get_image()
+
+    if original_image.is_2d or original_image.is_2d_time_series:
+        raise ValueError(
+            "The input image is 2D; "
+            "projection is only supported for 3D images."
+        )
+
+    # Compute the new shape and pixel size
+    dest_on_disk_shape = _compute_new_shape(original_image)
+
+    dest_pixel_size = original_image.pixel_size
+    dest_pixel_size.z = 1.0
+    logger.info(f"New shape: {dest_on_disk_shape=}")
+
+    # Create the new empty image
+    new_ngff_image = original_ngff_image.derive_new_image(
+        store=zarr_url,
+        name="MIP",
+        on_disk_shape=dest_on_disk_shape,
+        pixel_sizes=dest_pixel_size,
+        overwrite=init_args.overwrite,
+        copy_labels=False,
+        copy_tables=True,
+    )
+    logger.info(f"New projection image created - {new_ngff_image=}")
+    new_image = new_ngff_image.get_image()
+
+    # Process the image
+    z_axis_index = original_image.find_axis("z")
+    source_dask = original_image.get_array(
+        mode="dask", preserve_dimensions=True
+    )
+
+    dest_dask = method.apply(dask_array=source_dask, axis=z_axis_index)
+    dest_dask = da.expand_dims(dest_dask, axis=z_axis_index)
+    new_image.set_array(dest_dask)
+    new_image.consolidate()
+
+    # Copy over the tables, flattening each ROI along Z
+    for roi_table_name in new_ngff_image.tables.list(table_type="roi_table"):
+        table = new_ngff_image.tables.get_table(roi_table_name)
+
+        roi_list = []
+        for roi in table.rois:
+            roi.z = 0.0
+            roi.z_length = 1.0
+            roi_list.append(roi)
+
+        table.set_rois(roi_list, overwrite=True)
+        table.consolidate()
+        logger.info(f"Table {roi_table_name} projection done")
+
+    # Generate image_list_updates
+    image_list_update_dict = dict(
+        image_list_updates=[
+            dict(
+                zarr_url=zarr_url,
+                origin=init_args.origin_url,
+                attributes=dict(plate=init_args.new_plate_name),
+                types=dict(is_3D=False),
+            )
+        ]
+    )
+    return image_list_update_dict
+
+
+if __name__ == "__main__":
+    from fractal_tasks_core.tasks._utils import run_fractal_task
+
+    run_fractal_task(
+        task_function=projection,
+        logger_name=logger.name,
+    )
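
The core of the task is a plain dask reduction along the z axis followed by `expand_dims`, so the destination keeps the source's axis order with a size-1 z axis, matching the shape computed by `_compute_new_shape`. A self-contained sketch of this step, with invented shapes and chunks standing in for the ngio-managed on-disk arrays:

    import dask.array as da

    # Invented (c, z, y, x) stack; real shape/chunks come from the OME-Zarr image
    source = da.random.random((2, 10, 256, 256), chunks=(1, 10, 128, 128))
    z_axis_index = 1  # position of "z" among the on-disk axes

    # The task dispatches via method.apply(); .max() shown here for the MIP case
    projected = source.max(axis=z_axis_index)            # shape (2, 256, 256)
    projected = da.expand_dims(projected, axis=z_axis_index)
    assert projected.shape == (2, 1, 256, 256)

Keeping the singleton z axis (rather than writing a truly 2D array) is what lets the derived image reuse the original axis metadata, with only the z pixel size and the ROI-table z extents rewritten.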
diff --git a/fractal_tasks_core/tasks/projection_utils.py b/fractal_tasks_core/tasks/projection_utils.py
new file mode 100644
index 000000000..4d71fa25a
--- /dev/null
+++ b/fractal_tasks_core/tasks/projection_utils.py
@@ -0,0 +1,136 @@
+from enum import Enum
+from typing import Any
+from typing import Dict
+
+import dask.array as da
+import numpy as np
+
+
+def safe_sum(
+    dask_array: da.Array, axis: int = 0, **kwargs: Dict[str, Any]
+) -> da.Array:
+    """
+    Perform a safe sum on a Dask array to avoid overflow, by clipping the
+    result of da.sum & casting it to its original dtype.
+
+    Dask.array already correctly handles promotion to uint32 or uint64 when
+    necessary internally, but we want to ensure we clip the result.
+
+    Args:
+        dask_array (dask.array.Array): The input Dask array.
+        axis (int, optional): The axis along which to sum the array.
+            Defaults to 0.
+        **kwargs: Additional keyword arguments passed to da.sum.
+
+    Returns:
+        dask.array.Array: The result of the sum, safely clipped and cast
+            back to the original dtype.
+    """
+    # Handle empty array
+    if any(dim == 0 for dim in dask_array.shape):
+        return dask_array
+
+    # Determine the original dtype
+    original_dtype = dask_array.dtype
+    max_value = np.iinfo(original_dtype).max
+
+    # Perform the sum
+    result = da.sum(dask_array, axis=axis, **kwargs)
+
+    # Clip the values to the maximum possible value for the original dtype
+    result = da.clip(result, 0, max_value)
+
+    # Cast back to the original dtype
+    result = result.astype(original_dtype)
+
+    return result
+
+
+def mean_wrapper(
+    dask_array: da.Array, axis: int = 0, **kwargs: Dict[str, Any]
+) -> da.Array:
+    """
+    Perform a da.mean on the dask_array & cast it to its original dtype.
+
+    Without casting, the result can change dtype to e.g. float64
+
+    Args:
+        dask_array (dask.array.Array): The input Dask array.
+        axis (int, optional): The axis along which to compute the mean.
+            Defaults to 0.
+        **kwargs: Additional keyword arguments passed to da.mean.
+
+    Returns:
+        dask.array.Array: The result of the mean, cast back to the original
+            dtype.
+    """
+    # Handle empty array
+    if any(dim == 0 for dim in dask_array.shape):
+        return dask_array
+
+    # Determine the original dtype
+    original_dtype = dask_array.dtype
+
+    # Perform the mean
+    result = da.mean(dask_array, axis=axis, **kwargs)
+
+    # Cast back to the original dtype
+    result = result.astype(original_dtype)
+
+    return result
+
+
+class DaskProjectionMethod(Enum):
+    """
+    Projection method selection
+
+    Choose which method to use for intensity projection along the Z axis.
+
+    Attributes:
+        MIP: Maximum intensity projection
+        MINIP: Minimum intensity projection
+        MEANIP: Mean intensity projection
+        SUMIP: Sum intensity projection
+    """
+
+    MIP = "mip"
+    MINIP = "minip"
+    MEANIP = "meanip"
+    SUMIP = "sumip"
+
+    def apply(
+        self, dask_array: da.Array, axis: int = 0, **kwargs: Dict[str, Any]
+    ) -> da.Array:
+        """
+        Apply the selected projection method to the given Dask array.
+
+        Args:
+            dask_array (dask.array.Array): The Dask array to project.
+            axis (int): The axis along which to apply the projection.
+            **kwargs: Additional keyword arguments to pass to the projection
+                method.
+
+        Returns:
+            dask.array.Array: The resulting Dask array after applying the
+                projection.
+
+        Example:
+            >>> array = da.random.random((1000, 1000), chunks=(100, 100))
+            >>> method = DaskProjectionMethod.MIP
+            >>> result = method.apply(array, axis=0)
+            >>> computed_result = result.compute()
+            >>> print(computed_result)
+        """
+        # Map the Enum values to the actual Dask array methods
+        method_map = {
+            DaskProjectionMethod.MIP: lambda arr, axis, **kw: arr.max(
+                axis=axis, **kw
+            ),
+            DaskProjectionMethod.MINIP: lambda arr, axis, **kw: arr.min(
+                axis=axis, **kw
+            ),
+            DaskProjectionMethod.MEANIP: mean_wrapper,
+            DaskProjectionMethod.SUMIP: safe_sum,
+        }
+        # Call the appropriate method, passing in the dask_array explicitly
+        return method_map[self](dask_array, axis=axis, **kwargs)
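
Why the clipping in `safe_sum` matters: summing a uint8 z-stack easily exceeds 255, and a plain `da.sum` promotes the accumulator dtype, so the raw result neither wraps nor fits the original dtype; `safe_sum` clips at the dtype maximum and casts back. A small sketch with toy values (the promoted dtype shown in the comment assumes a 64-bit platform):

    import dask.array as da
    import numpy as np

    from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod
    from fractal_tasks_core.tasks.projection_utils import safe_sum

    # Four z-planes of constant value 200: the per-pixel sum is 800
    stack = da.from_array(np.full((4, 2, 2), 200, dtype=np.uint8), chunks=2)

    plain = da.sum(stack, axis=0).compute()
    clipped = safe_sum(stack, axis=0).compute()
    print(plain.dtype, plain[0, 0])      # e.g. uint64 800 (promoted, unclipped)
    print(clipped.dtype, clipped[0, 0])  # uint8 255 (clipped to dtype max)

    # The same clipped result via the enum dispatch used by the projection task
    same = DaskProjectionMethod("sumip").apply(stack, axis=0).compute()
    assert (same == clipped).all()

Clipping to the dtype maximum is a deliberate trade-off: saturated pixels all collapse to 255 here, but the output stays storable in the original OME-Zarr dtype without wrap-around artifacts.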