Add example script for automatic segmentation (#781)
anwai98 authored Nov 15, 2024
1 parent 031dae1 commit 962e34d
Showing 3 changed files with 159 additions and 9 deletions.
11 changes: 6 additions & 5 deletions examples/annotator_2d.py
@@ -1,17 +1,18 @@
import os

import imageio.v3 as imageio
+from micro_sam.util import get_cache_directory
from micro_sam.sam_annotator import annotator_2d
from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data
-from micro_sam.util import get_cache_directory


DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
os.makedirs(EMBEDDING_CACHE, exist_ok=True)


def livecell_annotator(use_finetuned_model):
"""Run the 2d annotator for an example image from the LiveCELL dataset.
"""Run the 2d annotator for an example image from the LIVEcell dataset.
See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.
"""
@@ -29,7 +30,7 @@ def livecell_annotator(use_finetuned_model):


def hela_2d_annotator(use_finetuned_model):
"""Run the 2d annotator for an example image form the cell tracking challenge HeLa 2d dataset.
"""Run the 2d annotator for an example image from the Cell Tracking Challenge (HeLa 2d) dataset.
"""
example_data = fetch_hela_2d_example_data(DATA_CACHE)
image = imageio.imread(example_data)
@@ -46,7 +47,7 @@ def hela_2d_annotator(use_finetuned_model):

def wholeslide_annotator(use_finetuned_model):
"""Run the 2d annotator with tiling for an example whole-slide image from the
NeuRIPS cell segmentation challenge.
NeurIPS Cell Segmentation challenge.
See https://neurips22-cellseg.grand-challenge.org/ for details on the data.
"""
@@ -79,6 +80,6 @@ def main():

# The corresponding CLI call for hela_2d_annotator:
# (replace with cache directory on your machine)
-# $ micro_sam.annotator_2d -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-hela2d.zarr
+# $ micro_sam.annotator_2d -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-hela2d.zarr  # noqa
if __name__ == "__main__":
    main()
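
The function bodies below the docstrings are collapsed in this view. Purely for orientation, a minimal sketch of how hela_2d_annotator plausibly continues, assuming annotator_2d accepts the loaded image plus an embedding_path keyword (its signature is not shown in this diff); the embedding file name is taken from the CLI comment above:

# Hypothetical continuation (not part of the commit), shown only for orientation:
example_data = fetch_hela_2d_example_data(DATA_CACHE)
image = imageio.imread(example_data)
# Cache the precomputed image embeddings next to the other examples; the file name is illustrative.
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr")
annotator_2d(image, embedding_path=embedding_path)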
150 changes: 150 additions & 0 deletions examples/automatic_segmentation.py
@@ -0,0 +1,150 @@
import os

import imageio.v3 as imageio

from micro_sam.util import get_cache_directory
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation
from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data


DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")


def livecell_automatic_segmentation(model_type, use_amg, generate_kwargs):
    """Run the automatic segmentation for an example image from the LIVECell dataset.
    See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.
    """
    example_data = fetch_livecell_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,
        checkpoint=None,  # Replace this with your custom checkpoint.
        amg=use_amg,
        is_tiled=False,  # Switch to 'True' if you would like to perform tiling-based prediction.
    )

    segmentation = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=image,
        ndim=2,
        tile_shape=None,  # If you set 'is_tiled' in 'get_predictor_and_segmenter' to True, set a tile shape.
        halo=None,  # If you set 'is_tiled' in 'get_predictor_and_segmenter' to True, set a halo shape.
        **generate_kwargs
    )

    import napari
    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(segmentation)
    napari.run()


def hela_automatic_segmentation(model_type, use_amg, generate_kwargs):
    """Run the automatic segmentation for an example image from the Cell Tracking Challenge (HeLa 2d) dataset.
    """
    example_data = fetch_hela_2d_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,
        checkpoint=None,  # Replace this with your custom checkpoint.
        amg=use_amg,
        is_tiled=False,  # Switch to 'True' if you would like to perform tiling-based prediction.
    )

    segmentation = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=image,
        ndim=2,
        tile_shape=None,  # If you set 'is_tiled' in 'get_predictor_and_segmenter' to True, set a tile shape.
        halo=None,  # If you set 'is_tiled' in 'get_predictor_and_segmenter' to True, set a halo shape.
        **generate_kwargs
    )

    import napari
    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(segmentation)
    napari.run()


def wholeslide_automatic_segmentation(model_type, use_amg, generate_kwargs):
    """Run the automatic segmentation with tiling for an example whole-slide image from the
    NeurIPS Cell Segmentation challenge.
    """
    example_data = fetch_wholeslide_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,
        checkpoint=None,  # Replace this with your custom checkpoint.
        amg=use_amg,
        is_tiled=True,
    )

    segmentation = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=image,
        ndim=2,
        tile_shape=(1024, 1024),
        halo=(256, 256),
        **generate_kwargs
    )

    import napari
    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(segmentation)
    napari.run()


def main():
    # The choice of Segment Anything model.
    model_type = "vit_b_lm"

    # Whether to use:
    # - the automatic mask generation (AMG): supported by all our models.
    # - the automatic instance segmentation (AIS): supported by 'micro-sam' models.
    use_amg = False  # 'False' chooses AIS as the automatic segmentation mode.

    # Post-processing parameters for automatic segmentation.
    if use_amg:  # AMG parameters
        generate_kwargs = {
            "pred_iou_thresh": 0.88,
            "stability_score_thresh": 0.95,
            "box_nms_thresh": 0.7,
            "crop_nms_thresh": 0.7,
            "min_mask_region_area": 0,
            "output_mode": "binary_mask",
        }
    else:  # AIS parameters
        generate_kwargs = {
            "center_distance_threshold": 0.5,
            "boundary_distance_threshold": 0.5,
            "foreground_threshold": 0.5,
            "foreground_smoothing": 1.0,
            "distance_smoothing": 1.6,
            "min_size": 0,
            "output_mode": "binary_mask",
        }

    # Automatic segmentation for LIVECell data.
    livecell_automatic_segmentation(model_type, use_amg, generate_kwargs)

    # Automatic segmentation for Cell Tracking Challenge HeLa data.
    # hela_automatic_segmentation(model_type, use_amg, generate_kwargs)

    # Automatic segmentation for a whole-slide image.
    # wholeslide_automatic_segmentation(model_type, use_amg, generate_kwargs)


# The corresponding CLI call for hela_automatic_segmentation:
# (replace with cache directory on your machine)
# $ micro_sam.automatic_segmentation -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -o hela-2d-image_segmentation.tif  # noqa
if __name__ == "__main__":
    main()
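
Condensing the new example file above, the automatic segmentation API amounts to two calls. The following sketch reuses only names that already appear in this commit, with a placeholder input path:

import imageio.v3 as imageio

from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

# Placeholder path: replace with your own 2d image.
image = imageio.imread("my-image.tif")

# AIS mode (amg=False) with the 'vit_b_lm' model, without tiling.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", amg=False, is_tiled=False)
segmentation = automatic_instance_segmentation(predictor=predictor, segmenter=segmenter, input_path=image, ndim=2)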
7 changes: 3 additions & 4 deletions micro_sam/automatic_segmentation.py
@@ -55,11 +55,9 @@ def get_predictor_and_segmenter(
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_amg(
-        predictor=predictor,
-        is_tiled=is_tiled,
-        decoder=decoder,
-        **kwargs
+        predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs
    )

    return predictor, segmenter


@@ -232,6 +230,7 @@ def _convert_argval(value):
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
+        amg=args.amg,
        is_tiled=args.tile_shape is not None,
    )
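
For context, this hunk sits in the CLI entry point: forwarding the parsed amg flag into get_predictor_and_segmenter means the command-line switch between AMG and AIS now takes effect. A sketch of such an invocation, reusing the CLI comment from the new example script; the --amg spelling is an assumption inferred from args.amg above and may differ from the actual parser definition:

# (replace with cache directory on your machine; '--amg' flag name assumed from 'args.amg')
# $ micro_sam.automatic_segmentation -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -o hela-2d-image_segmentation.tif --amg  # noqa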

