diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 9cded944c21..2cd8faceed3 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -270,7 +270,7 @@ def _resize_seg(self, results): results['scale'], interpolation='nearest', backend=self.backend) - results['gt_semantic_seg'] = gt_seg + results[key] = gt_seg def __call__(self, results): """Call function to resize images, bounding boxes, masks, semantic diff --git a/tests/test_data/test_pipelines/test_transform/test_transform.py b/tests/test_data/test_pipelines/test_transform/test_transform.py index 282c02980c5..d57f063d062 100644 --- a/tests/test_data/test_pipelines/test_transform/test_transform.py +++ b/tests/test_data/test_pipelines/test_transform/test_transform.py @@ -81,6 +81,28 @@ def test_resize(): assert results['img_shape'] == (800, 1280, 3) assert results['img'].dtype == results['img'].dtype == np.uint8 + results_seg = { + 'img': img, + 'img_shape': img.shape, + 'ori_shape': img.shape, + 'gt_semantic_seg': copy.deepcopy(img), + 'gt_seg': copy.deepcopy(img), + 'seg_fields': ['gt_semantic_seg', 'gt_seg'] + } + transform = dict( + type='Resize', + img_scale=(640, 400), + multiscale_mode='value', + keep_ratio=False) + resize_module = build_from_cfg(transform, PIPELINES) + results_seg = resize_module(results_seg) + assert results_seg['gt_semantic_seg'].shape == results_seg['gt_seg'].shape + assert results_seg['img_shape'] == (400, 640, 3) + assert results_seg['img_shape'] != results_seg['ori_shape'] + assert results_seg['gt_semantic_seg'].shape == results_seg['img_shape'] + assert np.equal(results_seg['gt_semantic_seg'], + results_seg['gt_seg']).all() + def test_flip(): # test assertion for invalid flip_ratio