From 06fc6416ce18eda265c90cf58a52802dd7005b54 Mon Sep 17 00:00:00 2001
From: lorenzo
Date: Fri, 6 Sep 2024 14:04:02 +0200
Subject: [PATCH] first implementation draft

---
 src/ilastik_tasks/dev/task_list.py            |   4 +-
 ...astik_pixel_classification_segmentation.py | 472 ++++++++++++++++++
 src/ilastik_tasks/thresholding_label_task.py  | 192 --------
 tests/test_thresholding_label_task.py         |   8 +-
 4 files changed, 478 insertions(+), 198 deletions(-)
 create mode 100644 src/ilastik_tasks/ilastik_pixel_classification_segmentation.py
 delete mode 100644 src/ilastik_tasks/thresholding_label_task.py

diff --git a/src/ilastik_tasks/dev/task_list.py b/src/ilastik_tasks/dev/task_list.py
index e9037e3..56685e0 100644
--- a/src/ilastik_tasks/dev/task_list.py
+++ b/src/ilastik_tasks/dev/task_list.py
@@ -5,7 +5,7 @@
 TASK_LIST = [
     ParallelTask(
         name="Thresholding Label Task",
-        executable="thresholding_label_task.py",
-        meta={"cpus_per_task": 1, "mem": 4000},
+        executable="ilastik_pixel_classification_segmentation.py",
+        meta={"cpus_per_task": 8, "mem": 8000},
     ),
 ]
