From 83a5ab6421bb8224a2c17c469d6f7f7d1f18e445 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 7 Nov 2023 11:22:34 +0100 Subject: [PATCH 1/3] refactor to_pil_image --- torchvision/transforms/functional.py | 46 ++++++++++------------------ 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d176e00a8da..80ed8e8431d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -258,41 +258,27 @@ def to_pil_image(pic, mode=None): if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(to_pil_image) - if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): + if isinstance(pic, Image.Image): + return pic + if isinstance(pic, torch.Tensor): + if pic.is_floating_point() and mode != "F": + pic = pic.mul(255).byte() + if pic.ndim == 3: + pic = pic.permute((1, 2, 0)) + pic = pic.numpy(force=True) + elif not isinstance(pic, np.ndarray): raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.") - elif isinstance(pic, torch.Tensor): - if pic.ndimension() not in {2, 3}: - raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.") - - elif pic.ndimension() == 2: - # if 2D image, add channel dimension (CHW) - pic = pic.unsqueeze(0) - - # check number of channels - if pic.shape[-3] > 4: - raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.") - - elif isinstance(pic, np.ndarray): - if pic.ndim not in {2, 3}: - raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.") - - elif pic.ndim == 2: - # if 2D image, add channel dimension (HWC) - pic = np.expand_dims(pic, 2) + if pic.ndim == 2: + # if 2D image, add channel dimension (HWC) + pic = np.expand_dims(pic, 2) + if pic.ndim != 3: + raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.") - # check number of channels - if pic.shape[-1] > 4: - raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.") + if pic.shape[-1] > 4: + raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.") npimg = pic - if isinstance(pic, torch.Tensor): - if pic.is_floating_point() and mode != "F": - pic = pic.mul(255).byte() - npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) - - if not isinstance(npimg, np.ndarray): - raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}") if npimg.shape[2] == 1: expected_mode = None From 91981e6407d4b53404b39f9b55e0048d5438a2cb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 7 Nov 2023 11:28:50 +0100 Subject: [PATCH 2/3] align numpy and torch on floating point inputs Co-authored-by: Nicolas Hug --- test/test_transforms.py | 10 ++++++---- torchvision/transforms/functional.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 7c92baa9f5c..16d0e7e5d94 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -661,7 +661,7 @@ def test_1_channel_float_tensor_to_pil_image(self): @pytest.mark.parametrize( "img_data, expected_mode", [ - (torch.Tensor(4, 4, 1).uniform_().numpy(), "F"), + (torch.Tensor(4, 4, 1).uniform_().numpy(), "L"), (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"), (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"), (torch.IntTensor(4, 4, 1).random_().numpy(), "I"), @@ -671,6 +671,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) assert img.mode == expected_mode + if np.issubdtype(img_data.dtype, np.floating): + img_data = (img_data * 255).astype(np.uint8) # note: we explicitly convert img's dtype because pytorch doesn't support uint16 # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype)) @@ -741,7 +743,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe @pytest.mark.parametrize( "img_data, expected_mode", [ - (torch.Tensor(4, 4).uniform_().numpy(), "F"), + (torch.Tensor(4, 4).uniform_().numpy(), "L"), (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"), (torch.ShortTensor(4, 4).random_().numpy(), "I;16"), (torch.IntTensor(4, 4).random_().numpy(), "I"), @@ -751,6 +753,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) assert img.mode == expected_mode + if np.issubdtype(img_data.dtype, np.floating): + img_data = (img_data * 255).astype(np.uint8) np.testing.assert_allclose(img_data, img) @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"]) @@ -874,8 +878,6 @@ def test_ndarray_bad_types_to_pil_image(self): trans(np.ones([4, 4, 1], np.uint16)) with pytest.raises(TypeError, match=reg_msg): trans(np.ones([4, 4, 1], np.uint32)) - with pytest.raises(TypeError, match=reg_msg): - trans(np.ones([4, 4, 1], np.float64)) with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."): transforms.ToPILImage()(np.ones([1, 4, 4, 3])) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 80ed8e8431d..2f118072eb1 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -261,8 +261,6 @@ def to_pil_image(pic, mode=None): if isinstance(pic, Image.Image): return pic if isinstance(pic, torch.Tensor): - if pic.is_floating_point() and mode != "F": - pic = pic.mul(255).byte() if pic.ndim == 3: pic = pic.permute((1, 2, 0)) pic = pic.numpy(force=True) @@ -280,6 +278,9 @@ def to_pil_image(pic, mode=None): npimg = pic + if np.issubdtype(npimg.dtype, np.floating) and mode != "F": + npimg = (npimg * 255).astype(np.uint8) + if npimg.shape[2] == 1: expected_mode = None npimg = npimg[:, :, 0] From eeb81e08a9ac6ec71dcb762e2a8642a632745b62 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 7 Nov 2023 11:38:04 +0100 Subject: [PATCH 3/3] cleanup --- torchvision/transforms/functional.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 2f118072eb1..7cbe2d99071 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -258,8 +258,6 @@ def to_pil_image(pic, mode=None): if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(to_pil_image) - if isinstance(pic, Image.Image): - return pic if isinstance(pic, torch.Tensor): if pic.ndim == 3: pic = pic.permute((1, 2, 0))