Skip to content

Commit

Permalink
Prototype napari plugin widget for annotator_2d
Browse files Browse the repository at this point in the history
  • Loading branch information
GenevieveBuckley committed Sep 5, 2023
1 parent 9bad755 commit ce542f0
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 21 deletions.
6 changes: 6 additions & 0 deletions micro_sam/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ contributions:
- id: micro-sam.sample_data_segmentation
python_name: micro_sam.sample_data:sample_data_segmentation
title: Load segmentation sample data from micro-sam plugin
- id: micro-sam.annotator_2d_plugin
python_name: micro_sam.sam_annotator.annotator_2d:annotator_2d_plugin
title: micro-sam 2D annotator
widgets:
- command: micro-sam.annotator_2d_plugin
display_name: micro-sam 2D annotator
sample_data:
- command: micro-sam.sample_data_image_series
display_name: Image series example data
Expand Down
200 changes: 181 additions & 19 deletions micro_sam/sam_annotator/annotator_2d.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import warnings
import os
from pathlib import Path
from typing import Optional, Tuple
import warnings

import napari
from napari.types import ImageData, LabelsData
import numpy as np

from magicgui import magicgui
from magicgui import magicgui, magic_factory
from napari import Viewer
from segment_anything import SamPredictor

from .. import instance_segmentation, util
from ..precompute_state import cache_amg_state
from ..util import AVAILABLE_MODELS, _DEFAULT_MODEL
from ..visualization import project_embeddings_for_visualization
from . import util as vutil
from .gui_utils import show_wrong_file_warning
Expand Down Expand Up @@ -88,6 +92,117 @@ def _get_shape(raw):
return shape


def _start_2d_annotation(v: napari.Viewer, raw, segmentation_result=None):
print("_start_2d_annotation")
_setup_layers(v, raw, segmentation_result=segmentation_result)
_setup_dock_widgets(v)
_add_key_bindings(v)


def _setup_layers(v: napari.Viewer, raw, segmentation_result=None):
print("_setup_layers")
# TODO: ideally do not hard code layer names
shape = _get_shape(raw)

# raw input image data
if "raw" not in v.layers:
v.add_image(raw)
else:
if not np.allclose(raw, v.layers["raw"].data):
v.layers["raw"].data = raw
v.layers["raw"].refresh()

# auto_segmentation
if "auto_segmentation" not in v.layers:
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation")
else:
v.layers["auto_segmentation"].data = np.zeros(shape, dtype="uint32")
v.layers["auto_segmentation"].refresh()

# committed_objects / segmentation result
if "committed_objects" not in v.layers:
if segmentation_result is not None:
v.add_labels(segmentation_result, name="committed_objects")
else:
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="committed_objects")
else:
v.layers["committed_objects"].data = np.zeros(shape, dtype="uint32")
v.layers["committed_objects"].refresh()

# current_object
if "current_object" not in v.layers:
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object")
else:
v.layers["current_object"].data = np.zeros(shape, dtype="uint32")
v.layers["current_object"].refresh()

# prompts
if "prompts" not in v.layers:
labels = ["positive", "negative"]
prompts = v.add_points(
data=[[0.0, 0.0], [0.0, 0.0]], # FIXME workaround
name="prompts",
properties={"label": labels},
edge_color="label",
edge_color_cycle=vutil.LABEL_COLOR_CYCLE,
symbol="o",
face_color="transparent",
edge_width=0.5,
size=12,
ndim=2,
)
prompts.edge_color_mode = "cycle"
else:
v.layers["prompts"].data = []
v.layers["prompts"].refresh()
# box prompts
if "box_prompts" not in v.layers:
v.add_shapes(
name="box_prompts",
face_color="transparent",
edge_color="green",
edge_width=4,
)
# remove dummy point, and/or other existing point prompts
v.layers["box_prompts"].data = [] # warning, this also erases v.layers["prompts"].properties["labels"]
v.layers["box_prompts"].refresh()


def _setup_dock_widgets(v: napari.Viewer):
print("_setup_dock_widgets")
# TODO: try to make this a cleaner code design to get the prompt_widget
prompts = v.layers["prompts"]
labels = ["positive", "negative"]
prompt_widget = vutil.create_prompt_menu(prompts, labels)
v.window.add_dock_widget(prompt_widget)

v.window.add_dock_widget(_autosegment_widget)
v.window.add_dock_widget(_segment_widget)
v.window.add_dock_widget(vutil._commit_segmentation_widget)
v.window.add_dock_widget(vutil._clear_widget)


def _add_key_bindings(v: napari.Viewer):
print("_add_key_bindings")
@v.bind_key("s")
def _segmet(v):
_segment_widget(v)

@v.bind_key("c")
def _commit(v):
vutil._commit_segmentation_widget(v)

