diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index f1c5e644..ca58a9a5 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -16,11 +16,9 @@ import segment_anything.utils.amg as amg_utils import vigra -from elf.evaluation.matching import label_overlap, intersection_over_union from elf.segmentation import embeddings as embed from elf.segmentation.stitching import stitch_segmentation -from nifty.tools import blocking, takeDict -from scipy.optimize import linear_sum_assignment +from nifty.tools import blocking from segment_anything.predictor import SamPredictor @@ -1119,168 +1117,3 @@ def get_amg( amg = EmbeddingMaskGenerator(predictor, **kwargs) if embedding_based_amg else\ AutomaticMaskGenerator(predictor, **kwargs) return amg - - -# -# Experimental functionality -# - - -def _segment_instances_from_embeddings( - predictor, image_embeddings, i=None, - # mws settings - offsets=None, distance_type="l2", bias=0.0, - # sam settings - box_nms_thresh=0.7, pred_iou_thresh=0.88, - stability_score_thresh=0.95, stability_score_offset=1.0, - use_box=True, use_mask=True, use_points=False, - # general settings - min_initial_size=5, min_size=0, with_background=False, - verbose=True, return_initial_segmentation=False, box_extension=0.1, -): - amg = EmbeddingMaskGenerator( - predictor, offsets, min_initial_size, distance_type, bias, - use_box, use_mask, use_points, box_extension - ) - shape = image_embeddings["original_size"] - input_ = _FakeInput(shape) - amg.initialize(input_, image_embeddings, i, verbose=verbose) - mask_data = amg.generate( - pred_iou_thresh, stability_score_thresh, stability_score_offset, box_nms_thresh, min_size - ) - segmentation = mask_data_to_segmentation(mask_data, shape, with_background) - if return_initial_segmentation: - initial_segmentation = amg.get_initial_segmentation() - return segmentation, initial_segmentation - return segmentation - - -def _segment_instances_from_embeddings_with_tiling( - predictor, image_embeddings, i=None, verbose=True, with_background=True, - return_initial_segmentation=False, **kwargs, - ): - features = image_embeddings["features"] - shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"] - - initial_segmentations = {} - - def segment_tile(_, tile_id): - tile_features = features[tile_id] - tile_image_embeddings = { - "features": tile_features, - "input_size": tile_features.attrs["input_size"], - "original_size": tile_features.attrs["original_size"] - } - if return_initial_segmentation: - seg, initial_seg = _segment_instances_from_embeddings( - predictor, image_embeddings=tile_image_embeddings, i=i, - with_background=with_background, verbose=verbose, - return_initial_segmentation=True, **kwargs, - ) - initial_segmentations[tile_id] = initial_seg - else: - seg = _segment_instances_from_embeddings( - predictor, image_embeddings=tile_image_embeddings, i=i, - with_background=with_background, verbose=verbose, **kwargs, - ) - return seg - - # fake input data - input_ = _FakeInput(shape) - - # run stitching based segmentation - segmentation = stitch_segmentation( - input_, segment_tile, tile_shape, halo, with_background=with_background, verbose=verbose - ) - - if return_initial_segmentation: - initial_segmentation = stitch_segmentation( - input_, lambda _, tile_id: initial_segmentations[tile_id], - tile_shape, halo, - with_background=with_background, verbose=verbose - ) - return segmentation, initial_segmentation - - return segmentation - - -# TODO refactor in a class with `initialize` and `generate` logic -# this is still experimental and not yet ready to be integrated within the annotator_3d -# (will need to see how well it works with retrained models) -def _segment_instances_from_embeddings_3d(predictor, image_embeddings, verbose=1, iou_threshold=0.50, **kwargs): - if image_embeddings["original_size"] is None: # tiled embeddings - is_tiled = True - image_shape = tuple(image_embeddings["features"].attrs["shape"]) - n_slices = len(image_embeddings["features"][0]) - - else: # normal embeddings (not tiled) - is_tiled = False - image_shape = tuple(image_embeddings["original_size"]) - n_slices = image_embeddings["features"].shape[0] - - shape = (n_slices,) + image_shape - segmentation_function = _segment_instances_from_embeddings_with_tiling if is_tiled else\ - _segment_instances_from_embeddings - - segmentation = np.zeros(shape, dtype="uint32") - - def match_segments(seg, prev_seg): - overlap, ignore_idx = label_overlap(seg, prev_seg, ignore_label=0) - scores = intersection_over_union(overlap) - # remove ignore_label (remapped to continuous object_ids) - if ignore_idx[0] is not None: - scores = np.delete(scores, ignore_idx[0], axis=0) - if ignore_idx[1] is not None: - scores = np.delete(scores, ignore_idx[1], axis=1) - - n_matched = min(scores.shape) - no_match = n_matched == 0 or (not np.any(scores >= iou_threshold)) - - max_id = segmentation.max() - if no_match: - seg[seg != 0] += max_id - - else: - # compute optimal matching with scores as tie-breaker - costs = -(scores >= iou_threshold).astype(float) - scores / (2*n_matched) - seg_ind, prev_ind = linear_sum_assignment(costs) - - seg_ids, prev_ids = np.unique(seg)[1:], np.unique(prev_seg)[1:] - match_ok = scores[seg_ind, prev_ind] >= iou_threshold - - id_updates = {0: 0} - matched_ids, matched_prev = seg_ids[seg_ind[match_ok]], prev_ids[prev_ind[match_ok]] - id_updates.update( - {seg_id: prev_id for seg_id, prev_id in zip(matched_ids, matched_prev) if seg_id != 0} - ) - - unmatched_ids = np.setdiff1d(seg_ids, np.concatenate([np.zeros(1, dtype=matched_ids.dtype), matched_ids])) - id_updates.update({seg_id: max_id + i for i, seg_id in enumerate(unmatched_ids, 1)}) - - seg = takeDict(id_updates, seg) - - return seg - - ids_to_slices = {} - # segment the objects starting from slice 0 - for z in tqdm( - range(0, n_slices), total=n_slices, desc="Run instance segmentation in 3d", disable=not bool(verbose) - ): - seg = segmentation_function(predictor, image_embeddings, i=z, verbose=False, **kwargs) - if z > 0: - prev_seg = segmentation[z - 1] - seg = match_segments(seg, prev_seg) - - # keep track of the slices per object id to get rid of unconnected objects in the post-processing - this_ids = np.unique(seg)[1:] - for id_ in this_ids: - ids_to_slices[id_] = ids_to_slices.get(id_, []) + [z] - - segmentation[z] = seg - - # get rid of objects that are just in a single slice - filter_objects = [seg_id for seg_id, slice_list in ids_to_slices.items() if len(slice_list) == 1] - segmentation[np.isin(segmentation, filter_objects)] = 0 - vigra.analysis.relabelConsecutive(segmentation, out=segmentation) - - return segmentation