Skip to content

Commit

Permalink
Merge pull request #58 from roboflow/feature/update_to_support_sam
Browse files Browse the repository at this point in the history
feature/update_to_support_sam
SkalskiP authored Apr 10, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
2 parents bc12a8e + ddc8a8c commit dba4d9f
Showing 12 changed files with 292 additions and 45 deletions.
7 changes: 7 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
### 0.5.0 <small>April 10, 2023</small>

- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.mask` to enable segmentation support.
- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `MaskAnnotator` to allow easy `Detections.mask` annotation.
- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.from_sam` to enable native Segment Anything Model (SAM) support.
- Changed [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.area` behaviour to work not only with boxes but also with masks.

### 0.4.0 <small>April 5, 2023</small>

- Added [[#46](https://github.com/roboflow/supervision/discussions/48)]: `Detections.empty` to allow easy creation of empty `Detections` objects.
6 changes: 5 additions & 1 deletion docs/detection/annotate.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## BoxAnnotator

:::supervision.detection.annotate.BoxAnnotator
:::supervision.detection.annotate.BoxAnnotator

## MaskAnnotator

:::supervision.detection.annotate.MaskAnnotator
7 changes: 7 additions & 0 deletions docs/detection/tools/polygon_zone.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## PolygonZone

:::supervision.detection.tools.polygon_zone.PolygonZone

## PolygonZoneAnnotator

:::supervision.detection.tools.polygon_zone.PolygonZoneAnnotator
6 changes: 5 additions & 1 deletion docs/detection/utils.md
Original file line number Diff line number Diff line change
@@ -8,4 +8,8 @@

## non_max_suppression

:::supervision.detection.utils.non_max_suppression
:::supervision.detection.utils.non_max_suppression

## mask_to_xyxy

:::supervision.detection.utils.mask_to_xyxy
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -29,6 +29,8 @@ nav:
- Core: detection/core.md
- Annotate: detection/annotate.md
- Utils: detection/utils.md
- Tools:
- Polygon Zone: detection/tools/polygon_zone.md
- Draw:
- Utils: draw/utils.md
- Annotations:
8 changes: 4 additions & 4 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__version__ = "0.4.0"
__version__ = "0.5.0"

from supervision.annotation.voc import detections_to_voc_xml
from supervision.detection.annotate import BoxAnnotator
from supervision.detection.annotate import BoxAnnotator, MaskAnnotator
from supervision.detection.core import Detections
from supervision.detection.line_counter import LineZone, LineZoneAnnotator
from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator
from supervision.detection.utils import generate_2d_mask
from supervision.detection.tools.polygon_zone import PolygonZone, PolygonZoneAnnotator
from supervision.detection.utils import generate_2d_mask, mask_to_xyxy
from supervision.draw.color import Color, ColorPalette
from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text
from supervision.geometry.core import Point, Position, Rect
62 changes: 60 additions & 2 deletions supervision/detection/annotate.py
Original file line number Diff line number Diff line change
@@ -56,8 +56,11 @@ def annotate(
np.ndarray: The image with the bounding boxes drawn on it
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections):
x1, y1, x2, y2 = xyxy.astype(int)
for i in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
class_id = (
detections.class_id[i] if detections.class_id is not None else None
)
idx = class_id if class_id is not None else i
color = (
self.color.by_idx(idx)
@@ -114,3 +117,58 @@ def annotate(
lineType=cv2.LINE_AA,
)
return scene


class MaskAnnotator:
"""
A class for overlaying masks on an image using detections provided.
Attributes:
color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette
"""

def __init__(
self,
color: Union[Color, ColorPalette] = ColorPalette.default(),
):
self.color: Union[Color, ColorPalette] = color

def annotate(
self, scene: np.ndarray, detections: Detections, opacity: float = 0.5
) -> np.ndarray:
"""
Overlays the masks on the given image based on the provided detections, with a specified opacity.
Parameters:
scene (np.ndarray): The image on which the masks will be overlaid
detections (Detections): The detections for which the masks will be overlaid
opacity (float): The opacity of the masks, between 0 and 1, default is 0.5
Returns:
np.ndarray: The image with the masks overlaid
"""
for i in range(len(detections.xyxy)):
if detections.mask is None:
continue

class_id = (
detections.class_id[i] if detections.class_id is not None else None
)
idx = class_id if class_id is not None else i
color = (
self.color.by_idx(idx)
if isinstance(self.color, ColorPalette)
else self.color
)

mask = detections.mask[i]
colored_mask = np.zeros_like(scene, dtype=np.uint8)
colored_mask[:] = color.as_bgr()

scene = np.where(
np.expand_dims(mask, axis=-1),
np.uint8(opacity * colored_mask + (1 - opacity) * scene),
scene,
)

return scene
141 changes: 112 additions & 29 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,78 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple, Union
from typing import Any, Iterator, List, Optional, Tuple, Union

import numpy as np

from supervision.detection.utils import non_max_suppression
from supervision.detection.utils import non_max_suppression, xywh_to_xyxy
from supervision.geometry.core import Position


def _validate_xyxy(xyxy: Any, n: int) -> None:
is_valid = isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4)
if not is_valid:
raise ValueError("xyxy must be 2d np.ndarray with (n, 4) shape")


def _validate_mask(mask: Any, n: int) -> None:
is_valid = mask is None or (
isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask.shape[0] == n
)
if not is_valid:
raise ValueError("mask must be 3d np.ndarray with (n, W, H) shape")


def _validate_class_id(class_id: Any, n: int) -> None:
is_valid = class_id is None or (
isinstance(class_id, np.ndarray) and class_id.shape == (n,)
)
if not is_valid:
raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape")


def _validate_confidence(confidence: Any, n: int) -> None:
is_valid = confidence is None or (
isinstance(confidence, np.ndarray) and confidence.shape == (n,)
)
if not is_valid:
raise ValueError("confidence must be None or 1d np.ndarray with (n,) shape")


def _validate_tracker_id(tracker_id: Any, n: int) -> None:
is_valid = tracker_id is None or (
isinstance(tracker_id, np.ndarray) and tracker_id.shape == (n,)
)
if not is_valid:
raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape")


@dataclass
class Detections:
"""
Data class containing information about the detections in a video frame.
Attributes:
xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
mask: (Optional[np.ndarray]): An array of shape `(n, W, H)` containing the segmentation masks.
class_id (Optional[np.ndarray]): An array of shape `(n,)` containing the class ids of the detections.
confidence (Optional[np.ndarray]): An array of shape `(n,)` containing the confidence scores of the detections.
tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
"""

xyxy: np.ndarray
mask: np.Optional[np.ndarray] = None
class_id: Optional[np.ndarray] = None
confidence: Optional[np.ndarray] = None
tracker_id: Optional[np.ndarray] = None

def __post_init__(self):
n = len(self.xyxy)
validators = [
(isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)),
self.class_id is None
or (isinstance(self.class_id, np.ndarray) and self.class_id.shape == (n,)),
self.confidence is None
or (
isinstance(self.confidence, np.ndarray)
and self.confidence.shape == (n,)
),
self.tracker_id is None
or (
isinstance(self.tracker_id, np.ndarray)
and self.tracker_id.shape == (n,)
),
]
if not all(validators):
raise ValueError(
"xyxy must be 2d np.ndarray with (n, 4) shape, "
"class_id must be None or 1d np.ndarray with (n,) shape, "
"confidence must be None or 1d np.ndarray with (n,) shape, "
"tracker_id must be None or 1d np.ndarray with (n,) shape"
)
_validate_xyxy(xyxy=self.xyxy, n=n)
_validate_mask(mask=self.mask, n=n)
_validate_class_id(class_id=self.class_id, n=n)
_validate_confidence(confidence=self.confidence, n=n)
_validate_tracker_id(tracker_id=self.tracker_id, n=n)

def __len__(self):
"""
@@ -59,13 +82,22 @@ def __len__(self):

def __iter__(
self,
) -> Iterator[Tuple[np.ndarray, Optional[float], int, Optional[Union[str, int]]]]:
) -> Iterator[
Tuple[
np.ndarray,
Optional[np.ndarray],
Optional[float],
Optional[int],
Optional[int],
]
]:
"""
Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection.
Iterates over the Detections object and yield a tuple of `(xyxy, mask, confidence, class_id, tracker_id)` for each detection.
"""
for i in range(len(self.xyxy)):
yield (
self.xyxy[i],
self.mask[i] if self.mask is not None else None,
self.confidence[i] if self.confidence is not None else None,
self.class_id[i] if self.class_id is not None else None,
self.tracker_id[i] if self.tracker_id is not None else None,
@@ -75,6 +107,12 @@ def __eq__(self, other: Detections):
return all(
[
np.array_equal(self.xyxy, other.xyxy),
any(
[
self.mask is None and other.mask is None,
np.array_equal(self.mask, other.mask),
]
),
any(
[
self.class_id is None and other.class_id is None,
@@ -113,7 +151,7 @@ def from_yolov5(cls, yolov5_results) -> Detections:
>>> from supervision import Detections
>>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
>>> results = model(frame)
>>> results = model(IMAGE)
>>> detections = Detections.from_yolov5(results)
```
"""
@@ -141,8 +179,8 @@ def from_yolov8(cls, yolov8_results) -> Detections:
>>> from supervision import Detections
>>> model = YOLO('yolov8s.pt')
>>> results = model(frame)[0]
>>> detections = Detections.from_yolov8(results)
>>> yolov8_results = model(IMAGE)[0]
>>> detections = Detections.from_yolov8(yolov8_results)
```
"""
return cls(
@@ -201,6 +239,37 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio
class_id=np.array(class_id).astype(int),
)

@classmethod
def from_sam(cls, sam_result: List[dict]) -> Detections:
"""
Creates a Detections instance from Segment Anything Model (SAM) by Meta AI.
Args:
sam_result (List[dict]): The output Results instance from SAM
Returns:
Detections: A new Detections object.
Example:
```python
>>> from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
>>> import supervision as sv
>>> sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
>>> mask_generator = SamAutomaticMaskGenerator(sam)
>>> sam_result = mask_generator.generate(IMAGE)
>>> detections = sv.Detections.from_sam(sam_result=sam_result)
```
"""
sorted_generated_masks = sorted(
sam_result, key=lambda x: x["area"], reverse=True
)

xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
mask = np.array([mask["segmentation"] for mask in sorted_generated_masks])

return Detections(xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask)

@classmethod
def from_coco_annotations(cls, coco_annotation: dict) -> Detections:
xyxy, class_id = [], []
@@ -264,6 +333,20 @@ def __getitem__(self, index: np.ndarray) -> Detections:

@property
def area(self) -> np.ndarray:
"""
Calculate the area of each detection in the set of object detections. If masks field is defined property
returns are of each mask. If only box is given property return area of each box.
Returns:
np.ndarray: An array of floats containing the area of each detection in the format of `(area_1, area_2, ..., area_n)`, where n is the number of detections.
"""
if self.mask is not None:
return np.ndarray([np.sum(mask) for mask in self.mask])
else:
return self.box_area

@property
def box_area(self) -> np.ndarray:
"""
Calculate the area of each bounding box in the set of object detections.
Empty file.
Original file line number Diff line number Diff line change
@@ -13,6 +13,17 @@


class PolygonZone:
"""
A class for defining a polygon-shaped zone within a frame for detecting objects.
Attributes:
polygon (np.ndarray): A numpy array defining the polygon vertices
frame_resolution_wh (Tuple[int, int]): The frame resolution (width, height)
triggering_position (Position): The position within the bounding box that triggers the zone (default: Position.BOTTOM_CENTER)
current_count (int): The current count of detected objects within the zone
mask (np.ndarray): The 2D bool mask for the polygon zone
"""

def __init__(
self,
polygon: np.ndarray,
@@ -30,6 +41,16 @@ def __init__(
)

def trigger(self, detections: Detections) -> np.ndarray:
"""
Determines if the detections are within the polygon zone.
Parameters:
detections (Detections): The detections to be checked against the polygon zone
Returns:
np.ndarray: A boolean numpy array indicating if each detection is within the polygon zone
"""

clipped_xyxy = clip_boxes(
boxes_xyxy=detections.xyxy, frame_resolution_wh=self.frame_resolution_wh
)
@@ -43,6 +64,21 @@ def trigger(self, detections: Detections) -> np.ndarray:


class PolygonZoneAnnotator:
"""
A class for annotating a polygon-shaped zone within a frame with a count of detected objects.
Attributes:
zone (PolygonZone): The polygon zone to be annotated
color (Color): The color to draw the polygon lines
thickness (int): The thickness of the polygon lines, default is 2
text_color (Color): The color of the text on the polygon, default is black
text_scale (float): The scale of the text on the polygon, default is 0.5
text_thickness (int): The thickness of the text on the polygon, default is 1
text_padding (int): The padding around the text on the polygon, default is 10
font (int): The font type for the text on the polygon, default is cv2.FONT_HERSHEY_SIMPLEX
center (Tuple[int, int]): The center of the polygon for text placement
"""

def __init__(
self,
zone: PolygonZone,
@@ -64,6 +100,16 @@ def __init__(
self.center = get_polygon_center(polygon=zone.polygon)

def annotate(self, scene: np.ndarray, label: Optional[str] = None) -> np.ndarray:
"""
Annotates the polygon zone within a frame with a count of detected objects.
Parameters:
scene (np.ndarray): The image on which the polygon zone will be annotated
label (Optional[str]): An optional label for the count of detected objects within the polygon zone (default: None)
Returns:
np.ndarray: The image with the polygon zone and count of detected objects
"""
annotated_frame = draw_polygon(
scene=scene,
polygon=self.zone.polygon,
31 changes: 31 additions & 0 deletions supervision/detection/utils.py
Original file line number Diff line number Diff line change
@@ -115,3 +115,34 @@ def clip_boxes(
result[:, [0, 2]] = result[:, [0, 2]].clip(0, width)
result[:, [1, 3]] = result[:, [1, 3]].clip(0, height)
return result


def xywh_to_xyxy(boxes_xywh: np.ndarray) -> np.ndarray:
xyxy = boxes_xywh.copy()
xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2]
xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3]
return xyxy


def mask_to_xyxy(masks: np.ndarray) -> np.ndarray:
"""
Converts a 3D `np.array` of 2D bool masks into a 2D `np.array` of bounding boxes.
Parameters:
masks (np.ndarray): A 3D `np.array` of shape `(N, W, H)` containing 2D bool masks
Returns:
np.ndarray: A 2D `np.array` of shape `(N, 4)` containing the bounding boxes `(x_min, y_min, x_max, y_max)` for each mask
"""
n = masks.shape[0]
bboxes = np.zeros((n, 4), dtype=int)

for i, mask in enumerate(masks):
rows, cols = np.where(mask)

if len(rows) > 0 and len(cols) > 0:
x_min, x_max = np.min(cols), np.max(cols)
y_min, y_max = np.min(rows), np.max(rows)
bboxes[i, :] = [x_min, y_min, x_max, y_max]

return bboxes
21 changes: 13 additions & 8 deletions supervision/notebook/utils.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@


def plot_image(
image: np.ndarray, size: Tuple[int, int] = (10, 10), cmap: Optional[str] = "gray"
image: np.ndarray, size: Tuple[int, int] = (12, 12), cmap: Optional[str] = "gray"
) -> None:
"""
Plots image using matplotlib.
@@ -27,12 +27,14 @@ def plot_image(
>>> sv.plot_image(image, (16, 16))
```
"""
plt.figure(figsize=size)

if image.ndim == 2:
plt.figure(figsize=size)
plt.imshow(image, cmap=cmap)
else:
plt.figure(figsize=size)
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

plt.axis("off")
plt.show()


@@ -41,6 +43,7 @@ def plot_images_grid(
grid_size: Tuple[int, int],
titles: Optional[List[str]] = None,
size: Tuple[int, int] = (12, 12),
cmap: Optional[str] = "gray",
) -> None:
"""
Plots images in a grid using matplotlib.
@@ -50,6 +53,7 @@ def plot_images_grid(
grid_size (Tuple[int, int]): A tuple specifying the number of rows and columns for the grid.
titles (Optional[List[str]]): A list of titles for each image. Defaults to None.
size (Tuple[int, int]): A tuple specifying the width and height of the entire plot in inches.
cmap (str): the colormap to use for single channel images.
Raises:
ValueError: If the number of images exceeds the grid size.
@@ -70,7 +74,6 @@ def plot_images_grid(
>>> plot_images_grid(images, grid_size=(2, 2), titles=titles, figsize=(16, 16))
```
"""

nrows, ncols = grid_size

if len(images) > nrows * ncols:
@@ -82,11 +85,13 @@ def plot_images_grid(

for idx, ax in enumerate(axes.flat):
if idx < len(images):
ax.imshow(cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB))
if images[idx].ndim == 2:
ax.imshow(images[idx], cmap=cmap)
else:
ax.imshow(cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB))

if titles is not None and idx < len(titles):
ax.set_title(titles[idx])
ax.axis("off")
else:
ax.axis("off")

ax.axis("off")
plt.show()

0 comments on commit dba4d9f

Please sign in to comment.