diff --git a/fractal_tasks_core/tasks/__init__.py b/fractal_tasks_core/tasks/__init__.py deleted file mode 100644 index d959b02bb..000000000 --- a/fractal_tasks_core/tasks/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Tasks subpackage (requires installation extra `fractal-tasks`). -""" diff --git a/fractal_tasks_core/tasks/_registration_utils.py b/fractal_tasks_core/tasks/_registration_utils.py deleted file mode 100755 index 8cac07736..000000000 --- a/fractal_tasks_core/tasks/_registration_utils.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2024 (C) BioVisionCenter -# -# Original authors: -# Joel Lüthi -# -# This file is part of Fractal -"""Utils functions for registration""" -import copy - -import anndata as ad -import dask.array as da -import numpy as np -import pandas as pd -from image_registration import chi2_shift - - -def calculate_physical_shifts( - shifts: np.array, - level: int, - coarsening_xy: int, - full_res_pxl_sizes_zyx: list[float], -) -> list[float]: - """ - Calculates shifts in physical units based on pixel shifts - - Args: - shifts: array of shifts, zyx or yx - level: resolution level - coarsening_xy: coarsening factor between levels - full_res_pxl_sizes_zyx: pixel sizes in physical units as zyx - - Returns: - shifts in physical units as zyx - """ - - curr_pixel_size = np.array(full_res_pxl_sizes_zyx) * coarsening_xy**level - if len(shifts) == 3: - shifts_physical = shifts * curr_pixel_size - elif len(shifts) == 2: - shifts_physical = [ - 0, - shifts[0] * curr_pixel_size[1], - shifts[1] * curr_pixel_size[2], - ] - else: - raise ValueError( - f"Wrong input for calculate_physical_shifts ({shifts=})" - ) - return shifts_physical - - -def get_ROI_table_with_translation( - ROI_table: ad.AnnData, - new_shifts: dict[str, list[float]], -) -> ad.AnnData: - """ - Adds translation columns to a ROI table - - Args: - ROI_table: Fractal ROI table - new_shifts: zyx list of shifts - - Returns: - Fractal ROI table with 3 additional columns for calculated translations - """ - - shift_table = pd.DataFrame(new_shifts).T - shift_table.columns = ["translation_z", "translation_y", "translation_x"] - shift_table = shift_table.rename_axis("FieldIndex") - new_roi_table = ROI_table.to_df().merge( - shift_table, left_index=True, right_index=True - ) - if len(new_roi_table) != len(ROI_table): - raise ValueError( - "New ROI table with registration info has a " - f"different length ({len(new_roi_table)=}) " - f"from the original ROI table ({len(ROI_table)=})" - ) - - adata = ad.AnnData(X=new_roi_table.astype(np.float32)) - adata.obs_names = new_roi_table.index - adata.var_names = list(map(str, new_roi_table.columns)) - return adata - - -# Helper functions -def add_zero_translation_columns(ad_table: ad.AnnData): - """ - Add three zero-filled columns (`translation_{x,y,z}`) to an AnnData table. - """ - columns = ["translation_z", "translation_y", "translation_x"] - if ad_table.var.index.isin(columns).any().any(): - raise ValueError( - "The roi table already contains translation columns. Did you " - "enter a wrong reference acquisition?" 
-        )
-    df = pd.DataFrame(np.zeros([len(ad_table), 3]), columns=columns)
-    df.index = ad_table.obs.index
-    ad_new = ad.concat([ad_table, ad.AnnData(df)], axis=1)
-    return ad_new
-
-
-def calculate_min_max_across_dfs(tables_list):
-    """
-    Calculate the element-wise max & min across a list of dataframes.
-    """
-    # Initialize dataframes to store the max and min values
-    max_df = pd.DataFrame(
-        index=tables_list[0].index, columns=tables_list[0].columns
-    )
-    min_df = pd.DataFrame(
-        index=tables_list[0].index, columns=tables_list[0].columns
-    )
-
-    # Loop through the tables and calculate max and min values
-    for table in tables_list:
-        max_df = pd.DataFrame(
-            np.maximum(max_df.values, table.values),
-            columns=max_df.columns,
-            index=max_df.index,
-        )
-        min_df = pd.DataFrame(
-            np.minimum(min_df.values, table.values),
-            columns=min_df.columns,
-            index=min_df.index,
-        )
-
-    return max_df, min_df
-
-
-def apply_registration_to_single_ROI_table(
-    roi_table: ad.AnnData,
-    max_df: pd.DataFrame,
-    min_df: pd.DataFrame,
-) -> ad.AnnData:
-    """
-    Applies the registration to a ROI table
-
-    Calculates the new position as: p = position + max(shift, 0) - own_shift
-    Calculates the new len as: l = len - max(shift, 0) + min(shift, 0)
-
-    Args:
-        roi_table: AnnData table which contains a Fractal ROI table.
-            Rows are ROIs
-        max_df: Max translation shift in z, y, x for each ROI. Rows are ROIs,
-            columns are translation_z, translation_y, translation_x
-        min_df: Min translation shift in z, y, x for each ROI. Rows are ROIs,
-            columns are translation_z, translation_y, translation_x
-    Returns:
-        ROI table where all ROIs are registered to the smallest common area
-        across all acquisitions.
-    """
-    roi_table = copy.deepcopy(roi_table)
-    rois = roi_table.obs.index
-    if (rois != max_df.index).any() or (rois != min_df.index).any():
-        raise ValueError(
-            "ROI table and max & min translation need to contain the same "
-            f"ROIs, but they were {rois=}, {max_df.index=}, {min_df.index=}"
-        )
-
-    for roi in rois:
-        roi_table[[roi], ["z_micrometer"]] = (
-            roi_table[[roi], ["z_micrometer"]].X
-            + float(max_df.loc[roi, "translation_z"])
-            - roi_table[[roi], ["translation_z"]].X
-        )
-        roi_table[[roi], ["y_micrometer"]] = (
-            roi_table[[roi], ["y_micrometer"]].X
-            + float(max_df.loc[roi, "translation_y"])
-            - roi_table[[roi], ["translation_y"]].X
-        )
-        roi_table[[roi], ["x_micrometer"]] = (
-            roi_table[[roi], ["x_micrometer"]].X
-            + float(max_df.loc[roi, "translation_x"])
-            - roi_table[[roi], ["translation_x"]].X
-        )
-        # This calculation only works if all ROIs are the same size initially!
-        roi_table[[roi], ["len_z_micrometer"]] = (
-            roi_table[[roi], ["len_z_micrometer"]].X
-            - float(max_df.loc[roi, "translation_z"])
-            + float(min_df.loc[roi, "translation_z"])
-        )
-        roi_table[[roi], ["len_y_micrometer"]] = (
-            roi_table[[roi], ["len_y_micrometer"]].X
-            - float(max_df.loc[roi, "translation_y"])
-            + float(min_df.loc[roi, "translation_y"])
-        )
-        roi_table[[roi], ["len_x_micrometer"]] = (
-            roi_table[[roi], ["len_x_micrometer"]].X
-            - float(max_df.loc[roi, "translation_x"])
-            + float(min_df.loc[roi, "translation_x"])
-        )
-    return roi_table
-
-
-def chi2_shift_out(img_ref, img_cycle_x) -> list[np.ndarray]:
-    """
-    Helper function to get the output of chi2_shift into the same format as
-    phase_cross_correlation. Calculates the shift between two images using
-    the chi2_shift method.
-
-    Args:
-        img_ref (np.ndarray): First image.
-        img_cycle_x (np.ndarray): Second image.
-
-    Returns:
-        List containing numpy array of shift in y and x direction.
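-
-    Example (sketch with hypothetical arrays; `chi2_shift` comes from the
-    `image_registration` package imported above):
-
-        >>> img_ref = np.random.rand(64, 64)
-        >>> img_mov = np.roll(img_ref, shift=(2, 3), axis=(0, 1))
-        >>> shifts = chi2_shift_out(img_ref, img_mov)[0]
-        >>> shifts.shape
-        (2,)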
- """ - x, y, a, b = chi2_shift(np.squeeze(img_ref), np.squeeze(img_cycle_x)) - - """ - Running into issues when using direct float output for fractal. - When rounding to integer and using integer dtype, it typically works - but for some reasons fails when run over a whole 384 well plate (but - the well where it fails works fine when run alone). For now, rounding - to integer, but still using float64 dtype (like the scikit-image - phase cross correlation function) seems to be the safest option. - """ - shifts = np.array([-np.round(y), -np.round(x)], dtype="float64") - # return as a list to adhere to the phase_cross_correlation output format - return [shifts] - - -def is_3D(dask_array: da.array) -> bool: - """ - Check if a dask array is 3D. - - Treats singelton Z dimensions as 2D images. - (1, 2000, 2000) => False - (10, 2000, 2000) => True - - Args: - dask_array: Input array to be checked - - Returns: - bool on whether the array is 3D - """ - if len(dask_array.shape) == 3 and dask_array.shape[0] > 1: - return True - else: - return False diff --git a/fractal_tasks_core/tasks/_utils.py b/fractal_tasks_core/tasks/_utils.py deleted file mode 100644 index a7c0a4b89..000000000 --- a/fractal_tasks_core/tasks/_utils.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Standard input/output interface for tasks. -""" -import json -import logging -from argparse import ArgumentParser -from json import JSONEncoder -from pathlib import Path -from typing import Callable -from typing import Optional - - -class TaskParameterEncoder(JSONEncoder): - """ - Custom JSONEncoder that transforms Path objects to strings. - """ - - def default(self, value): - """ - Subclass implementation of `default`, to serialize Path objects as - strings. - """ - if isinstance(value, Path): - return value.as_posix() - return JSONEncoder.default(self, value) - - -def run_fractal_task( - *, - task_function: Callable, - logger_name: Optional[str] = None, -): - """ - Implement standard task interface and call task_function. - - Args: - task_function: the callable function that runs the task. - logger_name: TBD - """ - - # Parse `-j` and `--metadata-out` arguments - parser = ArgumentParser() - parser.add_argument( - "--args-json", help="Read parameters from json file", required=True - ) - parser.add_argument( - "--out-json", - help="Output file to redirect serialised returned data", - required=True, - ) - parsed_args = parser.parse_args() - - # Set logger - logger = logging.getLogger(logger_name) - - # Preliminary check - if Path(parsed_args.out_json).exists(): - logger.error( - f"Output file {parsed_args.out_json} already exists. 
Terminating" - ) - exit(1) - - # Read parameters dictionary - with open(parsed_args.args_json, "r") as f: - pars = json.load(f) - - # Run task - logger.info(f"START {task_function.__name__} task") - metadata_update = task_function(**pars) - logger.info(f"END {task_function.__name__} task") - - # Write output metadata to file, with custom JSON encoder - with open(parsed_args.out_json, "w") as fout: - json.dump(metadata_update, fout, cls=TaskParameterEncoder, indent=2) diff --git a/fractal_tasks_core/tasks/_zarr_utils.py b/fractal_tasks_core/tasks/_zarr_utils.py deleted file mode 100644 index 0f030272b..000000000 --- a/fractal_tasks_core/tasks/_zarr_utils.py +++ /dev/null @@ -1,205 +0,0 @@ -import copy -import logging - -import anndata as ad -import zarr -from filelock import FileLock - -from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tables.v1 import get_tables_list_v1 -from fractal_tasks_core.utils import _split_well_path_image_path - -logger = logging.getLogger(__name__) - - -def _copy_hcs_ome_zarr_metadata( - zarr_url_origin: str, - zarr_url_new: str, -) -> None: - """ - Updates the necessary metadata for a new copy of an OME-Zarr image - - Based on an existing OME-Zarr image in the same well, the metadata is - copied and added to the new zarr well. Additionally, the well-level - metadata is updated to include this new image. - - Args: - zarr_url_origin: zarr_url of the origin image - zarr_url_new: zarr_url of the newly created image. The zarr-group - already needs to exist, but metadata is written by this function. - """ - # Copy over OME-Zarr metadata for illumination_corrected image - # See #681 for discussion for validation of this zattrs - old_image_group = zarr.open_group(zarr_url_origin, mode="r") - old_attrs = old_image_group.attrs.asdict() - zarr_url_new = zarr_url_new.rstrip("/") - new_image_group = zarr.group(zarr_url_new) - new_image_group.attrs.put(old_attrs) - - # Update well metadata about adding the new image: - new_image_path = zarr_url_new.split("/")[-1] - well_url, old_image_path = _split_well_path_image_path(zarr_url_origin) - _update_well_metadata(well_url, old_image_path, new_image_path) - - -def _update_well_metadata( - well_url: str, - old_image_path: str, - new_image_path: str, - timeout: int = 120, -) -> None: - """ - Update the well metadata by adding the new_image_path to the image list. - - The content of new_image_path will be based on old_image_path, the origin - for the new image that was created. 
-    This function aims to avoid race conditions with other processes that try
-    to update the well metadata file by using FileLock & timeouts.
-
-    Args:
-        well_url: Path to the HCS OME-Zarr well that needs to be updated
-        old_image_path: path relative to well_url where the original image is
-            found
-        new_image_path: path relative to well_url where the new image is placed
-        timeout: Timeout in seconds for trying to get the file lock
-    """
-    lock = FileLock(f"{well_url}/.zattrs.lock")
-    with lock.acquire(timeout=timeout):
-        well_meta = load_NgffWellMeta(well_url)
-        existing_well_images = [image.path for image in well_meta.well.images]
-        if new_image_path in existing_well_images:
-            raise ValueError(
-                f"Could not add the {new_image_path=} image to the well "
-                "metadata because an image with that name "
-                f"already exists in the well metadata: {well_meta}"
-            )
-        try:
-            well_meta_image_old = next(
-                image
-                for image in well_meta.well.images
-                if image.path == old_image_path
-            )
-        except StopIteration:
-            raise ValueError(
-                f"Could not find an image with {old_image_path=} in the "
-                "current well metadata."
-            )
-        well_meta_image = copy.deepcopy(well_meta_image_old)
-        well_meta_image.path = new_image_path
-        well_meta.well.images.append(well_meta_image)
-        well_meta.well.images = sorted(
-            well_meta.well.images,
-            key=lambda _image: _image.path,
-        )
-
-        well_group = zarr.group(well_url)
-        well_group.attrs.put(well_meta.model_dump(exclude_none=True))
-
-    # One could catch the timeout with a `try ... except Timeout` block, but
-    # it is unclear what the fallback behavior should be.
-
-
-def _split_base_suffix(input: str) -> tuple[str, str]:
-    parts = input.split("_")
-    base = parts[0]
-    if len(parts) > 1:
-        suffix = "_".join(parts[1:])
-    else:
-        suffix = ""
-    return base, suffix
-
-
-def _get_matching_ref_acquisition_path_heuristic(
-    path_list: list[str], path: str
-) -> str:
-    """
-    Pick the best match from path_list to a given path
-
-    This is a workaround to find the reference registration acquisition when
-    there are multiple OME-Zarrs with the same acquisition identifier in the
-    well metadata and we need to find which one is the reference for a given
-    path.
-
-    Args:
-        path_list: List of paths to OME-Zarr images in the well metadata. For
-            example: ['0', '0_illum_corr']
-        path: A given path for which we want to find the reference image. For
-            example, '1_illum_corr'
-
-    Returns:
-        The best matching reference path. If no direct match is found, it
-        returns the most similar one based on suffix hierarchy or the base
-        path if applicable. For example, '0_illum_corr' with the example
-        inputs above.
-    """
-
-    # Extract the base number and suffix from the input path
-    base, suffix = _split_base_suffix(path)
-
-    # Sort path_list
-    sorted_path_list = sorted(path_list)
-
-    # Never return the input `path`
-    if path in sorted_path_list:
-        sorted_path_list.remove(path)
-
-    # First matching rule: a path with the same suffix
-    for p in sorted_path_list:
-        # Split the list path into base and suffix
-        p_base, p_suffix = _split_base_suffix(p)
-        # If the suffixes match, that path is the match
-        if p_suffix == suffix:
-            return p
-
-    # If no match is found, return the first entry in the list
-    logger.warning(
-        "No heuristic reference acquisition match found, defaulting to first "
-        f"option {sorted_path_list[0]}."
- ) - return sorted_path_list[0] - - -def _copy_tables_from_zarr_url( - origin_zarr_url: str, - target_zarr_url: str, - table_type: str = None, - overwrite: bool = True, -) -> None: - """ - Copies all ROI tables from one Zarr into a new Zarr - - Args: - origin_zarr_url: url of the OME-Zarr image that contains tables. - e.g. /path/to/my_plate.zarr/B/03/0 - target_zarr_url: url of the new OME-Zarr image where tables are copied - to. e.g. /path/to/my_plate.zarr/B/03/0_illum_corr - table_type: Filter for specific table types that should be copied. - overwrite: Whether existing tables of the same name in the - target_zarr_url should be overwritten. - """ - table_list = get_tables_list_v1( - zarr_url=origin_zarr_url, table_type=table_type - ) - - if table_list: - logger.info( - f"Copying the tables {table_list} from {origin_zarr_url} to " - f"{target_zarr_url}." - ) - new_image_group = zarr.group(target_zarr_url) - - for table in table_list: - logger.info(f"Copying table: {table}") - # Get the relevant metadata of the Zarr table & add it - table_url = f"{origin_zarr_url}/tables/{table}" - old_table_group = zarr.open_group(table_url, mode="r") - # Write the Zarr table - curr_table = ad.read_zarr(table_url) - write_table( - new_image_group, - table, - curr_table, - table_attrs=old_table_group.attrs.asdict(), - overwrite=overwrite, - ) diff --git a/fractal_tasks_core/tasks/apply_registration_to_image.py b/fractal_tasks_core/tasks/apply_registration_to_image.py deleted file mode 100644 index 156f5b3e6..000000000 --- a/fractal_tasks_core/tasks/apply_registration_to_image.py +++ /dev/null @@ -1,392 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Joel Lüthi -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Calculates translation for 2D image-based registration -""" -import logging -import os -import shutil -import time -from typing import Callable - -import anndata as ad -import dask.array as da -import numpy as np -import zarr -from pydantic import validate_call - -from fractal_tasks_core.ngff import load_NgffImageMeta -from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta -from fractal_tasks_core.pyramids import build_pyramid -from fractal_tasks_core.roi import ( - convert_indices_to_regions, -) -from fractal_tasks_core.roi import ( - convert_ROI_table_to_indices, -) -from fractal_tasks_core.roi import is_standard_roi_table -from fractal_tasks_core.roi import load_region -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tasks._zarr_utils import ( - _get_matching_ref_acquisition_path_heuristic, -) -from fractal_tasks_core.tasks._zarr_utils import _update_well_metadata -from fractal_tasks_core.utils import _get_table_path_dict -from fractal_tasks_core.utils import ( - _split_well_path_image_path, -) - -logger = logging.getLogger(__name__) - - -@validate_call -def apply_registration_to_image( - *, - # Fractal parameters - zarr_url: str, - # Core parameters - registered_roi_table: str, - reference_acquisition: int = 0, - overwrite_input: bool = True, -): - """ - Apply registration to images by using a registered ROI table - - This task consists of 4 parts: - - 1. 
Mask all regions in images that are not available in the - registered ROI table and store each acquisition aligned to the - reference_acquisition (by looping over ROIs). - 2. Do the same for all label images. - 3. Copy all tables from the non-aligned image to the aligned image - (currently only works well if the only tables are well & FOV ROI tables - (registered and original). Not implemented for measurement tables and - other ROI tables). - 4. Clean up: Delete the old, non-aligned image and rename the new, - aligned image to take over its place. - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be processed. - (standard argument for Fractal tasks, managed by Fractal server). - registered_roi_table: Name of the ROI table which has been registered - and will be applied to mask and shift the images. - Examples: `registered_FOV_ROI_table` => loop over the field of - views, `registered_well_ROI_table` => process the whole well as - one image. - reference_acquisition: Which acquisition to register against. Uses the - OME-NGFF HCS well metadata acquisition keys to find the reference - acquisition. - overwrite_input: Whether the old image data should be replaced with the - newly registered image data. Currently only implemented for - `overwrite_input=True`. - - """ - logger.info(zarr_url) - logger.info( - f"Running `apply_registration_to_image` on {zarr_url=}, " - f"{registered_roi_table=} and {reference_acquisition=}. " - f"Using {overwrite_input=}" - ) - - well_url, old_img_path = _split_well_path_image_path(zarr_url) - new_zarr_url = f"{well_url}/{zarr_url.split('/')[-1]}_registered" - # Get the zarr_url for the reference acquisition - acq_dict = load_NgffWellMeta(well_url).get_acquisition_paths() - if reference_acquisition not in acq_dict: - raise ValueError( - f"{reference_acquisition=} was not one of the available " - f"acquisitions in {acq_dict=} for well {well_url}" - ) - elif len(acq_dict[reference_acquisition]) > 1: - ref_path = _get_matching_ref_acquisition_path_heuristic( - acq_dict[reference_acquisition], old_img_path - ) - logger.warning( - "Running registration when there are multiple images of the same " - "acquisition in a well. Using a heuristic to match the reference " - f"acquisition. Using {ref_path} as the reference image." 
- ) - else: - ref_path = acq_dict[reference_acquisition][0] - reference_zarr_url = f"{well_url}/{ref_path}" - - ROI_table_ref = ad.read_zarr( - f"{reference_zarr_url}/tables/{registered_roi_table}" - ) - ROI_table_acq = ad.read_zarr(f"{zarr_url}/tables/{registered_roi_table}") - - ngff_image_meta = load_NgffImageMeta(zarr_url) - coarsening_xy = ngff_image_meta.coarsening_xy - num_levels = ngff_image_meta.num_levels - - #################### - # Process images - #################### - logger.info("Write the registered Zarr image to disk") - write_registered_zarr( - zarr_url=zarr_url, - new_zarr_url=new_zarr_url, - ROI_table=ROI_table_acq, - ROI_table_ref=ROI_table_ref, - num_levels=num_levels, - coarsening_xy=coarsening_xy, - aggregation_function=np.mean, - ) - - #################### - # Process labels - #################### - try: - labels_group = zarr.open_group(f"{zarr_url}/labels", "r") - label_list = labels_group.attrs["labels"] - except (zarr.errors.GroupNotFoundError, KeyError): - label_list = [] - - if label_list: - logger.info(f"Processing the label images: {label_list}") - labels_group = zarr.group(f"{new_zarr_url}/labels") - labels_group.attrs["labels"] = label_list - - for label in label_list: - write_registered_zarr( - zarr_url=f"{zarr_url}/labels/{label}", - new_zarr_url=f"{new_zarr_url}/labels/{label}", - ROI_table=ROI_table_acq, - ROI_table_ref=ROI_table_ref, - num_levels=num_levels, - coarsening_xy=coarsening_xy, - aggregation_function=np.max, - ) - - #################### - # Copy tables - # 1. Copy all standard ROI tables from the reference acquisition. - # 2. Copy all tables that aren't standard ROI tables from the given - # acquisition. - #################### - table_dict_reference = _get_table_path_dict(reference_zarr_url) - table_dict_component = _get_table_path_dict(zarr_url) - - table_dict = {} - # Define which table should get copied: - for table in table_dict_reference: - if is_standard_roi_table(table): - table_dict[table] = table_dict_reference[table] - for table in table_dict_component: - if not is_standard_roi_table(table): - if reference_zarr_url != zarr_url: - logger.warning( - f"{zarr_url} contained a table that is not a standard " - "ROI table. The `Apply Registration To Image task` is " - "best used before additional tables are generated. It " - f"will copy the {table} from this acquisition without " - "applying any transformations. This will work well if " - f"{table} contains measurements. But if {table} is a " - "custom ROI table coming from another task, the " - "transformation is not applied and it will not match " - "with the registered image anymore." - ) - table_dict[table] = table_dict_component[table] - - if table_dict: - logger.info(f"Processing the tables: {table_dict}") - new_image_group = zarr.group(new_zarr_url) - - for table in table_dict.keys(): - logger.info(f"Copying table: {table}") - # Get the relevant metadata of the Zarr table & add it - # See issue #516 for the need for this workaround - max_retries = 20 - sleep_time = 5 - current_round = 0 - while current_round < max_retries: - try: - old_table_group = zarr.open_group( - table_dict[table], mode="r" - ) - current_round = max_retries - except zarr.errors.GroupNotFoundError: - logger.debug( - f"Table {table} not found in attempt {current_round}. " - f"Waiting {sleep_time} seconds before trying again." 
- ) - current_round += 1 - time.sleep(sleep_time) - # Write the Zarr table - curr_table = ad.read_zarr(table_dict[table]) - write_table( - new_image_group, - table, - curr_table, - table_attrs=old_table_group.attrs.asdict(), - overwrite=True, - ) - - #################### - # Clean up Zarr file - #################### - if overwrite_input: - logger.info( - "Replace original zarr image with the newly created Zarr image" - ) - # Potential for race conditions: Every acquisition reads the - # reference acquisition, but the reference acquisition also gets - # modified - # See issue #516 for the details - os.rename(zarr_url, f"{zarr_url}_tmp") - os.rename(new_zarr_url, zarr_url) - shutil.rmtree(f"{zarr_url}_tmp") - image_list_updates = dict(image_list_updates=[dict(zarr_url=zarr_url)]) - else: - image_list_updates = dict( - image_list_updates=[dict(zarr_url=new_zarr_url, origin=zarr_url)] - ) - # Update the metadata of the the well - well_url, new_img_path = _split_well_path_image_path(new_zarr_url) - _update_well_metadata( - well_url=well_url, - old_image_path=old_img_path, - new_image_path=new_img_path, - ) - - return image_list_updates - - -def write_registered_zarr( - zarr_url: str, - new_zarr_url: str, - ROI_table: ad.AnnData, - ROI_table_ref: ad.AnnData, - num_levels: int, - coarsening_xy: int = 2, - aggregation_function: Callable = np.mean, -): - """ - Write registered zarr array based on ROI tables - - This function loads the image or label data from a zarr array based on the - ROI bounding-box coordinates and stores them into a new zarr array. - The new Zarr array has the same shape as the original array, but will have - 0s where the ROI tables don't specify loading of the image data. - The ROIs loaded from `list_indices` will be written into the - `list_indices_ref` position, thus performing translational registration if - the two lists of ROI indices vary. - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be used as - the basis for the new OME-Zarr image. - new_zarr_url: Path or url to the new OME-Zarr image to be written - ROI_table: Fractal ROI table for the component - ROI_table_ref: Fractal ROI table for the reference acquisition - num_levels: Number of pyramid layers to be created (argument of - `build_pyramid`). - coarsening_xy: Coarsening factor between pyramid levels - aggregation_function: Function to be used when downsampling (argument - of `build_pyramid`). - - """ - # Read pixel sizes from Zarr attributes - ngff_image_meta = load_NgffImageMeta(zarr_url) - pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) - - # Create list of indices for 3D ROIs - list_indices = convert_ROI_table_to_indices( - ROI_table, - level=0, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=pxl_sizes_zyx, - ) - list_indices_ref = convert_ROI_table_to_indices( - ROI_table_ref, - level=0, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=pxl_sizes_zyx, - ) - - old_image_group = zarr.open_group(zarr_url, mode="r") - old_ngff_image_meta = load_NgffImageMeta(zarr_url) - new_image_group = zarr.group(new_zarr_url) - new_image_group.attrs.put(old_image_group.attrs.asdict()) - - # Loop over all channels. For each channel, write full-res image data. - data_array = da.from_zarr(old_image_group["0"]) - # Create dask array with 0s of same shape - new_array = da.zeros_like(data_array) - - # TODO: Add sanity checks on the 2 ROI tables: - # 1. The number of ROIs need to match - # 2. 
The size of the ROIs need to match - # (otherwise, we can't assign them to the reference regions) - # ROI_table_ref vs ROI_table_acq - for i, roi_indices in enumerate(list_indices): - reference_region = convert_indices_to_regions(list_indices_ref[i]) - region = convert_indices_to_regions(roi_indices) - - axes_list = old_ngff_image_meta.axes_names - - if axes_list == ["c", "z", "y", "x"]: - num_channels = data_array.shape[0] - # Loop over channels - for ind_ch in range(num_channels): - idx = tuple( - [slice(ind_ch, ind_ch + 1)] + list(reference_region) - ) - new_array[idx] = load_region( - data_zyx=data_array[ind_ch], region=region, compute=False - ) - elif axes_list == ["z", "y", "x"]: - new_array[reference_region] = load_region( - data_zyx=data_array, region=region, compute=False - ) - elif axes_list == ["c", "y", "x"]: - # TODO: Implement cyx case (based on looping over xy case) - raise NotImplementedError( - "`write_registered_zarr` has not been implemented for " - f"a zarr with {axes_list=}" - ) - elif axes_list == ["y", "x"]: - # TODO: Implement yx case - raise NotImplementedError( - "`write_registered_zarr` has not been implemented for " - f"a zarr with {axes_list=}" - ) - else: - raise NotImplementedError( - "`write_registered_zarr` has not been implemented for " - f"a zarr with {axes_list=}" - ) - - new_array.to_zarr( - f"{new_zarr_url}/0", - overwrite=True, - dimension_separator="/", - write_empty_chunks=False, - ) - - # Starting from on-disk highest-resolution data, build and write to - # disk a pyramid of coarser levels - build_pyramid( - zarrurl=new_zarr_url, - overwrite=True, - num_levels=num_levels, - coarsening_xy=coarsening_xy, - chunksize=data_array.chunksize, - aggregation_function=aggregation_function, - ) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=apply_registration_to_image, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/calculate_registration_image_based.py b/fractal_tasks_core/tasks/calculate_registration_image_based.py deleted file mode 100755 index 6bb783662..000000000 --- a/fractal_tasks_core/tasks/calculate_registration_image_based.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Joel Lüthi -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. 
-""" -Calculates translation for image-based registration -""" -import logging -from enum import Enum - -import anndata as ad -import dask.array as da -import numpy as np -import zarr -from pydantic import validate_call -from skimage.exposure import rescale_intensity -from skimage.registration import phase_cross_correlation - -from fractal_tasks_core.channels import get_channel_from_image_zarr -from fractal_tasks_core.channels import OmeroChannel -from fractal_tasks_core.ngff import load_NgffImageMeta -from fractal_tasks_core.roi import check_valid_ROI_indices -from fractal_tasks_core.roi import ( - convert_indices_to_regions, -) -from fractal_tasks_core.roi import ( - convert_ROI_table_to_indices, -) -from fractal_tasks_core.roi import load_region -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tasks._registration_utils import ( - calculate_physical_shifts, -) -from fractal_tasks_core.tasks._registration_utils import chi2_shift_out -from fractal_tasks_core.tasks._registration_utils import ( - get_ROI_table_with_translation, -) -from fractal_tasks_core.tasks._registration_utils import is_3D -from fractal_tasks_core.tasks.io_models import InitArgsRegistration - -logger = logging.getLogger(__name__) - - -class RegistrationMethod(Enum): - """ - RegistrationMethod Enum class - - Attributes: - PHASE_CROSS_CORRELATION: phase cross correlation based on scikit-image - (works with 2D & 3D images). - CHI2_SHIFT: chi2 shift based on image-registration library - (only works with 2D images). - """ - - PHASE_CROSS_CORRELATION = "phase_cross_correlation" - CHI2_SHIFT = "chi2_shift" - - def register(self, img_ref, img_acq_x): - if self == RegistrationMethod.PHASE_CROSS_CORRELATION: - return phase_cross_correlation(img_ref, img_acq_x) - elif self == RegistrationMethod.CHI2_SHIFT: - return chi2_shift_out(img_ref, img_acq_x) - - -@validate_call -def calculate_registration_image_based( - *, - # Fractal arguments - zarr_url: str, - init_args: InitArgsRegistration, - # Core parameters - wavelength_id: str, - method: RegistrationMethod = RegistrationMethod.PHASE_CROSS_CORRELATION, - lower_rescale_quantile: float = 0.0, - upper_rescale_quantile: float = 0.99, - roi_table: str = "FOV_ROI_table", - level: int = 2, -) -> None: - """ - Calculate registration based on images - - This task consists of 3 parts: - - 1. Loading the images of a given ROI (=> loop over ROIs) - 2. Calculating the transformation for that ROI - 3. Storing the calculated transformation in the ROI table - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be processed. - (standard argument for Fractal tasks, managed by Fractal server). - init_args: Intialization arguments provided by - `image_based_registration_hcs_init`. They contain the - reference_zarr_url that is used for registration. - (standard argument for Fractal tasks, managed by Fractal server). - wavelength_id: Wavelength that will be used for image-based - registration; e.g. `A01_C01` for Yokogawa, `C01` for MD. - method: Method to use for image registration. The available methods - are `phase_cross_correlation` (scikit-image package, works for 2D - & 3D) and "chi2_shift" (image_registration package, only works for - 2D images). - lower_rescale_quantile: Lower quantile for rescaling the image - intensities before applying registration. Can be helpful - to deal with image artifacts. Default is 0. - upper_rescale_quantile: Upper quantile for rescaling the image - intensities before applying registration. 
Can be helpful - to deal with image artifacts. Default is 0.99. - roi_table: Name of the ROI table over which the task loops to - calculate the registration. Examples: `FOV_ROI_table` => loop over - the field of views, `well_ROI_table` => process the whole well as - one image. - level: Pyramid level of the image to be used for registration. - Choose `0` to process at full resolution. - - """ - logger.info( - f"Running for {zarr_url=}.\n" - f"Calculating translation registration per {roi_table=} for " - f"{wavelength_id=}." - ) - - init_args.reference_zarr_url = init_args.reference_zarr_url - - # Read some parameters from Zarr metadata - ngff_image_meta = load_NgffImageMeta(str(init_args.reference_zarr_url)) - coarsening_xy = ngff_image_meta.coarsening_xy - - # Get channel_index via wavelength_id. - # Intially only allow registration of the same wavelength - channel_ref: OmeroChannel = get_channel_from_image_zarr( - image_zarr_path=init_args.reference_zarr_url, - wavelength_id=wavelength_id, - ) - channel_index_ref = channel_ref.index - - channel_align: OmeroChannel = get_channel_from_image_zarr( - image_zarr_path=zarr_url, - wavelength_id=wavelength_id, - ) - channel_index_align = channel_align.index - - # Lazily load zarr array - data_reference_zyx = da.from_zarr( - f"{init_args.reference_zarr_url}/{level}" - )[channel_index_ref] - data_alignment_zyx = da.from_zarr(f"{zarr_url}/{level}")[ - channel_index_align - ] - - # Check if data is 3D (as not all registration methods work in 3D) - # TODO: Abstract this check into a higher-level Zarr loading class - if is_3D(data_reference_zyx): - if method == RegistrationMethod(RegistrationMethod.CHI2_SHIFT): - raise ValueError( - f"The `{RegistrationMethod.CHI2_SHIFT}` registration method " - "has not been implemented for 3D images and the input image " - f"had a shape of {data_reference_zyx.shape}." - ) - - # Read ROIs - ROI_table_ref = ad.read_zarr( - f"{init_args.reference_zarr_url}/tables/{roi_table}" - ) - ROI_table_x = ad.read_zarr(f"{zarr_url}/tables/{roi_table}") - logger.info( - f"Found {len(ROI_table_x)} ROIs in {roi_table=} to be processed." - ) - - # Check that table type of ROI_table_ref is valid. Note that - # "ngff:region_table" and None are accepted for backwards compatibility - valid_table_types = [ - "roi_table", - "masking_roi_table", - "ngff:region_table", - None, - ] - ROI_table_ref_group = zarr.open_group( - f"{init_args.reference_zarr_url}/tables/{roi_table}", - mode="r", - ) - ref_table_attrs = ROI_table_ref_group.attrs.asdict() - ref_table_type = ref_table_attrs.get("type") - if ref_table_type not in valid_table_types: - raise ValueError( - ( - f"Table '{roi_table}' (with type '{ref_table_type}') is " - "not a valid ROI table." - ) - ) - - # For each acquisition, get the relevant info - # TODO: Add additional checks on ROIs? - if (ROI_table_ref.obs.index != ROI_table_x.obs.index).all(): - raise ValueError( - "Registration is only implemented for ROIs that match between the " - "acquisitions (e.g. well, FOV ROIs). Here, the ROIs in the " - f"reference acquisitions were {ROI_table_ref.obs.index}, but the " - f"ROIs in the alignment acquisition were {ROI_table_x.obs.index}" - ) - # TODO: Make this less restrictive? i.e. could we also run it if different - # acquisitions have different FOVs? But then how do we know which FOVs to - # match? - # If we relax this, downstream assumptions on matching based on order - # in the list will break. 
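-    # A stricter check could also compare the ROI sizes between the two
-    # acquisitions. Sketch (assuming the standard Fractal ROI-table
-    # `len_*_micrometer` columns):
-    #
-    #     for col in ["len_z_micrometer", "len_y_micrometer",
-    #                 "len_x_micrometer"]:
-    #         if not np.allclose(
-    #             ROI_table_ref.to_df()[col], ROI_table_x.to_df()[col]
-    #         ):
-    #             raise ValueError(f"ROI sizes differ in {col}")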
- - # Read pixel sizes from zarr attributes - ngff_image_meta_acq_x = load_NgffImageMeta(zarr_url) - pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) - pxl_sizes_zyx_acq_x = ngff_image_meta_acq_x.get_pixel_sizes_zyx(level=0) - - if pxl_sizes_zyx != pxl_sizes_zyx_acq_x: - raise ValueError( - "Pixel sizes need to be equal between acquisitions for " - "registration." - ) - - # Create list of indices for 3D ROIs spanning the entire Z direction - list_indices_ref = convert_ROI_table_to_indices( - ROI_table_ref, - level=level, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=pxl_sizes_zyx, - ) - check_valid_ROI_indices(list_indices_ref, roi_table) - - list_indices_acq_x = convert_ROI_table_to_indices( - ROI_table_x, - level=level, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=pxl_sizes_zyx, - ) - check_valid_ROI_indices(list_indices_acq_x, roi_table) - - num_ROIs = len(list_indices_ref) - compute = True - new_shifts = {} - for i_ROI in range(num_ROIs): - logger.info( - f"Now processing ROI {i_ROI+1}/{num_ROIs} " - f"for channel {channel_align}." - ) - img_ref = load_region( - data_zyx=data_reference_zyx, - region=convert_indices_to_regions(list_indices_ref[i_ROI]), - compute=compute, - ) - img_acq_x = load_region( - data_zyx=data_alignment_zyx, - region=convert_indices_to_regions(list_indices_acq_x[i_ROI]), - compute=compute, - ) - - # Rescale the images - img_ref = rescale_intensity( - img_ref, - in_range=( - np.quantile(img_ref, lower_rescale_quantile), - np.quantile(img_ref, upper_rescale_quantile), - ), - ) - img_acq_x = rescale_intensity( - img_acq_x, - in_range=( - np.quantile(img_acq_x, lower_rescale_quantile), - np.quantile(img_acq_x, upper_rescale_quantile), - ), - ) - - ############## - # Calculate the transformation - ############## - if img_ref.shape != img_acq_x.shape: - raise NotImplementedError( - "This registration is not implemented for ROIs with " - "different shapes between acquisitions." - ) - - shifts = method.register(np.squeeze(img_ref), np.squeeze(img_acq_x))[0] - - ############## - # Store the calculated transformation ### - ############## - # Adapt ROIs for the given ROI table: - ROI_name = ROI_table_ref.obs.index[i_ROI] - new_shifts[ROI_name] = calculate_physical_shifts( - shifts, - level=level, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=pxl_sizes_zyx, - ) - - # Write physical shifts to disk (as part of the ROI table) - logger.info(f"Updating the {roi_table=} with translation columns") - image_group = zarr.group(zarr_url) - new_ROI_table = get_ROI_table_with_translation(ROI_table_x, new_shifts) - write_table( - image_group, - roi_table, - new_ROI_table, - overwrite=True, - table_attrs=ref_table_attrs, - ) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=calculate_registration_image_based, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/cellpose_segmentation.py b/fractal_tasks_core/tasks/cellpose_segmentation.py deleted file mode 100644 index 9ee23b8e8..000000000 --- a/fractal_tasks_core/tasks/cellpose_segmentation.py +++ /dev/null @@ -1,627 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Marco Franzon -# Joel Lüthi -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. 
-# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Image segmentation via Cellpose library. -""" -import logging -import os -import time -from typing import Literal -from typing import Optional - -import anndata as ad -import cellpose -import dask.array as da -import numpy as np -import zarr -from cellpose import models -from pydantic import Field -from pydantic import validate_call - -import fractal_tasks_core -from fractal_tasks_core.labels import prepare_label_group -from fractal_tasks_core.masked_loading import masked_loading_wrapper -from fractal_tasks_core.ngff import load_NgffImageMeta -from fractal_tasks_core.pyramids import build_pyramid -from fractal_tasks_core.roi import ( - array_to_bounding_box_table, -) -from fractal_tasks_core.roi import check_valid_ROI_indices -from fractal_tasks_core.roi import ( - convert_ROI_table_to_indices, -) -from fractal_tasks_core.roi import create_roi_table_from_df_list -from fractal_tasks_core.roi import ( - find_overlaps_in_ROI_indices, -) -from fractal_tasks_core.roi import get_overlapping_pairs_3D -from fractal_tasks_core.roi import is_ROI_table_valid -from fractal_tasks_core.roi import load_region -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tasks.cellpose_utils import ( - _normalize_cellpose_channels, -) -from fractal_tasks_core.tasks.cellpose_utils import ( - CellposeChannel1InputModel, -) -from fractal_tasks_core.tasks.cellpose_utils import ( - CellposeChannel2InputModel, -) -from fractal_tasks_core.tasks.cellpose_utils import ( - CellposeCustomNormalizer, -) -from fractal_tasks_core.tasks.cellpose_utils import ( - CellposeModelParams, -) -from fractal_tasks_core.utils import rescale_datasets - -logger = logging.getLogger(__name__) - -__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ - - -def segment_ROI( - x: np.ndarray, - num_labels_tot: dict[str, int], - model: models.CellposeModel = None, - do_3D: bool = True, - channels: list[int] = [0, 0], - diameter: float = 30.0, - normalize: CellposeCustomNormalizer = CellposeCustomNormalizer(), - normalize2: Optional[CellposeCustomNormalizer] = None, - label_dtype: Optional[np.dtype] = None, - relabeling: bool = True, - advanced_cellpose_model_params: CellposeModelParams = CellposeModelParams(), # noqa: E501 -) -> np.ndarray: - """ - Internal function that runs Cellpose segmentation for a single ROI. - - Args: - x: 4D numpy array. - num_labels_tot: Number of labels already in total image. Used for - relabeling purposes. Using a dict to have a mutable object that - can be edited from within the function without having to be passed - back through the masked_loading_wrapper. - model: An instance of `models.CellposeModel`. - do_3D: If `True`, cellpose runs in 3D mode: runs on xy, xz & yz planes, - then averages the flows. - channels: Which channels to use. If only one channel is provided, `[0, - 0]` should be used. If two channels are provided (the first - dimension of `x` has length of 2), `[1, 2]` should be used - (`x[0, :, :,:]` contains the membrane channel and - `x[1, :, :, :]` contains the nuclear channel). - diameter: Expected object diameter in pixels for cellpose. - normalize: By default, data is normalized so 0.0=1st percentile and - 1.0=99th percentile of image intensities in each channel. - This automatic normalization can lead to issues when the image to - be segmented is very sparse. You can turn off the default - rescaling. 
With the "custom" option, you can either provide your - own rescaling percentiles or fixed rescaling upper and lower - bound integers. - normalize2: Normalization options for channel 2. If one channel is - normalized with default settings, both channels need to be - normalized with default settings. - label_dtype: Label images are cast into this `np.dtype`. - relabeling: Whether relabeling based on num_labels_tot is performed. - advanced_cellpose_model_params: Advanced Cellpose model parameters - that are passed to the Cellpose `model.eval` method. - """ - - # Write some debugging info - logger.info( - "[segment_ROI] START |" - f" x: {type(x)}, {x.shape} |" - f" {do_3D=} |" - f" {model.diam_mean=} |" - f" {diameter=} |" - f" {advanced_cellpose_model_params.flow_threshold=} |" - f" {normalize.type=}" - ) - x = _normalize_cellpose_channels(x, channels, normalize, normalize2) - - # Actual labeling - t0 = time.perf_counter() - mask, _, _ = model.eval( - x, - channels=channels, - do_3D=do_3D, - net_avg=advanced_cellpose_model_params.net_avg, - augment=advanced_cellpose_model_params.augment, - diameter=diameter, - anisotropy=advanced_cellpose_model_params.anisotropy, - cellprob_threshold=advanced_cellpose_model_params.cellprob_threshold, - flow_threshold=advanced_cellpose_model_params.flow_threshold, - normalize=normalize.cellpose_normalize, - min_size=advanced_cellpose_model_params.min_size, - batch_size=advanced_cellpose_model_params.batch_size, - invert=advanced_cellpose_model_params.invert, - tile=advanced_cellpose_model_params.tile, - tile_overlap=advanced_cellpose_model_params.tile_overlap, - resample=advanced_cellpose_model_params.resample, - interp=advanced_cellpose_model_params.interp, - stitch_threshold=advanced_cellpose_model_params.stitch_threshold, - ) - - if mask.ndim == 2: - # If we get a 2D image, we still return it as a 3D array - mask = np.expand_dims(mask, axis=0) - t1 = time.perf_counter() - - # Write some debugging info - logger.info( - "[segment_ROI] END |" - f" Elapsed: {t1-t0:.3f} s |" - f" {mask.shape=}," - f" {mask.dtype=} (then {label_dtype})," - f" {np.max(mask)=} |" - f" {model.diam_mean=} |" - f" {diameter=} |" - f" {advanced_cellpose_model_params.flow_threshold=}" - ) - - # Shift labels and update relabeling counters - if relabeling: - num_labels_roi = np.max(mask) - mask[mask > 0] += num_labels_tot["num_labels_tot"] - num_labels_tot["num_labels_tot"] += num_labels_roi - - # Write some logs - logger.info(f"ROI had {num_labels_roi=}, {num_labels_tot=}") - - # Check that total number of labels is under control - if num_labels_tot["num_labels_tot"] > np.iinfo(label_dtype).max: - raise ValueError( - "ERROR in re-labeling:" - f"Reached {num_labels_tot} labels, " - f"but dtype={label_dtype}" - ) - - return mask.astype(label_dtype) - - -@validate_call -def cellpose_segmentation( - *, - # Fractal parameters - zarr_url: str, - # Core parameters - level: int, - channel: CellposeChannel1InputModel, - channel2: CellposeChannel2InputModel = Field( - default_factory=CellposeChannel2InputModel - ), - input_ROI_table: str = "FOV_ROI_table", - output_ROI_table: Optional[str] = None, - output_label_name: Optional[str] = None, - # Cellpose-related arguments - diameter_level0: float = 30.0, - # https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/401 # noqa E501 - model_type: Literal[tuple(models.MODEL_NAMES)] = "cyto2", # type: ignore - pretrained_model: Optional[str] = None, - relabeling: bool = True, - use_masks: bool = True, - advanced_cellpose_model_params: 
CellposeModelParams = Field( - default_factory=CellposeModelParams - ), - overwrite: bool = True, -) -> None: - """ - Run cellpose segmentation on the ROIs of a single OME-Zarr image. - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be processed. - (standard argument for Fractal tasks, managed by Fractal server). - level: Pyramid level of the image to be segmented. Choose `0` to - process at full resolution. - channel: Primary channel for segmentation; requires either - `wavelength_id` (e.g. `A01_C01`) or `label` (e.g. `DAPI`), but not - both. Also contains normalization options. By default, data is - normalized so 0.0=1st percentile and 1.0=99th percentile of image - intensities in each channel. - This automatic normalization can lead to issues when the image to - be segmented is very sparse. You can turn off the default - rescaling. With the "custom" option, you can either provide your - own rescaling percentiles or fixed rescaling upper and lower - bound integers. - channel2: Second channel for segmentation (in the same format as - `channel`). If specified, cellpose runs in dual channel mode. - For dual channel segmentation of cells, the first channel should - contain the membrane marker, the second channel should contain the - nuclear marker. - input_ROI_table: Name of the ROI table over which the task loops to - apply Cellpose segmentation. Examples: `FOV_ROI_table` => loop over - the field of views, `organoid_ROI_table` => loop over the organoid - ROI table (generated by another task), `well_ROI_table` => process - the whole well as one image. - output_ROI_table: If provided, a ROI table with that name is created, - which will contain the bounding boxes of the newly segmented - labels. ROI tables should have `ROI` in their name. - output_label_name: Name of the output label image (e.g. `"organoids"`). - diameter_level0: Expected diameter of the objects that should be - segmented in pixels at level 0. Initial diameter is rescaled using - the `level` that was selected. The rescaled value is passed as - the diameter to the `CellposeModel.eval` method. - model_type: Parameter of `CellposeModel` class. Defines which model - should be used. Typical choices are `nuclei`, `cyto`, `cyto2`, etc. - pretrained_model: Parameter of `CellposeModel` class (takes - precedence over `model_type`). Allows you to specify the path of - a custom trained cellpose model. - relabeling: If `True`, apply relabeling so that label values are - unique for all objects in the well. - use_masks: If `True`, try to use masked loading and fall back to - `use_masks=False` if the ROI table is not suitable. Masked - loading is relevant when only a subset of the bounding box should - actually be processed (e.g. running within `organoid_ROI_table`). - advanced_cellpose_model_params: Advanced Cellpose model parameters - that are passed to the Cellpose `model.eval` method. - overwrite: If `True`, overwrite the task output. 
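-
-    Example (hypothetical paths and channel; a minimal sketch of calling the
-    task directly rather than through a Fractal workflow):
-
-        >>> from fractal_tasks_core.tasks.cellpose_utils import (
-        ...     CellposeChannel1InputModel,
-        ... )
-        >>> cellpose_segmentation(
-        ...     zarr_url="/path/to/my_plate.zarr/B/03/0",
-        ...     level=2,
-        ...     channel=CellposeChannel1InputModel(label="DAPI"),
-        ...     output_label_name="nuclei",
-        ... )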
- """ - logger.info(f"Processing {zarr_url=}") - - # Preliminary checks on Cellpose model - if pretrained_model: - if not os.path.exists(pretrained_model): - raise ValueError(f"{pretrained_model=} does not exist.") - - # Read attributes from NGFF metadata - ngff_image_meta = load_NgffImageMeta(zarr_url) - num_levels = ngff_image_meta.num_levels - coarsening_xy = ngff_image_meta.coarsening_xy - full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) - actual_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=level) - logger.info(f"NGFF image has {num_levels=}") - logger.info(f"NGFF image has {coarsening_xy=}") - logger.info( - f"NGFF image has full-res pixel sizes {full_res_pxl_sizes_zyx}" - ) - logger.info( - f"NGFF image has level-{level} pixel sizes " - f"{actual_res_pxl_sizes_zyx}" - ) - - # Find channel index - omero_channel = channel.get_omero_channel(zarr_url) - if omero_channel: - ind_channel = omero_channel.index - else: - return - - # Find channel index for second channel, if one is provided - if channel2.is_set(): - omero_channel_2 = channel2.get_omero_channel(zarr_url) - if omero_channel_2: - ind_channel_c2 = omero_channel_2.index - else: - return - - # Set channel label - if output_label_name is None: - try: - channel_label = omero_channel.label - output_label_name = f"label_{channel_label}" - except (KeyError, IndexError): - output_label_name = f"label_{ind_channel}" - - # Load ZYX data - # Workaround for #788: Only load channel index when there is a channel - # dimension - if ngff_image_meta.axes_names[0] != "c": - data_zyx = da.from_zarr(f"{zarr_url}/{level}") - if channel2.is_set(): - raise ValueError( - "Dual channel input was specified for an OME-Zarr image " - "without a channel axis" - ) - else: - data_zyx = da.from_zarr(f"{zarr_url}/{level}")[ind_channel] - if channel2.is_set(): - data_zyx_c2 = da.from_zarr(f"{zarr_url}/{level}")[ind_channel_c2] - logger.info(f"Second channel: {data_zyx_c2.shape=}") - logger.info(f"{data_zyx.shape=}") - - # Read ROI table - ROI_table_path = f"{zarr_url}/tables/{input_ROI_table}" - ROI_table = ad.read_zarr(ROI_table_path) - - # Perform some checks on the ROI table - valid_ROI_table = is_ROI_table_valid( - table_path=ROI_table_path, use_masks=use_masks - ) - if use_masks and not valid_ROI_table: - logger.info( - f"ROI table at {ROI_table_path} cannot be used for masked " - "loading. Set use_masks=False." - ) - use_masks = False - logger.info(f"{use_masks=}") - - # Create list of indices for 3D ROIs spanning the entire Z direction - list_indices = convert_ROI_table_to_indices( - ROI_table, - level=level, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, - ) - check_valid_ROI_indices(list_indices, input_ROI_table) - - # If we are not planning to use masked loading, fail for overlapping ROIs - if not use_masks: - overlap = find_overlaps_in_ROI_indices(list_indices) - if overlap: - raise ValueError( - f"ROI indices created from {input_ROI_table} table have " - "overlaps, but we are not using masked loading." 
- ) - - # Select 2D/3D behavior and set some parameters - do_3D = data_zyx.shape[0] > 1 and len(data_zyx.shape) == 3 - if do_3D: - if advanced_cellpose_model_params.anisotropy is None: - # Compute anisotropy as pixel_size_z/pixel_size_x - advanced_cellpose_model_params.anisotropy = ( - actual_res_pxl_sizes_zyx[0] / actual_res_pxl_sizes_zyx[2] - ) - logger.info(f"Anisotropy: {advanced_cellpose_model_params.anisotropy}") - - # Rescale datasets (only relevant for level>0) - # Workaround for #788 - if ngff_image_meta.axes_names[0] != "c": - new_datasets = rescale_datasets( - datasets=[ds.model_dump() for ds in ngff_image_meta.datasets], - coarsening_xy=coarsening_xy, - reference_level=level, - remove_channel_axis=False, - ) - else: - new_datasets = rescale_datasets( - datasets=[ds.model_dump() for ds in ngff_image_meta.datasets], - coarsening_xy=coarsening_xy, - reference_level=level, - remove_channel_axis=True, - ) - - label_attrs = { - "image-label": { - "version": __OME_NGFF_VERSION__, - "source": {"image": "../../"}, - }, - "multiscales": [ - { - "name": output_label_name, - "version": __OME_NGFF_VERSION__, - "axes": [ - ax.dict() - for ax in ngff_image_meta.multiscale.axes - if ax.type != "channel" - ], - "datasets": new_datasets, - } - ], - } - - image_group = zarr.group(zarr_url) - label_group = prepare_label_group( - image_group, - output_label_name, - overwrite=overwrite, - label_attrs=label_attrs, - logger=logger, - ) - - logger.info( - f"Helper function `prepare_label_group` returned {label_group=}" - ) - logger.info(f"Output label path: {zarr_url}/labels/{output_label_name}/0") - store = zarr.storage.FSStore(f"{zarr_url}/labels/{output_label_name}/0") - label_dtype = np.uint32 - - # Ensure that all output shapes & chunks are 3D (for 2D data: (1, y, x)) - # https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/398 - shape = data_zyx.shape - if len(shape) == 2: - shape = (1, *shape) - chunks = data_zyx.chunksize - if len(chunks) == 2: - chunks = (1, *chunks) - mask_zarr = zarr.create( - shape=shape, - chunks=chunks, - dtype=label_dtype, - store=store, - overwrite=False, - dimension_separator="/", - ) - - logger.info( - f"mask will have shape {data_zyx.shape} " - f"and chunks {data_zyx.chunks}" - ) - - # Initialize cellpose - gpu = advanced_cellpose_model_params.use_gpu and cellpose.core.use_gpu() - if pretrained_model: - model = models.CellposeModel( - gpu=gpu, pretrained_model=pretrained_model - ) - else: - model = models.CellposeModel(gpu=gpu, model_type=model_type) - - # Initialize other things - logger.info(f"Start cellpose_segmentation task for {zarr_url}") - logger.info(f"relabeling: {relabeling}") - logger.info(f"do_3D: {do_3D}") - logger.info(f"use_gpu: {gpu}") - logger.info(f"level: {level}") - logger.info(f"model_type: {model_type}") - logger.info(f"pretrained_model: {pretrained_model}") - logger.info(f"anisotropy: {advanced_cellpose_model_params.anisotropy}") - logger.info("Total well shape/chunks:") - logger.info(f"{data_zyx.shape}") - logger.info(f"{data_zyx.chunks}") - if channel2.is_set(): - logger.info("Dual channel input for cellpose model") - logger.info(f"{data_zyx_c2.shape}") - logger.info(f"{data_zyx_c2.chunks}") - - # Counters for relabeling - num_labels_tot = {"num_labels_tot": 0} - - # Iterate over ROIs - num_ROIs = len(list_indices) - - if output_ROI_table: - bbox_dataframe_list = [] - - logger.info(f"Now starting loop over {num_ROIs} ROIs") - for i_ROI, indices in enumerate(list_indices): - # Define region - s_z, e_z, s_y, e_y, s_x, e_x = 
indices[:] - region = ( - slice(s_z, e_z), - slice(s_y, e_y), - slice(s_x, e_x), - ) - logger.info(f"Now processing ROI {i_ROI+1}/{num_ROIs}") - - # Prepare single-channel or dual-channel input for cellpose - if channel2.is_set(): - # Dual channel mode, first channel is the membrane channel - img_1 = load_region( - data_zyx, - region, - compute=True, - return_as_3D=True, - ) - img_np = np.zeros((2, *img_1.shape)) - img_np[0, :, :, :] = img_1 - img_np[1, :, :, :] = load_region( - data_zyx_c2, - region, - compute=True, - return_as_3D=True, - ) - channels = [1, 2] - else: - img_np = np.expand_dims( - load_region(data_zyx, region, compute=True, return_as_3D=True), - axis=0, - ) - channels = [0, 0] - - # Prepare keyword arguments for segment_ROI function - kwargs_segment_ROI = dict( - num_labels_tot=num_labels_tot, - model=model, - channels=channels, - do_3D=do_3D, - label_dtype=label_dtype, - diameter=diameter_level0 / coarsening_xy**level, - normalize=channel.normalize, - normalize2=channel2.normalize, - relabeling=relabeling, - advanced_cellpose_model_params=advanced_cellpose_model_params, - ) - - # Prepare keyword arguments for preprocessing function - preprocessing_kwargs = {} - if use_masks: - preprocessing_kwargs = dict( - region=region, - current_label_path=f"{zarr_url}/labels/{output_label_name}/0", - ROI_table_path=ROI_table_path, - ROI_positional_index=i_ROI, - ) - - # Call segment_ROI through the masked-loading wrapper, which includes - # pre/post-processing functions if needed - new_label_img = masked_loading_wrapper( - image_array=img_np, - function=segment_ROI, - kwargs=kwargs_segment_ROI, - use_masks=use_masks, - preprocessing_kwargs=preprocessing_kwargs, - ) - - if output_ROI_table: - bbox_df = array_to_bounding_box_table( - new_label_img, - actual_res_pxl_sizes_zyx, - origin_zyx=(s_z, s_y, s_x), - ) - - bbox_dataframe_list.append(bbox_df) - - overlap_list = get_overlapping_pairs_3D( - bbox_df, full_res_pxl_sizes_zyx - ) - if len(overlap_list) > 0: - logger.warning( - f"ROI {indices} has " - f"{len(overlap_list)} bounding-box pairs overlap" - ) - - # Compute and store 0-th level to disk - da.array(new_label_img).to_zarr( - url=mask_zarr, - region=region, - compute=True, - ) - - logger.info( - f"End cellpose_segmentation task for {zarr_url}, " - "now building pyramids." 
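- # Sketch of what the pyramid step below does (shapes illustrative - # only): each coarser level downsamples y and x by coarsening_xy, - # aggregating label values with np.max, e.g. for coarsening_xy=2: - # level 0: (10, 2160, 2560) -> level 1: (10, 1080, 1280) -> ...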
- ) - - # Starting from on-disk highest-resolution data, build and write to disk a - # pyramid of coarser levels - build_pyramid( - zarrurl=f"{zarr_url}/labels/{output_label_name}", - overwrite=overwrite, - num_levels=num_levels, - coarsening_xy=coarsening_xy, - chunksize=chunks, - aggregation_function=np.max, - ) - - logger.info("End building pyramids") - - if output_ROI_table: - bbox_table = create_roi_table_from_df_list(bbox_dataframe_list) - - # Write to zarr group - image_group = zarr.group(zarr_url) - logger.info( - "Now writing bounding-box ROI table to " - f"{zarr_url}/tables/{output_ROI_table}" - ) - table_attrs = { - "type": "masking_roi_table", - "region": {"path": f"../labels/{output_label_name}"}, - "instance_key": "label", - } - write_table( - image_group, - output_ROI_table, - bbox_table, - overwrite=overwrite, - table_attrs=table_attrs, - ) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=cellpose_segmentation, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/cellpose_utils.py b/fractal_tasks_core/tasks/cellpose_utils.py deleted file mode 100644 index 6ba22a29e..000000000 --- a/fractal_tasks_core/tasks/cellpose_utils.py +++ /dev/null @@ -1,468 +0,0 @@ -# Copyright 2023 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Joel Lüthi -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Helper functions for image normalization in Cellpose segmentation. -""" -import logging -from typing import Literal -from typing import Optional - -import numpy as np -from pydantic import BaseModel -from pydantic import Field -from pydantic import model_validator -from typing_extensions import Self - -from fractal_tasks_core.channels import ChannelInputModel -from fractal_tasks_core.channels import ChannelNotFoundError -from fractal_tasks_core.channels import get_channel_from_image_zarr -from fractal_tasks_core.channels import OmeroChannel - - -logger = logging.getLogger(__name__) - - -class CellposeCustomNormalizer(BaseModel): - """ - Validator to handle different normalization scenarios for Cellpose models. - - If `type="default"`, then Cellpose default normalization is - used and no other parameters can be specified. - If `type="no_normalization"`, then no normalization is used and no - other parameters can be specified. - If `type="custom"`, then either percentiles or explicit integer - bounds can be applied. - - Attributes: - type: - One of `default` (Cellpose default normalization), `custom` - (using the other custom parameters) or `no_normalization`. - lower_percentile: Specify a custom lower-bound percentile for rescaling - as a float value between 0 and 100. Set to 1 to run the same as - default. You can only specify percentiles or bounds, not both. - upper_percentile: Specify a custom upper-bound percentile for rescaling - as a float value between 0 and 100. Set to 99 to run the same as - default, set to e.g. 99.99 if the default rescaling was too harsh. - You can only specify percentiles or bounds, not both. - lower_bound: Explicit lower bound value to rescale the image at. - Needs to be an integer, e.g. 100. - You can only specify percentiles or bounds, not both. - upper_bound: Explicit upper bound value to rescale the image at. - Needs to be an integer, e.g.
2000. - You can only specify percentiles or bounds, not both. - """ - - type: Literal["default", "custom", "no_normalization"] = "default" - lower_percentile: Optional[float] = Field(None, ge=0, le=100) - upper_percentile: Optional[float] = Field(None, ge=0, le=100) - lower_bound: Optional[int] = None - upper_bound: Optional[int] = None - - # In the future, add an option to allow using precomputed percentiles - # that are stored in OME-Zarr histograms and use this pydantic model to - # validate that those histograms actually exist - - @model_validator(mode="after") - def validate_conditions(self: Self) -> Self: - # Extract values - type = self.type - lower_percentile = self.lower_percentile - upper_percentile = self.upper_percentile - lower_bound = self.lower_bound - upper_bound = self.upper_bound - - # Verify that custom parameters are only provided when type="custom" - if type != "custom": - if lower_percentile is not None: - raise ValueError( - f"Type='{type}' but {lower_percentile=}. " - "Hint: set type='custom'." - ) - if upper_percentile is not None: - raise ValueError( - f"Type='{type}' but {upper_percentile=}. " - "Hint: set type='custom'." - ) - if lower_bound is not None: - raise ValueError( - f"Type='{type}' but {lower_bound=}. " - "Hint: set type='custom'." - ) - if upper_bound is not None: - raise ValueError( - f"Type='{type}' but {upper_bound=}. " - "Hint: set type='custom'." - ) - - # The only valid options are: - # 1. Both percentiles are set and both bounds are unset - # 2. Both bounds are set and both percentiles are unset - are_percentiles_set = ( - lower_percentile is not None, - upper_percentile is not None, - ) - are_bounds_set = ( - lower_bound is not None, - upper_bound is not None, - ) - if len(set(are_percentiles_set)) != 1: - raise ValueError( - "Both lower_percentile and upper_percentile must be set " - "together." - ) - if len(set(are_bounds_set)) != 1: - raise ValueError( - "Both lower_bound and upper_bound must be set together." - ) - if lower_percentile is not None and lower_bound is not None: - raise ValueError( - "You cannot set both explicit bounds and percentile bounds " - "at the same time. Hint: use only one of the two options." - ) - - return self - - @property - def cellpose_normalize(self) -> bool: - """ - Determine whether cellpose should apply its internal normalization. - - If type is set to `custom` or `no_normalization`, don't apply cellpose - internal normalization. - """ - return self.type == "default" - - -class CellposeModelParams(BaseModel): - """ - Advanced Cellpose Model Parameters - - Attributes: - cellprob_threshold: Parameter of `CellposeModel.eval` method. Valid - values between -6 and 6. From Cellpose documentation: "Decrease this - threshold if cellpose is not returning as many ROIs as you'd - expect. Similarly, increase this threshold if cellpose is returning - too many ROIs particularly from dim areas." - flow_threshold: Parameter of `CellposeModel.eval` method. Valid - values between 0.0 and 1.0. From Cellpose documentation: "Increase - this threshold if cellpose is not returning as many ROIs as you'd - expect. Similarly, decrease this threshold if cellpose is returning - too many ill-shaped ROIs." - anisotropy: Ratio of the pixel sizes along Z and XY axis (ignored if - the image is not three-dimensional). If unset, it is inferred from - the OME-NGFF metadata. - min_size: Parameter of `CellposeModel` class. Minimum size of the - segmented objects (in pixels). Use `-1` to turn off the size - filter. - augment: Parameter of `CellposeModel` class.
Whether to use cellpose - augmentation to tile images with overlap. - net_avg: Parameter of `CellposeModel` class. Whether to use cellpose - net averaging to run the 4 built-in networks (useful for `nuclei`, - `cyto` and `cyto2`, not sure it works for the others). - use_gpu: If `False`, always use the CPU; if `True`, use the GPU if - possible (as defined in `cellpose.core.use_gpu()`) and fall-back - to the CPU otherwise. - batch_size: number of 224x224 patches to run simultaneously on the GPU - (can make smaller or bigger depending on GPU memory usage) - invert: invert image pixel intensity before running network (if True, - image is also normalized) - tile: tiles image to ensure GPU/CPU memory usage limited (recommended) - tile_overlap: fraction of overlap of tiles when computing flows - resample: run dynamics at original image size (will be slower but - create more accurate boundaries) - interp: interpolate during 2D dynamics (not available in 3D) - (in previous versions it was False, now it defaults to True) - stitch_threshold: if stitch_threshold>0.0 and not do_3D and equal - image sizes, masks are stitched in 3D to return volume segmentation - """ - - cellprob_threshold: float = 0.0 - flow_threshold: float = 0.4 - anisotropy: Optional[float] = None - min_size: int = 15 - augment: bool = False - net_avg: bool = False - use_gpu: bool = True - batch_size: int = 8 - invert: bool = False - tile: bool = True - tile_overlap: float = 0.1 - resample: bool = True - interp: bool = True - stitch_threshold: float = 0.0 - - -class CellposeChannel1InputModel(ChannelInputModel): - """ - Channel input for cellpose with normalization options. - - Attributes: - wavelength_id: Unique ID for the channel wavelength, e.g. `A01_C01`. - Can only be specified if label is not set. - label: Name of the channel. Can only be specified if wavelength_id is - not set. - normalize: Validator to handle different normalization scenarios for - Cellpose models - """ - - normalize: CellposeCustomNormalizer = Field( - default_factory=CellposeCustomNormalizer - ) - - def get_omero_channel(self, zarr_url) -> OmeroChannel: - try: - return get_channel_from_image_zarr( - image_zarr_path=zarr_url, - wavelength_id=self.wavelength_id, - label=self.label, - ) - except ChannelNotFoundError as e: - logger.warning( - f"Channel with wavelength_id: {self.wavelength_id} " - f"and label: {self.label} not found, exit from the task.\n" - f"Original error: {str(e)}" - ) - return None - - -class CellposeChannel2InputModel(BaseModel): - """ - Channel input for secondary cellpose channel with normalization options. - - The secondary channel is Optional, thus both wavelength_id and label are - optional to be set. The `is_set` function shows whether either value was - set. - - Attributes: - wavelength_id: Unique ID for the channel wavelength, e.g. `A01_C01`. - Can only be specified if label is not set. - label: Name of the channel. Can only be specified if wavelength_id is - not set. - normalize: Validator to handle different normalization scenarios for - Cellpose models - """ - - wavelength_id: Optional[str] = None - label: Optional[str] = None - normalize: CellposeCustomNormalizer = Field( - default_factory=CellposeCustomNormalizer - ) - - @model_validator(mode="after") - def mutually_exclusive_channel_attributes(self: Self) -> Self: - """ - Check that only 1 of `label` or `wavelength_id` is set. 
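- - For example (illustrative values): setting only `wavelength_id="A01_C01"` - is valid, and so is setting only `label="DAPI"`; setting both raises a - `ValueError`.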
- """ - wavelength_id = self.wavelength_id - label = self.label - if (wavelength_id is not None) and (label is not None): - raise ValueError( - "`wavelength_id` and `label` cannot be both set " - f"(given {wavelength_id=} and {label=})." - ) - return self - - def is_set(self): - if self.wavelength_id or self.label: - return True - return False - - def get_omero_channel(self, zarr_url) -> OmeroChannel: - try: - return get_channel_from_image_zarr( - image_zarr_path=zarr_url, - wavelength_id=self.wavelength_id, - label=self.label, - ) - except ChannelNotFoundError as e: - logger.warning( - f"Second channel with wavelength_id: {self.wavelength_id} " - f"and label: {self.label} not found, exit from the task.\n" - f"Original error: {str(e)}" - ) - return None - - -def _normalize_cellpose_channels( - x: np.ndarray, - channels: list[int], - normalize: CellposeCustomNormalizer, - normalize2: CellposeCustomNormalizer, -) -> np.ndarray: - """ - Normalize a cellpose input array by channel. - - Args: - x: 4D numpy array. - channels: Which channels to use. If only one channel is provided, `[0, - 0]` should be used. If two channels are provided (the first - dimension of `x` has length of 2), `[1, 2]` should be used - (`x[0, :, :,:]` contains the membrane channel and - `x[1, :, :, :]` contains the nuclear channel). - normalize: By default, data is normalized so 0.0=1st percentile and - 1.0=99th percentile of image intensities in each channel. - This automatic normalization can lead to issues when the image to - be segmented is very sparse. You can turn off the default - rescaling. With the "custom" option, you can either provide your - own rescaling percentiles or fixed rescaling upper and lower - bound integers. - normalize2: Normalization options for channel 2. If one channel is - normalized with default settings, both channels need to be - normalized with default settings. - - """ - # Optionally perform custom normalization - # normalize channels separately, if normalize2 is provided: - if channels == [1, 2]: - # If run in single channel mode, fails (specified as channel = [0, 0]) - if (normalize.type == "default") != (normalize2.type == "default"): - raise ValueError( - "ERROR in normalization:" - f" {normalize.type=} and {normalize2.type=}." - " Either both need to be 'default', or none of them." 
- ) - if normalize.type == "custom": - x[channels[0] - 1 : channels[0]] = normalized_img( - x[channels[0] - 1 : channels[0]], - lower_p=normalize.lower_percentile, - upper_p=normalize.upper_percentile, - lower_bound=normalize.lower_bound, - upper_bound=normalize.upper_bound, - ) - if normalize2.type == "custom": - x[channels[1] - 1 : channels[1]] = normalized_img( - x[channels[1] - 1 : channels[1]], - lower_p=normalize2.lower_percentile, - upper_p=normalize2.upper_percentile, - lower_bound=normalize2.lower_bound, - upper_bound=normalize2.upper_bound, - ) - - # otherwise, use first normalize to normalize all channels: - else: - if normalize.type == "custom": - x = normalized_img( - x, - lower_p=normalize.lower_percentile, - upper_p=normalize.upper_percentile, - lower_bound=normalize.lower_bound, - upper_bound=normalize.upper_bound, - ) - - return x - - -def normalized_img( - img: np.ndarray, - axis: int = -1, - invert: bool = False, - lower_p: float = 1.0, - upper_p: float = 99.0, - lower_bound: Optional[int] = None, - upper_bound: Optional[int] = None, -): - """normalize each channel of the image so that 0.0=lower percentile - or lower bound and 1.0=upper percentile or upper bound of image intensities. - - The normalization can result in values < 0 or > 1 (no clipping). - - Based on https://github.com/MouseLand/cellpose/blob/4f5661983c3787efa443bbbd3f60256f4fd8bf53/cellpose/transforms.py#L375 # noqa E501 - - optional inversion - - Parameters - ------------ - - img: ND-array (at least 3 dimensions) - - axis: channel axis to loop over for normalization - - invert: invert image (useful if cells are dark instead of bright) - - lower_p: Lower percentile for rescaling - - upper_p: Upper percentile for rescaling - - lower_bound: Lower fixed-value used for rescaling - - upper_bound: Upper fixed-value used for rescaling - - Returns - --------------- - - img: ND-array, float32 - normalized image of same size - - """ - if img.ndim < 3: - error_message = "Image needs to have at least 3 dimensions" - logger.critical(error_message) - raise ValueError(error_message) - - img = img.astype(np.float32) - img = np.moveaxis(img, axis, 0) - for k in range(img.shape[0]): - if lower_p is not None: - # ptp can still give nan's with weird images - i99 = np.percentile(img[k], upper_p) - i1 = np.percentile(img[k], lower_p) - if i99 - i1 > +1e-3: # np.ptp(img[k]) > 1e-3: - img[k] = normalize_percentile( - img[k], lower=lower_p, upper=upper_p - ) - if invert: - img[k] = -1 * img[k] + 1 - else: - img[k] = 0 - elif lower_bound is not None: - if upper_bound - lower_bound > +1e-3: - img[k] = normalize_bounds( - img[k], lower=lower_bound, upper=upper_bound - ) - if invert: - img[k] = -1 * img[k] + 1 - else: - img[k] = 0 - else: - raise ValueError("No normalization method specified") - img = np.moveaxis(img, 0, axis) - return img - - -def normalize_percentile(Y: np.ndarray, lower: float = 1, upper: float = 99): - """normalize image so 0.0 is lower percentile and 1.0 is upper percentile - Percentiles are passed as floats (must be between 0 and 100) - - Args: - Y: The image to be normalized - lower: Lower percentile - upper: Upper percentile - - """ - X = Y.copy() - x01 = np.percentile(X, lower) - x99 = np.percentile(X, upper) - X = (X - x01) / (x99 - x01) - return X - - -def normalize_bounds(Y: np.ndarray, lower: int = 0, upper: int = 65535): - """normalize image so 0.0 is lower value and 1.0 is upper value - - Args: - Y: The image to be normalized - lower: Lower normalization value - upper: Upper normalization value
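- - Example (illustrative values): with lower=100 and upper=2000, a pixel - value of 100 maps to 0.0, 2000 maps to 1.0, and 4000 maps to ~2.05 - (no clipping is applied).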
- - """ - X = Y.copy() - X = (X - lower) / (upper - lower) - return X diff --git a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_compute.py b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_compute.py deleted file mode 100644 index a316e07db..000000000 --- a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_compute.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Marco Franzon -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Task that writes image data to an existing OME-NGFF zarr array. -""" -import logging - -import dask.array as da -import zarr -from anndata import read_zarr -from dask.array.image import imread -from pydantic import validate_call - -from fractal_tasks_core.cellvoyager.filenames import ( - glob_with_multiple_patterns, -) -from fractal_tasks_core.cellvoyager.filenames import parse_filename -from fractal_tasks_core.channels import get_omero_channel_list -from fractal_tasks_core.channels import OmeroChannel -from fractal_tasks_core.ngff import load_NgffImageMeta -from fractal_tasks_core.pyramids import build_pyramid -from fractal_tasks_core.roi import check_valid_ROI_indices -from fractal_tasks_core.roi import ( - convert_ROI_table_to_indices, -) -from fractal_tasks_core.tasks.io_models import InitArgsCellVoyager - - -logger = logging.getLogger(__name__) - - -def sort_fun(filename: str) -> list[int]: - """ - Takes a string (filename of a Yokogawa image), extract site and - z-index metadata and returns them as a list of integers. - - Args: - filename: Name of the image file. - """ - - filename_metadata = parse_filename(filename) - site = int(filename_metadata["F"]) - z_index = int(filename_metadata["Z"]) - return [site, z_index] - - -@validate_call -def cellvoyager_to_ome_zarr_compute( - *, - # Fractal parameters - zarr_url: str, - init_args: InitArgsCellVoyager, -): - """ - Convert Yokogawa output (png, tif) to zarr file. - - This task is run after an init task (typically - `cellvoyager_to_ome_zarr_init` or - `cellvoyager_to_ome_zarr_init_multiplex`), and it populates the empty - OME-Zarr files that were prepared. - - Note that the current task always overwrites existing data. To avoid this - behavior, set the `overwrite` argument of the init task to `False`. - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be processed. - (standard argument for Fractal tasks, managed by Fractal server). - init_args: Intialization arguments provided by - `create_cellvoyager_ome_zarr_init`. 
- """ - zarr_url = zarr_url.rstrip("/") - # Read attributes from NGFF metadata - ngff_image_meta = load_NgffImageMeta(zarr_url) - num_levels = ngff_image_meta.num_levels - coarsening_xy = ngff_image_meta.coarsening_xy - full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) - logger.info(f"NGFF image has {num_levels=}") - logger.info(f"NGFF image has {coarsening_xy=}") - logger.info( - f"NGFF image has full-res pixel sizes {full_res_pxl_sizes_zyx}" - ) - - channels: list[OmeroChannel] = get_omero_channel_list( - image_zarr_path=zarr_url - ) - wavelength_ids = [c.wavelength_id for c in channels] - - # Read useful information from ROI table - adata = read_zarr(f"{zarr_url}/tables/FOV_ROI_table") - fov_indices = convert_ROI_table_to_indices( - adata, - full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, - ) - check_valid_ROI_indices(fov_indices, "FOV_ROI_table") - adata_well = read_zarr(f"{zarr_url}/tables/well_ROI_table") - well_indices = convert_ROI_table_to_indices( - adata_well, - full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, - ) - check_valid_ROI_indices(well_indices, "well_ROI_table") - if len(well_indices) > 1: - raise ValueError(f"Something wrong with {well_indices=}") - - max_z = well_indices[0][1] - max_y = well_indices[0][3] - max_x = well_indices[0][5] - - # Load a single image, to retrieve useful information - include_patterns = [ - f"{init_args.plate_prefix}_{init_args.well_ID}_*." - f"{init_args.image_extension}" - ] - if init_args.include_glob_patterns: - include_patterns.extend(init_args.include_glob_patterns) - - exclude_patterns = [] - if init_args.exclude_glob_patterns: - exclude_patterns.extend(init_args.exclude_glob_patterns) - - tmp_images = glob_with_multiple_patterns( - folder=init_args.image_dir, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - sample = imread(tmp_images.pop()) - - # Initialize zarr - chunksize = (1, 1, sample.shape[1], sample.shape[2]) - canvas_zarr = zarr.create( - shape=(len(wavelength_ids), max_z, max_y, max_x), - chunks=chunksize, - dtype=sample.dtype, - store=zarr.storage.FSStore(zarr_url + "/0"), - overwrite=True, - dimension_separator="/", - ) - - # Loop over channels - for i_c, wavelength_id in enumerate(wavelength_ids): - A, C = wavelength_id.split("_") - - include_patterns = [ - f"{init_args.plate_prefix}_{init_args.well_ID}_*{A}*{C}*." 
- f"{init_args.image_extension}" - ] - if init_args.include_glob_patterns: - include_patterns.extend(init_args.include_glob_patterns) - filenames_set = glob_with_multiple_patterns( - folder=init_args.image_dir, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - filenames = sorted(list(filenames_set), key=sort_fun) - if len(filenames) == 0: - raise ValueError( - "Error in yokogawa_to_ome_zarr: len(filenames)=0.\n" - f" image_dir: {init_args.image_dir}\n" - f" wavelength_id: {wavelength_id},\n" - f" patterns: {include_patterns}\n" - f" exclusion patterns: {exclude_patterns}\n" - ) - # Loop over 3D FOV ROIs - for indices in fov_indices: - s_z, e_z, s_y, e_y, s_x, e_x = indices[:] - region = ( - slice(i_c, i_c + 1), - slice(s_z, e_z), - slice(s_y, e_y), - slice(s_x, e_x), - ) - FOV_3D = da.concatenate( - [imread(img) for img in filenames[:e_z]], - ) - FOV_4D = da.expand_dims(FOV_3D, axis=0) - filenames = filenames[e_z:] - da.array(FOV_4D).to_zarr( - url=canvas_zarr, - region=region, - compute=True, - ) - - # Starting from on-disk highest-resolution data, build and write to disk a - # pyramid of coarser levels - build_pyramid( - zarrurl=zarr_url, - overwrite=True, - num_levels=num_levels, - coarsening_xy=coarsening_xy, - chunksize=chunksize, - ) - - # Generate image list updates - # TODO: Can we check for dimensionality more robustly? Just checks for the - # last FOV of the last wavelength now - if FOV_4D.shape[-3] > 1: - is_3D = True - else: - is_3D = False - # FIXME: Get plate name from zarr_url => works for duplicate plate names - # with suffixes - print(zarr_url) - plate_name = zarr_url.split("/")[-4] - attributes = { - "plate": plate_name, - "well": init_args.well_ID, - } - if init_args.acquisition is not None: - attributes["acquisition"] = init_args.acquisition - - image_list_updates = dict( - image_list_updates=[ - dict( - zarr_url=zarr_url, - attributes=attributes, - types={"is_3D": is_3D}, - ) - ] - ) - - return image_list_updates - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=cellvoyager_to_ome_zarr_compute, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init.py b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init.py deleted file mode 100644 index 113f0113f..000000000 --- a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init.py +++ /dev/null @@ -1,494 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Create structure for OME-NGFF zarr array. 
-""" -import os -from pathlib import Path -from typing import Any -from typing import Optional - -import pandas as pd -from pydantic import validate_call - -import fractal_tasks_core -from fractal_tasks_core.cellvoyager.filenames import ( - glob_with_multiple_patterns, -) -from fractal_tasks_core.cellvoyager.filenames import parse_filename -from fractal_tasks_core.cellvoyager.metadata import ( - parse_yokogawa_metadata, -) -from fractal_tasks_core.cellvoyager.wells import generate_row_col_split -from fractal_tasks_core.cellvoyager.wells import get_filename_well_id -from fractal_tasks_core.channels import check_unique_wavelength_ids -from fractal_tasks_core.channels import define_omero_channels -from fractal_tasks_core.channels import OmeroChannel -from fractal_tasks_core.ngff.specs import NgffImageMeta -from fractal_tasks_core.ngff.specs import Plate -from fractal_tasks_core.ngff.specs import Well -from fractal_tasks_core.roi import prepare_FOV_ROI_table -from fractal_tasks_core.roi import prepare_well_ROI_table -from fractal_tasks_core.roi import remove_FOV_overlaps -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tasks.io_models import InitArgsCellVoyager -from fractal_tasks_core.zarr_utils import open_zarr_group_with_overwrite - -__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ - -import logging - -logger = logging.getLogger(__name__) - - -@validate_call -def cellvoyager_to_ome_zarr_init( - *, - # Fractal parameters - zarr_urls: list[str], - zarr_dir: str, - # Core parameters - image_dirs: list[str], - allowed_channels: list[OmeroChannel], - # Advanced parameters - include_glob_patterns: Optional[list[str]] = None, - exclude_glob_patterns: Optional[list[str]] = None, - num_levels: int = 5, - coarsening_xy: int = 2, - image_extension: str = "tif", - metadata_table_file: Optional[str] = None, - overwrite: bool = False, -) -> dict[str, Any]: - """ - Create a OME-NGFF zarr folder, without reading/writing image data. - - Find plates (for each folder in input_paths): - - - glob image files, - - parse metadata from image filename to identify plates, - - identify populated channels. - - Create a zarr folder (for each plate): - - - parse mlf metadata, - - identify wells and field of view (FOV), - - create FOV ZARR, - - verify that channels are uniform (i.e., same channels). - - Args: - zarr_urls: List of paths or urls to the individual OME-Zarr image to - be processed. Not used by the converter task. - (standard argument for Fractal tasks, managed by Fractal server). - zarr_dir: path of the directory where the new OME-Zarrs will be - created. - (standard argument for Fractal tasks, managed by Fractal server). - image_dirs: list of paths to the folders that contains the Cellvoyager - image files. Each entry is a path to a folder that contains the - image files themselves for a multiwell plate and the - MeasurementData & MeasurementDetail metadata files. - allowed_channels: A list of `OmeroChannel` s, where each channel must - include the `wavelength_id` attribute and where the - `wavelength_id` values must be unique across the list. - include_glob_patterns: If specified, only parse images with filenames - that match with all these patterns. Patterns must be defined as in - https://docs.python.org/3/library/fnmatch.html, Example: - `image_glob_pattern=["*_B03_*"]` => only process well B03 - `image_glob_pattern=["*_C09_*", "*F016*", "*Z[0-5][0-9]C*"]` => - only process well C09, field of view 16 and Z planes 0-59. 
- Can interact with exclude_glob_patterns: the final list of images - to process is all included images minus all excluded images. - exclude_glob_patterns: If specified, exclude any image where the - filename matches any of the exclusion patterns. Patterns are - specified the same as for include_glob_patterns. - num_levels: Number of resolution-pyramid levels. If set to `5`, there - will be the full-resolution level and 4 levels of - downsampled images. - coarsening_xy: Linear coarsening factor between subsequent levels. - If set to `2`, level 1 is 2x downsampled, level 2 is - 4x downsampled etc. - image_extension: Filename extension of images (e.g. `"tif"` or `"png"`). - metadata_table_file: If `None`, parse Yokogawa metadata from mrf/mlf - files in the `image_dirs` folders; else, the full path to a csv file - containing the parsed metadata table. - overwrite: If `True`, overwrite the task output. - - Returns: - A metadata dictionary containing important metadata about the OME-Zarr - plate, the images and some parameters required by downstream tasks - (like `num_levels`). - """ - - # Preliminary checks on metadata_table_file - if metadata_table_file: - if not metadata_table_file.endswith(".csv"): - raise ValueError(f"{metadata_table_file=} is not a csv file") - if not os.path.isfile(metadata_table_file): - raise FileNotFoundError(f"{metadata_table_file=} does not exist") - - # Identify all plates and all channels, across all input folders - plates = [] - actual_wavelength_ids = None - dict_plate_paths = {} - dict_plate_prefixes: dict[str, Any] = {} - - # Preliminary checks on allowed_channels argument - check_unique_wavelength_ids(allowed_channels) - - for image_dir in image_dirs: - # Glob image filenames - include_patterns = [f"*.{image_extension}"] - exclude_patterns = [] - if include_glob_patterns: - include_patterns.extend(include_glob_patterns) - if exclude_glob_patterns: - exclude_patterns.extend(exclude_glob_patterns) - input_filenames = glob_with_multiple_patterns( - folder=image_dir, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - tmp_wavelength_ids = [] - tmp_plates = [] - for fn in input_filenames: - try: - filename_metadata = parse_filename(Path(fn).name) - plate_prefix = filename_metadata["plate_prefix"] - plate = filename_metadata["plate"] - if plate not in dict_plate_prefixes.keys(): - dict_plate_prefixes[plate] = plate_prefix - tmp_plates.append(plate) - A = filename_metadata["A"] - C = filename_metadata["C"] - tmp_wavelength_ids.append(f"A{A}_C{C}") - except ValueError as e: - logger.warning( - f'Skipping "{Path(fn).name}".
Original error: ' + str(e) - ) - tmp_plates = sorted(list(set(tmp_plates))) - tmp_wavelength_ids = sorted(list(set(tmp_wavelength_ids))) - - info = ( - "Listing plates/channels:\n" - f"Folder: {image_dir}\n" - f"Include Patterns: {include_patterns}\n" - f"Exclude Patterns: {exclude_patterns}\n" - f"Plates: {tmp_plates}\n" - f"Channels: {tmp_wavelength_ids}\n" - ) - - # Check that only one plate is found - if len(tmp_plates) > 1: - raise ValueError(f"{info}ERROR: {len(tmp_plates)} plates detected") - elif len(tmp_plates) == 0: - raise ValueError(f"{info}ERROR: No plates detected") - plate = tmp_plates[0] - - # If plate already exists in other folder, add suffix - if plate in plates: - ind = 1 - new_plate = f"{plate}_{ind}" - while new_plate in plates: - ind += 1 - new_plate = f"{plate}_{ind}" - logger.info( - f"WARNING: {plate} already exists, renaming it to {new_plate}" - ) - plates.append(new_plate) - dict_plate_prefixes[new_plate] = dict_plate_prefixes[plate] - plate = new_plate - else: - plates.append(plate) - - # Check that channels are the same as in previous plates - if actual_wavelength_ids is None: - actual_wavelength_ids = tmp_wavelength_ids[:] - else: - if actual_wavelength_ids != tmp_wavelength_ids: - raise ValueError( - f"ERROR\n{info}\nERROR:" - f" expected channels {actual_wavelength_ids}," - f" found {tmp_wavelength_ids}" - ) - - # Update dict_plate_paths - dict_plate_paths[plate] = image_dir - - # Check that all channels are in the allowed_channels - allowed_wavelength_ids = [ - channel.wavelength_id for channel in allowed_channels - ] - if not set(actual_wavelength_ids).issubset(set(allowed_wavelength_ids)): - msg = "ERROR in create_ome_zarr\n" - msg += f"actual_wavelength_ids: {actual_wavelength_ids}\n" - msg += f"allowed_wavelength_ids: {allowed_wavelength_ids}\n" - raise ValueError(msg) - - # Create actual_channels, i.e. a list of the channel dictionaries which are - # present - actual_channels = [ - channel - for channel in allowed_channels - if channel.wavelength_id in actual_wavelength_ids - ] - - ################################################################ - # Create well/image OME-Zarr folders on disk, and prepare output - # metadata - parallelization_list = [] - - for plate in plates: - # Define plate zarr - relative_zarrurl = f"{plate}.zarr" - in_path = dict_plate_paths[plate] - logger.info(f"Creating {relative_zarrurl}") - # Call zarr.open_group wrapper, which handles overwrite=True/False - group_plate = open_zarr_group_with_overwrite( - str(Path(zarr_dir) / relative_zarrurl), - overwrite=overwrite, - ) - - # Obtain FOV-metadata dataframe - if metadata_table_file is None: - mrf_path = f"{in_path}/MeasurementDetail.mrf" - mlf_path = f"{in_path}/MeasurementData.mlf" - - site_metadata, number_images_mlf = parse_yokogawa_metadata( - mrf_path, - mlf_path, - include_patterns=include_glob_patterns, - exclude_patterns=exclude_glob_patterns, - ) - site_metadata = remove_FOV_overlaps(site_metadata) - - # If a metadata table was passed, load it and use it directly - else: - logger.warning( - "Since a custom metadata table was provided, there will " - "be no additional check on the number of image files."
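- # (sketch of the assumed csv layout: the custom table must provide - # at least the columns indexed/used below, i.e. well_id, FieldIndex, - # pixel_size_z, pixel_size_y, pixel_size_x, bit_depth, ...)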
- ) - site_metadata = pd.read_csv(metadata_table_file) - site_metadata.set_index(["well_id", "FieldIndex"], inplace=True) - - # Extract pixel sizes and bit_depth - pixel_size_z = site_metadata["pixel_size_z"].iloc[0] - pixel_size_y = site_metadata["pixel_size_y"].iloc[0] - pixel_size_x = site_metadata["pixel_size_x"].iloc[0] - bit_depth = site_metadata["bit_depth"].iloc[0] - - if min(pixel_size_z, pixel_size_y, pixel_size_x) < 1e-9: - raise ValueError(pixel_size_z, pixel_size_y, pixel_size_x) - - # Identify all wells - plate_prefix = dict_plate_prefixes[plate] - - include_patterns = [f"{plate_prefix}_*.{image_extension}"] - if include_glob_patterns: - include_patterns.extend(include_glob_patterns) - plate_images = glob_with_multiple_patterns( - folder=str(in_path), - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - wells = [ - parse_filename(os.path.basename(fn))["well"] for fn in plate_images - ] - wells = sorted(list(set(wells))) - - # Verify that all wells have all channels - for well in wells: - include_patterns = [f"{plate_prefix}_{well}_*.{image_extension}"] - if include_glob_patterns: - include_patterns.extend(include_glob_patterns) - well_images = glob_with_multiple_patterns( - folder=str(in_path), - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - # Check number of images matches with expected one - if metadata_table_file is None: - num_images_glob = len(well_images) - num_images_expected = number_images_mlf[well] - if num_images_glob != num_images_expected: - raise ValueError( - f"Wrong number of images for {well=}\n" - f"Expected {num_images_expected} (from mlf file)\n" - f"Found {num_images_glob} files\n" - "Other parameters:\n" - f" {image_extension=}\n" - f" {include_glob_patterns=}\n" - f" {exclude_glob_patterns=}\n" - ) - - well_wavelength_ids = [] - for fpath in well_images: - try: - filename_metadata = parse_filename(os.path.basename(fpath)) - well_wavelength_ids.append( - f"A{filename_metadata['A']}_C{filename_metadata['C']}" - ) - except IndexError: - logger.info(f"Skipping {fpath}") - well_wavelength_ids = sorted(list(set(well_wavelength_ids))) - if well_wavelength_ids != actual_wavelength_ids: - raise ValueError( - f"ERROR: well {well} in plate {plate} (prefix: " - f"{plate_prefix}) has missing channels.\n" - f"Expected: {actual_channels}\n" - f"Found: {well_wavelength_ids}.\n" - ) - - well_rows_columns = generate_row_col_split(wells) - - row_list = [ - well_row_column[0] for well_row_column in well_rows_columns - ] - col_list = [ - well_row_column[1] for well_row_column in well_rows_columns - ] - row_list = sorted(list(set(row_list))) - col_list = sorted(list(set(col_list))) - - plate_attrs = { - "acquisitions": [{"id": 0, "name": plate}], - "columns": [{"name": col} for col in col_list], - "rows": [{"name": row} for row in row_list], - "version": __OME_NGFF_VERSION__, - "wells": [ - { - "path": well_row_column[0] + "/" + well_row_column[1], - "rowIndex": row_list.index(well_row_column[0]), - "columnIndex": col_list.index(well_row_column[1]), - } - for well_row_column in well_rows_columns - ], - } - - # Validate plate attrs: - Plate(**plate_attrs) - - group_plate.attrs["plate"] = plate_attrs - - for row, column in well_rows_columns: - parallelization_list.append( - { - "zarr_url": f"{zarr_dir}/{plate}.zarr/{row}/{column}/0", - "init_args": InitArgsCellVoyager( - image_dir=in_path, - plate_prefix=plate_prefix, - well_ID=get_filename_well_id(row, column), - image_extension=image_extension, - 
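- # (well_ID is built by get_filename_well_id from row/column, - # presumably e.g. row "B" + column "03" -> "B03"; illustrative)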
include_glob_patterns=include_glob_patterns, - exclude_glob_patterns=exclude_glob_patterns, - ).model_dump(), - } - ) - group_well = group_plate.create_group(f"{row}/{column}/") - - well_attrs = { - "images": [{"path": "0"}], - "version": __OME_NGFF_VERSION__, - } - - # Validate well attrs: - Well(**well_attrs) - group_well.attrs["well"] = well_attrs - - group_image = group_well.create_group("0") # noqa: F841 - group_image.attrs["multiscales"] = [ - { - "version": __OME_NGFF_VERSION__, - "axes": [ - {"name": "c", "type": "channel"}, - { - "name": "z", - "type": "space", - "unit": "micrometer", - }, - { - "name": "y", - "type": "space", - "unit": "micrometer", - }, - { - "name": "x", - "type": "space", - "unit": "micrometer", - }, - ], - "datasets": [ - { - "path": f"{ind_level}", - "coordinateTransformations": [ - { - "type": "scale", - "scale": [ - 1, - pixel_size_z, - pixel_size_y - * coarsening_xy**ind_level, - pixel_size_x - * coarsening_xy**ind_level, - ], - } - ], - } - for ind_level in range(num_levels) - ], - } - ] - - group_image.attrs["omero"] = { - "id": 1, # TODO does this depend on the plate number? - "name": "TBD", - "version": __OME_NGFF_VERSION__, - "channels": define_omero_channels( - channels=actual_channels, bit_depth=bit_depth - ), - } - - # Validate Image attrs - NgffImageMeta(**group_image.attrs) - - # Prepare AnnData tables for FOV/well ROIs - well_id = get_filename_well_id(row, column) - FOV_ROIs_table = prepare_FOV_ROI_table(site_metadata.loc[well_id]) - well_ROIs_table = prepare_well_ROI_table( - site_metadata.loc[well_id] - ) - - # Write AnnData tables into the `tables` zarr group - write_table( - group_image, - "FOV_ROI_table", - FOV_ROIs_table, - overwrite=overwrite, - table_attrs={"type": "roi_table"}, - ) - write_table( - group_image, - "well_ROI_table", - well_ROIs_table, - overwrite=overwrite, - table_attrs={"type": "roi_table"}, - ) - - return dict(parallelization_list=parallelization_list) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=cellvoyager_to_ome_zarr_init, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init_multiplex.py b/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init_multiplex.py deleted file mode 100644 index d9598a049..000000000 --- a/fractal_tasks_core/tasks/cellvoyager_to_ome_zarr_init_multiplex.py +++ /dev/null @@ -1,541 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Create OME-NGFF zarr group, for multiplexing dataset. 
-""" -import os -from pathlib import Path -from typing import Any -from typing import Optional - -import pandas as pd -import zarr -from pydantic import validate_call -from zarr.errors import ContainsGroupError - -import fractal_tasks_core -from fractal_tasks_core.cellvoyager.filenames import ( - glob_with_multiple_patterns, -) -from fractal_tasks_core.cellvoyager.filenames import parse_filename -from fractal_tasks_core.cellvoyager.metadata import ( - parse_yokogawa_metadata, -) -from fractal_tasks_core.cellvoyager.wells import generate_row_col_split -from fractal_tasks_core.cellvoyager.wells import get_filename_well_id -from fractal_tasks_core.channels import check_unique_wavelength_ids -from fractal_tasks_core.channels import check_well_channel_labels -from fractal_tasks_core.channels import define_omero_channels -from fractal_tasks_core.ngff.specs import NgffImageMeta -from fractal_tasks_core.ngff.specs import Plate -from fractal_tasks_core.ngff.specs import Well -from fractal_tasks_core.roi import prepare_FOV_ROI_table -from fractal_tasks_core.roi import prepare_well_ROI_table -from fractal_tasks_core.roi import remove_FOV_overlaps -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tasks.io_models import InitArgsCellVoyager -from fractal_tasks_core.tasks.io_models import MultiplexingAcquisition -from fractal_tasks_core.zarr_utils import open_zarr_group_with_overwrite - -__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ - -import logging - -logger = logging.getLogger(__name__) - - -@validate_call -def cellvoyager_to_ome_zarr_init_multiplex( - *, - # Fractal parameters - zarr_urls: list[str], - zarr_dir: str, - # Core parameters - acquisitions: dict[str, MultiplexingAcquisition], - # Advanced parameters - include_glob_patterns: Optional[list[str]] = None, - exclude_glob_patterns: Optional[list[str]] = None, - num_levels: int = 5, - coarsening_xy: int = 2, - image_extension: str = "tif", - metadata_table_files: Optional[dict[str, str]] = None, - overwrite: bool = False, -) -> dict[str, Any]: - """ - Create OME-NGFF structure and metadata to host a multiplexing dataset. - - This task takes a set of image folders (i.e. different multiplexing - acquisitions) and build the internal structure and metadata of a OME-NGFF - zarr group, without actually loading/writing the image data. - - Each element in input_paths should be treated as a different acquisition. - - Args: - zarr_urls: List of paths or urls to the individual OME-Zarr image to - be processed. Not used by the converter task. - (standard argument for Fractal tasks, managed by Fractal server). - zarr_dir: path of the directory where the new OME-Zarrs will be - created. - (standard argument for Fractal tasks, managed by Fractal server). - acquisitions: dictionary of acquisitions. Each key is the acquisition - identifier (normally 0, 1, 2, 3 etc.). Each item defines the - acquisition by providing the image_dir and the allowed_channels. - include_glob_patterns: If specified, only parse images with filenames - that match with all these patterns. Patterns must be defined as in - https://docs.python.org/3/library/fnmatch.html, Example: - `image_glob_pattern=["*_B03_*"]` => only process well B03 - `image_glob_pattern=["*_C09_*", "*F016*", "*Z[0-5][0-9]C*"]` => - only process well C09, field of view 16 and Z planes 0-59. 
- Can interact with exclude_glob_patterns: the final list of images - to process is all included images minus all excluded images. - exclude_glob_patterns: If specified, exclude any image where the - filename matches any of the exclusion patterns. Patterns are - specified the same as for include_glob_patterns. - num_levels: Number of resolution-pyramid levels. If set to `5`, there - will be the full-resolution level and 4 levels of downsampled - images. - coarsening_xy: Linear coarsening factor between subsequent levels. - If set to `2`, level 1 is 2x downsampled, level 2 is 4x downsampled - etc. - image_extension: Filename extension of images - (e.g. `"tif"` or `"png"`). - metadata_table_files: If `None`, parse Yokogawa metadata from mrf/mlf - files in the acquisition image folders; else, a dictionary of - key-value pairs like `(acquisition, path)` with `acquisition` a - string like the key of the `acquisitions` dict and `path` pointing - to a csv file containing the parsed metadata table. - overwrite: If `True`, overwrite the task output. - - Returns: - A metadata dictionary containing important metadata about the OME-Zarr - plate, the images and some parameters required by downstream tasks - (like `num_levels`). - """ - - if metadata_table_files: - # Checks on the dict: - # 1. Acquisitions in acquisitions dict and metadata_table_files match - # 2. Files end with ".csv" - # 3. Files exist. - if set(acquisitions.keys()) != set(metadata_table_files.keys()): - raise ValueError( - "Mismatch in acquisition keys between " - f"{acquisitions.keys()=} and " - f"{metadata_table_files.keys()=}" - ) - for f in metadata_table_files.values(): - if not f.endswith(".csv"): - raise ValueError( - f"{f} (in metadata_table_file) is not a csv file." - ) - if not os.path.isfile(f): - raise ValueError( - f"{f} (in metadata_table_file) does not exist." - ) - - # Preliminary checks on acquisitions - # Note that in metadata the keys of dictionary arguments should be - # strings (and not integers), so that they can be read from a JSON file - for key, values in acquisitions.items(): - if not isinstance(key, str): - raise ValueError(f"{acquisitions=} has non-string keys") - check_unique_wavelength_ids(values.allowed_channels) - try: - int(key) - except ValueError: - raise ValueError( - "Acquisition dictionary keys must be strings that can be " - "cast to integers" - ) - - # Identify all plates and all channels, per input folder - dict_acquisitions: dict = {} - acquisitions_sorted = sorted(list(acquisitions.keys())) - for acquisition in acquisitions_sorted: - acq_input = acquisitions[acquisition] - dict_acquisitions[acquisition] = {} - - actual_wavelength_ids = [] - plates = [] - plate_prefixes = [] - - # Loop over all images - include_patterns = [f"*.{image_extension}"] - exclude_patterns = [] - if include_glob_patterns: - include_patterns.extend(include_glob_patterns) - if exclude_glob_patterns: - exclude_patterns.extend(exclude_glob_patterns) - input_filenames = glob_with_multiple_patterns( - folder=acq_input.image_dir, - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - for fn in input_filenames: - try: - filename_metadata = parse_filename(Path(fn).name) - plate = filename_metadata["plate"] - plates.append(plate) - plate_prefix = filename_metadata["plate_prefix"] - plate_prefixes.append(plate_prefix) - A = filename_metadata["A"] - C = filename_metadata["C"] - actual_wavelength_ids.append(f"A{A}_C{C}") - except ValueError as e: - logger.warning( - f'Skipping "{Path(fn).name}".
Original error: ' + str(e) - ) - plates = sorted(list(set(plates))) - actual_wavelength_ids = sorted(list(set(actual_wavelength_ids))) - - info = ( - "Listing all plates/channels:\n" - f"Include patterns: {include_patterns}\n" - f"Exclude patterns: {exclude_patterns}\n" - f"Plates: {plates}\n" - f"Actual wavelength IDs: {actual_wavelength_ids}\n" - ) - - # Check that a folder includes a single plate - if len(plates) > 1: - raise ValueError(f"{info}ERROR: {len(plates)} plates detected") - elif len(plates) == 0: - raise ValueError(f"{info}ERROR: No plates detected") - original_plate = plates[0] - plate_prefix = plate_prefixes[0] - - # Replace plate with the one of the first acquisition - if acquisition != acquisitions_sorted[0]: - plate = dict_acquisitions[acquisitions_sorted[0]]["plate"] - logger.warning( - f"For {acquisition=}, we replace {original_plate=} with " - f"{plate=} (the one for acquisition {acquisitions_sorted[0]})" - ) - - # Check that all channels are in the allowed_channels - allowed_wavelength_ids = [ - c.wavelength_id for c in acq_input.allowed_channels - ] - if not set(actual_wavelength_ids).issubset( - set(allowed_wavelength_ids) - ): - msg = "ERROR in create_ome_zarr\n" - msg += f"actual_wavelength_ids: {actual_wavelength_ids}\n" - msg += f"allowed_wavelength_ids: {allowed_wavelength_ids}\n" - raise ValueError(msg) - - # Create actual_channels, i.e. a list of the channel dictionaries which - # are present - actual_channels = [ - channel - for channel in acq_input.allowed_channels - if channel.wavelength_id in actual_wavelength_ids - ] - - logger.info(f"plate: {plate}") - logger.info(f"actual_channels: {actual_channels}") - - dict_acquisitions[acquisition] = {} - dict_acquisitions[acquisition]["plate"] = plate - dict_acquisitions[acquisition]["original_plate"] = original_plate - dict_acquisitions[acquisition]["plate_prefix"] = plate_prefix - dict_acquisitions[acquisition]["image_folder"] = acq_input.image_dir - dict_acquisitions[acquisition]["original_paths"] = [ - acq_input.image_dir - ] - dict_acquisitions[acquisition]["actual_channels"] = actual_channels - dict_acquisitions[acquisition][ - "actual_wavelength_ids" - ] = actual_wavelength_ids - - parallelization_list = [] - current_plates = [item["plate"] for item in dict_acquisitions.values()] - if len(set(current_plates)) > 1: - raise ValueError(f"{current_plates=}") - plate = current_plates[0] - - zarrurl = dict_acquisitions[acquisitions_sorted[0]]["plate"] + ".zarr" - full_zarrurl = str(Path(zarr_dir) / zarrurl) - logger.info(f"Creating {full_zarrurl=}") - # Call zarr.open_group wrapper, which handles overwrite=True/False - group_plate = open_zarr_group_with_overwrite( - full_zarrurl, overwrite=overwrite - ) - group_plate.attrs["plate"] = { - "acquisitions": [ - { - "id": int(acquisition), - "name": dict_acquisitions[acquisition]["original_plate"], - } - for acquisition in acquisitions_sorted - ] - } - - zarrurls: dict[str, list[str]] = {"well": [], "image": []} - zarrurls["plate"] = [f"{plate}.zarr"] - - ################################################################ - logging.info(f"{acquisitions_sorted=}") - - for i, acquisition in enumerate(acquisitions_sorted): - # Define plate zarr - image_folder = dict_acquisitions[acquisition]["image_folder"] - logger.info(f"Looking at {image_folder=}") - - # Obtain FOV-metadata dataframe - if metadata_table_files is None: - mrf_path = f"{image_folder}/MeasurementDetail.mrf" - mlf_path = f"{image_folder}/MeasurementData.mlf" - site_metadata, _ = parse_yokogawa_metadata( - 
mrf_path, - mlf_path, - include_patterns=include_glob_patterns, - exclude_patterns=exclude_glob_patterns, - ) - site_metadata = remove_FOV_overlaps(site_metadata) - else: - site_metadata = pd.read_csv(metadata_table_files[acquisition]) - site_metadata.set_index(["well_id", "FieldIndex"], inplace=True) - - # Extract pixel sizes and bit_depth - pixel_size_z = site_metadata["pixel_size_z"].iloc[0] - pixel_size_y = site_metadata["pixel_size_y"].iloc[0] - pixel_size_x = site_metadata["pixel_size_x"].iloc[0] - bit_depth = site_metadata["bit_depth"].iloc[0] - - if min(pixel_size_z, pixel_size_y, pixel_size_x) < 1e-9: - raise ValueError(pixel_size_z, pixel_size_y, pixel_size_x) - - # Identify all wells - plate_prefix = dict_acquisitions[acquisition]["plate_prefix"] - include_patterns = [f"{plate_prefix}_*.{image_extension}"] - if include_glob_patterns: - include_patterns.extend(include_glob_patterns) - plate_images = glob_with_multiple_patterns( - folder=str(image_folder), - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - wells = [ - parse_filename(os.path.basename(fn))["well"] for fn in plate_images - ] - wells = sorted(list(set(wells))) - logger.info(f"{wells=}") - - # Verify that all wells have all channels - actual_channels = dict_acquisitions[acquisition]["actual_channels"] - for well in wells: - include_patterns = [f"{plate_prefix}_{well}_*.{image_extension}"] - if include_glob_patterns: - include_patterns.extend(include_glob_patterns) - well_images = glob_with_multiple_patterns( - folder=str(image_folder), - include_patterns=include_patterns, - exclude_patterns=exclude_patterns, - ) - - well_wavelength_ids = [] - for fpath in well_images: - try: - filename_metadata = parse_filename(os.path.basename(fpath)) - A = filename_metadata["A"] - C = filename_metadata["C"] - well_wavelength_ids.append(f"A{A}_C{C}") - except IndexError: - logger.info(f"Skipping {fpath}") - well_wavelength_ids = sorted(list(set(well_wavelength_ids))) - actual_wavelength_ids = dict_acquisitions[acquisition][ - "actual_wavelength_ids" - ] - if well_wavelength_ids != actual_wavelength_ids: - raise ValueError( - f"ERROR: well {well} in plate {plate} (prefix: " - f"{plate_prefix}) has missing channels.\n" - f"Expected: {actual_wavelength_ids}\n" - f"Found: {well_wavelength_ids}.\n" - ) - - well_rows_columns = generate_row_col_split(wells) - row_list = [ - well_row_column[0] for well_row_column in well_rows_columns - ] - col_list = [ - well_row_column[1] for well_row_column in well_rows_columns - ] - row_list = sorted(list(set(row_list))) - col_list = sorted(list(set(col_list))) - - plate_attrs = group_plate.attrs["plate"] - plate_attrs["columns"] = [{"name": col} for col in col_list] - plate_attrs["rows"] = [{"name": row} for row in row_list] - plate_attrs["wells"] = [ - { - "path": well_row_column[0] + "/" + well_row_column[1], - "rowIndex": row_list.index(well_row_column[0]), - "columnIndex": col_list.index(well_row_column[1]), - } - for well_row_column in well_rows_columns - ] - plate_attrs["version"] = __OME_NGFF_VERSION__ - # Validate plate attrs - Plate(**plate_attrs) - group_plate.attrs["plate"] = plate_attrs - - for row, column in well_rows_columns: - parallelization_list.append( - { - "zarr_url": ( - f"{zarr_dir}/{plate}.zarr/{row}/{column}/" f"{i}" - ), - "init_args": InitArgsCellVoyager( - image_dir=acquisitions[acquisition].image_dir, - plate_prefix=plate_prefix, - well_ID=get_filename_well_id(row, column), - image_extension=image_extension, - 
include_glob_patterns=include_glob_patterns, - exclude_glob_patterns=exclude_glob_patterns, - acquisition=acquisition, - ).model_dump(), - } - ) - try: - group_well = group_plate.create_group(f"{row}/{column}/") - logging.info(f"Created new group_well at {row}/{column}/") - well_attrs = { - "images": [ - { - "path": f"{i}", - "acquisition": int(acquisition), - } - ], - "version": __OME_NGFF_VERSION__, - } - # Validate well attrs: - Well(**well_attrs) - group_well.attrs["well"] = well_attrs - zarrurls["well"].append(f"{plate}.zarr/{row}/{column}") - except ContainsGroupError: - group_well = zarr.open_group( - f"{full_zarrurl}/{row}/{column}/", mode="r+" - ) - logging.info( - f"Loaded group_well from {full_zarrurl}/{row}/{column}" - ) - current_images = group_well.attrs["well"]["images"] + [ - {"path": f"{i}", "acquisition": int(acquisition)} - ] - well_attrs = dict( - images=current_images, - version=group_well.attrs["well"]["version"], - ) - # Validate well attrs: - Well(**well_attrs) - group_well.attrs["well"] = well_attrs - - group_image = group_well.create_group(f"{i}/") # noqa: F841 - logging.info(f"Created image group {row}/{column}/{i}") - image = f"{plate}.zarr/{row}/{column}/{i}" - zarrurls["image"].append(image) - - group_image.attrs["multiscales"] = [ - { - "version": __OME_NGFF_VERSION__, - "axes": [ - {"name": "c", "type": "channel"}, - { - "name": "z", - "type": "space", - "unit": "micrometer", - }, - { - "name": "y", - "type": "space", - "unit": "micrometer", - }, - { - "name": "x", - "type": "space", - "unit": "micrometer", - }, - ], - "datasets": [ - { - "path": f"{ind_level}", - "coordinateTransformations": [ - { - "type": "scale", - "scale": [ - 1, - pixel_size_z, - pixel_size_y - * coarsening_xy**ind_level, - pixel_size_x - * coarsening_xy**ind_level, - ], - } - ], - } - for ind_level in range(num_levels) - ], - } - ] - - group_image.attrs["omero"] = { - "id": 1, # FIXME does this depend on the plate number? - "name": "TBD", - "version": __OME_NGFF_VERSION__, - "channels": define_omero_channels( - channels=actual_channels, - bit_depth=bit_depth, - label_prefix=i, - ), - } - # Validate Image attrs - NgffImageMeta(**group_image.attrs) - - # Prepare AnnData tables for FOV/well ROIs - well_id = get_filename_well_id(row, column) - FOV_ROIs_table = prepare_FOV_ROI_table(site_metadata.loc[well_id]) - well_ROIs_table = prepare_well_ROI_table( - site_metadata.loc[well_id] - ) - - # Write AnnData tables into the `tables` zarr group - write_table( - group_image, - "FOV_ROI_table", - FOV_ROIs_table, - overwrite=overwrite, - table_attrs={"type": "roi_table"}, - ) - write_table( - group_image, - "well_ROI_table", - well_ROIs_table, - overwrite=overwrite, - table_attrs={"type": "roi_table"}, - ) - - # Check that the different images (e.g. 
different acquisitions) in each - # well have unique labels - for well_path in zarrurls["well"]: - check_well_channel_labels( - well_zarr_path=str(Path(zarr_dir) / well_path) - ) - - return dict(parallelization_list=parallelization_list) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=cellvoyager_to_ome_zarr_init_multiplex, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/copy_ome_zarr_hcs_plate.py b/fractal_tasks_core/tasks/copy_ome_zarr_hcs_plate.py deleted file mode 100644 index 4ec5e8111..000000000 --- a/fractal_tasks_core/tasks/copy_ome_zarr_hcs_plate.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Marco Franzon -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Task that copies the structure of an OME-NGFF zarr array to a new one. -""" -import logging -from typing import Any - -from pydantic import validate_call - -import fractal_tasks_core -from fractal_tasks_core.ngff.specs import NgffPlateMeta -from fractal_tasks_core.ngff.specs import WellInPlate -from fractal_tasks_core.ngff.zarr_utils import load_NgffPlateMeta -from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta -from fractal_tasks_core.tasks.io_models import InitArgsMIP -from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod -from fractal_tasks_core.zarr_utils import open_zarr_group_with_overwrite - -logger = logging.getLogger(__name__) - - -__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ - - -def _get_plate_url_from_image_url(zarr_url: str) -> str: - """ - Given the absolute `zarr_url` for an OME-Zarr image within an HCS plate, - return the path to the plate zarr group. - """ - zarr_url = zarr_url.rstrip("/") - plate_path = "/".join(zarr_url.split("/")[:-3]) - return plate_path - - -def _get_well_sub_url(zarr_url: str) -> str: - """ - Given the absolute `zarr_url` for an OME-Zarr image within an HCS plate, - return the relative path of its well zarr group within the plate - (e.g. `B/03`). - """ - zarr_url = zarr_url.rstrip("/") - well_url = "/".join(zarr_url.split("/")[-3:-1]) - return well_url - - -def _get_image_sub_url(zarr_url: str) -> str: - """ - Given the absolute `zarr_url` for an OME-Zarr image, return the image - zarr-group name. - """ - zarr_url = zarr_url.rstrip("/") - image_sub_url = zarr_url.split("/")[-1] - return image_sub_url - - -def _generate_wells_rows_columns( - well_list: list[str], -) -> tuple[list[WellInPlate], list[str], list[str]]: - """ - Generate the plate well metadata based on the list of wells. - """ - rows = [] - columns = [] - wells = [] - for well in well_list: - rows.append(well.split("/")[0]) - columns.append(well.split("/")[1]) - rows = sorted(list(set(rows))) - columns = sorted(list(set(columns))) - for well in well_list: - wells.append( - WellInPlate( - path=well, - rowIndex=rows.index(well.split("/")[0]), - columnIndex=columns.index(well.split("/")[1]), - ) - ) - - return wells, rows, columns - - -def _generate_plate_well_metadata( - zarr_urls: list[str], -) -> tuple[dict[str, dict], dict[str, dict[str, dict]], dict[str, dict]]: - """ - Generate metadata for OME-Zarr HCS plates & wells.
- - Based on the list of zarr_urls, generate metadata for all plates and all - their wells. - - Args: - zarr_urls: List of paths or urls to the individual OME-Zarr image to - be processed. - - Returns: - plate_metadata_dicts: Dictionary of plate plate metadata. The structure - is: {"old_plate_name": NgffPlateMeta (as dict)}. - new_well_image_attrs: Dictionary of image lists for the new wells. - The structure is: {"old_plate_name": {"old_well_name": - [ImageInWell(as dict)]}} - well_image_attrs: Dictionary of Image attributes of the existing wells. - """ - # TODO: Simplify this block. Currently complicated, because we need to loop - # through all potential plates, all their wells & their images to build up - # the metadata for the plate & well. - plate_metadata_dicts = {} - plate_wells = {} - well_image_attrs = {} - new_well_image_attrs = {} - for zarr_url in zarr_urls: - # Extract plate/well/image parts of `zarr_url` - old_plate_url = _get_plate_url_from_image_url(zarr_url) - well_sub_url = _get_well_sub_url(zarr_url) - curr_img_sub_url = _get_image_sub_url(zarr_url) - - # The first time a plate is found, create its metadata - if old_plate_url not in plate_metadata_dicts: - logger.info(f"Reading plate metadata of {old_plate_url=}") - old_plate_meta = load_NgffPlateMeta(old_plate_url) - plate_metadata = dict( - plate=dict( - acquisitions=old_plate_meta.plate.acquisitions, - field_count=old_plate_meta.plate.field_count, - name=old_plate_meta.plate.name, - # The new field count could be different from the old - # field count - version=old_plate_meta.plate.version, - ) - ) - plate_metadata_dicts[old_plate_url] = plate_metadata - plate_wells[old_plate_url] = [] - well_image_attrs[old_plate_url] = {} - new_well_image_attrs[old_plate_url] = {} - - # The first time a plate/well pair is found, create the well metadata - if well_sub_url not in plate_wells[old_plate_url]: - plate_wells[old_plate_url].append(well_sub_url) - old_well_url = f"{old_plate_url}/{well_sub_url}" - logger.info(f"Reading well metadata of {old_well_url}") - well_attrs = load_NgffWellMeta(old_well_url) - well_image_attrs[old_plate_url][well_sub_url] = well_attrs.well - new_well_image_attrs[old_plate_url][well_sub_url] = [] - - # Find images of the current well with name matching the current image - # TODO: clarify whether this list must always have length 1 - curr_well_image_list = [ - img - for img in well_image_attrs[old_plate_url][well_sub_url].images - if img.path == curr_img_sub_url - ] - new_well_image_attrs[old_plate_url][ - well_sub_url - ] += curr_well_image_list - - # Fill in the plate metadata based on all available wells - for old_plate_url in plate_metadata_dicts: - well_list, row_list, column_list = _generate_wells_rows_columns( - plate_wells[old_plate_url] - ) - plate_metadata_dicts[old_plate_url]["plate"]["columns"] = [] - for column in column_list: - plate_metadata_dicts[old_plate_url]["plate"]["columns"].append( - {"name": column} - ) - - plate_metadata_dicts[old_plate_url]["plate"]["rows"] = [] - for row in row_list: - plate_metadata_dicts[old_plate_url]["plate"]["rows"].append( - {"name": row} - ) - plate_metadata_dicts[old_plate_url]["plate"]["wells"] = well_list - - # Validate with NgffPlateMeta model - plate_metadata_dicts[old_plate_url] = NgffPlateMeta( - **plate_metadata_dicts[old_plate_url] - ).model_dump(exclude_none=True) - - return plate_metadata_dicts, new_well_image_attrs, well_image_attrs - - -@validate_call -def copy_ome_zarr_hcs_plate( - *, - # Fractal parameters - zarr_urls: list[str], - 
zarr_dir: str, - method: DaskProjectionMethod = DaskProjectionMethod.MIP, - # Advanced parameters - overwrite: bool = False, -) -> dict[str, Any]: - """ - Duplicate the OME-Zarr HCS structure for a set of zarr_urls. - - This task only processes the zarr images in the zarr_urls, not all the - images in the plate. It copies all the plate & well structure, but none - of the image metadata or the actual image data: - - - For each plate, create a new OME-Zarr HCS plate with the attributes for - all the images in zarr_urls - - For each well (in each plate), create a new zarr subgroup with the - same attributes as the original one. - - Note: this task makes use of methods from the `Attributes` class, see - https://zarr.readthedocs.io/en/stable/api/attrs.html. - - Args: - zarr_urls: List of paths or urls to the individual OME-Zarr image to - be processed. - (standard argument for Fractal tasks, managed by Fractal server). - zarr_dir: path of the directory where the new OME-Zarrs will be - created. - (standard argument for Fractal tasks, managed by Fractal server). - method: Choose which method to use for intensity projection along the - Z axis. mip is the default and performs a maximum intensity - projection. minip performs a minimum intensity projection, meanip - a mean intensity projection and sumip a sum intensity projection. - overwrite: If `True`, overwrite the task output. - - Returns: - A parallelization list to be used in a compute task to fill the wells - with OME-Zarr images. - """ - - parallelization_list = [] - - # Generate parallelization list - for zarr_url in zarr_urls: - old_plate_url = _get_plate_url_from_image_url(zarr_url) - well_sub_url = _get_well_sub_url(zarr_url) - old_plate_name = old_plate_url.split(".zarr")[-2].split("/")[-1] - new_plate_name = f"{old_plate_name}_{method.value}" - zarrurl_plate_new = f"{zarr_dir}/{new_plate_name}.zarr" - curr_img_sub_url = _get_image_sub_url(zarr_url) - new_zarr_url = f"{zarrurl_plate_new}/{well_sub_url}/{curr_img_sub_url}" - parallelization_item = dict( - zarr_url=new_zarr_url, - init_args=dict( - origin_url=zarr_url, - method=method.value, - overwrite=overwrite, - new_plate_name=f"{new_plate_name}.zarr", - ), - ) - InitArgsMIP(**parallelization_item["init_args"]) - parallelization_list.append(parallelization_item) - - # Generate the plate metadata & parallelization list - ( - plate_attrs_dicts, - new_well_image_attrs, - well_image_attrs, - ) = _generate_plate_well_metadata(zarr_urls=zarr_urls) - - # Create the new OME-Zarr HCS plate - for old_plate_url, plate_attrs in plate_attrs_dicts.items(): - old_plate_name = old_plate_url.split(".zarr")[-2].split("/")[-1] - new_plate_name = f"{old_plate_name}_{method.value}" - zarrurl_new = f"{zarr_dir}/{new_plate_name}.zarr" - logger.info(f"{old_plate_url=}") - logger.info(f"{zarrurl_new=}") - new_plate_group = open_zarr_group_with_overwrite( - zarrurl_new, overwrite=overwrite - ) - new_plate_group.attrs.put(plate_attrs) - - # Write well groups: - for well_sub_url in new_well_image_attrs[old_plate_url]: - new_well_group = new_plate_group.create_group(f"{well_sub_url}") - well_attrs = dict( - well=dict( - images=[ - img.model_dump(exclude_none=True) - for img in new_well_image_attrs[old_plate_url][ - well_sub_url - ] - ], - version=well_image_attrs[old_plate_url][ - well_sub_url - ].version, - ) - ) - new_well_group.attrs.put(well_attrs) - - return dict(parallelization_list=parallelization_list) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - 
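    # Editor's note: a minimal sketch (hypothetical path, not part of the
    # original file) of how the URL helpers defined above decompose the
    # absolute zarr_url of an HCS image into plate / well / image parts.
    _demo_url = "/data/my_plate.zarr/B/03/0"
    assert _get_plate_url_from_image_url(_demo_url) == "/data/my_plate.zarr"
    assert _get_well_sub_url(_demo_url) == "B/03"
    assert _get_image_sub_url(_demo_url) == "0"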
run_fractal_task( - task_function=copy_ome_zarr_hcs_plate, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/find_registration_consensus.py b/fractal_tasks_core/tasks/find_registration_consensus.py deleted file mode 100644 index f0aadb010..000000000 --- a/fractal_tasks_core/tasks/find_registration_consensus.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Joel Lüthi -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Applies the multiplexing translation to all ROI tables -""" -import logging -from typing import Optional - -import anndata as ad -import zarr -from pydantic import validate_call - -from fractal_tasks_core.roi import ( - are_ROI_table_columns_valid, -) -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tasks._registration_utils import ( - add_zero_translation_columns, -) -from fractal_tasks_core.tasks._registration_utils import ( - apply_registration_to_single_ROI_table, -) -from fractal_tasks_core.tasks._registration_utils import ( - calculate_min_max_across_dfs, -) -from fractal_tasks_core.tasks.io_models import InitArgsRegistrationConsensus - - -logger = logging.getLogger(__name__) - - -@validate_call -def find_registration_consensus( - *, - # Fractal parameters - zarr_url: str, - init_args: InitArgsRegistrationConsensus, - # Core parameters - roi_table: str = "FOV_ROI_table", - # Advanced parameters - new_roi_table: Optional[str] = None, -): - """ - Applies pre-calculated registration to ROI tables. - - Apply pre-calculated registration such that resulting ROIs contain - the consensus align region between all acquisitions. - - Parallelization level: well - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be processed. - Refers to the zarr_url of the reference acquisition. - (standard argument for Fractal tasks, managed by Fractal server). - init_args: Intialization arguments provided by - `init_group_by_well_for_multiplexing`. It contains the - zarr_url_list listing all the zarr_urls in the same well as the - zarr_url of the reference acquisition that are being processed. - (standard argument for Fractal tasks, managed by Fractal server). - roi_table: Name of the ROI table over which the task loops to - calculate the registration. Examples: `FOV_ROI_table` => loop over - the field of views, `well_ROI_table` => process the whole well as - one image. - new_roi_table: Optional name for the new, registered ROI table. If no - name is given, it will default to "registered_" + `roi_table` - - """ - if not new_roi_table: - new_roi_table = "registered_" + roi_table - logger.info( - f"Running for {zarr_url=} & the other acquisitions in that well. \n" - f"Applying translation registration to {roi_table=} and storing it as " - f"{new_roi_table=}." 
- ) - - # Collect all the ROI tables - roi_tables = {} - roi_tables_attrs = {} - for acq_zarr_url in init_args.zarr_url_list: - curr_ROI_table = ad.read_zarr(f"{acq_zarr_url}/tables/{roi_table}") - curr_ROI_table_group = zarr.open_group( - f"{acq_zarr_url}/tables/{roi_table}", mode="r" - ) - curr_ROI_table_attrs = curr_ROI_table_group.attrs.asdict() - - # For reference_acquisition, handle the fact that it doesn't - # have the shifts - if acq_zarr_url == zarr_url: - curr_ROI_table = add_zero_translation_columns(curr_ROI_table) - # Check for valid ROI tables - are_ROI_table_columns_valid(table=curr_ROI_table) - translation_columns = [ - "translation_z", - "translation_y", - "translation_x", - ] - if curr_ROI_table.var.index.isin(translation_columns).sum() != 3: - raise ValueError( - f"{roi_table=} in {acq_zarr_url} does not contain the " - f"translation columns {translation_columns} necessary to use " - "this task." - ) - roi_tables[acq_zarr_url] = curr_ROI_table - roi_tables_attrs[acq_zarr_url] = curr_ROI_table_attrs - - # Check that all acquisitions have the same ROIs - rois = roi_tables[list(roi_tables.keys())[0]].obs.index - for acq_zarr_url, acq_roi_table in roi_tables.items(): - if not (acq_roi_table.obs.index == rois).all(): - raise ValueError( - f"Acquisition {acq_zarr_url} does not contain the same ROIs " - f"as the reference acquisition {zarr_url}:\n" - f"{acq_zarr_url}: {acq_roi_table.obs.index}\n" - f"{zarr_url}: {rois}" - ) - - roi_table_dfs = [ - roi_table.to_df().loc[:, translation_columns] - for roi_table in roi_tables.values() - ] - logger.info("Calculating min & max translation across acquisitions.") - max_df, min_df = calculate_min_max_across_dfs(roi_table_dfs) - shifted_rois = {} - - # Loop over acquisitions - for acq_zarr_url in init_args.zarr_url_list: - shifted_rois[acq_zarr_url] = apply_registration_to_single_ROI_table( - roi_tables[acq_zarr_url], max_df, min_df - ) - - # TODO: Drop translation columns from this table? - - logger.info( - f"Write the registered ROI table {new_roi_table} for " - f"{acq_zarr_url=}" - ) - # Save the shifted ROI table as a new table - image_group = zarr.group(acq_zarr_url) - write_table( - image_group, - new_roi_table, - shifted_rois[acq_zarr_url], - table_attrs=roi_tables_attrs[acq_zarr_url], - overwrite=True, - ) - - # TODO: Optionally apply registration to other tables as well? - # e.g. to well_ROI_table based on FOV_ROI_table - # => out of scope for the initial task, apply registration separately - # to each table - # Easiest implementation: Apply average shift calculated here to other - # ROIs. From many to 1 (e.g. FOV => well) => average shift, but crop len - # From well to many (e.g. well to FOVs) => average shift, crop len by that - # amount - # Many to many (FOVs to organoids) => tricky because of matching - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=find_registration_consensus, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/illumination_correction.py b/fractal_tasks_core/tasks/illumination_correction.py deleted file mode 100644 index edb045b7b..000000000 --- a/fractal_tasks_core/tasks/illumination_correction.py +++ /dev/null @@ -1,292 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Marco Franzon -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l.
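# Editor's sketch (hypothetical shift values, not part of the original
# files): the consensus step in `find_registration_consensus` above takes,
# for each ROI, the element-wise max & min translation across acquisitions;
# together these bound the common region shared by all acquisitions.
import numpy as np
import pandas as pd

cols = ["translation_z", "translation_y", "translation_x"]
t_ref = pd.DataFrame([[0.0, 0.0, 0.0]], columns=cols, index=["FOV_1"])
t_acq = pd.DataFrame([[0.0, 2.5, -1.0]], columns=cols, index=["FOV_1"])
max_df = pd.DataFrame(
    np.maximum(t_ref.values, t_acq.values), columns=cols, index=t_ref.index
)
min_df = pd.DataFrame(
    np.minimum(t_ref.values, t_acq.values), columns=cols, index=t_ref.index
)
# max_df -> [0.0, 2.5, 0.0]; min_df -> [0.0, 0.0, -1.0]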
-# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Apply illumination correction to all fields of view. -""" -import logging -import time -import warnings -from pathlib import Path -from typing import Any - -import anndata as ad -import dask.array as da -import numpy as np -import zarr -from pydantic import validate_call -from skimage.io import imread - -from fractal_tasks_core.channels import get_omero_channel_list -from fractal_tasks_core.channels import OmeroChannel -from fractal_tasks_core.ngff import load_NgffImageMeta -from fractal_tasks_core.pyramids import build_pyramid -from fractal_tasks_core.roi import check_valid_ROI_indices -from fractal_tasks_core.roi import ( - convert_ROI_table_to_indices, -) -from fractal_tasks_core.tasks._zarr_utils import _copy_hcs_ome_zarr_metadata -from fractal_tasks_core.tasks._zarr_utils import _copy_tables_from_zarr_url - -logger = logging.getLogger(__name__) - - -def correct( - img_stack: np.ndarray, - corr_img: np.ndarray, - background: int = 110, -): - """ - Corrects a stack of images, using a given illumination profile (e.g. bright - in the center of the image, dim outside). - - Args: - img_stack: 4D numpy array (czyx), with dummy size along c. - corr_img: 2D numpy array (yx) - background: Background value that is subtracted from the image before - the illumination correction is applied. - """ - - logger.info(f"Start correct, {img_stack.shape}") - - # Check shapes - if corr_img.shape != img_stack.shape[2:] or img_stack.shape[0] != 1: - raise ValueError( - "Error in illumination_correction:\n" - f"{img_stack.shape=}\n{corr_img.shape=}" - ) - - # Store info about dtype - dtype = img_stack.dtype - dtype_max = np.iinfo(dtype).max - - # Background subtraction - img_stack[img_stack <= background] = 0 - img_stack[img_stack > background] -= background - - # Apply the normalized correction matrix (requires a float array) - # img_stack = img_stack.astype(np.float64) - new_img_stack = img_stack / (corr_img / np.max(corr_img))[None, None, :, :] - - # Handle edge case: corrected image may have values beyond the limit of - # the encoding, e.g. beyond 65535 for 16bit images. This clips values - # that surpass this limit and triggers a warning - if np.sum(new_img_stack > dtype_max) > 0: - warnings.warn( - "Illumination correction created values beyond the max range of " - f"the current image type. These have been clipped to {dtype_max=}." - ) - new_img_stack[new_img_stack > dtype_max] = dtype_max - - logger.info("End correct") - - # Cast back to original dtype and return - return new_img_stack.astype(dtype) - - -@validate_call -def illumination_correction( - *, - # Fractal parameters - zarr_url: str, - # Core parameters - illumination_profiles_folder: str, - illumination_profiles: dict[str, str], - background: int = 0, - input_ROI_table: str = "FOV_ROI_table", - overwrite_input: bool = True, - # Advanced parameters - suffix: str = "_illum_corr", -) -> dict[str, Any]: - """ - Applies illumination correction to the images in the OME-Zarr. - - Assumes that the illumination correction profiles were generated before - separately and that the same background subtraction was used during - calculation of the illumination correction (otherwise, it will not work - well & the correction may only be partial). - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be processed. - (standard argument for Fractal tasks, managed by Fractal server). 
- illumination_profiles_folder: Path of folder of illumination profiles. - illumination_profiles: Dictionary where keys match the `wavelength_id` - attributes of existing channels (e.g. `A01_C01` ) and values are - the filenames of the corresponding illumination profiles. - background: Background value that is subtracted from the image before - the illumination correction is applied. Set it to `0` if you don't - want any background subtraction. - input_ROI_table: Name of the ROI table that contains the information - about the location of the individual field of views (FOVs) to - which the illumination correction shall be applied. Defaults to - "FOV_ROI_table", the default name Fractal converters give the ROI - tables that list all FOVs separately. If you generated your - OME-Zarr with a different converter and used Import OME-Zarr to - generate the ROI tables, `image_ROI_table` is the right choice if - you only have 1 FOV per Zarr image and `grid_ROI_table` if you - have multiple FOVs per Zarr image and set the right grid options - during import. - overwrite_input: If `True`, the results of this task will overwrite - the input image data. If `False`, a new image is generated and the - illumination corrected data is saved there. - suffix: What suffix to append to the illumination corrected images. - Only relevant if `overwrite_input=False`. - """ - - # Define old/new zarr URLs - if overwrite_input: - zarr_url_new = zarr_url.rstrip("/") - else: - zarr_url_new = zarr_url.rstrip("/") + suffix - - t_start = time.perf_counter() - logger.info("Start illumination_correction") - logger.info(f" {overwrite_input=}") - logger.info(f" {zarr_url=}") - logger.info(f" {zarr_url_new=}") - - # Read attributes from NGFF metadata - ngff_image_meta = load_NgffImageMeta(zarr_url) - num_levels = ngff_image_meta.num_levels - coarsening_xy = ngff_image_meta.coarsening_xy - full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) - logger.info(f"NGFF image has {num_levels=}") - logger.info(f"NGFF image has {coarsening_xy=}") - logger.info( - f"NGFF image has full-res pixel sizes {full_res_pxl_sizes_zyx}" - ) - - # Read channels from .zattrs - channels: list[OmeroChannel] = get_omero_channel_list( - image_zarr_path=zarr_url - ) - num_channels = len(channels) - - # Read FOV ROIs - FOV_ROI_table = ad.read_zarr(f"{zarr_url}/tables/{input_ROI_table}") - - # Create list of indices for 3D FOVs spanning the entire Z direction - list_indices = convert_ROI_table_to_indices( - FOV_ROI_table, - level=0, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, - ) - check_valid_ROI_indices(list_indices, input_ROI_table) - - # Extract image size from FOV-ROI indices.
Note: this works at level=0, - # where FOVs should all be of the exact same size (in pixels) - ref_img_size = None - for indices in list_indices: - img_size = (indices[3] - indices[2], indices[5] - indices[4]) - if ref_img_size is None: - ref_img_size = img_size - else: - if img_size != ref_img_size: - raise ValueError( - "ERROR: inconsistent image sizes in list_indices" - ) - img_size_y, img_size_x = img_size[:] - - # Assemble dictionary of matrices and check their shapes - corrections = {} - for channel in channels: - wavelength_id = channel.wavelength_id - corrections[wavelength_id] = imread( - ( - Path(illumination_profiles_folder) - / illumination_profiles[wavelength_id] - ).as_posix() - ) - if corrections[wavelength_id].shape != (img_size_y, img_size_x): - raise ValueError( - "Error in illumination_correction, " - "correction matrix has wrong shape." - ) - - # Lazily load highest-res level from original zarr array - data_czyx = da.from_zarr(f"{zarr_url}/0") - - # Create zarr for output - if overwrite_input: - new_zarr = zarr.open(f"{zarr_url_new}/0") - else: - new_zarr = zarr.create( - shape=data_czyx.shape, - chunks=data_czyx.chunksize, - dtype=data_czyx.dtype, - store=zarr.storage.FSStore(f"{zarr_url_new}/0"), - overwrite=False, - dimension_separator="/", - ) - _copy_hcs_ome_zarr_metadata(zarr_url, zarr_url_new) - # Copy ROI tables from the old zarr_url to keep ROI tables and other - # tables available in the new Zarr - _copy_tables_from_zarr_url(zarr_url, zarr_url_new) - - # Iterate over FOV ROIs - num_ROIs = len(list_indices) - for i_c, channel in enumerate(channels): - for i_ROI, indices in enumerate(list_indices): - # Define region - s_z, e_z, s_y, e_y, s_x, e_x = indices[:] - region = ( - slice(i_c, i_c + 1), - slice(s_z, e_z), - slice(s_y, e_y), - slice(s_x, e_x), - ) - logger.info( - f"Now processing ROI {i_ROI+1}/{num_ROIs} " - f"for channel {i_c+1}/{num_channels}" - ) - # Execute illumination correction - corrected_fov = correct( - data_czyx[region].compute(), - corrections[channel.wavelength_id], - background=background, - ) - # Write to disk - da.array(corrected_fov).to_zarr( - url=new_zarr, - region=region, - compute=True, - ) - - # Starting from on-disk highest-resolution data, build and write to disk a - # pyramid of coarser levels - build_pyramid( - zarrurl=zarr_url_new, - overwrite=True, - num_levels=num_levels, - coarsening_xy=coarsening_xy, - chunksize=data_czyx.chunksize, - ) - - t_end = time.perf_counter() - logger.info(f"End illumination_correction, elapsed: {t_end-t_start}") - - if overwrite_input: - image_list_updates = dict(image_list_updates=[dict(zarr_url=zarr_url)]) - else: - image_list_updates = dict( - image_list_updates=[dict(zarr_url=zarr_url_new, origin=zarr_url)] - ) - return image_list_updates - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=illumination_correction, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/image_based_registration_hcs_init.py b/fractal_tasks_core/tasks/image_based_registration_hcs_init.py deleted file mode 100644 index db76eb595..000000000 --- a/fractal_tasks_core/tasks/image_based_registration_hcs_init.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Joel Lüthi -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. 
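# Editor's sketch (hypothetical values, not part of the original files):
# the per-FOV arithmetic used by `correct()` in the illumination-correction
# task above, i.e. subtract the background, then divide by the illumination
# profile normalized to its maximum (clipping to the dtype max is omitted
# here for brevity).
import numpy as np

img = np.full((4, 4), 210, dtype=np.uint16)  # dummy yx plane
profile = np.linspace(0.5, 1.0, 16).reshape(4, 4)  # hypothetical flat-field
flat = np.clip(img.astype(np.int64) - 110, 0, None)  # background=110
corrected = (flat / (profile / profile.max())).astype(img.dtype)
# corrected ranges from 100 (profile=1.0) up to 200 (profile=0.5)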
-# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Initializes the parallelization list for registration in HCS plates. -""" -import logging -from typing import Any - -from pydantic import validate_call - -from fractal_tasks_core.utils import ( - create_well_acquisition_dict, -) - -logger = logging.getLogger(__name__) - - -@validate_call -def image_based_registration_hcs_init( - *, - # Fractal parameters - zarr_urls: list[str], - zarr_dir: str, - # Core parameters - reference_acquisition: int = 0, -) -> dict[str, list[dict[str, Any]]]: - """ - Initialized calculate registration task - - This task prepares a parallelization list of all zarr_urls that need to be - used to calculate the registration between acquisitions (all zarr_urls - except the reference acquisition vs. the reference acquisition). - This task only works for HCS OME-Zarrs for 2 reasons: Only HCS OME-Zarrs - currently have defined acquisition metadata to determine reference - acquisitions. And we have only implemented the grouping of images for - HCS OME-Zarrs by well (with the assumption that every well just has 1 - image per acqusition). - - Args: - zarr_urls: List of paths or urls to the individual OME-Zarr image to - be processed. - (standard argument for Fractal tasks, managed by Fractal server). - zarr_dir: path of the directory where the new OME-Zarrs will be - created. Not used by this task. - (standard argument for Fractal tasks, managed by Fractal server). - reference_acquisition: Which acquisition to register against. Needs to - match the acquisition metadata in the OME-Zarr image. - - Returns: - task_output: Dictionary for Fractal server that contains a - parallelization list. - """ - logger.info( - f"Running `image_based_registration_hcs_init` for {zarr_urls=}" - ) - image_groups = create_well_acquisition_dict(zarr_urls) - - # Create the parallelization list - parallelization_list = [] - for key, image_group in image_groups.items(): - # Assert that all image groups have the reference acquisition present - if reference_acquisition not in image_group.keys(): - raise ValueError( - f"Registration with {reference_acquisition=} can only work if " - "all wells have the reference acquisition present. It was not " - f"found for well {key}." - ) - # Add all zarr_urls except the reference acquisition to the - # parallelization list - for acquisition, zarr_url in image_group.items(): - if acquisition != reference_acquisition: - reference_zarr_url = image_group[reference_acquisition] - parallelization_list.append( - dict( - zarr_url=zarr_url, - init_args=dict(reference_zarr_url=reference_zarr_url), - ) - ) - - return dict(parallelization_list=parallelization_list) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=image_based_registration_hcs_init, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/import_ome_zarr.py b/fractal_tasks_core/tasks/import_ome_zarr.py deleted file mode 100644 index f35bdd48f..000000000 --- a/fractal_tasks_core/tasks/import_ome_zarr.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. 
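# Editor's sketch (hypothetical URLs, not part of the original files): the
# init task above pairs every non-reference acquisition in a well with the
# reference acquisition of that same well.
image_group = {0: "/zarr/plate.zarr/B/03/0", 1: "/zarr/plate.zarr/B/03/1"}
reference_acquisition = 0
parallelization_list = [
    dict(
        zarr_url=url,
        init_args=dict(
            reference_zarr_url=image_group[reference_acquisition]
        ),
    )
    for acq, url in image_group.items()
    if acq != reference_acquisition
]
# -> one entry: zarr_url=".../B/03/1", reference_zarr_url=".../B/03/0"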
-# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Task to import an existing OME-Zarr. -""" -import logging -from typing import Any -from typing import Optional - -import dask.array as da -import zarr -from pydantic import validate_call - -from fractal_tasks_core.channels import update_omero_channels -from fractal_tasks_core.ngff import detect_ome_ngff_type -from fractal_tasks_core.ngff import NgffImageMeta -from fractal_tasks_core.roi import get_image_grid_ROIs -from fractal_tasks_core.roi import get_single_image_ROI -from fractal_tasks_core.tables import write_table - -logger = logging.getLogger(__name__) - - -def _process_single_image( - image_path: str, - add_image_ROI_table: bool, - add_grid_ROI_table: bool, - update_omero_metadata: bool, - *, - grid_YX_shape: Optional[tuple[int, int]] = None, - overwrite: bool = False, -) -> dict[str, str]: - """ - Validate OME-NGFF metadata and optionally generate ROI tables. - - This task: - - 1. Validates OME-NGFF image metadata, via `NgffImageMeta`; - 2. Optionally generates and writes two ROI tables; - 3. Optionally update OME-NGFF omero metadata. - 4. Returns dataset types - - Args: - image_path: Absolute path to the image Zarr group. - add_image_ROI_table: Whether to add a `image_ROI_table` table - (argument propagated from `import_ome_zarr`). - add_grid_ROI_table: Whether to add a `grid_ROI_table` table (argument - propagated from `import_ome_zarr`). - update_omero_metadata: Whether to update Omero-channels metadata - (argument propagated from `import_ome_zarr`). - grid_YX_shape: YX shape of the ROI grid (it must be not `None`, if - `add_grid_ROI_table=True`. - """ - - # Note from zarr docs: `r+` means read/write (must exist) - image_group = zarr.open_group(image_path, mode="r+") - image_meta = NgffImageMeta(**image_group.attrs.asdict()) - - # Preliminary checks - if add_grid_ROI_table and (grid_YX_shape is None): - raise ValueError( - f"_process_single_image called with {add_grid_ROI_table=}, " - f"but {grid_YX_shape=}." - ) - - pixels_ZYX = image_meta.get_pixel_sizes_zyx(level=0) - - # Read zarr array - dataset_subpath = image_meta.datasets[0].path - array = da.from_zarr(f"{image_path}/{dataset_subpath}") - - # Prepare image_ROI_table and write it into the zarr group - if add_image_ROI_table: - image_ROI_table = get_single_image_ROI(array.shape, pixels_ZYX) - write_table( - image_group, - "image_ROI_table", - image_ROI_table, - overwrite=overwrite, - table_attrs={"type": "roi_table"}, - ) - - # Prepare grid_ROI_table and write it into the zarr group - if add_grid_ROI_table: - grid_ROI_table = get_image_grid_ROIs( - array.shape, - pixels_ZYX, - grid_YX_shape, - ) - write_table( - image_group, - "grid_ROI_table", - grid_ROI_table, - overwrite=overwrite, - table_attrs={"type": "roi_table"}, - ) - - # Update Omero-channels metadata - if update_omero_metadata: - # Extract number of channels from zarr array - try: - channel_axis_index = image_meta.axes_names.index("c") - except ValueError: - logger.error(f"Existing axes: {image_meta.axes_names}") - msg = ( - "OME-Zarrs with no channel axis are not currently " - "supported in fractal-tasks-core. Upcoming flexibility " - "improvements are tracked in https://github.com/" - "fractal-analytics-platform/fractal-tasks-core/issues/150." 
- ) - logger.error(msg) - raise NotImplementedError(msg) - logger.info(f"Existing axes: {image_meta.axes_names}") - logger.info(f"Channel-axis index: {channel_axis_index}") - num_channels_zarr = array.shape[channel_axis_index] - logger.info( - f"{num_channels_zarr} channel(s) found in Zarr array " - f"at {image_path}/{dataset_subpath}" - ) - # Update or create omero channels metadata - old_omero = image_group.attrs.get("omero", {}) - old_channels = old_omero.get("channels", []) - if len(old_channels) > 0: - logger.info( - f"{len(old_channels)} channel(s) found in NGFF omero metadata" - ) - if len(old_channels) != num_channels_zarr: - error_msg = ( - "Channels-number mismatch: Number of channels in the " - f"zarr array ({num_channels_zarr}) differs from number " - "of channels listed in NGFF omero metadata " - f"({len(old_channels)})." - ) - logging.error(error_msg) - raise ValueError(error_msg) - else: - old_channels = [{} for ind in range(num_channels_zarr)] - new_channels = update_omero_channels(old_channels) - new_omero = old_omero.copy() - new_omero["channels"] = new_channels - image_group.attrs.update(omero=new_omero) - - # Determine image types: - # Later: also provide a has_T flag. - # TODO: Potentially also load acquisition metadata if available in a Zarr - is_3D = False - if "z" in image_meta.axes_names: - if array.shape[-3] > 1: - is_3D = True - types = dict(is_3D=is_3D) - return types - - -@validate_call -def import_ome_zarr( - *, - # Fractal parameters - zarr_urls: list[str], - zarr_dir: str, - # Core parameters - zarr_name: str, - update_omero_metadata: bool = True, - add_image_ROI_table: bool = True, - add_grid_ROI_table: bool = True, - # Advanced parameters - grid_y_shape: int = 2, - grid_x_shape: int = 2, - overwrite: bool = False, -) -> dict[str, Any]: - """ - Import a single OME-Zarr into Fractal. - - The single OME-Zarr can be a full OME-Zarr HCS plate or an individual - OME-Zarr image. The image needs to be in the zarr_dir as specified by the - dataset. The current version of this task: - - 1. Creates the appropriate components-related metadata, needed for - processing an existing OME-Zarr through Fractal. - 2. Optionally adds new ROI tables to the existing OME-Zarr. - - Args: - zarr_urls: List of paths or urls to the individual OME-Zarr image to - be processed. Not used. - (standard argument for Fractal tasks, managed by Fractal server). - zarr_dir: path of the directory where the new OME-Zarrs will be - created. - (standard argument for Fractal tasks, managed by Fractal server). - zarr_name: The OME-Zarr name, without its parent folder. The parent - folder is provided by zarr_dir; e.g. `zarr_name="array.zarr"`, - if the OME-Zarr path is in `/zarr_dir/array.zarr`. - add_image_ROI_table: Whether to add a `image_ROI_table` table to each - image, with a single ROI covering the whole image. - add_grid_ROI_table: Whether to add a `grid_ROI_table` table to each - image, with the image split into a rectangular grid of ROIs. - grid_y_shape: Y shape of the ROI grid in `grid_ROI_table`. - grid_x_shape: X shape of the ROI grid in `grid_ROI_table`. - update_omero_metadata: Whether to update Omero-channels metadata, to - make them Fractal-compatible. - overwrite: Whether new ROI tables (added when `add_image_ROI_table` - and/or `add_grid_ROI_table` are `True`) can overwite existing ones. - """ - - # Is this based on the Zarr_dir or the zarr_urls? - if len(zarr_urls) > 0: - logger.warning( - "Running import while there are already items from the image list " - "provided to the task. 
The following inputs were provided: " - f"{zarr_urls=}" - "This task will not process the existing images, but look for " - f"zarr files named {zarr_name=} in the {zarr_dir=} instead." - ) - - zarr_path = f"{zarr_dir.rstrip('/')}/{zarr_name}" - logger.info(f"Zarr path: {zarr_path}") - - root_group = zarr.open_group(zarr_path, mode="r") - ngff_type = detect_ome_ngff_type(root_group) - grid_YX_shape = (grid_y_shape, grid_x_shape) - - image_list_updates = [] - if ngff_type == "plate": - for well in root_group.attrs["plate"]["wells"]: - well_path = well["path"] - - well_group = zarr.open_group(zarr_path, path=well_path, mode="r") - for image in well_group.attrs["well"]["images"]: - image_path = image["path"] - zarr_url = f"{zarr_path}/{well_path}/{image_path}" - types = _process_single_image( - zarr_url, - add_image_ROI_table, - add_grid_ROI_table, - update_omero_metadata, - grid_YX_shape=grid_YX_shape, - overwrite=overwrite, - ) - image_list_updates.append( - dict( - zarr_url=zarr_url, - attributes=dict( - plate=zarr_name, - well=well_path.replace("/", ""), - ), - types=types, - ) - ) - elif ngff_type == "well": - logger.warning( - "Only OME-Zarr for plates are fully supported in Fractal; " - f"e.g. the current one ({ngff_type=}) cannot be " - "processed via the `maximum_intensity_projection` task." - ) - for image in root_group.attrs["well"]["images"]: - image_path = image["path"] - zarr_url = f"{zarr_path}/{image_path}" - well_name = "".join(zarr_path.split("/")[-2:]) - types = _process_single_image( - zarr_url, - add_image_ROI_table, - add_grid_ROI_table, - update_omero_metadata, - grid_YX_shape=grid_YX_shape, - overwrite=overwrite, - ) - image_list_updates.append( - dict( - zarr_url=zarr_url, - attributes=dict( - well=well_name, - ), - types=types, - ) - ) - elif ngff_type == "image": - logger.warning( - "Only OME-Zarr for plates are fully supported in Fractal; " - f"e.g. the current one ({ngff_type=}) cannot be " - "processed via the `maximum_intensity_projection` task." - ) - zarr_url = zarr_path - types = _process_single_image( - zarr_url, - add_image_ROI_table, - add_grid_ROI_table, - update_omero_metadata, - grid_YX_shape=grid_YX_shape, - overwrite=overwrite, - ) - image_list_updates.append( - dict( - zarr_url=zarr_url, - types=types, - ) - ) - - image_list_changes = dict(image_list_updates=image_list_updates) - return image_list_changes - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=import_ome_zarr, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/init_group_by_well_for_multiplexing.py b/fractal_tasks_core/tasks/init_group_by_well_for_multiplexing.py deleted file mode 100644 index 1d0072671..000000000 --- a/fractal_tasks_core/tasks/init_group_by_well_for_multiplexing.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Joel Lüthi -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. 
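# Editor's sketch (hypothetical values, not part of the original files): the
# shape of one image-list update emitted by `import_ome_zarr` above for a
# single image inside an HCS plate.
example_update = dict(
    zarr_url="/zarr_dir/array.zarr/B/03/0",
    attributes=dict(plate="array.zarr", well="B03"),
    types=dict(is_3D=True),
)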
-""" -Applies the multiplexing translation to all ROI tables -""" -import logging - -from pydantic import validate_call - -from fractal_tasks_core.utils import ( - create_well_acquisition_dict, -) - -logger = logging.getLogger(__name__) - - -@validate_call -def init_group_by_well_for_multiplexing( - *, - # Fractal parameters - zarr_urls: list[str], - zarr_dir: str, - # Core parameters - reference_acquisition: int = 0, -) -> dict[str, list[str]]: - """ - Finds images for all acquisitions per well. - - Returns the parallelization_list to run `find_registration_consensus`. - - Args: - zarr_urls: List of paths or urls to the individual OME-Zarr image to - be processed. - (standard argument for Fractal tasks, managed by Fractal server). - zarr_dir: path of the directory where the new OME-Zarrs will be - created. Not used by this task. - (standard argument for Fractal tasks, managed by Fractal server). - reference_acquisition: Which acquisition to register against. Uses the - OME-NGFF HCS well metadata acquisition keys to find the reference - acquisition. - """ - logger.info( - f"Running `init_group_by_well_for_multiplexing` for {zarr_urls=}" - ) - image_groups = create_well_acquisition_dict(zarr_urls) - - # Create the parallelization list - parallelization_list = [] - for key, image_group in image_groups.items(): - # Assert that all image groups have the reference acquisition present - if reference_acquisition not in image_group.keys(): - raise ValueError( - f"Registration with {reference_acquisition=} can only work if " - "all wells have the reference acquisition present. It was not " - f"found for well {key}." - ) - - # Create a parallelization list entry for each image group - zarr_url_list = [] - for acquisition, zarr_url in image_group.items(): - if acquisition == reference_acquisition: - reference_zarr_url = zarr_url - - zarr_url_list.append(zarr_url) - - parallelization_list.append( - dict( - zarr_url=reference_zarr_url, - init_args=dict(zarr_url_list=zarr_url_list), - ) - ) - - return dict(parallelization_list=parallelization_list) - - -if __name__ == "__main__": - from fractal_tasks_core.tasks._utils import run_fractal_task - - run_fractal_task( - task_function=init_group_by_well_for_multiplexing, - logger_name=logger.name, - ) diff --git a/fractal_tasks_core/tasks/io_models.py b/fractal_tasks_core/tasks/io_models.py deleted file mode 100644 index c061afc89..000000000 --- a/fractal_tasks_core/tasks/io_models.py +++ /dev/null @@ -1,185 +0,0 @@ -from typing import Literal -from typing import Optional - -from pydantic import BaseModel -from pydantic import Field -from pydantic import model_validator -from typing_extensions import Self - -from fractal_tasks_core.channels import ChannelInputModel -from fractal_tasks_core.channels import OmeroChannel - - -class InitArgsRegistration(BaseModel): - """ - Registration init args. - - Passed from `image_based_registration_hcs_init` to - `calculate_registration_image_based`. - - Attributes: - reference_zarr_url: zarr_url for the reference image - """ - - reference_zarr_url: str - - -class InitArgsRegistrationConsensus(BaseModel): - """ - Registration consensus init args. - - Provides the list of zarr_urls for all acquisitions for a given well - - Attributes: - zarr_url_list: List of zarr_urls for all the OME-Zarr images in the - well. 
- """ - - zarr_url_list: list[str] - - -class InitArgsCellVoyager(BaseModel): - """ - Arguments to be passed from cellvoyager converter init to compute - - Attributes: - image_dir: Directory where the raw images are found - plate_prefix: part of the image filename needed for finding the - right subset of image files - well_ID: part of the image filename needed for finding the - right subset of image files - image_extension: part of the image filename needed for finding the - right subset of image files - include_glob_patterns: Additional glob patterns to filter the available - images with. - exclude_glob_patterns: Glob patterns to exclude. - acquisition: Acquisition metadata needed for multiplexing - """ - - image_dir: str - plate_prefix: str - well_ID: str - image_extension: str - include_glob_patterns: Optional[list[str]] = None - exclude_glob_patterns: Optional[list[str]] = None - acquisition: Optional[int] = None - - -class InitArgsIllumination(BaseModel): - """ - Dummy model description. - - Attributes: - raw_path: dummy attribute description. - subsets: dummy attribute description. - """ - - raw_path: str - subsets: dict[Literal["C_index"], int] = Field(default_factory=dict) - - -class InitArgsMIP(BaseModel): - """ - Init Args for MIP task. - - Attributes: - origin_url: Path to the zarr_url with the 3D data - method: Projection method to be used. See `DaskProjectionMethod` - overwrite: If `True`, overwrite the task output. - new_plate_name: Name of the new OME-Zarr HCS plate - """ - - origin_url: str - method: str - overwrite: bool - new_plate_name: str - - -class MultiplexingAcquisition(BaseModel): - """ - Input class for Multiplexing Cellvoyager converter - - Attributes: - image_dir: Path to the folder that contains the Cellvoyager image - files for that acquisition and the MeasurementData & - MeasurementDetail metadata files. - allowed_channels: A list of `OmeroChannel` objects, where each channel - must include the `wavelength_id` attribute and where the - `wavelength_id` values must be unique across the list. - """ - - image_dir: str - allowed_channels: list[OmeroChannel] - - -class NapariWorkflowsOutput(BaseModel): - """ - A value of the `output_specs` argument in `napari_workflows_wrapper`. - - Attributes: - type: Output type (either `label` or `dataframe`). - label_name: Label name (for label outputs, it is used as the name of - the label; for dataframe outputs, it is used to fill the - `region["path"]` field). - table_name: Table name (for dataframe outputs only). - """ - - type: Literal["label", "dataframe"] - label_name: str - table_name: Optional[str] = None - - @model_validator(mode="after") - def table_name_only_for_dataframe_type(self: Self) -> Self: - """ - Check that table_name is set only for dataframe outputs. - """ - _type = self.type - _table_name = self.table_name - if (_type == "dataframe" and (not _table_name)) or ( - _type != "dataframe" and _table_name - ): - raise ValueError( - f"Output item has type={_type} but table_name={_table_name}." - ) - return self - - -class NapariWorkflowsInput(BaseModel): - """ - A value of the `input_specs` argument in `napari_workflows_wrapper`. - - Attributes: - type: Input type (either `image` or `label`). - label_name: Label name (for label inputs only). - channel: `ChannelInputModel` object (for image inputs only). 
- """ - - type: Literal["image", "label"] - label_name: Optional[str] = None - channel: Optional[ChannelInputModel] = None - - @model_validator(mode="after") - def label_name_is_present(self: Self) -> Self: - """ - Check that label inputs have `label_name` set. - """ - label_name = self.label_name - _type = self.type - if _type == "label" and label_name is None: - raise ValueError( - f"Input item has type={_type} but label_name={label_name}." - ) - return self - - @model_validator(mode="after") - def channel_is_present(self: Self) -> Self: - """ - Check that image inputs have `channel` set. - """ - _type = self.type - channel = self.channel - if _type == "image" and channel is None: - raise ValueError( - f"Input item has type={_type} but channel={channel}." - ) - return self diff --git a/fractal_tasks_core/tasks/napari_workflows_wrapper.py b/fractal_tasks_core/tasks/napari_workflows_wrapper.py deleted file mode 100644 index b381a5960..000000000 --- a/fractal_tasks_core/tasks/napari_workflows_wrapper.py +++ /dev/null @@ -1,638 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and -# University of Zurich -# -# Original authors: -# Tommaso Comparin -# Marco Franzon -# -# This file is part of Fractal and was originally developed by eXact lab S.r.l. -# under contract with Liberali Lab from the Friedrich Miescher -# Institute for Biomedical Research and Pelkmans Lab from the University of -# Zurich. -""" -Wrapper of napari-workflows. -""" -import logging -from typing import Any - -import anndata as ad -import dask.array as da -import napari_workflows -import numpy as np -import pandas as pd -import zarr -from napari_workflows._io_yaml_v1 import load_workflow -from pydantic import validate_call - -import fractal_tasks_core -from fractal_tasks_core.channels import get_channel_from_image_zarr -from fractal_tasks_core.labels import prepare_label_group -from fractal_tasks_core.ngff import load_NgffImageMeta -from fractal_tasks_core.pyramids import build_pyramid -from fractal_tasks_core.roi import check_valid_ROI_indices -from fractal_tasks_core.roi import ( - convert_ROI_table_to_indices, -) -from fractal_tasks_core.roi import load_region -from fractal_tasks_core.tables import write_table -from fractal_tasks_core.tasks.io_models import ( - NapariWorkflowsInput, -) -from fractal_tasks_core.tasks.io_models import ( - NapariWorkflowsOutput, -) -from fractal_tasks_core.upscale_array import upscale_array -from fractal_tasks_core.utils import rescale_datasets - - -__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__ - - -logger = logging.getLogger(__name__) - - -class OutOfTaskScopeError(NotImplementedError): - """ - Encapsulates features that are out-of-scope for the current wrapper task. - """ - - pass - - -@validate_call -def napari_workflows_wrapper( - *, - # Fractal parameters - zarr_url: str, - # Core parameters - workflow_file: str, - input_specs: dict[str, NapariWorkflowsInput], - output_specs: dict[str, NapariWorkflowsOutput], - input_ROI_table: str = "FOV_ROI_table", - level: int = 0, - # Advanced parameters - relabeling: bool = True, - expected_dimensions: int = 3, - overwrite: bool = True, -): - """ - Run a napari-workflow on the ROIs of a single OME-NGFF image. - - This task takes images and labels and runs a napari-workflow on them that - can produce a label and tables as output. 
- - Examples of allowed entries for `input_specs` and `output_specs`: - - ``` - input_specs = { - "in_1": {"type": "image", "channel": {"wavelength_id": "A01_C02"}}, - "in_2": {"type": "image", "channel": {"label": "DAPI"}}, - "in_3": {"type": "label", "label_name": "label_DAPI"}, - } - - output_specs = { - "out_1": {"type": "label", "label_name": "label_DAPI_new"}, - "out_2": {"type": "dataframe", "table_name": "measurements"}, - } - ``` - - Args: - zarr_url: Path or url to the individual OME-Zarr image to be processed. - (standard argument for Fractal tasks, managed by Fractal server). - workflow_file: Absolute path to napari-workflows YAML file - input_specs: A dictionary of `NapariWorkflowsInput` values. - output_specs: A dictionary of `NapariWorkflowsOutput` values. - input_ROI_table: Name of the ROI table over which the task loops to - apply napari workflows. - Examples: - `FOV_ROI_table` - => loop over the field of views; - `organoid_ROI_table` - => loop over the organoid ROI table (generated by another task); - `well_ROI_table` - => process the whole well as one image. - level: Pyramid level of the image to be used as input for - napari-workflows. Choose `0` to process at full resolution. - Levels > 0 are currently only supported for workflows that only - have intensity images as input and only produce label images as - output. - relabeling: If `True`, apply relabeling so that label values are - unique across all ROIs in the well. - expected_dimensions: Expected dimensions (either `2` or `3`). Useful - when loading 2D images that are stored in a 3D array with shape - `(1, size_x, size_y)` [which is the default way Fractal stores 2D - images], but you want to make sure the napari workflow gets a 2D - array to process. Also useful to set to `2` when loading a 2D - OME-Zarr that is saved as `(size_x, size_y)`. - overwrite: If `True`, overwrite the task output. - """ - wf: napari_workflows.Workflow = load_workflow(workflow_file) - logger.info(f"Loaded workflow from {workflow_file}") - - # Validation of input/output specs - if not (set(wf.leafs()) <= set(output_specs.keys())): - msg = f"Some item of {wf.leafs()=} is not part of {output_specs=}." - logger.warning(msg) - if not (set(wf.roots()) <= set(input_specs.keys())): - msg = f"Some item of {wf.roots()=} is not part of {input_specs=}." - logger.error(msg) - raise ValueError(msg) - list_outputs = sorted(output_specs.keys()) - - # Characterization of workflow and scope restriction - input_types = [in_params.type for (name, in_params) in input_specs.items()] - output_types = [ - out_params.type for (name, out_params) in output_specs.items() - ] - are_inputs_all_images = set(input_types) == {"image"} - are_outputs_all_labels = set(output_types) == {"label"} - are_outputs_all_dataframes = set(output_types) == {"dataframe"} - is_labeling_workflow = are_inputs_all_images and are_outputs_all_labels - is_measurement_only_workflow = are_outputs_all_dataframes - # Level-related constraint - logger.info(f"This workflow acts at {level=}") - logger.info( - f"Is the current workflow a labeling one? {is_labeling_workflow}" - ) - if level > 0 and not is_labeling_workflow: - msg = ( - f"{level=}>0 is currently only accepted for labeling workflows, " - "i.e. those going from image(s) to label(s)" - ) - logger.error(msg) - raise OutOfTaskScopeError(msg) - # Relabeling-related (soft) constraint - if is_measurement_only_workflow and relabeling: - logger.warning( - "This is a measurement-output-only workflow, setting " - "relabeling=False."
- ) - relabeling = False - if relabeling: - max_label_for_relabeling = 0 - - label_dtype = np.uint32 - - # Read ROI table - ROI_table = ad.read_zarr(f"{zarr_url}/tables/{input_ROI_table}") - - # Load image metadata - ngff_image_meta = load_NgffImageMeta(zarr_url) - num_levels = ngff_image_meta.num_levels - coarsening_xy = ngff_image_meta.coarsening_xy - - # Read pixel sizes from zattrs file - full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0) - - # Create list of indices for 3D FOVs spanning the entire Z direction - list_indices = convert_ROI_table_to_indices( - ROI_table, - level=level, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, - ) - check_valid_ROI_indices(list_indices, input_ROI_table) - num_ROIs = len(list_indices) - logger.info( - f"Completed reading ROI table {input_ROI_table}," - f" found {num_ROIs} ROIs." - ) - - # Input preparation: "image" type - image_inputs = [ - (name, in_params) - for (name, in_params) in input_specs.items() - if in_params.type == "image" - ] - input_image_arrays = {} - if image_inputs: - img_array = da.from_zarr(f"{zarr_url}/{level}") - # Loop over image inputs and assign corresponding channel of the image - for name, params in image_inputs: - channel = get_channel_from_image_zarr( - image_zarr_path=zarr_url, - wavelength_id=params.channel.wavelength_id, - label=params.channel.label, - ) - channel_index = channel.index - input_image_arrays[name] = img_array[channel_index] - - # Handle dimensions - shape = input_image_arrays[name].shape - if expected_dimensions == 3 and shape[0] == 1: - logger.warning( - f"Input {name} has shape {shape} " - f"but {expected_dimensions=}" - ) - if expected_dimensions == 2: - if len(shape) == 2: - # We already load the data as a 2D array - pass - elif shape[0] == 1: - input_image_arrays[name] = input_image_arrays[name][ - 0, :, : - ] - else: - msg = ( - f"Input {name} has shape {shape} " - f"but {expected_dimensions=}" - ) - logger.error(msg) - raise ValueError(msg) - logger.info(f"Prepared input with {name=} and {params=}") - logger.info(f"{input_image_arrays=}") - - # Input preparation: "label" type - label_inputs = [ - (name, in_params) - for (name, in_params) in input_specs.items() - if in_params.type == "label" - ] - if label_inputs: - # Set target_shape for upscaling labels - if not image_inputs: - logger.warning( - f"{len(label_inputs)=} but num_image_inputs=0. " - "Label array(s) will not be upscaled." 
- ) - upscale_labels = False - else: - target_shape = list(input_image_arrays.values())[0].shape - upscale_labels = True - # Loop over label inputs and load corresponding (upscaled) image - input_label_arrays = {} - for name, params in label_inputs: - label_name = params.label_name - label_array_raw = da.from_zarr( - f"{zarr_url}/labels/{label_name}/{level}" - ) - input_label_arrays[name] = label_array_raw - - # Handle dimensions - shape = input_label_arrays[name].shape - if expected_dimensions == 3 and shape[0] == 1: - logger.warning( - f"Input {name} has shape {shape} " - f"but {expected_dimensions=}" - ) - if expected_dimensions == 2: - if len(shape) == 2: - # We already load the data as a 2D array - pass - elif shape[0] == 1: - input_label_arrays[name] = input_label_arrays[name][ - 0, :, : - ] - else: - msg = ( - f"Input {name} has shape {shape} " - f"but {expected_dimensions=}" - ) - logger.error(msg) - raise ValueError(msg) - - if upscale_labels: - # Check that dimensionality matches the image - if len(input_label_arrays[name].shape) != len(target_shape): - raise ValueError( - f"Label {name} has shape " - f"{input_label_arrays[name].shape}. " - "But the corresponding image has shape " - f"{target_shape}. Those dimensionalities do not " - f"match. Is {expected_dimensions=} the correct " - "setting?" - ) - if expected_dimensions == 3: - upscaling_axes = [1, 2] - else: - upscaling_axes = [0, 1] - input_label_arrays[name] = upscale_array( - array=input_label_arrays[name], - target_shape=target_shape, - axis=upscaling_axes, - pad_with_zeros=True, - ) - - logger.info(f"Prepared input with {name=} and {params=}") - logger.info(f"{input_label_arrays=}") - - # Output preparation: "label" type - label_outputs = [ - (name, out_params) - for (name, out_params) in output_specs.items() - if out_params.type == "label" - ] - if label_outputs: - # Preliminary scope checks - if len(label_outputs) > 1: - raise OutOfTaskScopeError( - "Multiple label outputs would break label-inputs-only " - f"workflows (found {len(label_outputs)=})." - ) - if len(label_outputs) > 1 and relabeling: - raise OutOfTaskScopeError( - "Multiple label outputs would break relabeling in labeling+" - f"measurement workflows (found {len(label_outputs)=})." - ) - - # We only support two cases: - # 1. If there exist some input images, then use the first one to - # determine output-label array properties - # 2. If there are no input images, but there are input labels, then (A) - # re-load the pixel sizes and re-build ROI indices, and (B) use the - # first input label to determine output-label array properties - if image_inputs: - reference_array = list(input_image_arrays.values())[0] - elif label_inputs: - reference_array = list(input_label_arrays.values())[0] - # Re-load pixel size, matching to the correct level - input_label_name = label_inputs[0][1].label_name - ngff_label_image_meta = load_NgffImageMeta( - f"{zarr_url}/labels/{input_label_name}" - ) - full_res_pxl_sizes_zyx = ngff_label_image_meta.get_pixel_sizes_zyx( - level=0 - ) - # Create list of indices for 3D FOVs spanning the whole Z direction - list_indices = convert_ROI_table_to_indices( - ROI_table, - level=level, - coarsening_xy=coarsening_xy, - full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, - ) - check_valid_ROI_indices(list_indices, input_ROI_table) - num_ROIs = len(list_indices) - logger.info( - f"Re-create ROI indices from ROI table {input_ROI_table}, " - f"using {full_res_pxl_sizes_zyx=}. 
" - "This is necessary because label-input-only workflows may " - "have label inputs that are at a different resolution and " - "are not upscaled." - ) - else: - msg = ( - "Missing image_inputs and label_inputs, we cannot assign" - " label output properties" - ) - raise OutOfTaskScopeError(msg) - - # Extract label properties from reference_array, and make sure they are - # for three dimensions - label_shape = reference_array.shape - label_chunksize = reference_array.chunksize - if len(label_shape) == 2 and len(label_chunksize) == 2: - if expected_dimensions == 3: - raise ValueError( - f"Something wrong: {label_shape=} but " - f"{expected_dimensions=}" - ) - label_shape = (1, label_shape[0], label_shape[1]) - label_chunksize = (1, label_chunksize[0], label_chunksize[1]) - logger.info(f"{label_shape=}") - logger.info(f"{label_chunksize=}") - - # Loop over label outputs and (1) set zattrs, (2) create zarr group - output_label_zarr_groups: dict[str, Any] = {} - for name, out_params in label_outputs: - # (1a) Rescale OME-NGFF datasets (relevant for level>0) - if not ngff_image_meta.multiscale.axes[0].name == "c": - raise ValueError( - "Cannot set `remove_channel_axis=True` for multiscale " - f"metadata with axes={ngff_image_meta.multiscale.axes}. " - 'First axis should have name "c".' - ) - new_datasets = rescale_datasets( - datasets=[ - ds.model_dump() - for ds in ngff_image_meta.multiscale.datasets - ], - coarsening_xy=coarsening_xy, - reference_level=level, - remove_channel_axis=True, - ) - - # (1b) Prepare attrs for label group - label_name = out_params.label_name - label_attrs = { - "image-label": { - "version": __OME_NGFF_VERSION__, - "source": {"image": "../../"}, - }, - "multiscales": [ - { - "name": label_name, - "version": __OME_NGFF_VERSION__, - "axes": [ - ax.model_dump() - for ax in ngff_image_meta.multiscale.axes - if ax.type != "channel" - ], - "datasets": new_datasets, - } - ], - } - - # (2) Prepare label group - image_group = zarr.group(zarr_url) - label_group = prepare_label_group( - image_group, - label_name, - overwrite=overwrite, - label_attrs=label_attrs, - logger=logger, - ) - logger.info( - "Helper function `prepare_label_group` returned " - f"{label_group=}" - ) - - # (3) Create zarr group at level=0 - store = zarr.storage.FSStore(f"{zarr_url}/labels/{label_name}/0") - mask_zarr = zarr.create( - shape=label_shape, - chunks=label_chunksize, - dtype=label_dtype, - store=store, - overwrite=overwrite, - dimension_separator="/", - ) - output_label_zarr_groups[name] = mask_zarr - logger.info(f"Prepared output with {name=} and {out_params=}") - logger.info(f"{output_label_zarr_groups=}") - - # Output preparation: "dataframe" type - dataframe_outputs = [ - (name, out_params) - for (name, out_params) in output_specs.items() - if out_params.type == "dataframe" - ] - output_dataframe_lists: dict[str, list] = {} - for name, out_params in dataframe_outputs: - output_dataframe_lists[name] = [] - logger.info(f"Prepared output with {name=} and {out_params=}") - logger.info(f"{output_dataframe_lists=}") - - ##### - - for i_ROI, indices in enumerate(list_indices): - s_z, e_z, s_y, e_y, s_x, e_x = indices[:] - region = (slice(s_z, e_z), slice(s_y, e_y), slice(s_x, e_x)) - - logger.info(f"ROI {i_ROI+1}/{num_ROIs}: {region=}") - - # Always re-load napari worfklow - wf = load_workflow(workflow_file) - - # Set inputs - for input_name in input_specs.keys(): - input_type = input_specs[input_name].type - - if input_type == "image": - wf.set( - input_name, - load_region( - 
- for i_ROI, indices in enumerate(list_indices):
- s_z, e_z, s_y, e_y, s_x, e_x = indices[:]
- region = (slice(s_z, e_z), slice(s_y, e_y), slice(s_x, e_x))
- 
- logger.info(f"ROI {i_ROI+1}/{num_ROIs}: {region=}")
- 
- # Always re-load the napari workflow
- wf = load_workflow(workflow_file)
- 
- # Set inputs
- for input_name in input_specs.keys():
- input_type = input_specs[input_name].type
- 
- if input_type == "image":
- wf.set(
- input_name,
- load_region(
- input_image_arrays[input_name],
- region,
- compute=True,
- return_as_3D=False,
- ),
- )
- elif input_type == "label":
- wf.set(
- input_name,
- load_region(
- input_label_arrays[input_name],
- region,
- compute=True,
- return_as_3D=False,
- ),
- )
- 
- # Get outputs
- outputs = wf.get(list_outputs)
- 
- # Iterate first over dataframe outputs (to use the correct
- # max_label_for_relabeling, if needed)
- for ind_output, output_name in enumerate(list_outputs):
- if output_specs[output_name].type != "dataframe":
- continue
- df = outputs[ind_output]
- if relabeling:
- df["label"] += max_label_for_relabeling
- logger.info(
- f'ROI {i_ROI+1}/{num_ROIs}: Relabeling "{output_name}" '
- f"dataframe output, with {max_label_for_relabeling=}"
- )
- 
- # Append the new-ROI dataframe to the all-ROIs list
- output_dataframe_lists[output_name].append(df)
- 
- # After all dataframe outputs, iterate over label outputs (of which
- # there can be at most one)
- for ind_output, output_name in enumerate(list_outputs):
- if output_specs[output_name].type != "label":
- continue
- mask = outputs[ind_output]
- 
- # Check dimensions
- if len(mask.shape) != expected_dimensions:
- msg = (
- f"Output {output_name} has shape {mask.shape} "
- f"but {expected_dimensions=}"
- )
- logger.error(msg)
- raise ValueError(msg)
- elif expected_dimensions == 2:
- mask = np.expand_dims(mask, axis=0)
- 
- # Sanity check: issue warning for non-consecutive labels
- unique_labels = np.unique(mask)
- num_unique_labels_in_this_ROI = len(unique_labels)
- if np.min(unique_labels) == 0:
- num_unique_labels_in_this_ROI -= 1
- num_labels_in_this_ROI = int(np.max(mask))
- if num_labels_in_this_ROI != num_unique_labels_in_this_ROI:
- logger.warning(
- f'ROI {i_ROI+1}/{num_ROIs}: "{output_name}" label output '
- f"has non-consecutive labels: {num_labels_in_this_ROI=} "
- f"but {num_unique_labels_in_this_ROI=}"
- )
- 
- if relabeling:
- mask[mask > 0] += max_label_for_relabeling
- logger.info(
- f'ROI {i_ROI+1}/{num_ROIs}: Relabeling "{output_name}" label '
- f"output, with {max_label_for_relabeling=}"
- )
- max_label_for_relabeling += num_labels_in_this_ROI
- logger.info(
- f"ROI {i_ROI+1}/{num_ROIs}: label-number update with "
- f"{num_labels_in_this_ROI=}; "
- f"new {max_label_for_relabeling=}"
- )
- 
- da.array(mask).to_zarr(
- url=output_label_zarr_groups[output_name],
- region=region,
- compute=True,
- overwrite=overwrite,
- )
- logger.info(f"ROI {i_ROI+1}/{num_ROIs}: output handling complete")
- 
- # Output handling: "dataframe" type (for each output, concatenate ROI
- # dataframes, clean up, and store in an AnnData table on disk)
- for name, out_params in dataframe_outputs:
- table_name = out_params.table_name
- # Concatenate all FOV dataframes
- list_dfs = output_dataframe_lists[name]
- if len(list_dfs) == 0:
- measurement_table = ad.AnnData()
- else:
- df_well = pd.concat(list_dfs, axis=0, ignore_index=True)
- # Extract labels and drop them from df_well
- labels = pd.DataFrame(df_well["label"].astype(str))
- df_well.drop(labels=["label"], axis=1, inplace=True)
- # Convert all to float (warning: some columns may be integer-valued
- # in principle)
- measurement_dtype = np.float32
- df_well = df_well.astype(measurement_dtype)
- df_well.index = df_well.index.map(str)
- # Convert to anndata
- measurement_table = ad.AnnData(df_well, dtype=measurement_dtype)
- measurement_table.obs = labels
- 
- # Write to zarr group
- image_group = zarr.group(zarr_url)
- table_attrs = dict(
- type="feature_table",
- region=dict(path=f"../labels/{out_params.label_name}"),
- instance_key="label",
- )
- 
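- # The call below is expected to store measurement_table as an
- # AnnData-backed group at {zarr_url}/tables/{table_name} and to attach
- # table_attrs to it, so that downstream tasks can join measurements
- # back to the label image via instance_key="label".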
- write_table(
- image_group,
- table_name,
- measurement_table,
- overwrite=overwrite,
- table_attrs=table_attrs,
- )
- 
- # Output handling: "label" type (for each output, build and write to disk
- # a pyramid of coarser levels)
- for name, out_params in label_outputs:
- label_name = out_params.label_name
- build_pyramid(
- zarrurl=f"{zarr_url}/labels/{label_name}",
- overwrite=overwrite,
- num_levels=num_levels,
- coarsening_xy=coarsening_xy,
- chunksize=label_chunksize,
- aggregation_function=np.max,
- )
-
-
-if __name__ == "__main__":
- from fractal_tasks_core.tasks._utils import run_fractal_task
-
- run_fractal_task(
- task_function=napari_workflows_wrapper,
- logger_name=logger.name,
- ) diff --git a/fractal_tasks_core/tasks/projection.py b/fractal_tasks_core/tasks/projection.py deleted file mode 100644 index 723cc1ba5..000000000 --- a/fractal_tasks_core/tasks/projection.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
-# University of Zurich
-#
-# Original authors:
-# Tommaso Comparin
-# Marco Franzon
-#
-# This file is part of Fractal and was originally developed by eXact lab S.r.l.
-# under contract with Liberali Lab from the Friedrich Miescher
-# Institute for Biomedical Research and Pelkmans Lab from the University of
-# Zurich.
-"""
-Task for 3D->2D intensity projection along the Z axis.
-"""
-import logging
-from typing import Any
-
-import dask.array as da
-from ngio import NgffImage
-from ngio.core import Image
-from pydantic import validate_call
-
-from fractal_tasks_core.tasks.io_models import InitArgsMIP
-from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod
-
-logger = logging.getLogger(__name__)
-
-
-def _compute_new_shape(source_image: Image) -> tuple[int, ...]:
- """Compute the new shape of the image after the projection.
-
- The new shape is the same as the original one,
- except for the z-axis, which is set to 1.
- """
- on_disk_shape = source_image.on_disk_shape
- logger.info(f"Source {on_disk_shape=}")
-
- on_disk_z_index = source_image.dataset.on_disk_axes_names.index("z")
-
- dest_on_disk_shape = list(on_disk_shape)
- dest_on_disk_shape[on_disk_z_index] = 1
- logger.info(f"Destination {dest_on_disk_shape=}")
- return tuple(dest_on_disk_shape)
-
-
-@validate_call
-def projection(
- *,
- # Fractal parameters
- zarr_url: str,
- init_args: InitArgsMIP,
-) -> dict[str, Any]:
- """
- Perform intensity projection along the Z axis with a chosen method.
-
- Note: this task stores the output in a new zarr file.
-
- Args:
- zarr_url: Path or url to the individual OME-Zarr image to be processed.
- (standard argument for Fractal tasks, managed by Fractal server).
- init_args: Initialization arguments provided by
- `create_cellvoyager_ome_zarr_init`.
- """
- method = DaskProjectionMethod(init_args.method)
- logger.info(f"{init_args.origin_url=}")
- logger.info(f"{zarr_url=}")
- logger.info(f"{method=}")
-
- # Read image metadata
- original_ngff_image = NgffImage(init_args.origin_url)
- original_image = original_ngff_image.get_image()
-
- if original_image.is_2d or original_image.is_2d_time_series:
- raise ValueError(
- "The input image is 2D; "
- "projection is only supported for 3D images."
- )
- 
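- # Illustrative example (hypothetical CZYX shapes): an on-disk image of
- # shape (1, 10, 2160, 2560) is projected to (1, 1, 2160, 2560); only
- # the z entry of the shape changes.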
- # Compute the new shape and pixel size
- dest_on_disk_shape = _compute_new_shape(original_image)
-
- dest_pixel_size = original_image.pixel_size
- dest_pixel_size.z = 1.0
- logger.info(f"New shape: {dest_on_disk_shape=}")
-
- # Create the new empty image
- new_ngff_image = original_ngff_image.derive_new_image(
- store=zarr_url,
- name="MIP",
- on_disk_shape=dest_on_disk_shape,
- pixel_sizes=dest_pixel_size,
- overwrite=init_args.overwrite,
- copy_labels=False,
- copy_tables=True,
- )
- logger.info(f"New projection image created - {new_ngff_image=}")
- new_image = new_ngff_image.get_image()
-
- # Process the image
- z_axis_index = original_image.find_axis("z")
- source_dask = original_image.get_array(
- mode="dask", preserve_dimensions=True
- )
-
- dest_dask = method.apply(dask_array=source_dask, axis=z_axis_index)
- dest_dask = da.expand_dims(dest_dask, axis=z_axis_index)
- new_image.set_array(dest_dask)
- new_image.consolidate()
- # End of projection computation
-
- # Update the ROI tables copied from the source image: the projected
- # image has a single z plane
- for roi_table_name in new_ngff_image.tables.list(table_type="roi_table"):
- table = new_ngff_image.tables.get_table(roi_table_name)
-
- roi_list = []
- for roi in table.rois:
- roi.z = 0.0
- roi.z_length = 1.0
- roi_list.append(roi)
-
- table.set_rois(roi_list, overwrite=True)
- table.consolidate()
- logger.info(f"Table {roi_table_name} projection done")
-
- # Generate image_list_updates
- image_list_update_dict = dict(
- image_list_updates=[
- dict(
- zarr_url=zarr_url,
- origin=init_args.origin_url,
- attributes=dict(plate=init_args.new_plate_name),
- types=dict(is_3D=False),
- )
- ]
- )
- return image_list_update_dict
-
-
-if __name__ == "__main__":
- from fractal_tasks_core.tasks._utils import run_fractal_task
-
- run_fractal_task(
- task_function=projection,
- logger_name=logger.name,
- ) diff --git a/fractal_tasks_core/tasks/projection_utils.py b/fractal_tasks_core/tasks/projection_utils.py deleted file mode 100644 index 4d71fa25a..000000000 --- a/fractal_tasks_core/tasks/projection_utils.py +++ /dev/null @@ -1,136 +0,0 @@ -from enum import Enum
-from typing import Any
-
-import dask.array as da
-import numpy as np
-
-
-def safe_sum(
- dask_array: da.Array, axis: int = 0, **kwargs: Any
-) -> da.Array:
- """
- Perform a safe sum on a Dask array to avoid overflow by clipping the
- result of da.sum & casting it to its original dtype.
-
- Dask.array already correctly handles promotion to uint32 or uint64 when
- necessary internally, but we want to ensure we clip the result.
-
- Args:
- dask_array (dask.array.Array): The input Dask array.
- axis (int, optional): The axis along which to sum the array.
- Defaults to 0.
- **kwargs: Additional keyword arguments passed to da.sum.
-
- Returns:
- dask.array.Array: The result of the sum, safely clipped and cast
- back to the original dtype.
- """
- # Handle empty array
- if any(dim == 0 for dim in dask_array.shape):
- return dask_array
-
- # Determine the original dtype
- original_dtype = dask_array.dtype
- max_value = np.iinfo(original_dtype).max
-
- # Perform the sum
- result = da.sum(dask_array, axis=axis, **kwargs)
-
- # Clip the values to the maximum possible value for the original dtype
- result = da.clip(result, 0, max_value)
-
- # Cast back to the original dtype
- result = result.astype(original_dtype)
-
- return result
-
-
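-# A minimal usage sketch for safe_sum (hypothetical data): summing two
-# nearly-saturated uint16 planes clips at 65535 instead of wrapping:
-#
-#     >>> arr = da.full((2, 4, 4), 60_000, dtype=np.uint16, chunks=2)
-#     >>> int(safe_sum(arr, axis=0).compute().max())
-#     65535
-
-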
-def mean_wrapper(
- dask_array: da.Array, axis: int = 0, **kwargs: Any
-) -> da.Array:
- """
- Perform a da.mean on the dask_array & cast it to its original dtype.
-
- Without casting, the result can change dtype, e.g. to float64.
-
- Args:
- dask_array (dask.array.Array): The input Dask array.
- axis (int, optional): The axis along which to average the array.
- Defaults to 0.
- **kwargs: Additional keyword arguments passed to da.mean.
-
- Returns:
- dask.array.Array: The result of the mean, cast back to the original
- dtype.
- """
- # Handle empty array
- if any(dim == 0 for dim in dask_array.shape):
- return dask_array
-
- # Determine the original dtype
- original_dtype = dask_array.dtype
-
- # Perform the mean
- result = da.mean(dask_array, axis=axis, **kwargs)
-
- # Cast back to the original dtype
- result = result.astype(original_dtype)
-
- return result
-
-
-class DaskProjectionMethod(Enum):
- """
- Projection method selection
-
- Choose which method to use for intensity projection along the Z axis.
-
- Attributes:
- MIP: Maximum intensity projection
- MINIP: Minimum intensity projection
- MEANIP: Mean intensity projection
- SUMIP: Sum intensity projection
- """
-
- MIP = "mip"
- MINIP = "minip"
- MEANIP = "meanip"
- SUMIP = "sumip"
-
- def apply(
- self, dask_array: da.Array, axis: int = 0, **kwargs: Any
- ) -> da.Array:
- """
- Apply the selected projection method to the given Dask array.
-
- Args:
- dask_array (dask.array.Array): The Dask array to project.
- axis (int): The axis along which to apply the projection.
- **kwargs: Additional keyword arguments to pass to the projection
- method.
-
- Returns:
- dask.array.Array: The resulting Dask array after applying the
- projection.
-
- Example:
- >>> array = da.random.random((1000, 1000), chunks=(100, 100))
- >>> method = DaskProjectionMethod.MIP
- >>> result = method.apply(array, axis=0)
- >>> computed_result = result.compute()
- >>> print(computed_result)
- """
- # Map the Enum values to the actual Dask array methods
- method_map = {
- DaskProjectionMethod.MIP: lambda arr, axis, **kw: arr.max(
- axis=axis, **kw
- ),
- DaskProjectionMethod.MINIP: lambda arr, axis, **kw: arr.min(
- axis=axis, **kw
- ),
- DaskProjectionMethod.MEANIP: mean_wrapper,
- DaskProjectionMethod.SUMIP: safe_sum,
- }
- # Call the appropriate method, passing in the dask_array explicitly
- return method_map[self](dask_array, axis=axis, **kwargs)
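-
-
-# A minimal end-to-end sketch (hypothetical array): picking a method by
-# its string value and projecting a small 3D stack along z.
-#
-#     >>> import dask.array as da
-#     >>> stack = da.random.random((5, 64, 64), chunks=(1, 64, 64))
-#     >>> DaskProjectionMethod("mip").apply(stack, axis=0).shape
-#     (64, 64)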