From 5c07947a24b53612498dabc516ae275a6eb4a747 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Tue, 23 Jan 2024 15:07:47 -0800
Subject: [PATCH 1/8] Testing fixing this issue

---
 test/test_transforms_v2.py                     |  7 +++++--
 torchvision/transforms/_functional_pil.py      |  1 +
 torchvision/transforms/_functional_tensor.py   |  2 ++
 torchvision/transforms/v2/_color.py            |  2 ++
 torchvision/transforms/v2/functional/_color.py | 16 +++++++++++++++-
 5 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 831a7e3b570..2b8e6078941 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4935,12 +4935,15 @@ def test_transform(self, transform, make_input):
         check_transform(transform, make_input())

     @pytest.mark.parametrize("num_output_channels", [1, 3])
+    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
     @pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
-    def test_image_correctness(self, num_output_channels, fn):
-        image = make_image(dtype=torch.uint8, device="cpu")
+    def test_image_correctness(self, num_output_channels, color_space, fn):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)

         actual = fn(image, num_output_channels=num_output_channels)
         expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
+        print(f"Ahmad test {num_output_channels=} {image.shape=} {actual.shape=}")
+        #assert_equal(True, False)

         assert_equal(actual, expected, rtol=0, atol=1)

diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py
index 277848224ac..5f532b77ff8 100644
--- a/torchvision/transforms/_functional_pil.py
+++ b/torchvision/transforms/_functional_pil.py
@@ -332,6 +332,7 @@ def perspective(

 @torch.jit.unused
 def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
+    print("Ahmad here3")
     if not _is_pil_image(img):
         raise TypeError(f"img should be PIL Image. Got {type(img)}")

diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 88dc9ca21cc..c4d2208da99 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -151,6 +151,8 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     if num_output_channels not in (1, 3):
         raise ValueError("num_output_channels should be either 1 or 3")

+    print("ahmad here")
+
     if img.shape[-3] == 3:
         r, g, b = img.unbind(dim=-3)
         # This implementation closely follows the TF one:
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index d20953451ab..870fe3f38ba 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -22,10 +22,12 @@ class Grayscale(Transform):
     _v1_transform_cls = _transforms.Grayscale

     def __init__(self, num_output_channels: int = 1):
+        print("Ahmad here in init functional library 32")
         super().__init__()
         self.num_output_channels = num_output_channels

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        print("Ahmad here outside functional")
         return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)


diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index b0189fd95ef..7d61aaef1e4 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -33,22 +33,36 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.
 def _rgb_to_grayscale_image(
     image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
 ) -> torch.Tensor:
-    if image.shape[-3] == 1:
+    if num_output_channels not in (1, 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+    print(f"Finally ahmad here aa {num_output_channels=} {image.shape=}")
+    if image.shape[-3] == 1 and num_output_channels == 1:
+        print("ahmad cloning")
         return image.clone()

+    if image.shape[-3] == 1 and num_output_channels == 3:
+        s = [-1] * len(image.shape)
+        s[-3] = 3
+        image = image.expand(s)
+
+    print(f"Finally ahmad here bb {num_output_channels=} {image.shape=}")
+
     r, g, b = image.unbind(dim=-3)
     l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
     l_img = l_img.unsqueeze(dim=-3)
     if preserve_dtype:
         l_img = l_img.to(image.dtype)
+    print(f"ahmad: {l_img.shape=}")
     if num_output_channels == 3:
         l_img = l_img.expand(image.shape)
+    print(f"ahmad: {l_img.shape=}")
     return l_img


 @_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
 @_register_kernel_internal(rgb_to_grayscale, tv_tensors.Image)
 def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
+    print("ahmad here 4")
     if num_output_channels not in (1, 3):
         raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
     return _rgb_to_grayscale_image(image, num_output_channels=num_output_channels, preserve_dtype=True)
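Note: Patch 1 teaches _rgb_to_grayscale_image to accept single-channel input: a grayscale image requested with num_output_channels=3 is expanded along the channel dimension instead of falling straight into the three-channel unbind. The kernel computes the ITU-R BT.601 luma, L = 0.2989*R + 0.587*G + 0.114*B; the weights sum to roughly 1, so a gray image replicated to three identical channels maps back to approximately itself. The following is a minimal, self-contained sketch of the patched control flow with the debug prints dropped; grayscale_sketch is a hypothetical name, not a torchvision API, and dtype handling is omitted (float input assumed):

    import torch

    def grayscale_sketch(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
        # Illustration only: mirrors the patched kernel, minus dtype preservation.
        if num_output_channels not in (1, 3):
            raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
        if image.shape[-3] == 1 and num_output_channels == 1:
            return image.clone()
        if image.shape[-3] == 1 and num_output_channels == 3:
            s = [-1] * len(image.shape)
            s[-3] = 3  # expand the singleton channel dimension to 3
            image = image.expand(s)
        r, g, b = image.unbind(dim=-3)
        # ITU-R BT.601 luma weights, as in the kernel
        l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114).unsqueeze(dim=-3)
        if num_output_channels == 3:
            l_img = l_img.expand(image.shape)
        return l_img

    print(grayscale_sketch(torch.rand(3, 4, 4), num_output_channels=1).shape)  # torch.Size([1, 4, 4])
    print(grayscale_sketch(torch.rand(1, 4, 4), num_output_channels=3).shape)  # torch.Size([3, 4, 4])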
Got {type(img)}") diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py index 88dc9ca21cc..c4d2208da99 100644 --- a/torchvision/transforms/_functional_tensor.py +++ b/torchvision/transforms/_functional_tensor.py @@ -151,6 +151,8 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: if num_output_channels not in (1, 3): raise ValueError("num_output_channels should be either 1 or 3") + print("ahmad here") + if img.shape[-3] == 3: r, g, b = img.unbind(dim=-3) # This implementation closely follows the TF one: diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index d20953451ab..870fe3f38ba 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -22,10 +22,12 @@ class Grayscale(Transform): _v1_transform_cls = _transforms.Grayscale def __init__(self, num_output_channels: int = 1): + print("Ahmad here in init functional library 32") super().__init__() self.num_output_channels = num_output_channels def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + print("Ahmad here outside functional") return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index b0189fd95ef..7d61aaef1e4 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -33,22 +33,36 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch. def _rgb_to_grayscale_image( image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True ) -> torch.Tensor: - if image.shape[-3] == 1: + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") + print(f"Finally ahmad here aa {num_output_channels=} {image.shape=}") + if image.shape[-3] == 1 and num_output_channels == 1: + print("ahmad cloning") return image.clone() + if image.shape[-3] == 1 and num_output_channels == 3: + s = [-1] * len(image.shape) + s[-3] = 3 + image = image.expand(s) + + print(f"Finally ahmad here bb {num_output_channels=} {image.shape=}") + r, g, b = image.unbind(dim=-3) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) l_img = l_img.unsqueeze(dim=-3) if preserve_dtype: l_img = l_img.to(image.dtype) + print(f"ahmad: {l_img.shape=}") if num_output_channels == 3: l_img = l_img.expand(image.shape) + print(f"ahmad: {l_img.shape=}") return l_img @_register_kernel_internal(rgb_to_grayscale, torch.Tensor) @_register_kernel_internal(rgb_to_grayscale, tv_tensors.Image) def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + print("ahmad here 4") if num_output_channels not in (1, 3): raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") return _rgb_to_grayscale_image(image, num_output_channels=num_output_channels, preserve_dtype=True) From 327ceda582c295b85a2c8df6f07605ec761f35b9 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 23 Jan 2024 15:20:42 -0800 Subject: [PATCH 2/8] Removed debug statements --- torchvision/transforms/v2/functional/_color.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 7d61aaef1e4..dd528dfd9f2 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ 
From 668f4bbc9e1118178419e166ee2d27673a7123a4 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Tue, 23 Jan 2024 15:22:41 -0800
Subject: [PATCH 4/8] .

---
 torchvision/transforms/v2/_color.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 870fe3f38ba..d20953451ab 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -22,12 +22,10 @@ class Grayscale(Transform):
     _v1_transform_cls = _transforms.Grayscale

     def __init__(self, num_output_channels: int = 1):
-        print("Ahmad here in init functional library 32")
         super().__init__()
         self.num_output_channels = num_output_channels

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        print("Ahmad here outside functional")
         return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)


From beaa2acd926c9ce266d85af783493b1117659c64 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Tue, 23 Jan 2024 15:42:13 -0800
Subject: [PATCH 5/8] Removed debug messages

---
 test/test_transforms_v2.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 2b8e6078941..a0d398d553c 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4942,8 +4942,6 @@ def test_image_correctness(self, num_output_channels, color_space, fn):

         actual = fn(image, num_output_channels=num_output_channels)
         expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
-        print(f"Ahmad test {num_output_channels=} {image.shape=} {actual.shape=}")
-        #assert_equal(True, False)

         assert_equal(actual, expected, rtol=0, atol=1)

From dcec85534a6e3684577c9e3734a548ba85a0e1c2 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Wed, 24 Jan 2024 12:51:34 -0800
Subject: [PATCH 6/8] Addressed comments

---
 torchvision/transforms/v2/functional/_color.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index 3b8b57c6b3d..b8503d1a9a6 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -33,14 +33,13 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.
 def _rgb_to_grayscale_image(
     image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
 ) -> torch.Tensor:
-    if num_output_channels not in (1, 3):
-        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+    # TODO: Move the validation that num_output_channels is 1 or 3 to this function instead of callers.
     if image.shape[-3] == 1 and num_output_channels == 1:
         return image.clone()
     if image.shape[-3] == 1 and num_output_channels == 3:
         s = [-1] * len(image.shape)
         s[-3] = 3
-        image = image.expand(s)
+        return image.expand(s)
     r, g, b = image.unbind(dim=-3)
     l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
     l_img = l_img.unsqueeze(dim=-3)

From 4a78738b517091485ffc96d71421829d5aefc118 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 25 Jan 2024 10:10:48 +0000
Subject: [PATCH 7/8] Update torchvision/transforms/v2/functional/_color.py

---
 torchvision/transforms/v2/functional/_color.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index b8503d1a9a6..3f140db7136 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -33,7 +33,7 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.
 def _rgb_to_grayscale_image(
     image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
 ) -> torch.Tensor:
-    # TODO: Move the validation that num_output_channels is 1 or 3 to this function instead of callers.
+    # TODO: Maybe move the validation that num_output_channels is 1 or 3 to this function instead of callers.
     if image.shape[-3] == 1 and num_output_channels == 1:
         return image.clone()
     if image.shape[-3] == 1 and num_output_channels == 3:
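Note: the early return image.expand(s) from patch 6 has a subtle consequence that patch 8 below addresses: Tensor.expand never allocates new memory. It returns a view whose stride along the expanded dimension is 0, so all three output channels alias the same storage, and an in-place write through one channel becomes visible through the others. A minimal demonstration with plain torch, independent of torchvision:

    import torch

    gray = torch.zeros(1, 2, 2)
    rgb = gray.expand(3, -1, -1)               # no copy is made
    print(rgb.stride())                        # (0, 2, 1): stride 0 along the channel dim
    print(rgb.data_ptr() == gray.data_ptr())   # True: same underlying storage
    rgb[0][0][0] = 1.0                         # write through "channel" 0 ...
    print(rgb[1][0][0])                        # tensor(1.): ... shows up in channel 1 too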
From c731be6c1f747ed71d5fe7f594acc4624645bdd6 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Thu, 25 Jan 2024 12:55:40 -0800
Subject: [PATCH 8/8] Use repeat instead of expand so we get clones instead of
 views into the same tensor

---
 test/test_transforms_v2.py                     | 8 ++++++++
 torchvision/transforms/v2/functional/_color.py | 4 ++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index a0d398d553c..b40d04fffdd 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4945,6 +4945,14 @@ def test_image_correctness(self, num_output_channels, color_space, fn):

         assert_equal(actual, expected, rtol=0, atol=1)

+    def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+
+        output_image = F.rgb_to_grayscale(image, num_output_channels=3)
+        assert_equal(output_image[0][0][0], output_image[1][0][0])
+        output_image[0][0][0] = output_image[0][0][0] + 1
+        assert output_image[0][0][0] != output_image[1][0][0]
+
     @pytest.mark.parametrize("num_input_channels", [1, 3])
     def test_random_transform_correctness(self, num_input_channels):
         image = make_image(
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index b8503d1a9a6..722202bafcc 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -37,9 +37,9 @@ def _rgb_to_grayscale_image(
     if image.shape[-3] == 1 and num_output_channels == 1:
         return image.clone()
     if image.shape[-3] == 1 and num_output_channels == 3:
-        s = [-1] * len(image.shape)
+        s = [1] * len(image.shape)
         s[-3] = 3
-        return image.expand(s)
+        return image.repeat(s)
     r, g, b = image.unbind(dim=-3)
     l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
     l_img = l_img.unsqueeze(dim=-3)
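Note: patch 8 switches from expand to repeat because Tensor.repeat materializes a real copy, so each output channel owns its own memory and the in-place write in the new test no longer leaks across channels. This is also why the size list changes from -1 placeholders to 1s: repeat interprets each entry as a repetition count and, unlike expand, rejects negative values. A short contrast, again a sketch with plain torch:

    import torch

    gray = torch.zeros(1, 2, 2)
    rgb = gray.repeat(3, 1, 1)                 # real copy: (1, 2, 2) -> (3, 2, 2)
    print(rgb.data_ptr() == gray.data_ptr())   # False: freshly allocated storage
    rgb[0][0][0] = 1.0
    print(rgb[1][0][0])                        # tensor(0.): other channels unaffected

The trade-off is a threefold memory cost for the one-to-three-channel path, accepted here in exchange for output that is safe to mutate.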