From f8d6f8fffeec1d70e6153ce39abb14a65c229f21 Mon Sep 17 00:00:00 2001 From: David Miguel Susano Pinto Date: Thu, 28 Nov 2024 14:55:36 +0000 Subject: [PATCH] Add `label_colors` argument to `draw_bounding_boxes` (#8578) Co-authored-by: Nicolas Hug --- .../draw_boxes_different_label_colors.png | Bin 0 -> 723 bytes test/test_utils.py | 15 +++++++++++++++ torchvision/utils.py | 12 ++++++++++-- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 test/assets/fakedata/draw_boxes_different_label_colors.png diff --git a/test/assets/fakedata/draw_boxes_different_label_colors.png b/test/assets/fakedata/draw_boxes_different_label_colors.png new file mode 100644 index 0000000000000000000000000000000000000000..721789306028259853cd007822e341adb72f8184 GIT binary patch literal 723 zcmeAS@N?(olHy`uVBq!ia0vp^DIm$3LF` z{6)dGF<^mHL{4fz1?RO-#}~>x3~f=Xlvc6}S4@gsTfZQggTKXQyBJ5(6;-uOO&*TC zr<%NUF1!l69B`***Nv6)(u;Kuz5Q#)du-m}?&rUs{{EX+%@e<`@N~e^pUL}Psd-Pj z{XFbjV*0H>+thz&b3^lI3WPLGeI?o2Y`uDw@SY`G|J*7o++Fav>e1(ii@$BWdHe9s zYicF`YSL4~Hfd)3?l85p ze>dCB|F-C10~JweoeZXpVf)ucxw_8Vr}ak6&vs^f|NLgPf+l|#VS7GXE2F;^WoPqk zEvHLMH>nA%{-5*zZ~EU!*Op#ivEdxk>w2EU2Z~!GUhH)8X*=}o;l4jshbB$=HG`|_ zbo3{uzH%8$9_ME|OsjtthuwHt;B#*h`(?SXALe{*JloITz0jtfA3S^78o^$ZRTC2> zBHU)~c$B%+`ikLF&1Y*iOxr3G>=U$o*JQ7zrA61KE%S*si%4`8jlHsH$(*9evr={o zUt7SYZeqAtGWy$wwu4-&BhO|ojkMaJ!OE?(;Q&+jfm`lv@26XoYGm7oOt$a%?R!kC zI6Oiqc3pIEP~TyLlyg=eA{J|Ic(rTQKfc$yb>B}ilfE*EeRT-GRq4lOtNWRXi^P{p zwk_OiyN+9XQpvx`KJsfWaP42WP5bJj_^MT1rO#H2DNouWFg2?nF#f9L>s!TlUk7fB z&5chHJ;=3IbNy`o?{VtSCEC_J-~>4b=q3&_Xe))O6=J?Or>{KPe(}xsNE6Zc>4_)i fxF-DVu46hdb?$Xf*_d)*VrB4j^>bP0l+XkKlV?y| literal 0 HcmV?d00001 diff --git a/test/test_utils.py b/test/test_utils.py index e89bef4a6d9..8dfe3a1080f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -116,6 +116,21 @@ def test_draw_boxes(): assert_equal(img, img_cp) +@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1") +def test_draw_boxes_with_coloured_labels(): + img = torch.full((3, 100, 100), 255, dtype=torch.uint8) + labels = ["a", "b", "c", "d"] + colors = ["green", "#FF00FF", (0, 255, 0), "red"] + label_colors = ["green", "red", (0, 255, 0), "#FF00FF"] + result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True, label_colors=label_colors) + + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_colors.png" + ) + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + + @pytest.mark.parametrize("fill", [True, False]) def test_draw_boxes_dtypes(fill): img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8) diff --git a/torchvision/utils.py b/torchvision/utils.py index b69edcb572e..0c62db363f2 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -161,6 +161,7 @@ def draw_bounding_boxes( width: int = 1, font: Optional[str] = None, font_size: Optional[int] = None, + label_colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, ) -> torch.Tensor: """ @@ -184,9 +185,12 @@ def draw_bounding_boxes( also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. font_size (int): The requested font size in points. + label_colors (color or list of colors, optional): Colors for the label text. See the description of the + `colors` argument for details. Defaults to the same colors used for the boxes. Returns: img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. + """ import torchvision.transforms.v2.functional as F # noqa @@ -219,6 +223,10 @@ def draw_bounding_boxes( ) colors = _parse_colors(colors, num_objects=num_boxes) + if label_colors: + label_colors = _parse_colors(label_colors, num_objects=num_boxes) + else: + label_colors = colors.copy() if font is None: if font_size is not None: @@ -243,7 +251,7 @@ def draw_bounding_boxes( else: draw = ImageDraw.Draw(img_to_draw) - for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] + for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type] if fill: fill_color = color + (100,) draw.rectangle(bbox, width=width, outline=color, fill=fill_color) @@ -252,7 +260,7 @@ def draw_bounding_boxes( if label is not None: margin = width + 1 - draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) out = F.pil_to_tensor(img_to_draw) if original_dtype.is_floating_point: