diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 3d04b88f50e..95451d2866e 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -85,7 +85,7 @@ def test_plot(self, dataset: VHR10) -> None: x['prediction_labels'] = x['class'] x['prediction_boxes'] = x['bbox_xyxy'] x['prediction_scores'] = torch.Tensor([scores[i]]) - if 'masks' in x: + if 'mask' in x: x['prediction_masks'] = x['mask'] dataset.plot(x, show_feats='masks') plt.close() diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index b0f0f2de340..6f8be71852a 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -459,7 +459,7 @@ def plot( ) # Add masks - if show_feats in {'masks', 'both'} and 'masks' in sample: + if show_feats in {'masks', 'both'} and 'mask' in sample: mask = masks[i] contours = skimage.measure.find_contours(mask, 0.5) for verts in contours: