Skip to content

Commit

Permalink
Remove deprecated instance segmentation functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Sep 2, 2023
1 parent fb2b608 commit 422909a
Showing 1 changed file with 1 addition and 168 deletions.
169 changes: 1 addition & 168 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
import segment_anything.utils.amg as amg_utils
import vigra

from elf.evaluation.matching import label_overlap, intersection_over_union
from elf.segmentation import embeddings as embed
from elf.segmentation.stitching import stitch_segmentation
from nifty.tools import blocking, takeDict
from scipy.optimize import linear_sum_assignment
from nifty.tools import blocking

from segment_anything.predictor import SamPredictor

Expand Down Expand Up @@ -1119,168 +1117,3 @@ def get_amg(
amg = EmbeddingMaskGenerator(predictor, **kwargs) if embedding_based_amg else\
AutomaticMaskGenerator(predictor, **kwargs)
return amg


#
# Experimental functionality
#


def _segment_instances_from_embeddings(
    predictor, image_embeddings, i=None,
    # mws settings
    offsets=None, distance_type="l2", bias=0.0,
    # sam settings
    box_nms_thresh=0.7, pred_iou_thresh=0.88,
    stability_score_thresh=0.95, stability_score_offset=1.0,
    use_box=True, use_mask=True, use_points=False,
    # general settings
    min_initial_size=5, min_size=0, with_background=False,
    verbose=True, return_initial_segmentation=False, box_extension=0.1,
):
    """Compute an instance segmentation from precomputed image embeddings.

    Builds an ``EmbeddingMaskGenerator``, initializes it with the embeddings
    for slice ``i`` (if given) and turns the generated mask data into a
    segmentation via ``mask_data_to_segmentation``.

    Returns the segmentation, or a tuple of (segmentation, initial segmentation)
    if ``return_initial_segmentation`` is set.
    """
    mask_generator = EmbeddingMaskGenerator(
        predictor, offsets, min_initial_size, distance_type, bias,
        use_box, use_mask, use_points, box_extension
    )
    original_shape = image_embeddings["original_size"]
    # The generator only needs an object carrying the input shape,
    # so a fake input stand-in is sufficient here.
    dummy_input = _FakeInput(original_shape)
    mask_generator.initialize(dummy_input, image_embeddings, i, verbose=verbose)
    masks = mask_generator.generate(
        pred_iou_thresh, stability_score_thresh, stability_score_offset, box_nms_thresh, min_size
    )
    segmentation = mask_data_to_segmentation(masks, original_shape, with_background)
    if not return_initial_segmentation:
        return segmentation
    return segmentation, mask_generator.get_initial_segmentation()


def _segment_instances_from_embeddings_with_tiling(
    predictor, image_embeddings, i=None, verbose=True, with_background=True,
    return_initial_segmentation=False, **kwargs,
):
    """Compute an instance segmentation from tiled image embeddings.

    Each tile is segmented individually via ``_segment_instances_from_embeddings``
    and the per-tile results are merged with ``stitch_segmentation``.

    Returns the stitched segmentation, or a tuple of
    (segmentation, initial segmentation) if ``return_initial_segmentation`` is set.
    """
    features = image_embeddings["features"]
    attrs = features.attrs
    shape, tile_shape, halo = attrs["shape"], attrs["tile_shape"], attrs["halo"]

    # Collects the per-tile initial segmentations as a side effect of segment_tile.
    initial_segmentations = {}

    def segment_tile(_, tile_id):
        tile_features = features[tile_id]
        tile_embeddings = {
            "features": tile_features,
            "input_size": tile_features.attrs["input_size"],
            "original_size": tile_features.attrs["original_size"]
        }
        result = _segment_instances_from_embeddings(
            predictor, image_embeddings=tile_embeddings, i=i,
            with_background=with_background, verbose=verbose,
            return_initial_segmentation=return_initial_segmentation, **kwargs,
        )
        if return_initial_segmentation:
            seg, initial_segmentations[tile_id] = result
            return seg
        return result

    # The stitching only needs an object carrying the full shape.
    input_ = _FakeInput(shape)

    segmentation = stitch_segmentation(
        input_, segment_tile, tile_shape, halo, with_background=with_background, verbose=verbose
    )

    if not return_initial_segmentation:
        return segmentation

    # Stitch the cached initial segmentations with the same tiling layout.
    initial_segmentation = stitch_segmentation(
        input_, lambda _, tile_id: initial_segmentations[tile_id],
        tile_shape, halo,
        with_background=with_background, verbose=verbose
    )
    return segmentation, initial_segmentation


# TODO refactor in a class with `initialize` and `generate` logic
# this is still experimental and not yet ready to be integrated within the annotator_3d
# (will need to see how well it works with retrained models)
def _segment_instances_from_embeddings_3d(predictor, image_embeddings, verbose=1, iou_threshold=0.50, **kwargs):
    """Segment instances in 3d by running 2d embedding-based segmentation per
    slice and linking objects across consecutive slices via IoU matching.

    Args:
        predictor: The predictor, passed through to the per-slice segmentation.
        image_embeddings: Precomputed embeddings. Tiled embeddings are detected
            by ``original_size`` being ``None``.
        verbose: Whether to show the progress bar over slices.
        iou_threshold: Minimal IoU between objects in consecutive slices for
            them to be assigned the same id.
        kwargs: Additional arguments passed to the per-slice segmentation function.

    Returns:
        The 3d instance segmentation, of shape ``(n_slices,) + image_shape``,
        relabeled consecutively and with single-slice objects removed.
    """
    if image_embeddings["original_size"] is None:  # tiled embeddings
        is_tiled = True
        image_shape = tuple(image_embeddings["features"].attrs["shape"])
        n_slices = len(image_embeddings["features"][0])

    else:  # normal embeddings (not tiled)
        is_tiled = False
        image_shape = tuple(image_embeddings["original_size"])
        n_slices = image_embeddings["features"].shape[0]

    shape = (n_slices,) + image_shape
    # Choose the matching per-slice segmentation function for the embedding layout.
    segmentation_function = _segment_instances_from_embeddings_with_tiling if is_tiled else\
        _segment_instances_from_embeddings

    segmentation = np.zeros(shape, dtype="uint32")

    def match_segments(seg, prev_seg):
        # Relabel the ids in seg so that objects overlapping sufficiently with
        # an object in the previous slice inherit that object's id, and all
        # other objects get fresh ids above the current global maximum.
        overlap, ignore_idx = label_overlap(seg, prev_seg, ignore_label=0)
        scores = intersection_over_union(overlap)
        # remove ignore_label (remapped to continuous object_ids)
        if ignore_idx[0] is not None:
            scores = np.delete(scores, ignore_idx[0], axis=0)
        if ignore_idx[1] is not None:
            scores = np.delete(scores, ignore_idx[1], axis=1)

        n_matched = min(scores.shape)
        no_match = n_matched == 0 or (not np.any(scores >= iou_threshold))

        max_id = segmentation.max()
        if no_match:
            # No object pair passes the threshold: offset all ids so they are new.
            seg[seg != 0] += max_id

        else:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= iou_threshold).astype(float) - scores / (2*n_matched)
            seg_ind, prev_ind = linear_sum_assignment(costs)

            seg_ids, prev_ids = np.unique(seg)[1:], np.unique(prev_seg)[1:]
            # Keep only assignments that actually pass the IoU threshold.
            match_ok = scores[seg_ind, prev_ind] >= iou_threshold

            # Build the id remapping: matched objects take the previous slice's id.
            id_updates = {0: 0}
            matched_ids, matched_prev = seg_ids[seg_ind[match_ok]], prev_ids[prev_ind[match_ok]]
            id_updates.update(
                {seg_id: prev_id for seg_id, prev_id in zip(matched_ids, matched_prev) if seg_id != 0}
            )

            # Unmatched objects get fresh ids above the current global maximum.
            unmatched_ids = np.setdiff1d(seg_ids, np.concatenate([np.zeros(1, dtype=matched_ids.dtype), matched_ids]))
            id_updates.update({seg_id: max_id + i for i, seg_id in enumerate(unmatched_ids, 1)})

            seg = takeDict(id_updates, seg)

        return seg

    ids_to_slices = {}
    # segment the objects starting from slice 0
    for z in tqdm(
        range(0, n_slices), total=n_slices, desc="Run instance segmentation in 3d", disable=not bool(verbose)
    ):
        seg = segmentation_function(predictor, image_embeddings, i=z, verbose=False, **kwargs)
        if z > 0:
            prev_seg = segmentation[z - 1]
            seg = match_segments(seg, prev_seg)

        # keep track of the slices per object id to get rid of unconnected objects in the post-processing
        this_ids = np.unique(seg)[1:]
        for id_ in this_ids:
            ids_to_slices[id_] = ids_to_slices.get(id_, []) + [z]

        segmentation[z] = seg

    # get rid of objects that are just in a single slice
    filter_objects = [seg_id for seg_id, slice_list in ids_to_slices.items() if len(slice_list) == 1]
    segmentation[np.isin(segmentation, filter_objects)] = 0
    vigra.analysis.relabelConsecutive(segmentation, out=segmentation)

    return segmentation

0 comments on commit 422909a

Please sign in to comment.