diff --git a/src/ilastik_tasks/ilastik_pixel_classification_segmentation.py b/src/ilastik_tasks/ilastik_pixel_classification_segmentation.py
new file mode 100644
index 0000000..3d57068
--- /dev/null
+++ b/src/ilastik_tasks/ilastik_pixel_classification_segmentation.py
@@ -0,0 +1,472 @@
+"""Ilastik-based segmentation task for Fractal.
+
+Code adapted from: https://github.com/fractal-analytics-platform/fractal-tasks-core/blob/main/fractal_tasks_core/tasks/cellpose_segmentation.py
+
+Copyright 2022 (C) Friedrich Miescher Institute for Biomedical Research and
+University of Zurich
+
+Original authors:
+    Tommaso Comparin
+    Marco Franzon
+    Joel Lüthi
+
+This file is part of Fractal and was originally developed by eXact lab S.r.l.
+<exact-lab.it> under contract with Liberali Lab from the Friedrich Miescher
+Institute for Biomedical Research and Pelkmans Lab from the University of Zurich.
+
+Ilastik adaptation by:
+    Lorenzo Cerrone
+    Alexa McIntyre
+"""
+
+import logging
+from typing import Any, Optional
+
+import anndata as ad
+import dask.array as da
+import fractal_tasks_core
+import numpy as np
+import skimage.measure
+import vigra
+import zarr
+from fractal_tasks_core.channels import ChannelInputModel, get_channel_from_image_zarr
+from fractal_tasks_core.labels import prepare_label_group
+from fractal_tasks_core.masked_loading import masked_loading_wrapper
+from fractal_tasks_core.ngff import load_NgffImageMeta
+from fractal_tasks_core.pyramids import build_pyramid
+from fractal_tasks_core.roi import (
+    array_to_bounding_box_table,
+    check_valid_ROI_indices,
+    convert_ROI_table_to_indices,
+    create_roi_table_from_df_list,
+    find_overlaps_in_ROI_indices,
+    get_overlapping_pairs_3D,
+    is_ROI_table_valid,
+    load_region,
+)
+from fractal_tasks_core.tables import write_table
+from fractal_tasks_core.utils import rescale_datasets
+from ilastik import app
+from ilastik.applets.dataSelection.opDataSelection import (
+    PreloadedArrayDatasetInfo,
+)
+from pydantic import validate_call
+
+logger = logging.getLogger(__name__)
+
+__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__
+
+
+def setup_ilastik(model_path: str) -> Any:
+    """Start a headless, read-only Ilastik shell for a project file.
+
+    Args:
+        model_path: Path to the Ilastik project file to load.
+
+    Returns:
+        The shell object returned by `ilastik.app.main`.
+    """
+    args = app.parse_args([])
+    args.headless = True
+    args.project = model_path
+    args.readonly = True
+    return app.main(args)
+
+
+# Backward-compatible alias for the original, misspelled function name.
+seutp_ilastik = setup_ilastik
+
+
+def segment_ROI(
+    input_data: np.ndarray,
+    shell: Any,
+    threshold: float = 0.4,
+    min_size: int = 15,
+    num_labels_tot: Optional[dict[str, int]] = None,
+    label_dtype: Optional[Any] = None,
+) -> np.ndarray:
+    """Run the Ilastik model on a single ROI and label the result.
+
+    Args:
+        input_data: Input data. Shape (z, y, x).
+        shell: Ilastik headless shell.
+        threshold: Pixels whose Ilastik output is `>= threshold` are treated
+            as foreground.
+        min_size: Objects with an area below this value (or with a major
+            axis length below 1) are removed. Use 0 to disable filtering.
+        num_labels_tot: Optional counter, `{"num_labels_tot": <int>}`, shared
+            across ROIs. When given, the labels of this ROI are shifted by
+            the current count and the counter is updated in place, so that
+            label values stay unique across ROIs.
+        label_dtype: Optional dtype of the returned label image.
+
+    Returns:
+        np.ndarray: Label image. Shape (z, y, x).
+
+    Raises:
+        ValueError: If the total number of labels exceeds the maximum value
+            of `label_dtype`.
+    """
+    logger.info(f"{input_data.shape=}")
+
+    # Reformat as tzyxc data expected by ilastik
+    input_data = input_data[np.newaxis, :, :, :, np.newaxis]
+    logger.info(f"{input_data.shape=}")
+
+    data = [
+        {
+            "Raw Data": PreloadedArrayDatasetInfo(
+                preloaded_array=input_data, axistags=vigra.defaultAxistags("tzyxc")
+            )
+        }
+    ]
+    ilastik_output = shell.workflow.batchProcessingApplet.run_export(
+        data, export_to_array=True
+    )[0]
+    logger.info(f"{ilastik_output.shape=}")
+
+    # Drop the singleton t and c axes again, i.e. go back to (z, y, x)
+    ilastik_output = np.reshape(ilastik_output, input_data.shape[1:4])
+    logger.info(f"{ilastik_output.shape=}")
+
+    # Label connected regions at or above threshold
+    ilastik_labels = skimage.measure.label(ilastik_output >= threshold)
+
+    # Remove objects below min_size, and anything with a major axis length
+    # below 1, for compatibility with current measurements task (01.24).
+    # NOTE(review): the minor axis length is not checked. Label images are
+    # always 3D here, and for single-plane (1, y, x) data the 3D minor axis
+    # length is expected to be ~0 for every object - confirm before adding it.
+    if min_size > 0:
+        labels2remove = [
+            prop.label
+            for prop in skimage.measure.regionprops(ilastik_labels)
+            if prop.area < min_size or prop.axis_major_length < 1
+        ]
+        logger.info(
+            f"number of labels before filtering for size = {ilastik_labels.max()}"
+        )
+        ilastik_labels[np.isin(ilastik_labels, labels2remove)] = 0
+        ilastik_labels = skimage.measure.label(ilastik_labels)
+        logger.info(
+            f"number of labels after filtering for size = {ilastik_labels.max()}"
+        )
+
+    # Shift labels and update the relabeling counter
+    if num_labels_tot is not None:
+        num_labels_roi = int(ilastik_labels.max())
+        ilastik_labels[ilastik_labels > 0] += num_labels_tot["num_labels_tot"]
+        num_labels_tot["num_labels_tot"] += num_labels_roi
+        logger.info(f"ROI had {num_labels_roi=}, {num_labels_tot=}")
+
+        # Check that the total number of labels fits into the output dtype
+        if (
+            label_dtype is not None
+            and num_labels_tot["num_labels_tot"] > np.iinfo(label_dtype).max
+        ):
+            raise ValueError(
+                "ERROR in re-labeling: "
+                f"Reached {num_labels_tot} labels, but dtype={label_dtype}"
+            )
+
+    if label_dtype is not None:
+        ilastik_labels = ilastik_labels.astype(label_dtype)
+    return ilastik_labels
+
+
+@validate_call
+def ilastik_pixel_classification_segmentation(
+    *,
+    # Fractal parameters
+    zarr_url: str,
+    # Core parameters
+    level: int,
+    channel: ChannelInputModel,
+    ilastik_model: str,
+    input_ROI_table: str = "FOV_ROI_table",
+    output_ROI_table: Optional[str] = None,
+    output_label_name: Optional[str] = None,
+    # Ilastik-related arguments
+    threshold: float = 0.4,
+    min_size: int = 15,
+    use_masks: bool = True,
+    overwrite: bool = True,
+) -> None:
+    """Run Ilastik Pixel Classification on a Zarr image.
+
+    Args:
+        zarr_url: URL of the Zarr image.
+        level: Pyramid level of the Zarr image to process.
+        channel: Channel (selected by label or wavelength ID) to segment.
+        ilastik_model: Path to the Ilastik project file.
+        input_ROI_table: Name of the ROI table over which the task loops.
+        output_ROI_table: If provided, a ROI table with that name is created,
+            containing the bounding boxes of the newly segmented labels.
+        output_label_name: Name of the output label image. Defaults to
+            `label_` followed by the channel label (or its index).
+        threshold: Pixels whose Ilastik output is `>= threshold` are treated
+            as foreground.
+        min_size: Objects with an area below this value are removed.
+        use_masks: If `True`, try to use masked loading and fall back to
+            `use_masks=False` if the ROI table is not suitable.
+        overwrite: Whether to overwrite existing data.
+
+    Raises:
+        ValueError: If masked loading is not used and the ROIs overlap.
+    """
+    logger.info(f"Processing {zarr_url=}")
+
+    # Read attributes from NGFF metadata
+    ngff_image_meta = load_NgffImageMeta(zarr_url)
+    num_levels = ngff_image_meta.num_levels
+    coarsening_xy = ngff_image_meta.coarsening_xy
+    full_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=0)
+    actual_res_pxl_sizes_zyx = ngff_image_meta.get_pixel_sizes_zyx(level=level)
+    logger.info(f"NGFF image has {num_levels=}")
+    logger.info(f"NGFF image has {coarsening_xy=}")
+    logger.info(f"NGFF image has full-res pixel sizes {full_res_pxl_sizes_zyx}")
+    logger.info(f"NGFF image has level-{level} pixel sizes {actual_res_pxl_sizes_zyx}")
+
+    # Find channel index
+    omero_channel = get_channel_from_image_zarr(
+        image_zarr_path=zarr_url,
+        label=channel.label,
+        wavelength_id=channel.wavelength_id,
+    )
+    if not omero_channel:
+        logger.warning(f"Channel {channel} not found in {zarr_url}, nothing to do.")
+        return
+    ind_channel = omero_channel.index
+
+    # Setup the Ilastik headless shell only once we know there is work to do
+    shell = setup_ilastik(ilastik_model)
+
+    # Set label name, falling back to the channel index if there is no label
+    if output_label_name is None:
+        channel_label = omero_channel.label
+        if channel_label is None:
+            channel_label = ind_channel
+        output_label_name = f"label_{channel_label}"
+
+    # Load ZYX data
+    # Workaround for #788: Only load channel index when there is a channel
+    # dimension
+    has_channel_axis = ngff_image_meta.axes_names[0] == "c"
+    data_zyx = da.from_zarr(f"{zarr_url}/{level}")
+    if has_channel_axis:
+        data_zyx = data_zyx[ind_channel]
+    logger.info(f"{data_zyx.shape=}")
+
+    # Read ROI table
+    ROI_table_path = f"{zarr_url}/tables/{input_ROI_table}"
+    ROI_table = ad.read_zarr(ROI_table_path)
+
+    # Perform some checks on the ROI table
+    valid_ROI_table = is_ROI_table_valid(table_path=ROI_table_path, use_masks=use_masks)
+    if use_masks and not valid_ROI_table:
+        logger.info(
+            f"ROI table at {ROI_table_path} cannot be used for masked "
+            "loading. Set use_masks=False."
+        )
+        use_masks = False
+    logger.info(f"{use_masks=}")
+
+    # Create list of indices for 3D ROIs spanning the entire Z direction
+    list_indices = convert_ROI_table_to_indices(
+        ROI_table,
+        level=level,
+        coarsening_xy=coarsening_xy,
+        full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx,
+    )
+    check_valid_ROI_indices(list_indices, input_ROI_table)
+
+    # If we are not planning to use masked loading, fail for overlapping ROIs
+    if not use_masks:
+        overlap = find_overlaps_in_ROI_indices(list_indices)
+        if overlap:
+            raise ValueError(
+                f"ROI indices created from {input_ROI_table} table have "
+                "overlaps, but we are not using masked loading."
+            )
+
+    # Rescale datasets (only relevant for level>0)
+    # Workaround for #788: only remove the channel axis if there is one
+    new_datasets = rescale_datasets(
+        datasets=[ds.model_dump() for ds in ngff_image_meta.datasets],
+        coarsening_xy=coarsening_xy,
+        reference_level=level,
+        remove_channel_axis=has_channel_axis,
+    )
+
+    label_attrs = {
+        "image-label": {
+            "version": __OME_NGFF_VERSION__,
+            "source": {"image": "../../"},
+        },
+        "multiscales": [
+            {
+                "name": output_label_name,
+                "version": __OME_NGFF_VERSION__,
+                "axes": [
+                    ax.model_dump()
+                    for ax in ngff_image_meta.multiscale.axes
+                    if ax.type != "channel"
+                ],
+                "datasets": new_datasets,
+            }
+        ],
+    }
+
+    image_group = zarr.group(zarr_url)
+    label_group = prepare_label_group(
+        image_group,
+        output_label_name,
+        overwrite=overwrite,
+        label_attrs=label_attrs,
+        logger=logger,
+    )
+
+    logger.info(f"Helper function `prepare_label_group` returned {label_group=}")
+    logger.info(f"Output label path: {zarr_url}/labels/{output_label_name}/0")
+    store = zarr.storage.FSStore(f"{zarr_url}/labels/{output_label_name}/0")
+    label_dtype = np.uint32
+
+    # Ensure that all output shapes & chunks are 3D (for 2D data: (1, y, x))
+    # https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/398
+    shape = data_zyx.shape
+    if len(shape) == 2:
+        shape = (1, *shape)
+    chunks = data_zyx.chunksize
+    if len(chunks) == 2:
+        chunks = (1, *chunks)
+    mask_zarr = zarr.create(
+        shape=shape,
+        chunks=chunks,
+        dtype=label_dtype,
+        store=store,
+        overwrite=False,
+        dimension_separator="/",
+    )
+    logger.info(f"mask will have shape {shape} and chunks {chunks}")
+
+    # Counter for relabeling, shared across ROIs and updated by segment_ROI
+    num_labels_tot = {"num_labels_tot": 0}
+
+    # Keyword arguments for segment_ROI (they are the same for every ROI)
+    kwargs_segment_ROI = {
+        "shell": shell,
+        "threshold": threshold,
+        "min_size": min_size,
+        "num_labels_tot": num_labels_tot,
+        "label_dtype": label_dtype,
+    }
+
+    # Iterate over ROIs
+    num_ROIs = len(list_indices)
+    bbox_dataframe_list = []
+
+    logger.info(f"Now starting loop over {num_ROIs} ROIs")
+    for i_ROI, indices in enumerate(list_indices):
+        # Define region
+        s_z, e_z, s_y, e_y, s_x, e_x = indices[:]
+        region = (
+            slice(s_z, e_z),
+            slice(s_y, e_y),
+            slice(s_x, e_x),
+        )
+        logger.info(f"Now processing ROI {i_ROI+1}/{num_ROIs}")
+
+        # Load the ROI as a 3D (z, y, x) numpy array
+        img_np = load_region(data_zyx, region, compute=True, return_as_3D=True)
+
+        # Prepare keyword arguments for preprocessing function
+        preprocessing_kwargs = {}
+        if use_masks:
+            preprocessing_kwargs = {
+                "region": region,
+                "current_label_path": f"{zarr_url}/labels/{output_label_name}/0",
+                "ROI_table_path": ROI_table_path,
+                "ROI_positional_index": i_ROI,
+            }
+
+        # Call segment_ROI through the masked-loading wrapper, which includes
+        # pre/post-processing functions if needed
+        new_label_img = masked_loading_wrapper(
+            image_array=img_np,
+            function=segment_ROI,
+            kwargs=kwargs_segment_ROI,
+            use_masks=use_masks,
+            preprocessing_kwargs=preprocessing_kwargs,
+        )
+
+        if output_ROI_table:
+            bbox_df = array_to_bounding_box_table(
+                new_label_img,
+                actual_res_pxl_sizes_zyx,
+                origin_zyx=(s_z, s_y, s_x),
+            )
+
+            bbox_dataframe_list.append(bbox_df)
+
+            overlap_list = get_overlapping_pairs_3D(bbox_df, full_res_pxl_sizes_zyx)
+            if len(overlap_list) > 0:
+                logger.warning(
+                    f"ROI {indices} has "
+                    f"{len(overlap_list)} bounding-box pairs overlap"
+                )
+
+        # Compute and store 0-th level to disk
+        da.array(new_label_img).to_zarr(
+            url=mask_zarr,
+            region=region,
+            compute=True,
+        )
+
+    logger.info(
+        "End ilastik_pixel_classification_segmentation task for "
+        f"{zarr_url}, now building pyramids."
+    )
+
+    # Starting from on-disk highest-resolution data, build and write to disk a
+    # pyramid of coarser levels
+    build_pyramid(
+        zarrurl=f"{zarr_url}/labels/{output_label_name}",
+        overwrite=overwrite,
+        num_levels=num_levels,
+        coarsening_xy=coarsening_xy,
+        chunksize=chunks,
+        aggregation_function=np.max,
+    )
+
+    logger.info("End building pyramids")
+
+    if output_ROI_table:
+        bbox_table = create_roi_table_from_df_list(bbox_dataframe_list)
+
+        # Write to zarr group
+        image_group = zarr.group(zarr_url)
+        logger.info(
+            "Now writing bounding-box ROI table to "
+            f"{zarr_url}/tables/{output_ROI_table}"
+        )
+        table_attrs = {
+            "type": "masking_roi_table",
+            "region": {"path": f"../labels/{output_label_name}"},
+            "instance_key": "label",
+        }
+        write_table(
+            image_group,
+            output_ROI_table,
+            bbox_table,
+            overwrite=overwrite,
+            table_attrs=table_attrs,
+        )
+
+
+if __name__ == "__main__":
+    from fractal_tasks_core.tasks._utils import run_fractal_task
+
+    run_fractal_task(
+        task_function=ilastik_pixel_classification_segmentation,
+        logger_name=logger.name,
+    )
diff --git a/src/ilastik_tasks/thresholding_label_task.py b/src/ilastik_tasks/thresholding_label_task.py
deleted file mode 100644
index 9e9e017..0000000
--- a/src/ilastik_tasks/thresholding_label_task.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""This is the Python module for my_task."""
-
-import logging
-from typing import Any, Optional
-
-import dask.array as da
-import fractal_tasks_core
-import numpy as np
-import zarr
-from fractal_tasks_core.channels import (
-    ChannelInputModel,
-    OmeroChannel,
-    get_channel_from_image_zarr,
-)
-from fractal_tasks_core.labels import prepare_label_group
-from fractal_tasks_core.ngff import load_NgffImageMeta
-from fractal_tasks_core.ngff.specs import NgffImageMeta
-from fractal_tasks_core.pyramids import build_pyramid
-from fractal_tasks_core.utils import rescale_datasets
-from pydantic import validate_call
-from skimage.measure import label
-from skimage.morphology import ball, dilation, opening, remove_small_objects
-
-__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__
-
-
-@validate_call
-def thresholding_label_task(
-    *,
-    zarr_url: str,
-    threshold: int,
-    channel: ChannelInputModel,
-    label_name: Optional[str] = None,
-    min_size: int = 50,
-    overwrite: bool = True,
-) -> None:
-    """Threshold an image and find connected components.
-
-    Args:
-        zarr_url: Absolute path to the OME-Zarr image.
-        threshold: Threshold value to be applied.
-        channel: Channel to be thresholded.
-        label_name: Name of the resulting label image
-        min_size: Minimum size of objects. Smaller objects are filtered out.
-        overwrite: Whether to overwrite an existing label image
-    """
-    # Use the first of input_paths
-    logging.info(f"{zarr_url=}")
-
-    # Parse and log several NGFF-image metadata attributes
-    ngff_image_meta = load_NgffImageMeta(zarr_url)
-    logging.info(f" Axes: {ngff_image_meta.axes_names}")
-    logging.info(f" Number of pyramid levels: {ngff_image_meta.num_levels}")
-    logging.info(
-        "Full-resolution ZYX pixel sizes (micrometer): "
-        f"{ngff_image_meta.get_pixel_sizes_zyx(level=0)}"
-    )
-
-    # Find the channel metadata
-    channel_model: OmeroChannel = get_channel_from_image_zarr(
-        image_zarr_path=zarr_url,
-        wavelength_id=channel.wavelength_id,
-        label=channel.label,
-    )
-
-    # Set label name
-    if not label_name:
-        label_name = f"{channel_model.label}_thresholded"
-
-    # Load the highest-resolution multiscale array through dask.array
-    array_zyx = da.from_zarr(f"{zarr_url}/0")[channel_model.index]
-    logging.info(f"{array_zyx=}")
-
-    # Process the image with an image processing approach of your choice
-    label_img = process_img(
-        array_zyx.compute(),
-        threshold=threshold,
-        min_size=min_size,
-    )
-
-    # Prepare label OME-Zarr
-    # If the resulting label image is of lower resolution than the intensity
-    # image, set the downsample variable to the number of downsamplings
-    # required (e.g. 2 if the image is downsampled 4x per axis with an
-    # ngff_image_meta.coarsening_xy of 2)
-    label_attrs = generate_label_attrs(ngff_image_meta, label_name, downsample=0)
-    label_group = prepare_label_group(
-        image_group=zarr.group(zarr_url),
-        label_name=label_name,
-        label_attrs=label_attrs,
-        overwrite=overwrite,
-    )
-    # Write the processed array back to the same full-resolution Zarr array
-    label_group.create_dataset(
-        "0",
-        data=label_img,
-        overwrite=overwrite,
-        dimension_separator="/",
-        chunks=array_zyx.chunksize,
-    )
-
-    # Starting from on-disk full-resolution data, build and write to disk a
-    # pyramid of coarser levels
-    build_pyramid(
-        zarrurl=f"{zarr_url}/labels/{label_name}",
-        overwrite=True,
-        num_levels=ngff_image_meta.num_levels,
-        coarsening_xy=ngff_image_meta.coarsening_xy,
-        aggregation_function=np.max,
-    )
-
-
-def process_img(int_img: np.array, threshold: int, min_size: int = 50) -> np.array:
-    """Image processing function, to be replaced with your custom logic
-
-    Numpy image & parameters in, label image out
-
-    Args:
-        int_img: Intensity image as a numpy array
-        threshold: Thresholding value to binarize the image
-        min_size: Object size threshold for filtering
-
-    Returns:
-        label_img: np.array
-    """
-    # Thresholding the image
-    binary_img = int_img >= threshold
-
-    # Removing small objects
-    cleaned_img = remove_small_objects(binary_img, min_size=min_size)
-    # Opening to separate touching objects
-    selem = ball(1)
-    opened_img = opening(cleaned_img, selem)
-
-    # Optional: Dilation to restore object size
-    dilated_img = dilation(opened_img, selem)
-
-    # Labeling the processed image
-    label_img = label(dilated_img, connectivity=1)
-
-    return label_img
-
-
-def generate_label_attrs(
-    ngff_image_meta: NgffImageMeta, label_name: str, downsample: int = 0
-) -> dict[str, Any]:
-    """Generates the label OME-zarr attrs based on the image metadata
-
-    Args:
-        ngff_image_meta: image meta object for the corresponding NGFF image
-        label_name: name of the newly generated label
-        downsample: How many levels the label image is downsampled from the
-            ngff_image_meta image (0 for no downsampling, 1 for downsampling
-            once by the coarsening factor etc.)
-
-    Returns:
-        label_attrs: Dict of new OME-Zarr label attrs
-
-    """
-    new_datasets = rescale_datasets(
-        datasets=[
-            dataset.dict(exclude_none=True) for dataset in ngff_image_meta.datasets
-        ],
-        coarsening_xy=ngff_image_meta.coarsening_xy,
-        reference_level=downsample,
-        remove_channel_axis=True,
-    )
-    label_attrs = {
-        "image-label": {
-            "version": __OME_NGFF_VERSION__,
-            "source": {"image": "../../"},
-        },
-        "multiscales": [
-            {
-                "name": label_name,
-                "version": __OME_NGFF_VERSION__,
-                "axes": [
-                    ax.dict()
-                    for ax in ngff_image_meta.multiscale.axes
-                    if ax.type != "channel"
-                ],
-                "datasets": new_datasets,
-            }
-        ],
-    }
-    return label_attrs
-
-
-if __name__ == "__main__":
-    from fractal_tasks_core.tasks._utils import run_fractal_task
-
-    run_fractal_task(task_function=thresholding_label_task)
diff --git a/tests/test_thresholding_label_task.py b/tests/test_thresholding_label_task.py
index 6c9ca9c..3021b6e 100644
--- a/tests/test_thresholding_label_task.py
+++ b/tests/test_thresholding_label_task.py
@@ -5,7 +5,9 @@
 from devtools import debug
 from fractal_tasks_core.channels import ChannelInputModel

-from ilastik_tasks.thresholding_label_task import thresholding_label_task
+from ilastik_tasks.ilastik_pixel_classification_segmentation import (
+    thresholding_label_task,
+)


 @pytest.fixture(scope="function")
@@ -22,7 +24,5 @@ def test_data_dir(tmp_path: Path) -> str:

 def test_thresholding_label_task(test_data_dir):
     thresholding_label_task(
-        zarr_url=test_data_dir,
-        threshold=180,
-        channel=ChannelInputModel(label="DAPI")
+        zarr_url=test_data_dir, threshold=180, channel=ChannelInputModel(label="DAPI")
     )