Prototype napari plugin widget for annotator_2d #177

Closed
6 changes: 6 additions & 0 deletions micro_sam/napari.yaml
@@ -23,6 +23,12 @@ contributions:
- id: micro-sam.sample_data_segmentation
python_name: micro_sam.sample_data:sample_data_segmentation
title: Load segmentation sample data from micro-sam plugin
- id: micro-sam.annotator_2d_plugin
python_name: micro_sam.sam_annotator.annotator_2d:annotator_2d_plugin
title: micro-sam 2D annotator
widgets:
- command: micro-sam.annotator_2d_plugin
display_name: micro-sam 2D annotator
sample_data:
- command: micro-sam.sample_data_image_series
display_name: Image series example data
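A hedged sketch (not part of the diff) of how the widget registered above could be opened once the manifest entry is in place, either from the napari Plugins menu or programmatically. The plugin and widget names are read off the manifest and command ids above and may need adjusting to the actual plugin name.

```python
import napari

viewer = napari.Viewer()
# Dock the micro-sam 2D annotator widget contributed via napari.yaml.
viewer.window.add_plugin_dock_widget("micro-sam", "micro-sam 2D annotator")
napari.run()
```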
221 changes: 147 additions & 74 deletions micro_sam/sam_annotator/annotator_2d.py
@@ -1,15 +1,19 @@
import warnings
import os
from pathlib import Path
from typing import Optional, Tuple
import warnings

import napari
from napari.types import ImageData, LabelsData
import numpy as np

from magicgui import magicgui
from magicgui import magicgui, magic_factory
from napari import Viewer
from segment_anything import SamPredictor

from .. import instance_segmentation, util
from ..precompute_state import cache_amg_state
from ..util import AVAILABLE_MODELS, _DEFAULT_MODEL
from ..visualization import project_embeddings_for_visualization
from . import util as vutil
from .gui_utils import show_wrong_file_warning
@@ -88,52 +92,89 @@ def _get_shape(raw):
return shape


def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings):
v = Viewer()
def _start_2d_annotation(v: napari.Viewer, raw, segmentation_result=None):
print("_start_2d_annotation")
_setup_layers(v, raw, segmentation_result=segmentation_result)
_setup_dock_widgets(v)
_add_key_bindings(v)

#
# initialize the viewer and add layers
#

v.add_image(raw)
def _setup_layers(v: napari.Viewer, raw, segmentation_result=None):
print("_setup_layers")
# TODO: ideally do not hard code layer names
shape = _get_shape(raw)

v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation")
if segmentation_result is None:
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="committed_objects")
# raw input image data
if "raw" not in v.layers:
v.add_image(raw)
else:
v.add_labels(segmentation_result, name="committed_objects")
v.layers["committed_objects"].new_colormap() # randomize colors so it is easy to see when object committed
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object")
if not np.allclose(raw, v.layers["raw"].data):
v.layers["raw"].data = raw
v.layers["raw"].refresh()

# show the PCA of the image embeddings
if show_embeddings:
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
v.add_image(embedding_vis, name="embeddings", scale=scale)
# auto_segmentation
if "auto_segmentation" not in v.layers:
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation")
else:
v.layers["auto_segmentation"].data = np.zeros(shape, dtype="uint32")
v.layers["auto_segmentation"].refresh()

# committed_objects / segmentation result
if "committed_objects" not in v.layers:
if segmentation_result is not None:
v.add_labels(segmentation_result, name="committed_objects")
else:
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="committed_objects")
else:
v.layers["committed_objects"].data = np.zeros(shape, dtype="uint32")
v.layers["committed_objects"].refresh()

labels = ["positive", "negative"]
prompts = v.add_points(
data=[[0.0, 0.0], [0.0, 0.0]], # FIXME workaround
name="prompts",
properties={"label": labels},
edge_color="label",
edge_color_cycle=vutil.LABEL_COLOR_CYCLE,
symbol="o",
face_color="transparent",
edge_width=0.5,
size=12,
ndim=2,
)
prompts.edge_color_mode = "cycle"
# current_object
if "current_object" not in v.layers:
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object")
else:
v.layers["current_object"].data = np.zeros(shape, dtype="uint32")
v.layers["current_object"].refresh()

# prompts
if "prompts" not in v.layers:
labels = ["positive", "negative"]
prompts = v.add_points(
data=[[0.0, 0.0], [0.0, 0.0]], # FIXME workaround
name="prompts",
properties={"label": labels},
edge_color="label",
edge_color_cycle=vutil.LABEL_COLOR_CYCLE,
symbol="o",
face_color="transparent",
edge_width=0.5,
size=12,
ndim=2,
)
prompts.edge_color_mode = "cycle"
else:
v.layers["prompts"].data = []
v.layers["prompts"].refresh()
# box prompts
if "box_prompts" not in v.layers:
v.add_shapes(
name="box_prompts",
face_color="transparent",
edge_color="green",
edge_width=4,
)
# remove dummy point, and/or other existing point prompts
v.layers["box_prompts"].data = [] # warning, this also erases v.layers["prompts"].properties["labels"]
v.layers["box_prompts"].refresh()

v.add_shapes(
face_color="transparent", edge_color="green", edge_width=4, name="box_prompts"
)

#
# add the widgets
#
def _setup_dock_widgets(v: napari.Viewer):
print("_setup_dock_widgets")
#TODO: how can we check if these dock widgets are already opened or not?

# TODO: try to make this a cleaner code design to get the prompt_widget
prompts = v.layers["prompts"]
labels = ["positive", "negative"]
prompt_widget = vutil.create_prompt_menu(prompts, labels)
v.window.add_dock_widget(prompt_widget)

@@ -142,10 +183,9 @@ def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings):
v.window.add_dock_widget(vutil._commit_segmentation_widget)
v.window.add_dock_widget(vutil._clear_widget)

#
# key bindings
#

def _add_key_bindings(v: napari.Viewer):
print("_add_key_bindings")
@v.bind_key("s")
def _segmet(v):
_segment_widget(v)
@@ -154,6 +194,8 @@ def _segmet(v):
def _commit(v):
vutil._commit_segmentation_widget(v)

# TODO: try to make this a cleaner code design
prompts = v.layers["prompts"]
@v.bind_key("t")
def _toggle_label(event=None):
vutil.toggle_label(prompts)
@@ -162,22 +204,6 @@ def _toggle_label(event=None):
def clear_prompts(v):
vutil.clear_annotations(v)

return v


def _update_viewer(v, raw, show_embeddings, segmentation_result):
if show_embeddings or segmentation_result is not None:
Reviewer comment (Contributor):

This function is needed by the image_series_annotator. I understand why you removed it (the current logic will not work any longer and we need to change it when adapting the image_series_annotator), but please leave this function in for now (without calling it), so that we have it as a reference for the series annotator.

raise NotImplementedError

# update the image layer
v.layers["raw"].data = raw
shape = _get_shape(raw)

# update the segmentation layers
v.layers["auto_segmentation"].data = np.zeros(shape, dtype="uint32")
v.layers["committed_objects"].data = np.zeros(shape, dtype="uint32")
v.layers["current_object"].data = np.zeros(shape, dtype="uint32")


def annotator_2d(
raw: np.ndarray,
@@ -191,6 +217,41 @@ def annotator_2d(
v: Optional[Viewer] = None,
predictor: Optional[SamPredictor] = None,
precompute_amg_state: bool = False,
) -> Optional[Viewer]:
"""This function can be called in python (or from CLI) to start the annotation tool"""
if v is None:
v = napari.Viewer()

v.add_image(raw)
v = annotator_2d_plugin(
raw,
embedding_path=embedding_path,
show_embeddings=show_embeddings,
segmentation_result=segmentation_result,
precompute_amg_state=precompute_amg_state,
tile_shape=tile_shape,
halo=halo,
predictor=predictor,
)

if return_viewer is True:
return v

napari.run()


@magic_factory(call_button="Start 2D annotation")
def annotator_2d_plugin(
raw: ImageData,
v: napari.Viewer,
embedding_path: os.PathLike = "./embeddings",
show_embeddings: bool = False,
segmentation_result: Optional[LabelsData] = None,
precompute_amg_state: bool = False,
model_type: AVAILABLE_MODELS = AVAILABLE_MODELS[_DEFAULT_MODEL],
tile_shape: Tuple[int, int] = (None, None),
halo: Tuple[int, int] = (None, None),
predictor: Optional[SamPredictor] = None, # not accessible via magicgui widget, not a recognised type
) -> Optional[Viewer]:
Reviewer comment (Contributor):

Regarding the comment: that's ok. This doesn't make sense to pass via napari anyway. The predictor is currently exposed for two reasons:

  1. To pass the already initialized predictor in the image_series_annotator, where the logic for image embeddings is done externally.
  2. To enable passing predictors that were loaded from a custom checkpoint.

As far as I understand your comment, it doesn't hurt to leave it in for now. After we have merged this PR I can rethink a bit how this is handled, because: re 1., we need to change the image_series_annotator logic anyway, and re 2., it would be good to enable loading custom models from the napari plugin as well. (Here, "custom model" means that we specify the checkpoint via a URL or filepath rather than one of the names from AVAILABLE_MODELS.)

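As a hedged illustration of point 2 above (my own sketch, not code from the PR): a predictor built from a custom checkpoint with the plain segment_anything API could be handed to the annotator via the predictor argument. The checkpoint and image paths are placeholders.

```python
import imageio.v3 as imageio
from segment_anything import sam_model_registry, SamPredictor

from micro_sam.sam_annotator.annotator_2d import annotator_2d

# Load a SAM model from a custom (e.g. fine-tuned) checkpoint file;
# the registry key must match the checkpoint's architecture.
sam = sam_model_registry["vit_b"](checkpoint="finetuned_vit_b.pth")  # placeholder path
predictor = SamPredictor(sam)

raw = imageio.imread("example_image.tif")  # placeholder input image
annotator_2d(raw, predictor=predictor)
```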
"""The 2d annotation tool.

@@ -224,39 +285,51 @@ def annotator_2d(
global PREDICTOR, IMAGE_EMBEDDINGS, AMG
AMG = None

if raw is None:
raise RuntimeError("You must provide a raw input image source.")

# Check if user provided non-zero input values to tile_shape and halo parameters
if all(tile_shape) is False:
tile_shape = None
if all(halo) is False:
halo = None

if predictor is None:
PREDICTOR = util.get_sam_model(model_type=model_type)
PREDICTOR = util.get_sam_model(model_type=model_type.name)
else:
PREDICTOR = predictor

# TODO: check if a pre-computed image embedding already matching the raw input image exists
# and if so use it without re-computing (warn the user you are re-using a previously generated embedding).

# TODO: dispatch long running computations to napari thread_worker
# and connect return to _start_2d_annotation function
# TODO: also consider a progress bar for long running computations
# see https://github.com/pyapp-kit/magicgui/issues/577#issuecomment-1705555387
print("Computing image embedding...")
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
wrong_file_callback=show_wrong_file_warning
)
if precompute_amg_state and (embedding_path is not None):
print("Computing AMG...")
AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path)

# show the PCA of the image embeddings
if show_embeddings:
print("Show embeddings")
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
v.add_image(embedding_vis, name="embeddings", scale=scale)

# we set the pre-computed image embeddings if we don't use tiling
# (if we use tiling we cannot directly set it because the tile will be chosen dynamically)
if tile_shape is None:
print("tiling precompute")
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)

# viewer is freshly initialized
if v is None:
v = _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings)
Reviewer comment (Contributor):

Just FYI: this condition is for the image_series_annotator: in that case we pass the napari viewer object via the v argument. Hence it is not None and _update_viewer is called to replace all relevant layers.
I just wanted to clarify because it relates to my comment on _update_viewer. We can go ahead and remove the condition, since the image_series_annotator needs to be redone after finalizing the 2d annotator.

# we use an existing viewer and just update all the layers
else:
_update_viewer(v, raw, show_embeddings, segmentation_result)

#
# start the viewer
#
vutil.clear_annotations(v, clear_segmentations=False)

if return_viewer:
return v

napari.run()
_start_2d_annotation(v, raw, segmentation_result)
print("All done")
return v


def main():
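Regarding the TODO in the diff above about dispatching long-running computations to a napari thread_worker: a rough sketch (an assumption on my part, not part of the PR) of how the embedding computation could be moved off the GUI thread and connected back to _start_2d_annotation. It assumes the code lives inside annotator_2d.py, so the module-level util, IMAGE_EMBEDDINGS and _start_2d_annotation names are available.

```python
from napari.qt.threading import thread_worker


def _compute_embeddings_async(v, predictor, raw, embedding_path, tile_shape, halo, segmentation_result):
    """Compute image embeddings in a background thread, then start the 2d annotation."""

    @thread_worker
    def _compute():
        # Runs in a worker thread so the napari GUI stays responsive.
        return util.precompute_image_embeddings(
            predictor, raw, save_path=embedding_path, ndim=2,
            tile_shape=tile_shape, halo=halo,
        )

    def _on_returned(image_embeddings):
        # Back on the main thread: store the embeddings and set up the viewer.
        global IMAGE_EMBEDDINGS
        IMAGE_EMBEDDINGS = image_embeddings
        if tile_shape is None:
            util.set_precomputed(predictor, image_embeddings)
        _start_2d_annotation(v, raw, segmentation_result)

    worker = _compute()
    worker.returned.connect(_on_returned)
    worker.start()
```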
2 changes: 1 addition & 1 deletion micro_sam/sam_annotator/util.py
@@ -19,7 +19,7 @@

def clear_annotations(v: napari.Viewer, clear_segmentations=True) -> None:
"""@private"""
v.layers["prompts"].data = []
v.layers["prompts"].data = [] # warning, this also erases v.layers["prompts"].properties["labels"]
v.layers["prompts"].refresh()
if "box_prompts" in v.layers:
v.layers["box_prompts"].data = []
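A small demonstration (my own sketch, not from the PR) of the behavior the new comment warns about: replacing a Points layer's data also resets its per-point properties, so the "label" values would have to be restored afterwards if they are still needed.

```python
import numpy as np
from napari.layers import Points

pts = Points(
    np.array([[0.0, 0.0], [1.0, 1.0]]),
    properties={"label": ["positive", "negative"]},
)
pts.data = []          # clears the points ...
print(pts.properties)  # ... and the previous "label" values go with them
```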
3 changes: 2 additions & 1 deletion micro_sam/util.py
@@ -2,6 +2,7 @@
Helper functions for downloading Segment Anything models and predicting image embeddings.
"""

from enum import Enum
import hashlib
import os
import pickle
@@ -61,7 +62,7 @@
# this is the default model used in micro_sam
# currently set to the default vit_h
_DEFAULT_MODEL = "vit_h"

AVAILABLE_MODELS = Enum("AVAILABLE MODELS", _MODEL_URLS)

# TODO define the proper type for image embeddings
ImageEmbeddings = Dict[str, Any]
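For context, a quick illustration (not part of the PR) of why the plugin passes model_type.name to get_sam_model: the functional Enum API turns the _MODEL_URLS keys into member names and the URLs into member values, so the member name is exactly the model-type string. The URL below is a placeholder.

```python
from enum import Enum

_MODEL_URLS = {"vit_h": "https://example.com/sam_vit_h.pth"}  # placeholder entry
AVAILABLE_MODELS = Enum("AVAILABLE MODELS", _MODEL_URLS)

model = AVAILABLE_MODELS["vit_h"]
print(model.name)   # "vit_h" -- the string passed on to util.get_sam_model
print(model.value)  # the checkpoint URL
```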