-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prototype napari plugin widget for annotator_2d #177
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,19 @@ | ||
import warnings | ||
import os | ||
from pathlib import Path | ||
from typing import Optional, Tuple | ||
import warnings | ||
|
||
import napari | ||
from napari.types import ImageData, LabelsData | ||
import numpy as np | ||
|
||
from magicgui import magicgui | ||
from magicgui import magicgui, magic_factory | ||
from napari import Viewer | ||
from segment_anything import SamPredictor | ||
|
||
from .. import instance_segmentation, util | ||
from ..precompute_state import cache_amg_state | ||
from ..util import AVAILABLE_MODELS, _DEFAULT_MODEL | ||
from ..visualization import project_embeddings_for_visualization | ||
from . import util as vutil | ||
from .gui_utils import show_wrong_file_warning | ||
|
@@ -88,52 +92,89 @@ def _get_shape(raw): | |
return shape | ||
|
||
|
||
def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings): | ||
v = Viewer() | ||
def _start_2d_annotation(v: napari.Viewer, raw, segmentation_result=None): | ||
print("_start_2d_annotation") | ||
_setup_layers(v, raw, segmentation_result=segmentation_result) | ||
_setup_dock_widgets(v) | ||
_add_key_bindings(v) | ||
|
||
# | ||
# initialize the viewer and add layers | ||
# | ||
|
||
v.add_image(raw) | ||
def _setup_layers(v: napari.Viewer, raw, segmentation_result=None): | ||
print("_setup_layers") | ||
# TODO: ideally do not hard code layer names | ||
shape = _get_shape(raw) | ||
|
||
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation") | ||
if segmentation_result is None: | ||
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="committed_objects") | ||
# raw input image data | ||
if "raw" not in v.layers: | ||
v.add_image(raw) | ||
else: | ||
v.add_labels(segmentation_result, name="committed_objects") | ||
v.layers["committed_objects"].new_colormap() # randomize colors so it is easy to see when object committed | ||
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object") | ||
if not np.allclose(raw, v.layers["raw"].data): | ||
v.layers["raw"].data = raw | ||
v.layers["raw"].refresh() | ||
|
||
# show the PCA of the image embeddings | ||
if show_embeddings: | ||
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS) | ||
v.add_image(embedding_vis, name="embeddings", scale=scale) | ||
# auto_segmentation | ||
if "auto_segmentation" not in v.layers: | ||
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation") | ||
else: | ||
v.layers["auto_segmentation"].data = np.zeros(shape, dtype="uint32") | ||
v.layers["auto_segmentation"].refresh() | ||
|
||
# committed_objects / segmentation result | ||
if "committed_objects" not in v.layers: | ||
if segmentation_result is not None: | ||
v.add_labels(segmentation_result, name="committed_objects") | ||
else: | ||
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="committed_objects") | ||
else: | ||
v.layers["committed_objects"].data = np.zeros(shape, dtype="uint32") | ||
v.layers["committed_objects"].refresh() | ||
|
||
labels = ["positive", "negative"] | ||
prompts = v.add_points( | ||
data=[[0.0, 0.0], [0.0, 0.0]], # FIXME workaround | ||
name="prompts", | ||
properties={"label": labels}, | ||
edge_color="label", | ||
edge_color_cycle=vutil.LABEL_COLOR_CYCLE, | ||
symbol="o", | ||
face_color="transparent", | ||
edge_width=0.5, | ||
size=12, | ||
ndim=2, | ||
) | ||
prompts.edge_color_mode = "cycle" | ||
# current_object | ||
if "current_object" not in v.layers: | ||
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object") | ||
else: | ||
v.layers["current_object"].data = np.zeros(shape, dtype="uint32") | ||
v.layers["current_object"].refresh() | ||
|
||
# prompts | ||
if "prompts" not in v.layers: | ||
labels = ["positive", "negative"] | ||
prompts = v.add_points( | ||
data=[[0.0, 0.0], [0.0, 0.0]], # FIXME workaround | ||
name="prompts", | ||
properties={"label": labels}, | ||
edge_color="label", | ||
edge_color_cycle=vutil.LABEL_COLOR_CYCLE, | ||
symbol="o", | ||
face_color="transparent", | ||
edge_width=0.5, | ||
size=12, | ||
ndim=2, | ||
) | ||
prompts.edge_color_mode = "cycle" | ||
else: | ||
v.layers["prompts"].data = [] | ||
v.layers["prompts"].refresh() | ||
# box prompts | ||
if "box_prompts" not in v.layers: | ||
v.add_shapes( | ||
name="box_prompts", | ||
face_color="transparent", | ||
edge_color="green", | ||
edge_width=4, | ||
) | ||
# remove dummy point, and/or other existing point prompts | ||
v.layers["box_prompts"].data = [] # warning, this also erases v.layers["prompts"].properties["labels"] | ||
v.layers["box_prompts"].refresh() | ||
|
||
v.add_shapes( | ||
face_color="transparent", edge_color="green", edge_width=4, name="box_prompts" | ||
) | ||
|
||
# | ||
# add the widgets | ||
# | ||
def _setup_dock_widgets(v: napari.Viewer): | ||
print("_setup_dock_widgets") | ||
#TODO: how can we check if these dock widgets are already opened or not? | ||
|
||
# TODO: try to make this a cleaner code design to get the prompt_widget | ||
prompts = v.layers["prompts"] | ||
labels = ["positive", "negative"] | ||
prompt_widget = vutil.create_prompt_menu(prompts, labels) | ||
v.window.add_dock_widget(prompt_widget) | ||
|
||
|
@@ -142,10 +183,9 @@ def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings): | |
v.window.add_dock_widget(vutil._commit_segmentation_widget) | ||
v.window.add_dock_widget(vutil._clear_widget) | ||
|
||
# | ||
# key bindings | ||
# | ||
|
||
def _add_key_bindings(v: napari.Viewer): | ||
print("_add_key_bindings") | ||
@v.bind_key("s") | ||
def _segmet(v): | ||
_segment_widget(v) | ||
|
@@ -154,6 +194,8 @@ def _segmet(v): | |
def _commit(v): | ||
vutil._commit_segmentation_widget(v) | ||
|
||
# TODO: try to make this a cleaner code design | ||
prompts = v.layers["prompts"] | ||
@v.bind_key("t") | ||
def _toggle_label(event=None): | ||
vutil.toggle_label(prompts) | ||
|
@@ -162,22 +204,6 @@ def _toggle_label(event=None): | |
def clear_prompts(v): | ||
vutil.clear_annotations(v) | ||
|
||
return v | ||
|
||
|
||
def _update_viewer(v, raw, show_embeddings, segmentation_result): | ||
if show_embeddings or segmentation_result is not None: | ||
raise NotImplementedError | ||
|
||
# update the image layer | ||
v.layers["raw"].data = raw | ||
shape = _get_shape(raw) | ||
|
||
# update the segmentation layers | ||
v.layers["auto_segmentation"].data = np.zeros(shape, dtype="uint32") | ||
v.layers["committed_objects"].data = np.zeros(shape, dtype="uint32") | ||
v.layers["current_object"].data = np.zeros(shape, dtype="uint32") | ||
|
||
|
||
def annotator_2d( | ||
raw: np.ndarray, | ||
|
@@ -191,6 +217,41 @@ def annotator_2d( | |
v: Optional[Viewer] = None, | ||
predictor: Optional[SamPredictor] = None, | ||
precompute_amg_state: bool = False, | ||
) -> Optional[Viewer]: | ||
"""This function can be called in python (or from CLI) to start the annotation tool""" | ||
if v is None: | ||
v = napari.Viewer() | ||
|
||
v.add_image(raw) | ||
v = annotator_2d_plugin( | ||
raw, | ||
embedding_path=embedding_path, | ||
show_embeddings=show_embeddings, | ||
segmentation_result=segmentation_result, | ||
precompute_amg_state=precompute_amg_state, | ||
tile_shape=tile_shape, | ||
halo=halo, | ||
predictor=predictor, | ||
) | ||
|
||
if return_viewer is True: | ||
return v | ||
|
||
napari.run() | ||
|
||
|
||
@magic_factory(call_button="Start 2D annotation") | ||
def annotator_2d_plugin( | ||
raw: ImageData, | ||
v: napari.Viewer, | ||
embedding_path: os.PathLike = "./embeddings", | ||
show_embeddings: bool = False, | ||
segmentation_result: Optional[LabelsData] = None, | ||
precompute_amg_state: bool = False, | ||
model_type: AVAILABLE_MODELS = AVAILABLE_MODELS[_DEFAULT_MODEL], | ||
tile_shape: Tuple[int, int] = (None, None), | ||
halo: Tuple[int, int] = (None, None), | ||
predictor: Optional[SamPredictor] = None, # not accessible via magicgui widget, not a recognised type | ||
) -> Optional[Viewer]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regarding the comment: that's ok. This doesn't make sense to pass via napari anyways. The predictor is currently exposed for 2 reasons:
As far as I understand your comment, it doesn't hurt to leave it in for now. After we have merged this PR I can rethink a bit how this is handled because: re 1. we need to change the |
||
"""The 2d annotation tool. | ||
|
||
|
@@ -224,39 +285,51 @@ def annotator_2d( | |
global PREDICTOR, IMAGE_EMBEDDINGS, AMG | ||
AMG = None | ||
|
||
if raw is None: | ||
raise RuntimeError("You must provide a raw input image source.") | ||
|
||
# Check if user provided non-zero input values to tile_shape and halo parameters | ||
if all(tile_shape) is False: | ||
tile_shape = None | ||
if all(halo) is False: | ||
halo = None | ||
|
||
if predictor is None: | ||
PREDICTOR = util.get_sam_model(model_type=model_type) | ||
PREDICTOR = util.get_sam_model(model_type=model_type.name) | ||
else: | ||
PREDICTOR = predictor | ||
|
||
# TODO: check if a pre-computed image embedding already matching the raw input image exists | ||
# and if so use it without re-computing (warn the user you are re-using a previously generated embedding). | ||
|
||
# TODO: dispatch long running computations to napari thread_worker | ||
# and connect return to _start_2d_annotation function | ||
# TODO: also consider a progress bar for long running computations | ||
# see https://github.com/pyapp-kit/magicgui/issues/577#issuecomment-1705555387 | ||
print("Computing image embedding...") | ||
IMAGE_EMBEDDINGS = util.precompute_image_embeddings( | ||
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, | ||
wrong_file_callback=show_wrong_file_warning | ||
) | ||
if precompute_amg_state and (embedding_path is not None): | ||
print("Computing AMG...") | ||
AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path) | ||
|
||
# show the PCA of the image embeddings | ||
if show_embeddings: | ||
print("Show embeddings") | ||
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS) | ||
v.add_image(embedding_vis, name="embeddings", scale=scale) | ||
|
||
# we set the pre-computed image embeddings if we don't use tiling | ||
# (if we use tiling we cannot directly set it because the tile will be chosen dynamically) | ||
if tile_shape is None: | ||
print("tiling precompute") | ||
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS) | ||
|
||
# viewer is freshly initialized | ||
if v is None: | ||
v = _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just fyi: this condition is for the |
||
# we use an existing viewer and just update all the layers | ||
else: | ||
_update_viewer(v, raw, show_embeddings, segmentation_result) | ||
|
||
# | ||
# start the viewer | ||
# | ||
vutil.clear_annotations(v, clear_segmentations=False) | ||
|
||
if return_viewer: | ||
return v | ||
|
||
napari.run() | ||
_start_2d_annotation(v, raw, segmentation_result) | ||
print("All done") | ||
return v | ||
|
||
|
||
def main(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is needed by the
image_series_annotator
. I understand why you removed it (the current logic will not work any longer and we need to change it when adapting theimage_series_annotator
), but please leave in this function for now (without calling it), so that we have it as a reference for the series annotator.