diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 5c21897cdf8..d4c94e4760e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -347,6 +347,7 @@ Color v2.RandomChannelPermutation v2.RandomPhotometricDistort v2.Grayscale + v2.RGB v2.RandomGrayscale v2.GaussianBlur v2.RandomInvert @@ -364,6 +365,7 @@ Functionals v2.functional.permute_channels v2.functional.rgb_to_grayscale + v2.functional.grayscale_to_rgb v2.functional.to_grayscale v2.functional.gaussian_blur v2.functional.invert @@ -584,7 +586,7 @@ Conversion while performing the conversion, while some may not do any scaling. By scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0, 255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`. - + .. autosummary:: :toctree: generated/ :template: class.rst diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 49855400e85..1ad47dda02e 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5005,6 +5005,54 @@ def test_random_transform_correctness(self, num_input_channels): assert_equal(actual, expected, rtol=0, atol=1) +class TestGrayscaleToRgb: + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_image(self, dtype, device): + check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device)) + + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image]) + def test_functional(self, make_input): + check_functional(F.grayscale_to_rgb, make_input()) + + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.rgb_to_grayscale_image, torch.Tensor), + (F._rgb_to_grayscale_image_pil, PIL.Image.Image), + (F.rgb_to_grayscale_image, tv_tensors.Image), + ], + ) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type) + + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image]) + def test_transform(self, make_input): + check_transform(transforms.RGB(), make_input(color_space="GRAY")) + + @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)]) + def test_image_correctness(self, fn): + image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY") + + actual = fn(image) + expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image))) + + assert_equal(actual, expected, rtol=0, atol=1) + + def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self): + image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY") + + output_image = F.grayscale_to_rgb(image) + assert_equal(output_image[0][0][0], output_image[1][0][0]) + output_image[0][0][0] = output_image[0][0][0] + 1 + assert output_image[0][0][0] != output_image[1][0][0] + + def test_rgb_image_is_unchanged(self): + image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB") + assert_equal(image.shape[-3], 3) + assert_equal(F.grayscale_to_rgb(image), image) + + class TestRandomZoomOut: # Tests are light because this largely relies on the already tested `pad` kernels. diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index dbc0474d307..fea39d3cf20 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -18,6 +18,7 @@ RandomPhotometricDistort, RandomPosterize, RandomSolarize, + RGB, ) from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index d20953451ab..49b4a8d8b10 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -54,6 +54,20 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) +class RGB(Transform): + """Convert images or videos to RGB (if they are already not RGB). + + If the input is a :class:`torch.Tensor`, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions + """ + + def __init__(self): + super().__init__() + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.grayscale_to_rgb, inpt) + + class ColorJitter(Transform): """Randomly change the brightness, contrast, saturation and hue of an image or video. diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 8f71a7463a7..69f5f4521fa 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -63,6 +63,8 @@ equalize, equalize_image, equalize_video, + grayscale_to_rgb, + grayscale_to_rgb_image, invert, invert_image, invert_video, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 2b9c1e738ca..3025f876dff 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -65,6 +65,32 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int return _FP.to_grayscale(image, num_output_channels=num_output_channels) +def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.GrayscaleToRgb` for details.""" + if torch.jit.is_scripting(): + return grayscale_to_rgb_image(inpt) + + _log_api_usage_once(grayscale_to_rgb) + + kernel = _get_kernel(grayscale_to_rgb, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(grayscale_to_rgb, torch.Tensor) +@_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image) +def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor: + if image.shape[-3] >= 3: + # Image already has RGB channels. We don't need to do anything. + return image + # rgb_to_grayscale can be used to add channels so we reuse that function. + return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True) + + +@_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image) +def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: + return image.convert(mode="RGB") + + def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: ratio = float(ratio) fp = image1.is_floating_point()