Skip to content

Commit

Permalink
Merge pull request #857 from MouseLand/3d
Browse files Browse the repository at this point in the history
cellpose3 merge
  • Loading branch information
carsen-stringer authored Feb 14, 2024
2 parents 28ddc24 + 13bc98d commit ffbf8d0
Show file tree
Hide file tree
Showing 43 changed files with 7,941 additions and 3,523 deletions.
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,21 @@ Please see install instructions [below](README.md/#Installation), and also check
**If you use Cellpose 1 or 2, please cite the Cellpose 1.0 [paper](https://t.co/kBMXmPp3Yn?amp=1):**
Stringer, C., Wang, T., Michaelos, M., & Pachitariu, M. (2021). Cellpose: a generalist algorithm for cellular segmentation. <em>Nature methods, 18</em>(1), 100-106.

**If you use the new human-in-the-loop training or use the new cyto2, livecell, or tissuenet models, please also cite the Cellpose 2.0 [paper](https://www.nature.com/articles/s41592-022-01663-4):**
**If you use the new human-in-the-loop training, please also cite the Cellpose 2.0 [paper](https://www.nature.com/articles/s41592-022-01663-4):**
Pachitariu, M. & Stringer, C. (2022). Cellpose 2.0: how to train your own model. <em>Nature methods</em>, 1-8.

:triangular_flag_on_post: the new tissuenet and livecell models (`tissuenet`, `TN1`, `TN2`, `TN3`, `livecell`, `LC1`, `LC2`, `LC3` and `LC4`) were trained using data under a **CC-BY-NC** license, so these models are **non-commercial use only**.
**If you use the restoration models, please also cite the Cellpose3 [paper](https://www.biorxiv.org/content/10.1101/2024.02.10.579780v1):**
Stringer, C. & Pachitariu, M. (2024). Cellpose3: one-click image restoration for improved segmentation. <em>bioRxiv</em>.

:triangular_flag_on_post: All models in Cellpose, except `yeast_BF_cp3`, `yeast_PhC_cp3`, and `deepbacs_cp3`, are trained on some amount of data that is **CC-BY-NC**. The Cellpose annotated dataset is also CC-BY-NC.

### :star2: v3 (Feb 2024) :star2:

Cellpose3 enables image restoration in the GUI and the API (CLI support and example notebooks coming soon!) To learn more...
* Check out the paper [thread]().
* Check out the [paper](https://www.biorxiv.org/content/10.1101/2024.02.10.579780v1).

Try out the new `cyto3` super-generalist Cellpose model with `model_type="cyto3"`. There are some Cellpose API changes from v2.0 which will be documented soon.

### :star2: v2.0 (April 2022) :star2:

Expand Down
73 changes: 27 additions & 46 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""
Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""

import sys, os, glob, pathlib, time
import numpy as np
from natsort import natsorted
from tqdm import tqdm
from cellpose import utils, models, io, core, version_str
from cellpose import utils, models, io, version_str, train
from cellpose.cli import get_arg_parser

try:
from cellpose.gui import gui
from cellpose.gui import gui3d, gui
GUI_ENABLED = True
except ImportError as err:
GUI_ERROR = err
Expand Down Expand Up @@ -50,14 +50,17 @@ def main():
print('GUI FAILED: GUI dependencies may not be installed, to install, run')
print(' pip install "cellpose[gui]"')
else:
gui.run()
if args.Zstack:
gui3d.run()
else:
gui.run()

else:
if args.verbose:
from .io import logger_setup
logger, log_file = logger_setup()
else:
print('>>>> !NEW LOGGING SETUP! To see cellpose progress, set --verbose')
print('>>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose')
print('No --verbose => no progress or info printed')
logger = logging.getLogger(__name__)

Expand All @@ -70,7 +73,6 @@ def main():
else:
imf = None


# Check with user if they REALLY mean to run without saving anything
if not (args.train or args.train_size):
saving_something = args.save_png or args.save_tif or args.save_flows or args.save_ncolor or args.save_txt
Expand All @@ -91,12 +93,11 @@ def main():
if ~np.any([model_type == s for s in all_models]):
model_type = 'cyto'
logger.warning('pretrained model has incorrect path')

if model_type=='nuclei':
szmean = 17.
else:
szmean = 30.
builtin_size = model_type == 'cyto' or model_type == 'cyto2' or model_type == 'nuclei'
builtin_size = model_type == 'cyto' or model_type == 'cyto2' or model_type == 'nuclei' or model_type=="cyto3"

if len(args.image_path) > 0 and (args.train or args.train_size):
raise ValueError('ERROR: cannot train model with single image input')
Expand All @@ -122,17 +123,14 @@ def main():

# handle built-in model exceptions; bacterial ones get no size model
if builtin_size:
model = models.Cellpose(gpu=gpu, device=device, model_type=model_type,
net_avg=(not args.fast_mode or args.net_avg))

model = models.Cellpose(gpu=gpu, device=device, model_type=model_type)
else:
if args.all_channels:
channels = None
pretrained_model = None if model_type is not None else pretrained_model
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type,
net_avg=False)
model_type=model_type)

# handle diameters
if args.diameter==0:
Expand All @@ -154,9 +152,8 @@ def main():
image = io.imread(image_name)
out = model.eval(image, channels=channels, diameter=diameter,
do_3D=args.do_3D,
net_avg=(not args.fast_mode or args.net_avg),
augment=args.augment,
resample=(not args.no_resample and not args.fast_mode),
resample=(not args.no_resample),
flow_threshold=args.flow_threshold,
cellprob_threshold=args.cellprob_threshold,
stitch_threshold=args.stitch_threshold,
Expand All @@ -167,8 +164,7 @@ def main():
normalize=(not args.no_norm),
channel_axis=args.channel_axis,
z_axis=args.z_axis,
anisotropy=args.anisotropy,
model_loaded=True)
anisotropy=args.anisotropy)
masks, flows = out[:2]
if len(out) > 3:
diams = out[-1]
Expand All @@ -177,7 +173,7 @@ def main():
if args.exclude_on_edges:
masks = utils.remove_edge_masks(masks)
if not args.no_npy:
io.masks_flows_to_seg(image, masks, flows, diams, image_name, channels)
io.masks_flows_to_seg(image, masks, flows, image_name, channels=channels, diams=diams)
if saving_something:
io.save_masks(image, masks, flows, image_name, png=args.save_png, tif=args.save_tif,
save_flows=args.save_flows,save_outlines=args.save_outlines,
Expand All @@ -189,7 +185,7 @@ def main():
else:

test_dir = None if len(args.test_dir)==0 else args.test_dir
output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter, args.unet, args.look_one_level_down)
output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter, args.look_one_level_down)
images, labels, image_names, test_images, test_labels, image_names_test = output

# training with all channels
Expand All @@ -202,7 +198,6 @@ def main():
channels = None
else:
nchan = 2


# model path
szmean = args.diam_mean
Expand All @@ -213,39 +208,26 @@ def main():
raise ValueError(error_message)
pretrained_model = False
logger.info('>>>> training from scratch')

if args.train:
logger.info('>>>> during training rescaling images to fixed diameter of %0.1f pixels'%args.diam_mean)

# initialize model
if args.unet:
model = core.UnetModel(device=device,
pretrained_model=pretrained_model,
diam_mean=szmean,
residual_on=args.residual_on,
style_on=args.style_on,
concatenation=args.concatenation,
nclasses=args.nclasses,
nchan=nchan)
else:
model = models.CellposeModel(device=device,
pretrained_model=pretrained_model if model_type is None else None,
model_type=model_type,
diam_mean=szmean,
residual_on=args.residual_on,
style_on=args.style_on,
concatenation=args.concatenation,
nchan=nchan)
model = models.CellposeModel(device=device,
pretrained_model=pretrained_model if model_type is None else None,
model_type=model_type,
diam_mean=szmean,
nchan=nchan)

# train segmentation model
if args.train:
cpmodel_path = model.train(images, labels, train_files=image_names,
cpmodel_path = train.train_seg(model.net, images, labels, train_files=image_names,
test_data=test_images, test_labels=test_labels, test_files=image_names_test,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
channels=channels,
save_path=os.path.realpath(args.dir), save_every=args.save_every,
save_each=args.save_each,
save_path=os.path.realpath(args.dir),
save_every=args.save_every,
SGD=args.SGD,
n_epochs=args.n_epochs,
batch_size=args.batch_size,
min_train_masks=args.min_train_masks,
Expand All @@ -259,13 +241,12 @@ def main():
masks = [lbl[0] for lbl in labels]
test_masks = [lbl[0] for lbl in test_labels] if test_labels is not None else test_labels
# data has already been normalized and reshaped
sz_model.train(images, masks, test_images, test_masks,
channels=None, normalize=False,
sz_model.params = train.train_size(model.net, model.pretrained_model, images, masks, test_images, test_masks,
channels=channels,
batch_size=args.batch_size)
if test_images is not None:
predicted_diams, diams_style = sz_model.eval(test_images,
channels=None,
normalize=False)
channels=channels)
ccs = np.corrcoef(diams_style, np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0,1]
cc = np.corrcoef(predicted_diams, np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0,1]
logger.info('style test correlation: %0.4f; final test correlation: %0.4f'%(ccs,cc))
Expand Down
26 changes: 7 additions & 19 deletions cellpose/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
"""

import argparse
Expand All @@ -19,6 +19,8 @@ def get_arg_parser():
parser.add_argument('--version', action='store_true', help='show cellpose version info')
parser.add_argument('--verbose', action='store_true',
help='show information about running and settings and save to log')
parser.add_argument('--Zstack', action='store_true',
help='run GUI in 3D mode')

# settings for CPU vs GPU
hardware_args = parser.add_argument_group("Hardware Arguments")
Expand Down Expand Up @@ -58,16 +60,11 @@ def get_arg_parser():
help='model to use for running or starting training')
model_args.add_argument('--add_model', required=False, default=None, type=str,
help='model path to copy model to hidden .cellpose folder for using in GUI/CLI')
model_args.add_argument('--unet', action='store_true', help='run standard unet instead of cellpose flow output')
model_args.add_argument('--nclasses', default=3, type=int,
help='if running unet, choose 2 or 3; cellpose always uses 3')


# algorithm settings
algorithm_args = parser.add_argument_group("Algorithm Arguments")
algorithm_args.add_argument('--no_resample', action='store_true',
help="disable dynamics on full image (makes algorithm faster for images with large diameters)")
algorithm_args.add_argument('--net_avg', action='store_true',
help='run 4 networks instead of 1 and average results')
algorithm_args.add_argument('--no_interp', action='store_true',
help='do not interpolate when running dynamics (was default)')
algorithm_args.add_argument('--no_norm', action='store_true', help='do not normalize images (normalize=False)')
Expand All @@ -79,9 +76,7 @@ def get_arg_parser():
help='compute masks in 2D then stitch together masks with IoU>0.9 across planes')
algorithm_args.add_argument('--min_size', required=False, default=15, type=int,
help='minimum number of pixels per mask, can turn off with -1')
algorithm_args.add_argument('--fast_mode', action='store_true',
help='now equivalent to --no_resample; make code run faster by turning off resampling')


algorithm_args.add_argument('--flow_threshold', default=0.4, type=float,
help='flow error threshold, 0 turns off this optional QC step. Default: %(default)s')
algorithm_args.add_argument('--cellprob_threshold', default=0, type=float,
Expand Down Expand Up @@ -145,18 +140,11 @@ def get_arg_parser():
training_args.add_argument('--min_train_masks',
default=5, type=int,
help='minimum number of masks a training image must have to be used. Default: %(default)s')
training_args.add_argument('--residual_on',
default=1, type=int, help='use residual connections')
training_args.add_argument('--style_on',
default=1, type=int, help='use style vector')
training_args.add_argument('--concatenation',
default=0, type=int,
help='concatenate downsampled layers with upsampled layers (off by default which means they are added)')
training_args.add_argument('--SGD',
default=1, type=int, help='use SGD')
training_args.add_argument('--save_every',
default=100, type=int,
help='number of epochs to skip between saves. Default: %(default)s')
training_args.add_argument('--save_each', action='store_true',
help='save the model under a different filename per --save_every epoch for later comparsion')
training_args.add_argument('--model_name_out', default=None, type=str,
help='Name of model to save as, defaults to name describing model architecture. '
'Model is saved in the folder specified by --dir in models subfolder.')
Expand Down
Loading

0 comments on commit ffbf8d0

Please sign in to comment.