From 917bb9cda96eca2e2ed60667d7290362ddd8e338 Mon Sep 17 00:00:00 2001 From: KshitijAucharmal Date: Mon, 28 Oct 2024 19:05:39 +0530 Subject: [PATCH] Added smart positioning (non overlapping labels) to LabelAnnotator --- supervision/annotators/core.py | 176 ++++++++++++++++++++++++--------- 1 file changed, 129 insertions(+), 47 deletions(-) diff --git a/supervision/annotators/core.py b/supervision/annotators/core.py index 1910ac9f4..35faea6ca 100644 --- a/supervision/annotators/core.py +++ b/supervision/annotators/core.py @@ -16,7 +16,7 @@ ) from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES from supervision.detection.core import Detections -from supervision.detection.utils import clip_boxes, mask_to_polygons +from supervision.detection.utils import clip_boxes, mask_to_polygons, spread_out from supervision.draw.color import Color, ColorPalette from supervision.draw.utils import draw_polygon from supervision.geometry.core import Position @@ -1054,6 +1054,7 @@ def __init__( text_position: Position = Position.TOP_LEFT, color_lookup: ColorLookup = ColorLookup.CLASS, border_radius: int = 0, + use_smart_positioning: bool = True, ): """ Args: @@ -1070,7 +1071,10 @@ def __init__( Options are `INDEX`, `CLASS`, `TRACK`. border_radius (int): The radius to apply round edges. If the selected value is higher than the lower dimension, width or height, is clipped. + use_smart_positioning (bool): Whether to use smart positioning to prevent + label overlapping or not. """ + self.use_smart_positioning: bool = use_smart_positioning self.border_radius: int = border_radius self.color: Union[Color, ColorPalette] = color self.text_color: Union[Color, ColorPalette] = text_color @@ -1128,11 +1132,35 @@ def annotate( ![label-annotator-example](https://media.roboflow.com/ supervision-annotator-examples/label-annotator-example-purple.png) """ + assert isinstance(scene, np.ndarray) - font = cv2.FONT_HERSHEY_SIMPLEX - anchors_coordinates = detections.get_anchors_coordinates( - anchor=self.text_anchor - ).astype(int) + self._validate_labels(labels, detections) + + # Get text properties for all detections + text_props = self._get_text_properties(detections, labels) + + # Calculate background coordinates for all labels + xyxy = self._calculate_label_backgrounds( + detections, text_props, self.text_anchor, self.text_padding + ) + + # Adjust positions if smart positioning is enabled + if self.use_smart_positioning: + xyxy = spread_out(xyxy, step=2) + + # Draw all labels + self._draw_labels( + scene=scene, + xyxy=xyxy, + text_props=text_props, + detections=detections, + custom_color_lookup=custom_color_lookup, + ) + + return scene + + def _validate_labels(self, labels: Optional[List[str]], detections: Detections): + """Validates that the number of labels matches the number of detections.""" if labels is not None and len(labels) != len(detections): raise ValueError( f"The number of labels ({len(labels)}) does not match the " @@ -1140,64 +1168,119 @@ def annotate( f"should have exactly 1 label." ) - for detection_idx, center_coordinates in enumerate(anchors_coordinates): - color = resolve_color( - color=self.color, - detections=detections, - detection_idx=detection_idx, - color_lookup=( - self.color_lookup - if custom_color_lookup is None - else custom_color_lookup - ), - ) - - text_color = resolve_color( - color=self.text_color, - detections=detections, - detection_idx=detection_idx, - color_lookup=( - self.color_lookup - if custom_color_lookup is None - else custom_color_lookup - ), - ) + def _get_text_properties( + self, detections: Detections, custom_labels: Optional[List[str]] + ) -> List[dict]: + """Gets text content and dimensions for all detections.""" + text_props = [] + font = cv2.FONT_HERSHEY_SIMPLEX - if labels is not None: - text = labels[detection_idx] - elif CLASS_NAME_DATA_FIELD in detections.data: - text = detections.data[CLASS_NAME_DATA_FIELD][detection_idx] - elif detections.class_id is not None: - text = str(detections.class_id[detection_idx]) - else: - text = str(detection_idx) + for idx in range(len(detections)): + # Determine label text + text = self._get_label_text(detections, custom_labels, idx) - text_w, text_h = cv2.getTextSize( + # Calculate text dimensions + (text_w, text_h) = cv2.getTextSize( text=text, fontFace=font, fontScale=self.text_scale, thickness=self.text_thickness, )[0] - text_w_padded = text_w + 2 * self.text_padding - text_h_padded = text_h + 2 * self.text_padding + + text_props.append( + { + "text": text, + "width": text_w, + "height": text_h, + "width_padded": text_w + 2 * self.text_padding, + "height_padded": text_h + 2 * self.text_padding, + } + ) + + return text_props + + def _get_label_text( + self, detections: Detections, custom_labels: Optional[List[str]], idx: int + ) -> str: + """Determines the label text for a given detection.""" + if custom_labels is not None: + return custom_labels[idx] + elif CLASS_NAME_DATA_FIELD in detections.data: + return detections.data[CLASS_NAME_DATA_FIELD][idx] + elif detections.class_id is not None: + return str(detections.class_id[idx]) + return str(idx) + + def _calculate_label_backgrounds( + self, + detections: Detections, + text_props: List[dict], + text_anchor: str, + text_padding: int, + ) -> np.ndarray: + """Calculates background coordinates for all labels.""" + anchors_coordinates = detections.get_anchors_coordinates( + anchor=text_anchor + ).astype(int) + + xyxy = [] + for idx, center_coords in enumerate(anchors_coordinates): text_background_xyxy = resolve_text_background_xyxy( - center_coordinates=tuple(center_coordinates), - text_wh=(text_w_padded, text_h_padded), - position=self.text_anchor, + center_coordinates=tuple(center_coords), + text_wh=( + text_props[idx]["width_padded"], + text_props[idx]["height_padded"], + ), + position=text_anchor, + ) + xyxy.append(text_background_xyxy) + + return np.array(xyxy) + + def _draw_labels( + self, + scene: ImageType, + xyxy: np.ndarray, + text_props: List[dict], + detections: Detections, + custom_color_lookup: Optional[np.ndarray], + ) -> None: + """Draws all labels and their backgrounds on the scene.""" + if custom_color_lookup is not None: + color_lookup = custom_color_lookup + else: + color_lookup = self.color_lookup + font = cv2.FONT_HERSHEY_SIMPLEX + + for idx, coordinates in enumerate(xyxy): + # Resolve colors + bg_color = resolve_color( + color=self.color, + detections=detections, + detection_idx=idx, + color_lookup=color_lookup, + ) + text_color = resolve_color( + color=self.text_color, + detections=detections, + detection_idx=idx, + color_lookup=color_lookup, ) - text_x = text_background_xyxy[0] + self.text_padding - text_y = text_background_xyxy[1] + self.text_padding + text_h + # Calculate text position + text_x = coordinates[0] + self.text_padding + text_y = coordinates[1] + self.text_padding + text_props[idx]["height"] + # Draw background and text self.draw_rounded_rectangle( scene=scene, - xyxy=text_background_xyxy, - color=color.as_bgr(), + xyxy=coordinates, + color=bg_color.as_bgr(), border_radius=self.border_radius, ) cv2.putText( img=scene, - text=text, + text=text_props[idx]["text"], org=(text_x, text_y), fontFace=font, fontScale=self.text_scale, @@ -1205,7 +1288,6 @@ def annotate( thickness=self.text_thickness, lineType=cv2.LINE_AA, ) - return scene @staticmethod def draw_rounded_rectangle(