diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index acdf1492..d0888e0f 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -756,8 +756,8 @@ def get_masks(p, iscell=None, rpad=20): return M0 def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, - flow_threshold=0.4, interp=True, do_3D=False, min_size=15, - resize=None, device=None): + flow_threshold=0.4, interp=True, do_3D=False, min_size=15, fill_holes=True, + area_threshold=None, resize=None, device=None): """Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None. Args: @@ -770,6 +770,8 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. min_size (int, optional): The minimum size of the masks. Defaults to 15. + fill_holes (bool, optional): Whether to fill holes in the masks. Defaults to True. + area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None. resize (tuple, optional): The desired size for resizing the masks. Defaults to None. device (str, optional): The torch device to use for computation. Defaults to None. @@ -779,7 +781,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold mask, p = compute_masks(dP, cellprob, p=p, niter=niter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, interp=interp, do_3D=do_3D, - min_size=min_size, device=device) + min_size=min_size, fill_holes=fill_holes, area_threshold=area_threshold, device=device) if resize is not None: mask = transforms.resize_image(mask, resize[0], resize[1], @@ -794,7 +796,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, flow_threshold=0.4, interp=True, do_3D=False, min_size=15, - device=None): + fill_holes=True, area_threshold=None, device=None): """Compute masks using dynamics from dP and cellprob. Args: @@ -807,6 +809,8 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. min_size (int, optional): The minimum size of the masks. Defaults to 15. + fill_holes (bool, optional): Whether to fill holes in the masks. Defaults to True. + area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None. device (str, optional): The torch device to use for computation. Defaults to None. Returns: @@ -856,7 +860,8 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, p = np.zeros((len(shape), *shape), np.uint16) return mask, p - mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size) + mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size, fill_holes=fill_holes, + area_threshold=area_threshold) if mask.dtype == np.uint32: dynamics_logger.warning( diff --git a/cellpose/io.py b/cellpose/io.py index 2abcd6b4..ad82c6ac 100644 --- a/cellpose/io.py +++ b/cellpose/io.py @@ -2,7 +2,7 @@ Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. """ -import os, datetime, gc, warnings, glob, shutil +import os, datetime, gc, warnings, glob, shutil, json from natsort import natsorted import numpy as np import cv2 @@ -86,6 +86,32 @@ def outlines_to_text(base, outlines): f.write("\n") +def polygons_to_geojson(base, polygons) -> None: + """ + Create a geojson file from polygons. + Args: + base (str): base name of the file to save + polygons (list): list of polygons + Returns: + None + """ + geojson = { + "type": "FeatureCollection", + "features": [] + } + for polygon in polygons: + geojson["features"].append({ + "type": "Feature", + "geometry": { + "type": "Polygon", + "coordinates": polygon + }, + "properties": {} + }) + with open(base + "_cp_outlines.geojson", "w") as f: + json.dump(geojson, f) + + def load_dax(filename): ### modified from ZhuangLab github: ### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156 @@ -261,7 +287,7 @@ def imsave(filename, arr): """ ext = os.path.splitext(filename)[-1].lower() if ext == ".tif" or ext == ".tiff": - tifffile.imwrite(filename, arr) + tifffile.imwrite(filename, arr, compression="zlib") else: if len(arr.shape) > 2: arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB) @@ -595,7 +621,7 @@ def save_rois(masks, file_name): def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0], suffix="", save_flows=False, save_outlines=False, dir_above=False, in_folders=False, savedir=None, save_txt=False, - save_mpl=False): + save_geojson=False, keep_holes=False, save_mpl=False): """ Save masks + nicely plotted segmentation image to png and/or tiff. Can save masks, flows to different directories, if in_folders is True. @@ -623,6 +649,8 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[ in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False. savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None. save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False. + save_geojson (bool, optional): Save masks as geojson. Defaults to False. + keep_holes (bool, optional): Keep holes outlines inside polygons. Default is False. save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D. This takes a long time for large images. Defaults to False. @@ -636,7 +664,7 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[ dir_above=dir_above, save_flows=save_flows, save_outlines=save_outlines, savedir=savedir, save_txt=save_txt, in_folders=in_folders, - save_mpl=save_mpl) + save_mpl=save_mpl, save_geojson=save_geojson) return if masks.ndim > 2 and not tif: @@ -720,10 +748,24 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[ outlines = utils.outlines_list(masks) outlines_to_text(os.path.join(txtdir, basename), outlines) + # QuPath geojson files + if masks.ndim < 3 and save_geojson: + polygons = utils.outlines_polygons(masks, keep_holes=keep_holes) + if polygons is not None: + polygons_to_geojson(os.path.join(txtdir, basename), polygons) + # RGB outline images if masks.ndim < 3 and save_outlines: check_dir(outlinedir) - outlines = utils.masks_to_outlines(masks) + polygons = utils.outlines_polygons(masks, keep_holes=True) + image_shape = images.shape[1:] if images.shape[0] < 4 else images.shape[:2] + outlines = np.zeros(shape=image_shape) + for polygon in polygons: # TODO: Little ad-hoc + for outline in polygon: + for coordinates in range(len(outline)): + x = outline[coordinates][0] + y = outline[coordinates][1] + outlines[int(y), int(x)] = 255 outX, outY = np.nonzero(outlines) img0 = transforms.normalize99(images) if img0.shape[0] < 4: diff --git a/cellpose/models.py b/cellpose/models.py index 8fad3e7f..334440cc 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -305,8 +305,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile=True, - tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, - progress=None): + tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, fill_holes=True, + area_threshold=None, progress=None): """ segment list of images x, or 4D array - Z x nchan x Y x X Args: @@ -352,6 +352,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, bsize (int, optional): block size for tiles, recommended to keep at 224, like in training. Defaults to 224. interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True. compute_masks (bool, optional): Whether or not to compute dynamics and return masks. This is set to False when retrieving the styles for the size model. Defaults to True. + fill_holes (bool, optional): Whether or not to fill holes in masks. Defaults to True. + area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None. progress (QProgressBar, optional): pyqt progress bar. Defaults to None. Returns: @@ -384,8 +386,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, tile_overlap=tile_overlap, bsize=bsize, resample=resample, interp=interp, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, compute_masks=compute_masks, - min_size=min_size, stitch_threshold=stitch_threshold, - progress=progress, niter=niter) + min_size=min_size, fill_holes=fill_holes, stitch_threshold=stitch_threshold, + area_threshold=area_threshold, progress=progress, niter=niter) masks.append(maski) flows.append(flowi) styles.append(stylei) @@ -412,8 +414,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, rescale=rescale, resample=resample, augment=augment, tile=tile, tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size, - do_3D=do_3D, anisotropy=anisotropy, niter=niter, - stitch_threshold=stitch_threshold) + do_3D=do_3D, anisotropy=anisotropy, niter=niter, fill_holes=fill_holes, + area_threshold=area_threshold, stitch_threshold=stitch_threshold) flows = [plot.dx_to_circ(dP), dP, cellprob, p] return masks, flows, styles @@ -421,7 +423,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=None, rescale=1.0, resample=True, augment=False, tile=True, tile_overlap=0.1, cellprob_threshold=0.0, bsize=224, flow_threshold=0.4, min_size=15, - interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0): + interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0, fill_holes=True, area_threshold=None): if isinstance(normalize, dict): normalize_params = {**normalize_default, **normalize} @@ -505,7 +507,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non masks, p = dynamics.resize_and_compute_masks( dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, interp=interp, do_3D=do_3D, - min_size=min_size, resize=None, + min_size=min_size, resize=None, fill_holes=fill_holes, area_threshold=area_threshold, device=self.device if self.gpu else None) else: masks, p = [], [] @@ -524,6 +526,8 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non resize=resize, min_size=min_size if stitch_threshold == 0 or nimg == 1 else -1, # turn off for 3D stitching + fill_holes=fill_holes, + area_threshold=area_threshold, device=self.device if self.gpu else None) masks.append(outputs[0]) p.append(outputs[1]) @@ -536,7 +540,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non ) masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold) masks = utils.fill_holes_and_remove_small_masks( - masks, min_size=min_size) + masks, min_size=min_size, fill_holes=fill_holes) elif nimg > 1: models_logger.warning("3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only") diff --git a/cellpose/utils.py b/cellpose/utils.py index 207b6783..4b8445ca 100644 --- a/cellpose/utils.py +++ b/cellpose/utils.py @@ -288,6 +288,77 @@ def outlines_list_multi(masks, num_processes=None): outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks]) return outpix + +def make_polygon(outline: np.ndarray, bb: tuple) -> list: + """ + Enclose a polygon by adding the first point to the end of the list. + Args: + outline (np.ndarray): A list of points in the polygon. + Returns: + polygon (list): The enclosed polygon. + """ + coordinates = outline + np.array([bb[1], bb[0]]) + polygon = [list(map(float, point)) for point in coordinates] + if polygon[0] != polygon[-1]: + polygon.append(polygon[0]) + return polygon + + +def get_polygon(outline: np.ndarray, bb: tuple, dim: int, keep_holes: bool = False) -> list: + """ + Compute contour contours from binary mask, translate to bounding box coordinates, and return as polygon. + Args: + outline (np.ndarray): Binary mask. + bb (tuple): Bounding box coordinates. + dim (int): In which dimension to look for the contour. + keep_holes (bool, optional): Whether to keep holes in the mask. Default is False. + Returns: + polygon (list): Polygon coordinates compatible with geojson format. + """ + cv2_method = cv2.RETR_TREE if keep_holes else cv2.RETR_EXTERNAL + contours, hierarchy = cv2.findContours( + image=outline, + mode=cv2_method, + method=cv2.CHAIN_APPROX_NONE, + ) + outline = contours[dim].squeeze() + polygon = make_polygon(outline, bb) + polygon = [polygon] + if keep_holes and hierarchy.shape[1] > 1: + inner_contours = np.where(hierarchy[0, :, 3].squeeze() != -1)[0] + inner_polygons = [] + for c_idx in inner_contours: + inner_outline = contours[c_idx].squeeze() + inner_polygon = make_polygon(inner_outline, bb) + inner_polygons.append(inner_polygon) + polygon.extend(inner_polygons) + return polygon + + +def outlines_polygons(masks: np.ndarray, keep_holes: bool = False) -> list: + """ + Get outlines of masks as polygons writing geojson. + Args: + masks (np.ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + keep_holes (bool, optional): Whether to keep holes in the mask. Default is False. + Returns: + polygons (list): List of polygons as pixel coordinates. + """ + polygons: list = [] + objects = find_objects(masks) + for i, sl in enumerate(objects): + lb = i + 1 + image = masks[sl] == lb + bbox = tuple([sl[i].start for i in range(masks.ndim)] + [sl[i].stop for i in range(masks.ndim)]) + outline = image.astype(np.uint8) + try: + polygon = get_polygon(outline, bbox, 0, keep_holes) + except TypeError: + polygon = get_polygon(outline, bbox, 1, keep_holes) + polygons.append(polygon) + return polygons + + def get_outline_multi(args): """Get the outline of a specific mask in a multi-mask image. @@ -611,7 +682,7 @@ def size_distribution(masks): counts = np.unique(masks, return_counts=True)[1][1:] return np.percentile(counts, 25) / np.percentile(counts, 75) -def fill_holes_and_remove_small_masks(masks, min_size=15): +def fill_holes_and_remove_small_masks(masks, min_size=15, fill_holes=True, area_threshold=None): """ Fills holes in masks (2D/3D) and discards masks smaller than min_size. This function fills holes in each mask using scipy.ndimage.morphology.binary_fill_holes. @@ -624,7 +695,9 @@ def fill_holes_and_remove_small_masks(masks, min_size=15): min_size (int, optional): Minimum number of pixels per mask. Masks smaller than min_size will be removed. Set to -1 to turn off this functionality. Default is 15. - + fill_holes (bool, optional): Whether to fill holes in masks. Default is True. + area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. + If None or SKIMAGE_ENABLED is False, fills all holes. Default is None. Returns: ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed. 0 represents no mask, while positive integers represent mask labels. @@ -644,11 +717,18 @@ def fill_holes_and_remove_small_masks(masks, min_size=15): if min_size > 0 and npix < min_size: masks[slc][msk] = 0 elif npix > 0: - if msk.ndim == 3: - for k in range(msk.shape[0]): - msk[k] = binary_fill_holes(msk[k]) - else: - msk = binary_fill_holes(msk) + if fill_holes: + if msk.ndim == 3: + for k in range(msk.shape[0]): + if area_threshold is not None and SKIMAGE_ENABLED: + msk[k] = remove_small_holes(msk[k], area_threshold=area_threshold) + else: + msk[k] = binary_fill_holes(msk[k]) + else: + if area_threshold is not None and SKIMAGE_ENABLED: + msk = remove_small_holes(msk, area_threshold=area_threshold) + else: + msk = binary_fill_holes(msk) masks[slc][msk] = (j + 1) j += 1 return masks