diff --git a/test/assets/fakedata/draw_boxes_different_label_colors.png b/test/assets/fakedata/draw_boxes_different_label_colors.png new file mode 100644 index 00000000000..72178930602 Binary files /dev/null and b/test/assets/fakedata/draw_boxes_different_label_colors.png differ diff --git a/test/test_utils.py b/test/test_utils.py index e89bef4a6d9..8dfe3a1080f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -116,6 +116,21 @@ def test_draw_boxes(): assert_equal(img, img_cp) +@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1") +def test_draw_boxes_with_coloured_labels(): + img = torch.full((3, 100, 100), 255, dtype=torch.uint8) + labels = ["a", "b", "c", "d"] + colors = ["green", "#FF00FF", (0, 255, 0), "red"] + label_colors = ["green", "red", (0, 255, 0), "#FF00FF"] + result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True, label_colors=label_colors) + + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_colors.png" + ) + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + + @pytest.mark.parametrize("fill", [True, False]) def test_draw_boxes_dtypes(fill): img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8) diff --git a/torchvision/utils.py b/torchvision/utils.py index b69edcb572e..0c62db363f2 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -161,6 +161,7 @@ def draw_bounding_boxes( width: int = 1, font: Optional[str] = None, font_size: Optional[int] = None, + label_colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, ) -> torch.Tensor: """ @@ -184,9 +185,12 @@ def draw_bounding_boxes( also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. font_size (int): The requested font size in points. + label_colors (color or list of colors, optional): Colors for the label text. See the description of the + `colors` argument for details. Defaults to the same colors used for the boxes. Returns: img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. + """ import torchvision.transforms.v2.functional as F # noqa @@ -219,6 +223,10 @@ def draw_bounding_boxes( ) colors = _parse_colors(colors, num_objects=num_boxes) + if label_colors: + label_colors = _parse_colors(label_colors, num_objects=num_boxes) + else: + label_colors = colors.copy() if font is None: if font_size is not None: @@ -243,7 +251,7 @@ def draw_bounding_boxes( else: draw = ImageDraw.Draw(img_to_draw) - for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] + for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type] if fill: fill_color = color + (100,) draw.rectangle(bbox, width=width, outline=color, fill=fill_color) @@ -252,7 +260,7 @@ def draw_bounding_boxes( if label is not None: margin = width + 1 - draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) out = F.pil_to_tensor(img_to_draw) if original_dtype.is_floating_point: