Skip to content

Commit

Permalink
Added smart positioning (non overlapping labels) to LabelAnnotator
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijaucharmal committed Oct 28, 2024
1 parent ba559c2 commit a2e2e37
Showing 1 changed file with 129 additions and 47 deletions.
176 changes: 129 additions & 47 deletions supervision/annotators/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from supervision.detection.core import Detections
from supervision.detection.utils import clip_boxes, mask_to_polygons
from supervision.detection.utils import clip_boxes, mask_to_polygons, spread_out
from supervision.draw.color import Color, ColorPalette
from supervision.draw.utils import draw_polygon
from supervision.geometry.core import Position
Expand Down Expand Up @@ -1054,6 +1054,7 @@ def __init__(
text_position: Position = Position.TOP_LEFT,
color_lookup: ColorLookup = ColorLookup.CLASS,
border_radius: int = 0,
use_smart_positioning: bool = False,
):
"""
Args:
Expand All @@ -1070,7 +1071,10 @@ def __init__(
Options are `INDEX`, `CLASS`, `TRACK`.
border_radius (int): The radius to apply round edges. If the selected
value is higher than the lower dimension, width or height, is clipped.
use_smart_positioning (bool): Whether to use smart positioning to prevent
label overlapping or not.
"""
self.use_smart_positioning: bool = use_smart_positioning
self.border_radius: int = border_radius
self.color: Union[Color, ColorPalette] = color
self.text_color: Union[Color, ColorPalette] = text_color
Expand Down Expand Up @@ -1128,84 +1132,162 @@ def annotate(
![label-annotator-example](https://media.roboflow.com/
supervision-annotator-examples/label-annotator-example-purple.png)
"""

assert isinstance(scene, np.ndarray)
font = cv2.FONT_HERSHEY_SIMPLEX
anchors_coordinates = detections.get_anchors_coordinates(
anchor=self.text_anchor
).astype(int)
self._validate_labels(labels, detections)

# Get text properties for all detections
text_props = self._get_text_properties(detections, labels)

# Calculate background coordinates for all labels
xyxy = self._calculate_label_backgrounds(
detections, text_props, self.text_anchor, self.text_padding
)

# Adjust positions if smart positioning is enabled
if self.use_smart_positioning:
xyxy = spread_out(xyxy, step=2)

# Draw all labels
self._draw_labels(
scene=scene,
xyxy=xyxy,
text_props=text_props,
detections=detections,
custom_color_lookup=custom_color_lookup,
)

return scene

def _validate_labels(self, labels: Optional[List[str]], detections: Detections):
"""Validates that the number of labels matches the number of detections."""
if labels is not None and len(labels) != len(detections):
raise ValueError(
f"The number of labels ({len(labels)}) does not match the "
f"number of detections ({len(detections)}). Each detection "
f"should have exactly 1 label."
)

for detection_idx, center_coordinates in enumerate(anchors_coordinates):
color = resolve_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
color_lookup=(
self.color_lookup
if custom_color_lookup is None
else custom_color_lookup
),
)

text_color = resolve_color(
color=self.text_color,
detections=detections,
detection_idx=detection_idx,
color_lookup=(
self.color_lookup
if custom_color_lookup is None
else custom_color_lookup
),
)
def _get_text_properties(
self, detections: Detections, custom_labels: Optional[List[str]]
) -> List[dict]:
"""Gets text content and dimensions for all detections."""
text_props = []
font = cv2.FONT_HERSHEY_SIMPLEX

if labels is not None:
text = labels[detection_idx]
elif CLASS_NAME_DATA_FIELD in detections.data:
text = detections.data[CLASS_NAME_DATA_FIELD][detection_idx]
elif detections.class_id is not None:
text = str(detections.class_id[detection_idx])
else:
text = str(detection_idx)
for idx in range(len(detections)):
# Determine label text
text = self._get_label_text(detections, custom_labels, idx)

text_w, text_h = cv2.getTextSize(
# Calculate text dimensions
(text_w, text_h) = cv2.getTextSize(
text=text,
fontFace=font,
fontScale=self.text_scale,
thickness=self.text_thickness,
)[0]
text_w_padded = text_w + 2 * self.text_padding
text_h_padded = text_h + 2 * self.text_padding

text_props.append(
{
"text": text,
"width": text_w,
"height": text_h,
"width_padded": text_w + 2 * self.text_padding,
"height_padded": text_h + 2 * self.text_padding,
}
)

return text_props

def _get_label_text(
self, detections: Detections, custom_labels: Optional[List[str]], idx: int
) -> str:
"""Determines the label text for a given detection."""
if custom_labels is not None:
return custom_labels[idx]
elif CLASS_NAME_DATA_FIELD in detections.data:
return detections.data[CLASS_NAME_DATA_FIELD][idx]
elif detections.class_id is not None:
return str(detections.class_id[idx])
return str(idx)

def _calculate_label_backgrounds(
self,
detections: Detections,
text_props: List[dict],
text_anchor: str,
text_padding: int,
) -> np.ndarray:
"""Calculates background coordinates for all labels."""
anchors_coordinates = detections.get_anchors_coordinates(
anchor=text_anchor
).astype(int)

xyxy = []
for idx, center_coords in enumerate(anchors_coordinates):
text_background_xyxy = resolve_text_background_xyxy(
center_coordinates=tuple(center_coordinates),
text_wh=(text_w_padded, text_h_padded),
position=self.text_anchor,
center_coordinates=tuple(center_coords),
text_wh=(
text_props[idx]["width_padded"],
text_props[idx]["height_padded"],
),
position=text_anchor,
)
xyxy.append(text_background_xyxy)

return np.array(xyxy)

def _draw_labels(
self,
scene: ImageType,
xyxy: np.ndarray,
text_props: List[dict],
detections: Detections,
custom_color_lookup: Optional[np.ndarray],
) -> None:
"""Draws all labels and their backgrounds on the scene."""
if custom_color_lookup is not None:
color_lookup = custom_color_lookup
else:
color_lookup = self.color_lookup
font = cv2.FONT_HERSHEY_SIMPLEX

for idx, coordinates in enumerate(xyxy):
# Resolve colors
bg_color = resolve_color(
color=self.color,
detections=detections,
detection_idx=idx,
color_lookup=color_lookup,
)
text_color = resolve_color(
color=self.text_color,
detections=detections,
detection_idx=idx,
color_lookup=color_lookup,
)

text_x = text_background_xyxy[0] + self.text_padding
text_y = text_background_xyxy[1] + self.text_padding + text_h
# Calculate text position
text_x = coordinates[0] + self.text_padding
text_y = coordinates[1] + self.text_padding + text_props[idx]["height"]

# Draw background and text
self.draw_rounded_rectangle(
scene=scene,
xyxy=text_background_xyxy,
color=color.as_bgr(),
xyxy=coordinates,
color=bg_color.as_bgr(),
border_radius=self.border_radius,
)
cv2.putText(
img=scene,
text=text,
text=text_props[idx]["text"],
org=(text_x, text_y),
fontFace=font,
fontScale=self.text_scale,
color=text_color.as_bgr(),
thickness=self.text_thickness,
lineType=cv2.LINE_AA,
)
return scene

@staticmethod
def draw_rounded_rectangle(
Expand Down

0 comments on commit a2e2e37

Please sign in to comment.