Skip to content

Commit

Permalink
Add example for using the new 3d segmentation function
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Nov 2, 2023
1 parent 320da27 commit 5692c1f
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 62 deletions.
104 changes: 49 additions & 55 deletions examples/use_as_library/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,36 +33,15 @@ def cell_segmentation():

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
instances_amg = amg.generate(pred_iou_thresh=0.88)
instances_amg = instance_segmentation.mask_data_to_segmentation(
instances_amg, shape=image.shape, with_background=True
)

# Use the mutex waterhsed based instance segmentation logic.
# Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
# These initial masks are used as prompts for the actual instance segmentation.
# This class uses the same overall design as 'AutomaticMaskGenerator'.

# Create the automatic mask generator class.
amg_mws = instance_segmentation.EmbeddingMaskGenerator(predictor, min_initial_size=10)

# Initialize the mask generator with the image and the pre-computed embeddings.
amg_mws.initialize(image, embeddings, verbose=True)

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
# NOTE: the main advantage of this method is that it's faster than the original implementation,
# however the quality is not as high as the original instance segmentation quality yet.
instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
instances_mws = instance_segmentation.mask_data_to_segmentation(
instances_mws, shape=image.shape, with_background=True
instances = amg.generate(pred_iou_thresh=0.88)
instances = instance_segmentation.mask_data_to_segmentation(
instances, shape=image.shape, with_background=True
)

# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances_amg)
v.add_labels(instances_mws)
v.add_labels(instances)
napari.run()


Expand Down Expand Up @@ -95,49 +74,64 @@ def cell_segmentation_with_tiling():

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
instances_amg = amg.generate(pred_iou_thresh=0.88)
instances_amg = instance_segmentation.mask_data_to_segmentation(
instances_amg, shape=image.shape, with_background=True
instances = amg.generate(pred_iou_thresh=0.88)
instances = instance_segmentation.mask_data_to_segmentation(
instances, shape=image.shape, with_background=True
)

# Use the mutex waterhsed based instance segmentation logic.
# Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
# These initial masks are used as prompts for the actual instance segmentation.
# This class uses the same overall design as 'AutomaticMaskGenerator'.

# Create the automatic mask generator class.
amg_mws = instance_segmentation.TiledEmbeddingMaskGenerator(predictor, min_initial_size=10)

# Initialize the mask generator with the image and the pre-computed embeddings.
amg_mws.initialize(image, embeddings, verbose=True)

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
# NOTE: the main advantage of this method is that it's faster than the original implementation.
# however the quality is not as high as the original instance segmentation quality yet.
instances_mws = amg_mws.generate(pred_iou_thresh=0.88)

# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances_amg)
v.add_labels(instances_mws)
v.add_labels(instances)
v.add_labels(instances)
napari.run()


def segmentation_in_3d():
"""Run instance segmentation in 3d, for segmenting all objects that intersect
with a given slice. If you use a fine-tuned model for this then you should
first find good parameters for 2d segmentation.
"""
"""
from micro_sam.sample_data import synthetic_data

shape = (5, 512, 512)
data, _ = synthetic_data(shape)
predictor = util.get_sam_model(model_type="vit_t")
seg = segment_3d_from_slice(predictor, data, embedding_path="./tmp_embeddings.zarr", verbose=True)
import imageio.v3 as imageio
from micro_sam.sample_data import fetch_nucleus_3d_example_data

# Load the example image data: 3d nucleus segmentation.
path = fetch_nucleus_3d_example_data("./data")
data = imageio.imread(path)

# Load the SAM model for prediction.
model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc.
checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here.
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)

# Run 3d segmentation for a given slice. Will segment all objects found in that slice
# throughout the volume.

# The slice that is used for segmentation in 2d. If you don't specify a slice
# then the middle slice is used.
z_slice = data.shape[0] // 2

# The threshold for filtering objects in the 2d segmentation based on the model's
# predicted iou score. If you use a custom model you should first find a good setting
# for this value, e.g. with the 2d annotation tool.
pred_iou_thresh = 0.88

# The threshold for filtering objects in the 2d segmentation based on the model's
# stability score for a given object. If you use a custom model you should first find a good setting
# for this value, e.g. with the 2d annotation tool.
stability_score_thresh = 0.95

instances = segment_3d_from_slice(
predictor, data, z=z_slice,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
verbose=True
)

# Show the results.
v = napari.Viewer()
v.add_image(data)
v.add_labels(seg)
v.add_labels(instances)
napari.run()


Expand Down
33 changes: 26 additions & 7 deletions micro_sam/multi_dimensional_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def segment_mask_in_volume(
iou_threshold: The IOU threshold for continuing segmentation across 3d.
projection: The projection method to use. One of 'mask', 'bounding_box' or 'points'.
progress_bar: Optional progress bar.
box_extension: Extension factor for increasing the box size after projection
box_extension: Extension factor for increasing the box size after projection.
Returns:
Array with the volumetric segmentation
Expand Down Expand Up @@ -150,23 +150,41 @@ def segment_3d_from_slice(
min_object_size_z: int = 50,
max_object_size_z: Optional[int] = None,
iou_threshold: float = 0.8,
precompute_amg_state: bool = True,
):
"""
"""Segment all objects in a volume intersecting with a specific slice.
This function first segments the objects in the specified slice using the
automatic instance segmentation functionality. Then it segments all objects that
were found in that slice in the volume.
Args:
predictor:
predictor: The segment anything predictor.
raw: The volumetric image data.
z: The slice from which to start segmentation.
If none is given the central slice will be used.
embedding_path: The path were embeddings will be cached.
If none is given embeddings will not be cached.
projection: The projection method to use. One of 'mask', 'bounding_box' or 'points'.
box_extension: Extension factor for increasing the box size after projection.
verbose: Whether to print progress bar and other status messages.
pred_iou_thresh: The predicted iou value to filter objects in `AutomaticMaskGenerator.generate`.
stability_score_thresh: The stability score to filter objects in `AutomaticMaskGenerator.generate`.
min_object_size_z: Minimal object size in the segmented frame.
max_object_size_z: Maximal object size in the segmented frame.
iou_threshold: The IOU threshold for linking objects across slices.
Returns:
The
Segmentation volume.
"""
# Perform automatic instance segmentation.
# Compute the image embeddings.
image_embeddings = util.precompute_image_embeddings(predictor, raw, save_path=embedding_path, ndim=3)

# Select the middle slice if no slice is given.
if z is None:
z = raw.shape[0] // 2

if precompute_amg_state and (embedding_path is not None):
# Perform automatic instance segmentation.
if embedding_path is not None:
amg = cache_amg_state(predictor, raw[z], image_embeddings, embedding_path, verbose=verbose, i=z)
else:
amg = AutomaticMaskGenerator(predictor)
Expand All @@ -179,6 +197,7 @@ def segment_3d_from_slice(
max_object_size=max_object_size_z,
)

# Segment all objects that were found in 3d.
seg_ids = np.unique(seg_z)[1:]
segmentation = np.zeros(raw.shape, dtype=seg_z.dtype)
for seg_id in tqdm(seg_ids, desc="Segment objects in 3d", disable=not verbose):
Expand Down
26 changes: 26 additions & 0 deletions micro_sam/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,29 @@ def synthetic_data(shape, seed=None):

segmentation = label(image)
return image, segmentation


def fetch_nucleus_3d_example_data(save_directory: Union[str, os.PathLike]) -> str:
"""Download the sample data for 3d segmentation of nuclei.
This data contains a small crop from a volume from the publication
"Efficient automatic 3D segmentation of cell nuclei for high-content screening"
https://doi.org/10.1186/s12859-022-04737-4
Args:
save_directory: Root folder to save the downloaded data.
Returns:
The path of the downloaded image.
"""
save_directory = Path(save_directory)
os.makedirs(save_directory, exist_ok=True)
print("Example data directory is:", save_directory.resolve())
fname = "3d-nucleus-data.tif"
pooch.retrieve(
url="https://owncloud.gwdg.de/index.php/s/eW0uNCo8gedzWU4/download",
known_hash="4946896f747dc1c3fc82fb2e1320226d92f99d22be88ea5f9c37e3ba4e281205",
fname=fname,
path=save_directory,
progressbar=True,
)
return os.path.join(save_directory, fname)

0 comments on commit 5692c1f

Please sign in to comment.