# TODO: try to make this a cleaner code design
prompts = v.layers["prompts"]
@v.bind_key("t")
def _toggle_label(event=None):
vutil.toggle_label(prompts)

@v.bind_key("Shift-C")
def clear_prompts(v):
vutil.clear_annotations(v)


def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings):
v = Viewer()

Expand Down Expand Up @@ -191,6 +306,41 @@ def annotator_2d(
v: Optional[Viewer] = None,
predictor: Optional[SamPredictor] = None,
precompute_amg_state: bool = False,
) -> Optional[Viewer]:
"""This function can be called in python (or from CLI) to start the annotation tool"""
if v is None:
v = napari.Viewer()

v.add_image(raw)
v = annotator_2d_plugin(
raw,
embedding_path=embedding_path,
show_embeddings=show_embeddings,
segmentation_result=segmentation_result,
precompute_amg_state=precompute_amg_state,
tile_shape=tile_shape,
halo=halo,
predictor=predictor,
)

if return_viewer is True:
return v

napari.run()


@magic_factory(call_button="Start 2D annotation")
def annotator_2d_plugin(
raw: ImageData,
v: napari.Viewer,
embedding_path: os.PathLike = "./embeddings",
show_embeddings: bool = False,
segmentation_result: Optional[LabelsData] = None,
precompute_amg_state: bool = False,
model_type: AVAILABLE_MODELS = AVAILABLE_MODELS[_DEFAULT_MODEL],
tile_shape: Tuple[int, int] = (None, None),
halo: Tuple[int, int] = (None, None),
predictor: Optional[SamPredictor] = None, # not accessible via magicgui widget, not a recognised type
) -> Optional[Viewer]:
"""The 2d annotation tool.
Expand Down Expand Up @@ -224,39 +374,51 @@ def annotator_2d(
global PREDICTOR, IMAGE_EMBEDDINGS, AMG
AMG = None

if raw is None:
raise RuntimeError("You must provide a raw input image source.")

# Check if user provided non-zero input values to tile_shape and halo parameters
if all(tile_shape) is False:
tile_shape = None
if all(halo) is False:
halo = None

if predictor is None:
PREDICTOR = util.get_sam_model(model_type=model_type)
PREDICTOR = util.get_sam_model(model_type=model_type.name)
else:
PREDICTOR = predictor

# TODO: check if a pre-computed image embedding already matching the raw input image exists
# and if so use it without re-computing (warn the user you are re-using a previously generated embedding).

# TODO: dispatch long running computations to napari thread_worker
# and connect return to _start_2d_annotation function
# TODO: also consider a progress bar for long running computations
# see https://github.com/pyapp-kit/magicgui/issues/577#issuecomment-1705555387
print("Computing image embedding...")
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
wrong_file_callback=show_wrong_file_warning
)
if precompute_amg_state and (embedding_path is not None):
print("Computing AMG...")
AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path)

# show the PCA of the image embeddings
if show_embeddings:
print("Show embeddings")
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
v.add_image(embedding_vis, name="embeddings", scale=scale)

# we set the pre-computed image embeddings if we don't use tiling
# (if we use tiling we cannot directly set it because the tile will be chosen dynamically)
if tile_shape is None:
print("tiling precompute")
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)

# viewer is freshly initialized
if v is None:
v = _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings)
# we use an existing viewer and just update all the layers
else:
_update_viewer(v, raw, show_embeddings, segmentation_result)

#
# start the viewer
#
vutil.clear_annotations(v, clear_segmentations=False)

if return_viewer:
return v

napari.run()
_start_2d_annotation(v, raw, segmentation_result)
print("All done")
return v


def main():
Expand Down
2 changes: 1 addition & 1 deletion micro_sam/sam_annotator/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

def clear_annotations(v: napari.Viewer, clear_segmentations=True) -> None:
"""@private"""
v.layers["prompts"].data = []
v.layers["prompts"].data = [] # warning, this also erases v.layers["prompts"].properties["labels"]
v.layers["prompts"].refresh()
if "box_prompts" in v.layers:
v.layers["box_prompts"].data = []
Expand Down
3 changes: 2 additions & 1 deletion micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Helper functions for downloading Segment Anything models and predicting image embeddings.
"""

from enum import Enum
import hashlib
import os
import pickle
Expand Down Expand Up @@ -61,7 +62,7 @@
# this is the default model used in micro_sam
# currently set to the default vit_h
_DEFAULT_MODEL = "vit_h"

AVAILABLE_MODELS = Enum("AVAILABLE MODELS", _MODEL_URLS)

# TODO define the proper type for image embeddings
ImageEmbeddings = Dict[str, Any]
Expand Down

0 comments on commit ce542f0

Please sign in to comment.