Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option not to fill masks. Compression to TIFF files. #884

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions cellpose/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,8 +756,8 @@ def get_masks(p, iscell=None, rpad=20):
return M0

def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False, min_size=15,
resize=None, device=None):
flow_threshold=0.4, interp=True, do_3D=False, min_size=15, fill_holes=True,
area_threshold=None, resize=None, device=None):
"""Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None.

Args:
Expand All @@ -770,6 +770,8 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold
interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True.
do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
min_size (int, optional): The minimum size of the masks. Defaults to 15.
fill_holes (bool, optional): Whether to fill holes in the masks. Defaults to True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None.
resize (tuple, optional): The desired size for resizing the masks. Defaults to None.
device (str, optional): The torch device to use for computation. Defaults to None.

Expand All @@ -779,7 +781,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold
mask, p = compute_masks(dP, cellprob, p=p, niter=niter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
min_size=min_size, device=device)
min_size=min_size, fill_holes=fill_holes, area_threshold=area_threshold, device=device)

if resize is not None:
mask = transforms.resize_image(mask, resize[0], resize[1],
Expand All @@ -794,7 +796,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold

def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False, min_size=15,
device=None):
fill_holes=True, area_threshold=None, device=None):
"""Compute masks using dynamics from dP and cellprob.

Args:
Expand All @@ -807,6 +809,8 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True.
do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
min_size (int, optional): The minimum size of the masks. Defaults to 15.
fill_holes (bool, optional): Whether to fill holes in the masks. Defaults to True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None.
device (str, optional): The torch device to use for computation. Defaults to None.

Returns:
Expand Down Expand Up @@ -856,7 +860,8 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
p = np.zeros((len(shape), *shape), np.uint16)
return mask, p

mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)
mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size, fill_holes=fill_holes,
area_threshold=area_threshold)

