diff --git a/gallery/others/plot_visualization_utils.py b/gallery/others/plot_visualization_utils.py index 98089c54dbb..9804fbad648 100644 --- a/gallery/others/plot_visualization_utils.py +++ b/gallery/others/plot_visualization_utils.py @@ -418,7 +418,7 @@ def show(imgs): show(res) # %% -# As we see the keypoints appear as colored circles over the image. +# As we see, the keypoints appear as colored circles over the image. # The coco keypoints for a person are ordered and represent the following list.\ coco_keypoints = [ @@ -460,3 +460,63 @@ def show(imgs): res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3) show(res) + +# %% +# That looks pretty good. +# +# .. _draw_keypoints_with_visibility: +# +# Drawing Keypoints with Visibility +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Let's have a look at the results, another keypoint prediction module produced, and show the connectivity: + +prediction = torch.tensor( + [[[208.0176, 214.2409, 1.0000], + [000.0000, 000.0000, 0.0000], + [197.8246, 210.6392, 1.0000], + [000.0000, 000.0000, 0.0000], + [178.6378, 217.8425, 1.0000], + [221.2086, 253.8591, 1.0000], + [160.6502, 269.4662, 1.0000], + [243.9929, 304.2822, 1.0000], + [138.4654, 328.8935, 1.0000], + [277.5698, 340.8990, 1.0000], + [153.4551, 374.5145, 1.0000], + [000.0000, 000.0000, 0.0000], + [226.0053, 370.3125, 1.0000], + [221.8081, 455.5516, 1.0000], + [273.9723, 448.9486, 1.0000], + [193.6275, 546.1933, 1.0000], + [273.3727, 545.5930, 1.0000]]] +) + +res = draw_keypoints(person_int, prediction, connectivity=connect_skeleton, colors="blue", radius=4, width=3) +show(res) + +# %% +# What happened there? +# The model, which predicted the new keypoints, +# can't detect the three points that are hidden on the upper left body of the skateboarder. +# More precisely, the model predicted that `(x, y, vis) = (0, 0, 0)` for the left_eye, left_ear, and left_hip. +# So we definitely don't want to display those keypoints and connections, and you don't have to. +# Looking at the parameters of :func:`~torchvision.utils.draw_keypoints`, +# we can see that we can pass a visibility tensor as an additional argument. +# Given the models' prediction, we have the visibility as the third keypoint dimension, we just need to extract it. +# Let's split the ``prediction`` into the keypoint coordinates and their respective visibility, +# and pass both of them as arguments to :func:`~torchvision.utils.draw_keypoints`. + +coordinates, visibility = prediction.split([2, 1], dim=-1) +visibility = visibility.bool() + +res = draw_keypoints( + person_int, coordinates, visibility=visibility, connectivity=connect_skeleton, colors="blue", radius=4, width=3 +) +show(res) + +# %% +# We can see that the undetected keypoints are not draw and the invisible keypoint connections were skipped. +# This can reduce the noise on images with multiple detections, or in cases like ours, +# when the keypoint-prediction model missed some detections. +# Most torch keypoint-prediction models return the visibility for every prediction, ready for you to use it. +# The :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn` model, +# which we used in the first case, does so too. diff --git a/test/assets/fakedata/draw_keypoints_visibility.png b/test/assets/fakedata/draw_keypoints_visibility.png new file mode 100644 index 00000000000..8cd34f84539 Binary files /dev/null and b/test/assets/fakedata/draw_keypoints_visibility.png differ diff --git a/test/test_utils.py b/test/test_utils.py index 020327d67fd..49dc553de3e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -361,6 +361,77 @@ def test_draw_keypoints_colored(colors): assert_equal(img, img_cp) +@pytest.mark.parametrize("connectivity", [[(0, 1)], [(0, 1), (1, 2)]]) +@pytest.mark.parametrize( + "vis", + [ + torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.bool), + torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.float).unsqueeze_(-1), + ], +) +def test_draw_keypoints_visibility(connectivity, vis): + # Keypoints is declared on top as global variable + keypoints_cp = keypoints.clone() + + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + img_cp = img.clone() + + vis_cp = vis if vis is None else vis.clone() + + result = utils.draw_keypoints( + image=img, + keypoints=keypoints, + connectivity=connectivity, + colors="red", + visibility=vis, + ) + assert result.size(0) == 3 + assert_equal(keypoints, keypoints_cp) + assert_equal(img, img_cp) + + # compare with a fakedata image + # connect the key points 0 to 1 for both skeletons and do not show the other key points + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoints_visibility.png" + ) + if not os.path.exists(path): + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) + + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + + if vis_cp is None: + assert vis is None + else: + assert_equal(vis, vis_cp) + assert vis.dtype == vis_cp.dtype + + +def test_draw_keypoints_visibility_default(): + # Keypoints is declared on top as global variable + keypoints_cp = keypoints.clone() + + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + img_cp = img.clone() + + result = utils.draw_keypoints( + image=img, + keypoints=keypoints, + connectivity=[(0, 1)], + colors="red", + visibility=None, + ) + assert result.size(0) == 3 + assert_equal(keypoints, keypoints_cp) + assert_equal(img, img_cp) + + # compare against fakedata image, which connects 0->1 for both key-point skeletons + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png") + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + + def test_draw_keypoints_errors(): h, w = 10, 10 img = torch.full((3, 100, 100), 0, dtype=torch.uint8) @@ -379,6 +450,18 @@ def test_draw_keypoints_errors(): with pytest.raises(ValueError, match="keypoints must be of shape"): invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float) utils.draw_keypoints(image=img, keypoints=invalid_keypoints) + with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")): + one_dim_visibility = torch.tensor([True, True, True], dtype=torch.bool) + utils.draw_keypoints(image=img, keypoints=keypoints, visibility=one_dim_visibility) + with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")): + three_dim_visibility = torch.ones((2, 3, 4), dtype=torch.bool) + utils.draw_keypoints(image=img, keypoints=keypoints, visibility=three_dim_visibility) + with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"): + vis_wrong_n = torch.ones((3, 3), dtype=torch.bool) + utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_n) + with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"): + vis_wrong_k = torch.ones((2, 4), dtype=torch.bool) + utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_k) @pytest.mark.parametrize("batch", (True, False)) diff --git a/torchvision/utils.py b/torchvision/utils.py index 530f938f7a7..79e533d4663 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -331,22 +331,36 @@ def draw_keypoints( colors: Optional[Union[str, Tuple[int, int, int]]] = None, radius: int = 2, width: int = 3, + visibility: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Draws Keypoints on given RGB image. The values of the input image should be uint8 between 0 and 255. + Keypoints can be drawn for multiple instances at a time. + + This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint. Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances, in the format [x, y]. - connectivity (List[Tuple[int, int]]]): A List of tuple where, - each tuple contains pair of keypoints to be connected. + connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints + to be connected. + If at least one of the two connected keypoints has a ``visibility`` of False, + this specific connection is not drawn. + Exclusions due to invisibility are computed per-instance. colors (str, Tuple): The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. radius (int): Integer denoting radius of keypoint. width (int): Integer denoting width of line connecting keypoints. + visibility (Tensor): Tensor of shape (num_instances, K) specifying the visibility of the K + keypoints for each of the N instances. + True means that the respective keypoint is visible and should be drawn. + False means invisible, so neither the point nor possible connections containing it are drawn. + The input tensor will be cast to bool. + Default ``None`` means that all the keypoints are visible. + For more details, see :ref:`draw_keypoints_with_visibility`. Returns: img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. @@ -354,6 +368,7 @@ def draw_keypoints( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(draw_keypoints) + # validate image if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") elif image.dtype != torch.uint8: @@ -363,24 +378,45 @@ def draw_keypoints( elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") + # validate keypoints if keypoints.ndim != 3: raise ValueError("keypoints must be of shape (num_instances, K, 2)") + # validate visibility + if visibility is None: # set default + visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool) + # If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction + # model, make sure visibility has shape (num_instances, K). + # Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place. + visibility = visibility.squeeze(-1) + if visibility.ndim != 2: + raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}") + if visibility.shape != keypoints.shape[:-1]: + raise ValueError( + "keypoints and visibility must have the same dimensionality for num_instances and K. " + f"Got {visibility.shape = } and {keypoints.shape = }" + ) + ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) img_kpts = keypoints.to(torch.int64).tolist() - - for kpt_id, kpt_inst in enumerate(img_kpts): - for inst_id, kpt in enumerate(kpt_inst): - x1 = kpt[0] - radius - x2 = kpt[0] + radius - y1 = kpt[1] - radius - y2 = kpt[1] + radius + img_vis = visibility.cpu().bool().tolist() + + for kpt_inst, vis_inst in zip(img_kpts, img_vis): + for kpt_coord, kp_vis in zip(kpt_inst, vis_inst): + if not kp_vis: + continue + x1 = kpt_coord[0] - radius + x2 = kpt_coord[0] + radius + y1 = kpt_coord[1] - radius + y2 = kpt_coord[1] + radius draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) if connectivity: for connection in connectivity: + if (not vis_inst[connection[0]]) or (not vis_inst[connection[1]]): + continue start_pt_x = kpt_inst[connection[0]][0] start_pt_y = kpt_inst[connection[0]][1]