diff --git a/labelme/ai/__init__.py b/labelme/ai/__init__.py
index 717e4924b..aee927f71 100644
--- a/labelme/ai/__init__.py
+++ b/labelme/ai/__init__.py
@@ -1,6 +1,7 @@
 import gdown
 
 from .efficient_sam import EfficientSam
+from .segment_anything2_model import SegmentAnything2Model
 from .segment_anything_model import SegmentAnythingModel
 from .text_to_annotation import get_rectangles_from_texts  # NOQA: F401
 from .text_to_annotation import get_shapes_from_annotations  # NOQA: F401
@@ -86,6 +87,31 @@ def __init__(self):
         ),
     )
 
+class SAM2HieraL(SegmentAnything2Model):
+    name = "SegmentAnything2 (accuracy)"
+
+    def __init__(self):
+        super().__init__(
+            encoder_path=gdown.cached_download(
+                url="https://github.com/jakep72/labelme/releases/download/SAM2/sam2_large.encoder.onnx",  # NOQA
+            ),
+            decoder_path=gdown.cached_download(
+                url="https://github.com/jakep72/labelme/releases/download/SAM2/sam2_large.decoder.onnx"  # NOQA
+            ),
+        )
+
+class SAM2HieraT(SegmentAnything2Model):
+    name = "SegmentAnything2 (speed)"
+
+    def __init__(self):
+        super().__init__(
+            encoder_path=gdown.cached_download(
+                url="https://github.com/jakep72/labelme/releases/download/SAM2/sam2_hiera_tiny.encoder.onnx"  # NOQA
+            ),
+            decoder_path=gdown.cached_download(
+                url="https://github.com/jakep72/labelme/releases/download/SAM2/sam2_hiera_tiny.decoder.onnx"  # NOQA
+            ),
+        )
 
 MODELS = [
     SegmentAnythingModelVitB,
@@ -93,4 +119,6 @@ def __init__(self):
     SegmentAnythingModelVitH,
     EfficientSamVitT,
     EfficientSamVitS,
+    SAM2HieraT,
+    SAM2HieraL,
 ]
diff --git a/labelme/ai/_utils.py b/labelme/ai/_utils.py
index 6806a5b06..bf6972470 100644
--- a/labelme/ai/_utils.py
+++ b/labelme/ai/_utils.py
@@ -10,13 +10,12 @@ def _get_contour_length(contour):
     contour_end = np.r_[contour[1:], contour[0:1]]
     return np.linalg.norm(contour_end - contour_start, axis=1).sum()
 
-
 def compute_polygon_from_mask(mask):
     contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
     if len(contours) == 0:
         logger.warning("No contour found, so returning empty polygon.")
         return np.empty((0, 2), dtype=np.float32)
-
+
     contour = max(contours, key=_get_contour_length)
     POLYGON_APPROX_TOLERANCE = 0.004
     polygon = skimage.measure.approximate_polygon(
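Note on the model registry above: `gdown.cached_download` downloads the file on first use and returns the cached local path on subsequent calls, so the new model classes are cheap to construct after the initial download. A minimal standalone sketch (not part of the patch), using one of the URLs from the diff:

import gdown

# First call downloads into gdown's cache directory; later calls just
# return the cached file path without re-downloading.
path = gdown.cached_download(
    url="https://github.com/jakep72/labelme/releases/download/SAM2/sam2_hiera_tiny.encoder.onnx"  # NOQA
)
print(path)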
diff --git a/labelme/ai/segment_anything2_model.py b/labelme/ai/segment_anything2_model.py
new file mode 100644
index 000000000..f7e20c549
--- /dev/null
+++ b/labelme/ai/segment_anything2_model.py
@@ -0,0 +1,352 @@
+import collections
+import threading
+from typing import Any
+
+import imgviz
+import numpy as np
+import onnxruntime
+import skimage.morphology
+
+from ..logger import logger
+from . import _utils
+
+
+class SegmentAnything2Model:
+    """Segmentation model using Segment Anything 2 (SAM2)."""
+
+    def __init__(self, encoder_path, decoder_path) -> None:
+        self.model = SegmentAnything2ONNX(encoder_path, decoder_path)
+        self._lock = threading.Lock()
+        self._image_embedding_cache = collections.OrderedDict()
+        self._thread = None
+
+    def set_image(self, image: np.ndarray):
+        with self._lock:
+            self._image = image
+            self._image_embedding = self._image_embedding_cache.get(
+                self._image.tobytes()
+            )
+
+        if self._image_embedding is None:
+            self._thread = threading.Thread(
+                target=self._compute_and_cache_image_embedding
+            )
+            self._thread.start()
+
+    def _compute_and_cache_image_embedding(self):
+        with self._lock:
+            logger.debug("Computing image embedding...")
+            self._image_embedding = self.model.encode(self._image)
+            if len(self._image_embedding_cache) > 10:
+                # Evict the oldest entry to bound memory usage.
+                self._image_embedding_cache.popitem(last=False)
+            self._image_embedding_cache[self._image.tobytes()] = self._image_embedding
+            logger.debug("Done computing image embedding.")
+
+    def _get_image_embedding(self):
+        if self._thread is not None:
+            self._thread.join()
+            self._thread = None
+        with self._lock:
+            return self._image_embedding
+
+    def predict_mask_from_points(self, points, point_labels):
+        embedding = self._get_image_embedding()
+        masks, scores, orig_im_size = self.model.predict_masks(
+            embedding, points, point_labels
+        )
+        # Keep the highest-scoring mask and resize it back to the original
+        # image resolution before thresholding.
+        best_mask = masks[np.argmax(scores)]
+        best_mask = imgviz.resize(
+            best_mask, height=orig_im_size[0], width=orig_im_size[1]
+        )
+        mask = best_mask > 0.0
+
+        MIN_SIZE_RATIO = 0.05
+        skimage.morphology.remove_small_objects(
+            mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
+        )
+
+        return mask
+
+    def predict_polygon_from_points(self, points, point_labels):
+        mask = self.predict_mask_from_points(
+            points=points, point_labels=point_labels
+        )
+        return _utils.compute_polygon_from_mask(mask=mask)
+
+
+class SegmentAnything2ONNX:
+    """ONNX Runtime wrapper around the SAM2 encoder/decoder pair."""
+
+    def __init__(self, encoder_model_path, decoder_model_path) -> None:
+        self.encoder = SAM2ImageEncoder(encoder_model_path)
+        self.decoder = SAM2ImageDecoder(
+            decoder_model_path, self.encoder.input_shape[2:]
+        )
+
+    def encode(self, cv_image: np.ndarray) -> dict[str, Any]:
+        original_size = cv_image.shape[:2]
+        high_res_feats_0, high_res_feats_1, image_embed = self.encoder(cv_image)
+        return {
+            "high_res_feats_0": high_res_feats_0,
+            "high_res_feats_1": high_res_feats_1,
+            "image_embedding": image_embed,
+            "original_size": original_size,
+        }
+
+    def predict_masks(
+        self, embedding, points, labels
+    ) -> tuple[np.ndarray, np.ndarray, tuple[int, int]]:
+        points, labels = np.array(points), np.array(labels)
+
+        image_embedding = embedding["image_embedding"]
+        high_res_feats_0 = embedding["high_res_feats_0"]
+        high_res_feats_1 = embedding["high_res_feats_1"]
+        original_size = embedding["original_size"]
+        self.decoder.set_image_size(original_size)
+        masks, scores, orig_im_size = self.decoder(
+            image_embedding,
+            high_res_feats_0,
+            high_res_feats_1,
+            points,
+            labels,
+        )
+
+        return masks, scores, orig_im_size
+
+
+class SAM2ImageEncoder:
+    def __init__(self, path: str) -> None:
+        # Initialize the ONNX Runtime session.
+        self.session = onnxruntime.InferenceSession(
+            path, providers=onnxruntime.get_available_providers()
+        )
+
+        # Cache input/output names and shapes.
+        self.get_input_details()
+        self.get_output_details()
+
+    def __call__(
+        self, image: np.ndarray
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        return self.encode_image(image)
+
+    def encode_image(
+        self, image: np.ndarray
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        input_tensor = self.prepare_input(image)
+
+        outputs = self.infer(input_tensor)
+
+        return self.process_output(outputs)
+
+    def prepare_input(self, image: np.ndarray) -> np.ndarray:
+        self.img_height, self.img_width = image.shape[:2]
+
+        # Reorder channels (BGR -> RGB).
+        input_img = image[:, :, [2, 1, 0]]
+
+        # Resize to the encoder's expected input resolution using imgviz.
+        input_img = imgviz.resize(
+            input_img, height=self.input_height, width=self.input_width
+        )
+
+        # Normalize with ImageNet statistics and convert to NCHW.
+        mean = np.array([0.485, 0.456, 0.406])
+        std = np.array([0.229, 0.224, 0.225])
+        input_img = (input_img / 255.0 - mean) / std
+        input_img = input_img.transpose(2, 0, 1)
+        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+
+        return input_tensor
+
+    def infer(self, input_tensor: np.ndarray) -> list[np.ndarray]:
+        outputs = self.session.run(
+            self.output_names, {self.input_names[0]: input_tensor}
+        )
+        return outputs
+
+    def process_output(
+        self, outputs: list[np.ndarray]
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        return outputs[0], outputs[1], outputs[2]
+
+    def get_input_details(self) -> None:
+        model_inputs = self.session.get_inputs()
+        self.input_names = [
+            model_inputs[i].name for i in range(len(model_inputs))
+        ]
+
+        self.input_shape = model_inputs[0].shape
+        self.input_height = self.input_shape[2]
+        self.input_width = self.input_shape[3]
+
+    def get_output_details(self) -> None:
+        model_outputs = self.session.get_outputs()
+        self.output_names = [
+            model_outputs[i].name for i in range(len(model_outputs))
+        ]
+
+
+class SAM2ImageDecoder:
+    def __init__(
+        self,
+        path: str,
+        encoder_input_size: tuple[int, int],
+        orig_im_size: tuple[int, int] | None = None,
+        mask_threshold: float = 0.0,
+    ) -> None:
+        # Initialize the ONNX Runtime session.
+        self.session = onnxruntime.InferenceSession(
+            path, providers=onnxruntime.get_available_providers()
+        )
+
+        self.orig_im_size = (
+            orig_im_size if orig_im_size is not None else encoder_input_size
+        )
+        self.encoder_input_size = encoder_input_size
+        self.mask_threshold = mask_threshold
+        self.scale_factor = 4
+
+        # Cache input/output names.
+        self.get_input_details()
+        self.get_output_details()
+
+    def __call__(
+        self,
+        image_embed: np.ndarray,
+        high_res_feats_0: np.ndarray,
+        high_res_feats_1: np.ndarray,
+        point_coords: list[np.ndarray] | np.ndarray,
+        point_labels: list[np.ndarray] | np.ndarray,
+    ) -> tuple[np.ndarray, np.ndarray, tuple[int, int]]:
+        return self.predict(
+            image_embed,
+            high_res_feats_0,
+            high_res_feats_1,
+            point_coords,
+            point_labels,
+        )
+
+    def predict(
+        self,
+        image_embed: np.ndarray,
+        high_res_feats_0: np.ndarray,
+        high_res_feats_1: np.ndarray,
+        point_coords: list[np.ndarray] | np.ndarray,
+        point_labels: list[np.ndarray] | np.ndarray,
+    ) -> tuple[np.ndarray, np.ndarray, tuple[int, int]]:
+        inputs = self.prepare_inputs(
+            image_embed,
+            high_res_feats_0,
+            high_res_feats_1,
+            point_coords,
+            point_labels,
+        )
+
+        outputs = self.infer(inputs)
+
+        return self.process_output(outputs)
+
+    def prepare_inputs(
+        self,
+        image_embed: np.ndarray,
+        high_res_feats_0: np.ndarray,
+        high_res_feats_1: np.ndarray,
+        point_coords: list[np.ndarray] | np.ndarray,
+        point_labels: list[np.ndarray] | np.ndarray,
+    ):
+        input_point_coords, input_point_labels = self.prepare_points(
+            point_coords, point_labels
+        )
+
+        num_labels = input_point_labels.shape[0]
+        # No previous low-resolution mask is fed back in, so pass an empty
+        # mask input and a zero "has_mask_input" flag.
+        mask_input = np.zeros(
+            (
+                num_labels,
+                1,
+                self.encoder_input_size[0] // self.scale_factor,
+                self.encoder_input_size[1] // self.scale_factor,
+            ),
+            dtype=np.float32,
+        )
+        has_mask_input = np.array([0], dtype=np.float32)
+
+        return (
+            image_embed,
+            high_res_feats_0,
+            high_res_feats_1,
+            input_point_coords,
+            input_point_labels,
+            mask_input,
+            has_mask_input,
+        )
+
+    def prepare_points(
+        self,
+        point_coords: list[np.ndarray] | np.ndarray,
+        point_labels: list[np.ndarray] | np.ndarray,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        if isinstance(point_coords, np.ndarray):
+            input_point_coords = point_coords[np.newaxis, ...]
+            input_point_labels = point_labels[np.newaxis, ...]
+        else:
+            max_num_points = max([coords.shape[0] for coords in point_coords])
+            # All prompts must have the same number of points, so pad the
+            # shorter ones with invalid points at (0, 0) labeled -1.
+            input_point_coords = np.zeros(
+                (len(point_coords), max_num_points, 2), dtype=np.float32
+            )
+            input_point_labels = (
+                np.ones((len(point_coords), max_num_points), dtype=np.float32)
+                * -1
+            )
+
+            for i, (coords, labels) in enumerate(
+                zip(point_coords, point_labels)
+            ):
+                input_point_coords[i, : coords.shape[0], :] = coords
+                input_point_labels[i, : labels.shape[0]] = labels
+
+        # Scale x and y from original-image to encoder-input coordinates.
+        input_point_coords[..., 0] = (
+            input_point_coords[..., 0]
+            / self.orig_im_size[1]
+            * self.encoder_input_size[1]
+        )
+        input_point_coords[..., 1] = (
+            input_point_coords[..., 1]
+            / self.orig_im_size[0]
+            * self.encoder_input_size[0]
+        )
+
+        return input_point_coords.astype(np.float32), input_point_labels.astype(
+            np.float32
+        )
+
+    def infer(self, inputs) -> list[np.ndarray]:
+        outputs = self.session.run(
+            self.output_names,
+            {
+                self.input_names[i]: inputs[i]
+                for i in range(len(self.input_names))
+            },
+        )
+        return outputs
+
+    def process_output(
+        self, outputs: list[np.ndarray]
+    ) -> tuple[np.ndarray, np.ndarray, tuple[int, int]]:
+        scores = outputs[1].squeeze()
+        masks = outputs[0][0]
+
+        return masks, scores, self.orig_im_size
+
+    def set_image_size(self, orig_im_size: tuple[int, int]) -> None:
+        self.orig_im_size = orig_im_size
+
+    def get_input_details(self) -> None:
+        model_inputs = self.session.get_inputs()
+        self.input_names = [
+            model_inputs[i].name for i in range(len(model_inputs))
+        ]
+
+    def get_output_details(self) -> None:
+        model_outputs = self.session.get_outputs()
+        self.output_names = [
+            model_outputs[i].name for i in range(len(model_outputs))
+        ]
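For orientation, a minimal sketch of how the classes above chain together (standalone, outside the patch; the image array and click coordinates are placeholders):

import numpy as np

from labelme.ai import SAM2HieraT

model = SAM2HieraT()  # downloads the ONNX encoder/decoder pair on first use
image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder H x W x 3 image
model.set_image(image)  # embedding is computed in a background thread

# One positive click (label 1) on the object, in image pixel coordinates.
polygon = model.predict_polygon_from_points(
    points=[[200.0, 150.0]], point_labels=[1]
)
print(polygon)  # (N, 2) array of polygon vertices

The background thread started by set_image is joined lazily inside _get_image_embedding, so the first prediction blocks until the embedding is ready; repeated predictions on the same image hit the embedding cache.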
diff --git a/labelme/app.py b/labelme/app.py
index 207a769cf..feae639c1 100644
--- a/labelme/app.py
+++ b/labelme/app.py
@@ -396,6 +396,22 @@ def __init__(
             if self.canvas.createMode == "ai_mask"
             else None
         )
+
+        createAiBatchMode = action(
+            self.tr("Create AI-Batch"),
+            lambda: self.toggleDrawMode(False, createMode="ai_batch"),
+            None,
+            "objects",
+            self.tr("Start drawing ai_batch. Ctrl+LeftClick ends creation."),
+            enabled=False,
+        )
+        createAiBatchMode.changed.connect(
+            lambda: self.canvas.initializeAiModel(
+                name=self._selectAiModelComboBox.currentText()
+            )
+            if self.canvas.createMode == "ai_batch"
+            else None
+        )
         editMode = action(
             self.tr("Edit Polygons"),
             self.setEditMode,
@@ -649,6 +665,7 @@ def __init__(
             createLineStripMode=createLineStripMode,
             createAiPolygonMode=createAiPolygonMode,
             createAiMaskMode=createAiMaskMode,
+            createAiBatchMode=createAiBatchMode,
             zoom=zoom,
             zoomIn=zoomIn,
             zoomOut=zoomOut,
@@ -687,6 +704,7 @@ def __init__(
                 createLineStripMode,
                 createAiPolygonMode,
                 createAiMaskMode,
+                createAiBatchMode,
                 editMode,
                 edit,
                 duplicate,
@@ -707,6 +725,7 @@ def __init__(
                 createLineStripMode,
                 createAiPolygonMode,
                 createAiMaskMode,
+                createAiBatchMode,
                 editMode,
                 brightnessContrast,
             ),
@@ -807,7 +826,7 @@ def __init__(
             lambda: self.canvas.initializeAiModel(
                 name=self._selectAiModelComboBox.currentText()
             )
-            if self.canvas.createMode in ["ai_polygon", "ai_mask"]
+            if self.canvas.createMode in ["ai_polygon", "ai_mask", "ai_batch"]
             else None
         )
@@ -943,6 +962,7 @@ def populateModeActions(self):
             self.actions.createLineStripMode,
             self.actions.createAiPolygonMode,
             self.actions.createAiMaskMode,
+            self.actions.createAiBatchMode,
             self.actions.editMode,
         )
         utils.addActions(self.menus.edit, actions + self.actions.editMenu)
@@ -976,6 +996,7 @@ def setClean(self):
         self.actions.createLineStripMode.setEnabled(True)
         self.actions.createAiPolygonMode.setEnabled(True)
         self.actions.createAiMaskMode.setEnabled(True)
+        self.actions.createAiBatchMode.setEnabled(True)
         title = __appname__
         if self.filename is not None:
             title = "{} - {}".format(title, self.filename)
@@ -1113,6 +1134,7 @@ def toggleDrawMode(self, edit=True, createMode="polygon"):
             "linestrip": self.actions.createLineStripMode,
             "ai_polygon": self.actions.createAiPolygonMode,
             "ai_mask": self.actions.createAiMaskMode,
+            "ai_batch": self.actions.createAiBatchMode,
         }
 
         self.canvas.setEditing(edit)
diff --git a/labelme/config/default_config.yaml b/labelme/config/default_config.yaml
index 128dc6d6a..c3c56e176 100644
--- a/labelme/config/default_config.yaml
+++ b/labelme/config/default_config.yaml
@@ -81,6 +81,7 @@ canvas:
     linestrip: false
     ai_polygon: false
     ai_mask: false
+    ai_batch: false
 
 shortcuts:
   close: Ctrl+W
diff --git a/labelme/widgets/canvas.py b/labelme/widgets/canvas.py
index d71a6125b..b8a22a983 100644
--- a/labelme/widgets/canvas.py
+++ b/labelme/widgets/canvas.py
@@ -9,6 +9,9 @@
 from labelme.logger import logger
 from labelme.shape import Shape
 
+import numpy as np
+import skimage.measure
+
 # TODO(unknown):
 # - [maybe] Find optimal epsilon value.
@@ -58,6 +61,7 @@ def __init__(self, *args, **kwargs):
                 "linestrip": False,
                 "ai_polygon": False,
                 "ai_mask": False,
+                "ai_batch": False,
             },
         )
         super(Canvas, self).__init__(*args, **kwargs)
@@ -124,6 +128,7 @@ def createMode(self, value):
             "linestrip",
             "ai_polygon",
             "ai_mask",
+            "ai_batch",
         ]:
             raise ValueError("Unsupported createMode: %s" % value)
         self._createMode = value
@@ -244,7 +249,7 @@ def mouseMoveEvent(self, ev):
         # Polygon drawing.
         if self.drawing():
-            if self.createMode in ["ai_polygon", "ai_mask"]:
+            if self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]:
                 self.line.shape_type = "points"
             else:
                 self.line.shape_type = self.createMode
@@ -272,7 +277,7 @@ def mouseMoveEvent(self, ev):
             if self.createMode in ["polygon", "linestrip"]:
                 self.line.points = [self.current[-1], pos]
                 self.line.point_labels = [1, 1]
-            elif self.createMode in ["ai_polygon", "ai_mask"]:
+            elif self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]:
                 self.line.points = [self.current.points[-1], pos]
                 self.line.point_labels = [
                     self.current.point_labels[-1],
@@ -432,7 +437,7 @@ def mousePressEvent(self, ev):
                     self.line[0] = self.current[-1]
                     if int(ev.modifiers()) == QtCore.Qt.ControlModifier:
                         self.finalise()
-                elif self.createMode in ["ai_polygon", "ai_mask"]:
+                elif self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]:
                     self.current.addPoint(
                         self.line.points[1],
                         label=self.line.point_labels[1],
@@ -445,14 +450,14 @@ def mousePressEvent(self, ev):
                 # Create new shape.
                 self.current = Shape(
                     shape_type="points"
-                    if self.createMode in ["ai_polygon", "ai_mask"]
+                    if self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]
                     else self.createMode
                 )
                 self.current.addPoint(pos, label=0 if is_shift_pressed else 1)
                 if self.createMode == "point":
                     self.finalise()
                 elif (
-                    self.createMode in ["ai_polygon", "ai_mask"]
+                    self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]
                     and ev.modifiers() & QtCore.Qt.ControlModifier
                 ):
                     self.finalise()
@@ -461,7 +466,7 @@ def mousePressEvent(self, ev):
                     self.current.shape_type = "circle"
                     self.line.points = [pos, pos]
                     if (
-                        self.createMode in ["ai_polygon", "ai_mask"]
+                        self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]
                         and is_shift_pressed
                     ):
                         self.line.point_labels = [0, 0]
@@ -548,7 +553,7 @@ def setHiding(self, enable=True):
     def canCloseShape(self):
         return self.drawing() and (
             (self.current and len(self.current) > 2)
-            or self.createMode in ["ai_polygon", "ai_mask"]
+            or self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]
         )
 
     def mouseDoubleClickEvent(self, ev):
@@ -557,7 +562,7 @@ def mouseDoubleClickEvent(self, ev):
         if (
             self.createMode == "polygon" and self.canCloseShape()
-        ) or self.createMode in ["ai_polygon", "ai_mask"]:
+        ) or self.createMode in ["ai_polygon", "ai_mask", "ai_batch"]:
             self.finalise()
 
     def selectShapes(self, shapes):
@@ -771,6 +776,37 @@ def paintEvent(self, event):
                 )
                 drawing_shape.selected = True
                 drawing_shape.paint(p)
+        elif self.createMode == "ai_batch" and self.current is not None:
+            drawing_shape = self.current.copy()
+            drawing_shape.addPoint(
+                point=self.line.points[1],
+                label=self.line.point_labels[1],
+            )
+            mask = self._ai_model.predict_mask_from_points(
+                points=[[point.x(), point.y()] for point in drawing_shape.points],
+                point_labels=drawing_shape.point_labels,
+            )
+            # Unlike ai_mask, keep every contour: each one is previewed as
+            # its own polygon (mirrors _utils.compute_polygon_from_mask).
+            contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
+
+            POLYGON_APPROX_TOLERANCE = 0.004
+            for contour in contours:
+                polygon = skimage.measure.approximate_polygon(
+                    coords=contour,
+                    tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
+                )
+                polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
+                polygon = polygon[:-1]
+
+                drawing_shape.setShapeRefined(
+                    shape_type="polygon",
+                    points=[QtCore.QPointF(point[1], point[0]) for point in polygon],
+                    point_labels=[1] * len(polygon),
+                )
+                drawing_shape.fill = self.fillDrawing()
+                drawing_shape.selected = True
+                drawing_shape.paint(p)
         p.end()
@@ -819,14 +855,52 @@ def finalise(self):
                 point_labels=[1, 1],
                 mask=mask[y1 : y2 + 1, x1 : x2 + 1],
             )
-        self.current.close()
-        self.shapes.append(self.current)
-        self.storeShapes()
-        self.current = None
-        self.setHiding(False)
-        self.newShape.emit()
-        self.update()
+        elif self.createMode == "ai_batch":
+            # Convert the clicked points to polygons with the AI model,
+            # creating one shape per detected contour.
+            assert self.current.shape_type == "points"
+            mask = self._ai_model.predict_mask_from_points(
+                points=[[point.x(), point.y()] for point in self.current.points],
+                point_labels=self.current.point_labels,
+            )
+            contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
+
+            POLYGON_APPROX_TOLERANCE = 0.004
+            for contour in contours:
+                polygon = skimage.measure.approximate_polygon(
+                    coords=contour,
+                    tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
+                )
+                polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
+                polygon = polygon[:-1]
+
+                self.current.setShapeRefined(
+                    shape_type="polygon",
+                    points=[QtCore.QPointF(point[1], point[0]) for point in polygon],
+                    point_labels=[1] * len(polygon),
+                )
+
+                # Store the finished shape, then start a fresh one for the
+                # next contour.
+                self.current.close()
+                self.shapes.append(self.current)
+                self.storeShapes()
+                self.setHiding(False)
+                self.newShape.emit()
+                self.update()
+                self.current = Shape(shape_type="points")
+
+            self.current = None
+
+        if self.createMode != "ai_batch":
+            self.current.close()
+            self.shapes.append(self.current)
+            self.storeShapes()
+            self.current = None
+            self.setHiding(False)
+            self.newShape.emit()
+            self.update()
 
     def closeEnough(self, p1, p2):
         # d = distance(p1 - p2)
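One observation on the canvas changes: paintEvent and finalise inline the same mask-to-polygons conversion. A possible shared helper is sketched below, outside the patch; the compute_polygons_from_mask name is hypothetical, by analogy with _utils.compute_polygon_from_mask, which keeps only the longest contour and so cannot be reused here.

import numpy as np
import skimage.measure

POLYGON_APPROX_TOLERANCE = 0.004


def compute_polygons_from_mask(mask):
    """Approximate every contour of a boolean mask as a polygon.

    Returns a list of (N, 2) arrays in (row, col) order, matching how the
    inline code above consumes them via QtCore.QPointF(point[1], point[0]).
    """
    polygons = []
    for contour in skimage.measure.find_contours(np.pad(mask, pad_width=1)):
        polygon = skimage.measure.approximate_polygon(
            coords=contour,
            tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
        )
        polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
        polygons.append(polygon[:-1])  # drop closing point (duplicates the first)
    return polygons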