if mask.dtype == np.uint32:
dynamics_logger.warning(
Expand Down
52 changes: 47 additions & 5 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""

import os, datetime, gc, warnings, glob, shutil
import os, datetime, gc, warnings, glob, shutil, json
from natsort import natsorted
import numpy as np
import cv2
Expand Down Expand Up @@ -86,6 +86,32 @@ def outlines_to_text(base, outlines):
f.write("\n")


def polygons_to_geojson(base, polygons) -> None:
    """Write polygons to ``<base>_cp_outlines.geojson`` as a GeoJSON FeatureCollection.

    Each entry of ``polygons`` becomes one ``Feature`` whose geometry is of type
    ``Polygon``; the entry is stored unchanged as the geometry's ``coordinates``
    and the feature's ``properties`` are left empty.

    Args:
        base (str): Base name of the file to save; "_cp_outlines.geojson" is appended.
        polygons (list): List of polygons. Each polygon must be JSON-serializable
            (e.g. nested lists of floats).

    Returns:
        None
    """
    features = [
        {
            "type": "Feature",
            "geometry": {
                "type": "Polygon",
                "coordinates": polygon,
            },
            "properties": {},
        }
        for polygon in polygons
    ]
    geojson = {
        "type": "FeatureCollection",
        "features": features,
    }
    # Explicit encoding keeps the written file independent of the platform locale.
    with open(base + "_cp_outlines.geojson", "w", encoding="utf-8") as f:
        json.dump(geojson, f)


def load_dax(filename):
### modified from ZhuangLab github:
### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
Expand Down Expand Up @@ -261,7 +287,7 @@ def imsave(filename, arr):
"""
ext = os.path.splitext(filename)[-1].lower()
if ext == ".tif" or ext == ".tiff":
tifffile.imwrite(filename, arr)
tifffile.imwrite(filename, arr, compression="zlib")
else:
if len(arr.shape) > 2:
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
Expand Down Expand Up @@ -595,7 +621,7 @@ def save_rois(masks, file_name):
def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0],
suffix="", save_flows=False, save_outlines=False,
dir_above=False, in_folders=False, savedir=None, save_txt=False,
save_mpl=False):
save_geojson=False, keep_holes=False, save_mpl=False):
""" Save masks + nicely plotted segmentation image to png and/or tiff.

Can save masks, flows to different directories, if in_folders is True.
Expand Down Expand Up @@ -623,6 +649,8 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[
in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False.
savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None.
save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False.
save_geojson (bool, optional): Save masks as geojson. Defaults to False.
keep_holes (bool, optional): Keep holes outlines inside polygons. Default is False.
save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D.
This takes a long time for large images. Defaults to False.

Expand All @@ -636,7 +664,7 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[
dir_above=dir_above, save_flows=save_flows,
save_outlines=save_outlines,
savedir=savedir, save_txt=save_txt, in_folders=in_folders,
save_mpl=save_mpl)
save_mpl=save_mpl, save_geojson=save_geojson)
return

if masks.ndim > 2 and not tif:
Expand Down Expand Up @@ -720,10 +748,24 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[
outlines = utils.outlines_list(masks)
outlines_to_text(os.path.join(txtdir, basename), outlines)

# QuPath geojson files
if masks.ndim < 3 and save_geojson:
polygons = utils.outlines_polygons(masks, keep_holes=keep_holes)
if polygons is not None:
polygons_to_geojson(os.path.join(txtdir, basename), polygons)

# RGB outline images
if masks.ndim < 3 and save_outlines:
check_dir(outlinedir)
outlines = utils.masks_to_outlines(masks)
polygons = utils.outlines_polygons(masks, keep_holes=True)
image_shape = images.shape[1:] if images.shape[0] < 4 else images.shape[:2]
outlines = np.zeros(shape=image_shape)
for polygon in polygons: # TODO: Little ad-hoc
for outline in polygon:
for coordinates in range(len(outline)):
x = outline[coordinates][0]
y = outline[coordinates][1]
outlines[int(y), int(x)] = 255
outX, outY = np.nonzero(outlines)
img0 = transforms.normalize99(images)
if img0.shape[0] < 4:
Expand Down
22 changes: 13 additions & 9 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile=True,
tile_overlap=0.1, bsize=224, interp=True, compute_masks=True,
progress=None):
tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, fill_holes=True,
area_threshold=None, progress=None):
""" segment list of images x, or 4D array - Z x nchan x Y x X

Args:
Expand Down Expand Up @@ -352,6 +352,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
bsize (int, optional): block size for tiles, recommended to keep at 224, like in training. Defaults to 224.
interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
compute_masks (bool, optional): Whether or not to compute dynamics and return masks. This is set to False when retrieving the styles for the size model. Defaults to True.
fill_holes (bool, optional): Whether or not to fill holes in masks. Defaults to True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None.
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.

Returns:
Expand Down Expand Up @@ -384,8 +386,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
tile_overlap=tile_overlap, bsize=bsize, resample=resample,
interp=interp, flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold, compute_masks=compute_masks,
min_size=min_size, stitch_threshold=stitch_threshold,
progress=progress, niter=niter)
min_size=min_size, fill_holes=fill_holes, stitch_threshold=stitch_threshold,
area_threshold=area_threshold, progress=progress, niter=niter)
masks.append(maski)
flows.append(flowi)
styles.append(stylei)
Expand All @@ -412,16 +414,16 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
rescale=rescale, resample=resample, augment=augment, tile=tile,
tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size,
do_3D=do_3D, anisotropy=anisotropy, niter=niter,
stitch_threshold=stitch_threshold)
do_3D=do_3D, anisotropy=anisotropy, niter=niter, fill_holes=fill_holes,
area_threshold=area_threshold, stitch_threshold=stitch_threshold)

flows = [plot.dx_to_circ(dP), dP, cellprob, p]
return masks, flows, styles

def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=None,
rescale=1.0, resample=True, augment=False, tile=True, tile_overlap=0.1,
cellprob_threshold=0.0, bsize=224, flow_threshold=0.4, min_size=15,
interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0):
interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0, fill_holes=True, area_threshold=None):

if isinstance(normalize, dict):
normalize_params = {**normalize_default, **normalize}
Expand Down Expand Up @@ -505,7 +507,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
masks, p = dynamics.resize_and_compute_masks(
dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
min_size=min_size, resize=None,
min_size=min_size, resize=None, fill_holes=fill_holes, area_threshold=area_threshold,
device=self.device if self.gpu else None)
else:
masks, p = [], []
Expand All @@ -524,6 +526,8 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
resize=resize,
min_size=min_size if stitch_threshold == 0 or nimg == 1 else
-1, # turn off for 3D stitching
fill_holes=fill_holes,
area_threshold=area_threshold,
device=self.device if self.gpu else None)
masks.append(outputs[0])
p.append(outputs[1])
Expand All @@ -536,7 +540,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
)
masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
masks = utils.fill_holes_and_remove_small_masks(
masks, min_size=min_size)
masks, min_size=min_size, fill_holes=fill_holes)
elif nimg > 1:
models_logger.warning("3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")

Expand Down
94 changes: 87 additions & 7 deletions cellpose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,77 @@ def outlines_list_multi(masks, num_processes=None):
outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks])
return outpix


def make_polygon(outline: np.ndarray, bb: tuple) -> list:
    """
    Translate outline points by a bounding-box offset and return them as a closed ring.

    ``bb[1]`` is added to the first coordinate of every point and ``bb[0]`` to the
    second. If the first and last translated points differ, the first point is
    appended so the ring is closed.

    Args:
        outline (np.ndarray): Array of points, one 2-element point per row.
        bb (tuple): Bounding box; only ``bb[0]`` and ``bb[1]`` are used, as offsets.
    Returns:
        polygon (list): The closed polygon as a list of ``[float, float]`` points.
    """
    offset = np.array([bb[1], bb[0]])
    shifted = outline + offset
    # Convert to plain Python floats so the result is JSON-serializable.
    ring = [[float(value) for value in point] for point in shifted]
    first, last = ring[0], ring[-1]
    if first != last:
        ring.append(first)
    return ring


def get_polygon(outline: np.ndarray, bb: tuple, dim: int, keep_holes: bool = False) -> list:
    """
    Find contours in a binary mask and return the selected one as a polygon.

    Contours are computed with ``cv2.findContours`` and translated by the
    bounding-box offset via ``make_polygon``. The result is a list of rings: the
    contour at index ``dim`` first, followed (when ``keep_holes`` is True and more
    than one contour was found) by every contour whose hierarchy parent entry is
    not -1.

    Args:
        outline (np.ndarray): Binary mask passed to ``cv2.findContours``.
        bb (tuple): Bounding box coordinates, forwarded to ``make_polygon``.
        dim (int): Index of the contour to use as the first ring.
        keep_holes (bool, optional): Whether to also return nested contours
            (uses ``cv2.RETR_TREE`` instead of ``cv2.RETR_EXTERNAL``). Default is False.
    Returns:
        polygon (list): List of closed rings, compatible with geojson polygon coordinates.
    """
    retrieval_mode = cv2.RETR_EXTERNAL
    if keep_holes:
        retrieval_mode = cv2.RETR_TREE
    contours, hierarchy = cv2.findContours(
        image=outline,
        mode=retrieval_mode,
        method=cv2.CHAIN_APPROX_NONE,
    )
    # squeeze() drops the singleton axis cv2 puts between point index and (x, y).
    rings = [make_polygon(contours[dim].squeeze(), bb)]
    if keep_holes and hierarchy.shape[1] > 1:
        # Column 3 of the hierarchy holds each contour's parent index (-1 = no parent).
        parents = hierarchy[0, :, 3].squeeze()
        nested_indices = np.where(parents != -1)[0]
        rings.extend(
            make_polygon(contours[idx].squeeze(), bb) for idx in nested_indices
        )
    return rings


def outlines_polygons(masks: np.ndarray, keep_holes: bool = False) -> list:
    """
    Get outlines of masks as polygons, e.g. for writing geojson.

    Labels that do not occur in ``masks`` are skipped, as are objects for which
    no usable contour can be extracted (for example a single-pixel mask).

    Args:
        masks (np.ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
        keep_holes (bool, optional): Whether to keep holes in the mask. Default is False.
    Returns:
        polygons (list): List of polygons as pixel coordinates.
    """
    polygons: list = []
    for label, sl in enumerate(find_objects(masks), start=1):
        if sl is None:
            # find_objects returns None for label values absent from masks
            # (non-contiguous labelling); there is nothing to outline.
            continue
        binary = (masks[sl] == label).astype(np.uint8)
        # Bounding box laid out as all slice starts followed by all slice stops.
        bbox = tuple([s.start for s in sl] + [s.stop for s in sl])
        try:
            polygon = get_polygon(binary, bbox, 0, keep_holes)
        except TypeError:
            # The first contour could not be converted (a contour that collapses
            # to a single point raises TypeError); fall back to the second contour.
            try:
                polygon = get_polygon(binary, bbox, 1, keep_holes)
            except (TypeError, IndexError):
                # No second contour, or it is degenerate too: skip this object
                # instead of aborting the whole export.
                continue
        polygons.append(polygon)
    return polygons


def get_outline_multi(args):
"""Get the outline of a specific mask in a multi-mask image.

Expand Down Expand Up @@ -611,7 +682,7 @@ def size_distribution(masks):
counts = np.unique(masks, return_counts=True)[1][1:]
return np.percentile(counts, 25) / np.percentile(counts, 75)

def fill_holes_and_remove_small_masks(masks, min_size=15):
def fill_holes_and_remove_small_masks(masks, min_size=15, fill_holes=True, area_threshold=None):
""" Fills holes in masks (2D/3D) and discards masks smaller than min_size.

This function fills holes in each mask using scipy.ndimage.morphology.binary_fill_holes.
Expand All @@ -624,7 +695,9 @@ def fill_holes_and_remove_small_masks(masks, min_size=15):
min_size (int, optional): Minimum number of pixels per mask.
Masks smaller than min_size will be removed.
Set to -1 to turn off this functionality. Default is 15.

fill_holes (bool, optional): Whether to fill holes in masks. Default is True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold.
If None or SKIMAGE_ENABLED is False, fills all holes. Default is None.
Returns:
ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
0 represents no mask, while positive integers represent mask labels.
Expand All @@ -644,11 +717,18 @@ def fill_holes_and_remove_small_masks(masks, min_size=15):
if min_size > 0 and npix < min_size:
masks[slc][msk] = 0
elif npix > 0:
if msk.ndim == 3:
for k in range(msk.shape[0]):
msk[k] = binary_fill_holes(msk[k])
else:
msk = binary_fill_holes(msk)
if fill_holes:
if msk.ndim == 3:
for k in range(msk.shape[0]):
if area_threshold is not None and SKIMAGE_ENABLED:
msk[k] = remove_small_holes(msk[k], area_threshold=area_threshold)
else:
msk[k] = binary_fill_holes(msk[k])
else:
if area_threshold is not None and SKIMAGE_ENABLED:
msk = remove_small_holes(msk, area_threshold=area_threshold)
else:
msk = binary_fill_holes(msk)
masks[slc][msk] = (j + 1)
j += 1
return masks