Implement 3d auto segmentation #169

Merged
merged 4 commits on Sep 2, 2023
169 changes: 1 addition & 168 deletions micro_sam/instance_segmentation.py
@@ -16,11 +16,9 @@
import segment_anything.utils.amg as amg_utils
import vigra

-from elf.evaluation.matching import label_overlap, intersection_over_union
from elf.segmentation import embeddings as embed
from elf.segmentation.stitching import stitch_segmentation
-from nifty.tools import blocking, takeDict
-from scipy.optimize import linear_sum_assignment
+from nifty.tools import blocking

from segment_anything.predictor import SamPredictor

@@ -1119,168 +1117,3 @@ def get_amg(
    amg = EmbeddingMaskGenerator(predictor, **kwargs) if embedding_based_amg else\
        AutomaticMaskGenerator(predictor, **kwargs)
    return amg


#
# Experimental functionality
#


def _segment_instances_from_embeddings(
    predictor, image_embeddings, i=None,
    # mws settings
    offsets=None, distance_type="l2", bias=0.0,
    # sam settings
    box_nms_thresh=0.7, pred_iou_thresh=0.88,
    stability_score_thresh=0.95, stability_score_offset=1.0,
    use_box=True, use_mask=True, use_points=False,
    # general settings
    min_initial_size=5, min_size=0, with_background=False,
    verbose=True, return_initial_segmentation=False, box_extension=0.1,
):
    amg = EmbeddingMaskGenerator(
        predictor, offsets, min_initial_size, distance_type, bias,
        use_box, use_mask, use_points, box_extension
    )
    shape = image_embeddings["original_size"]
    input_ = _FakeInput(shape)
    amg.initialize(input_, image_embeddings, i, verbose=verbose)
    mask_data = amg.generate(
        pred_iou_thresh, stability_score_thresh, stability_score_offset, box_nms_thresh, min_size
    )
    segmentation = mask_data_to_segmentation(mask_data, shape, with_background)
    if return_initial_segmentation:
        initial_segmentation = amg.get_initial_segmentation()
        return segmentation, initial_segmentation
    return segmentation


def _segment_instances_from_embeddings_with_tiling(
    predictor, image_embeddings, i=None, verbose=True, with_background=True,
    return_initial_segmentation=False, **kwargs,
):
    features = image_embeddings["features"]
    shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"]

    initial_segmentations = {}

    def segment_tile(_, tile_id):
        tile_features = features[tile_id]
        tile_image_embeddings = {
            "features": tile_features,
            "input_size": tile_features.attrs["input_size"],
            "original_size": tile_features.attrs["original_size"]
        }
        if return_initial_segmentation:
            seg, initial_seg = _segment_instances_from_embeddings(
                predictor, image_embeddings=tile_image_embeddings, i=i,
                with_background=with_background, verbose=verbose,
                return_initial_segmentation=True, **kwargs,
            )
            initial_segmentations[tile_id] = initial_seg
        else:
            seg = _segment_instances_from_embeddings(
                predictor, image_embeddings=tile_image_embeddings, i=i,
                with_background=with_background, verbose=verbose, **kwargs,
            )
        return seg

    # fake input data
    input_ = _FakeInput(shape)

    # run stitching based segmentation
    segmentation = stitch_segmentation(
        input_, segment_tile, tile_shape, halo, with_background=with_background, verbose=verbose
    )

    if return_initial_segmentation:
        initial_segmentation = stitch_segmentation(
            input_, lambda _, tile_id: initial_segmentations[tile_id],
            tile_shape, halo,
            with_background=with_background, verbose=verbose
        )
        return segmentation, initial_segmentation

    return segmentation


# TODO refactor in a class with `initialize` and `generate` logic
# this is still experimental and not yet ready to be integrated within the annotator_3d
# (will need to see how well it works with retrained models)
def _segment_instances_from_embeddings_3d(predictor, image_embeddings, verbose=1, iou_threshold=0.50, **kwargs):
    if image_embeddings["original_size"] is None:  # tiled embeddings
        is_tiled = True
        image_shape = tuple(image_embeddings["features"].attrs["shape"])
        n_slices = len(image_embeddings["features"][0])

    else:  # normal embeddings (not tiled)
        is_tiled = False
        image_shape = tuple(image_embeddings["original_size"])
        n_slices = image_embeddings["features"].shape[0]

    shape = (n_slices,) + image_shape
    segmentation_function = _segment_instances_from_embeddings_with_tiling if is_tiled else\
        _segment_instances_from_embeddings

    segmentation = np.zeros(shape, dtype="uint32")

    def match_segments(seg, prev_seg):
        overlap, ignore_idx = label_overlap(seg, prev_seg, ignore_label=0)
        scores = intersection_over_union(overlap)
        # remove ignore_label (remapped to continuous object_ids)
        if ignore_idx[0] is not None:
            scores = np.delete(scores, ignore_idx[0], axis=0)
        if ignore_idx[1] is not None:
            scores = np.delete(scores, ignore_idx[1], axis=1)

        n_matched = min(scores.shape)
        no_match = n_matched == 0 or (not np.any(scores >= iou_threshold))

        max_id = segmentation.max()
        if no_match:
            seg[seg != 0] += max_id

        else:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= iou_threshold).astype(float) - scores / (2 * n_matched)
            seg_ind, prev_ind = linear_sum_assignment(costs)

            seg_ids, prev_ids = np.unique(seg)[1:], np.unique(prev_seg)[1:]
            match_ok = scores[seg_ind, prev_ind] >= iou_threshold

            id_updates = {0: 0}
            matched_ids, matched_prev = seg_ids[seg_ind[match_ok]], prev_ids[prev_ind[match_ok]]
            id_updates.update(
                {seg_id: prev_id for seg_id, prev_id in zip(matched_ids, matched_prev) if seg_id != 0}
            )

            unmatched_ids = np.setdiff1d(seg_ids, np.concatenate([np.zeros(1, dtype=matched_ids.dtype), matched_ids]))
            id_updates.update({seg_id: max_id + i for i, seg_id in enumerate(unmatched_ids, 1)})

            seg = takeDict(id_updates, seg)

        return seg

    ids_to_slices = {}
    # segment the objects starting from slice 0
    for z in tqdm(
        range(0, n_slices), total=n_slices, desc="Run instance segmentation in 3d", disable=not bool(verbose)
    ):
        seg = segmentation_function(predictor, image_embeddings, i=z, verbose=False, **kwargs)
        if z > 0:
            prev_seg = segmentation[z - 1]
            seg = match_segments(seg, prev_seg)

        # keep track of the slices per object id to get rid of unconnected objects in the post-processing
        this_ids = np.unique(seg)[1:]
        for id_ in this_ids:
            ids_to_slices[id_] = ids_to_slices.get(id_, []) + [z]

        segmentation[z] = seg

    # get rid of objects that are just in a single slice
    filter_objects = [seg_id for seg_id, slice_list in ids_to_slices.items() if len(slice_list) == 1]
    segmentation[np.isin(segmentation, filter_objects)] = 0
    vigra.analysis.relabelConsecutive(segmentation, out=segmentation)

    return segmentation
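
Aside for reviewers: the deleted `match_segments` helper above boils down to IoU-based Hungarian matching of instance ids between adjacent slices. A minimal self-contained sketch of the same technique, using plain numpy/scipy instead of the `elf`/`nifty` helpers (the name `match_slice_ids` and all details are illustrative, not micro_sam API; the background is assumed to be label 0):

import numpy as np
from scipy.optimize import linear_sum_assignment


def match_slice_ids(seg, prev_seg, iou_threshold=0.5):
    # object ids without the background label 0
    seg_ids = np.unique(seg)[1:]
    prev_ids = np.unique(prev_seg)[1:]
    if len(seg_ids) == 0 or len(prev_ids) == 0:
        return seg

    # pairwise IoU between all objects of the current and the previous slice
    ious = np.zeros((len(seg_ids), len(prev_ids)))
    for i, sid in enumerate(seg_ids):
        mask = seg == sid
        for j, pid in enumerate(prev_ids):
            pmask = prev_seg == pid
            union = np.logical_or(mask, pmask).sum()
            if union > 0:
                ious[i, j] = np.logical_and(mask, pmask).sum() / union

    # Hungarian matching: maximize the summed IoU over one-to-one assignments
    row_ind, col_ind = linear_sum_assignment(-ious)
    matched = {r: c for r, c in zip(row_ind, col_ind) if ious[r, c] >= iou_threshold}

    # matched objects inherit the previous id, unmatched ones get fresh ids
    relabeled = np.zeros_like(seg)
    next_id = int(max(prev_ids.max(), seg_ids.max())) + 1
    for i, sid in enumerate(seg_ids):
        if i in matched:
            relabeled[seg == sid] = prev_ids[matched[i]]
        else:
            relabeled[seg == sid] = next_id
            next_id += 1
    return relabeled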
134 changes: 134 additions & 0 deletions micro_sam/multi_dimensional_segmentation.py
@@ -0,0 +1,134 @@
"""
Multi-dimensional segmentation with segment anything.
"""

from typing import Any, Optional

import numpy as np
from segment_anything.predictor import SamPredictor

from . import util
from .prompt_based_segmentation import segment_from_mask


def segment_mask_in_volume(
    segmentation: np.ndarray,
    predictor: SamPredictor,
    image_embeddings: util.ImageEmbeddings,
    segmented_slices: np.ndarray,
    stop_lower: bool,
    stop_upper: bool,
    iou_threshold: float,
    projection: str,
    progress_bar: Optional[Any] = None,
    box_extension: int = 0,
) -> np.ndarray:
"""Segment an object mask in in volumetric data.

Args:
segmentation: The initial segmentation for the object.
predictor: The segment anything predictor.
image_embeddings: The precomputed image embeddings for the volume.
segmented_slices: List of slices for which this object has already been segmented.
stop_lower: Whether to stop at the lowest segmented slice.
stop_upper: Wheter to stop at the topmost segmented slice.
iou_threshold: The IOU threshold for continuing segmentation across 3d.
projection: The projection method to use. One of 'mask', 'bounding_box' or 'points'.
progress_bar: Optional progress bar.
box_extension: Extension factor for increasing the box size after projection

Returns:
Array with the volumetric segmentation
"""
    assert projection in ("mask", "bounding_box", "points")
    if projection == "mask":
        use_box, use_mask, use_points = True, True, False
    elif projection == "points":
        use_box, use_mask, use_points = True, True, True
    else:
        use_box, use_mask, use_points = True, False, False

    def _update_progress():
        if progress_bar is not None:
            progress_bar.update(1)

    def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
        z = z_start + increment
        while True:
            if verbose:
                print(f"Segment {z_start} to {z_stop}: segmenting slice {z}")
            seg_prev = segmentation[z - increment]
            seg_z = segment_from_mask(predictor, seg_prev, image_embeddings=image_embeddings, i=z,
                                      use_mask=use_mask, use_box=use_box, use_points=use_points,
                                      box_extension=box_extension)
            if threshold is not None:
                iou = util.compute_iou(seg_prev, seg_z)
                if iou < threshold:
                    msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
                    print(msg)
                    break
            segmentation[z] = seg_z
            z += increment
            if stopping_criterion(z, z_stop):
                if verbose:
                    print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
                break
            _update_progress()

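    # The lowest and the highest slice that already contain a segmentation for this object.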
    z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())

    # segment below the min slice
    if z0 > 0 and not stop_lower:
        segment_range(z0, 0, -1, np.less, iou_threshold)

    # segment above the max slice
    if z1 < segmentation.shape[0] - 1 and not stop_upper:
        segment_range(z1, segmentation.shape[0] - 1, 1, np.greater, iou_threshold)

    verbose = False
    # segment in between min and max slice
    if z0 != z1:
        for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
            slice_diff = z_stop - z_start
            z_mid = int((z_start + z_stop) // 2)

            if slice_diff == 1:  # the slices are adjacent -> we don't need to do anything
                pass

            elif z_start == z0 and stop_lower:  # the lower slice is stop: we just segment from upper
                segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)

            elif z_stop == z1 and stop_upper:  # the upper slice is stop: we just segment from lower
                segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)

            elif slice_diff == 2:  # there is only one slice in between -> use combined mask
                z = z_start + 1
                seg_prompt = np.logical_or(segmentation[z_start] == 1, segmentation[z_stop] == 1)
                segmentation[z] = segment_from_mask(
                    predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
                    use_mask=use_mask, use_box=use_box, use_points=use_points,
                    box_extension=box_extension
                )
                _update_progress()

            else:  # there is a range of more than 2 slices in between -> segment ranges
                # segment from bottom
                segment_range(
                    z_start, z_mid, 1, np.greater_equal if slice_diff % 2 == 0 else np.greater, verbose=verbose
                )
                # segment from top
                segment_range(z_stop, z_mid, -1, np.less_equal, verbose=verbose)
                # if the difference between start and stop is even,
                # then we have a slice in the middle that is the same distance from top and bottom;
                # in this case the slice is not segmented in the ranges above, and we segment it
                # using the combined mask from the adjacent top and bottom slice as prompt
                if slice_diff % 2 == 0:
                    seg_prompt = np.logical_or(segmentation[z_mid - 1] == 1, segmentation[z_mid + 1] == 1)
                    segmentation[z_mid] = segment_from_mask(
                        predictor, seg_prompt, image_embeddings=image_embeddings, i=z_mid,
                        use_mask=use_mask, use_box=use_box, use_points=use_points,
                        box_extension=box_extension
                    )
                    _update_progress()

    return segmentation
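
For context, a hypothetical usage sketch of the new function. The predictor/embedding setup via `util.get_sam_model` and `util.precompute_image_embeddings` follows micro_sam's utilities, but the exact arguments, the example volume, and the placeholder mask are assumptions, not part of this diff:

import numpy as np

from micro_sam import util
from micro_sam.multi_dimensional_segmentation import segment_mask_in_volume

# hypothetical example data: a small volume with an object annotated on slice 5
volume = np.random.rand(16, 256, 256).astype("float32")
predictor = util.get_sam_model(model_type="vit_b")
image_embeddings = util.precompute_image_embeddings(predictor, volume, ndim=3)

segmentation = np.zeros(volume.shape, dtype="uint8")
segmentation[5, 100:160, 100:160] = 1  # placeholder mask for the annotated slice

# project the mask through the volume in both directions
segmentation = segment_mask_in_volume(
    segmentation, predictor, image_embeddings,
    segmented_slices=np.array([5]),
    stop_lower=False, stop_upper=False,
    iou_threshold=0.8, projection="mask",
)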
2 changes: 1 addition & 1 deletion micro_sam/sam_annotator/annotator_2d.py
@@ -67,7 +67,7 @@ def _autosegment_widget(

    shape = v.layers["raw"].data.shape[:2]
    seg = instance_segmentation.mask_data_to_segmentation(
-        seg, shape, with_background=True, min_object_size=min_object_size
+        seg, shape, with_background=with_background, min_object_size=min_object_size
    )
    assert isinstance(seg, np.ndarray)
