
Grayscale transform cannot increase the number of color channels #8167

Closed
landoskape opened this issue Dec 18, 2023 · 2 comments

Comments

@landoskape

landoskape commented Dec 18, 2023

🐛 Describe the bug

Background

As I understand it, the Grayscale transform (in torchvision.transforms.v2) should allow a user to request 1 or 3 output channels via the keyword argument num_output_channels.

Expected and Actual Behavior

I was attempting to use Grayscale to convert a grayscale image to RGB (it's a hack, I know, but a simple way to use a pretrained network like AlexNet on a grayscale dataset like MNIST). So I added Grayscale as a transform to increase the number of output channels from 1 to 3. This failed: the image still had a single color channel after the transform.

Comment

In the case where a dataloader produces an (image, label) tuple, one can simply call images = images.repeat(1, 3, 1, 1) after each batch is generated by the dataloader. But it's nice to have everything prepackaged in a composed transform.
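As a minimal sketch of that post-dataloader workaround (the batch shape (N, 1, H, W) is assumed for illustration):

```python
import torch

# A stand-in for one batch from the dataloader: 8 grayscale 28x28 images.
images = torch.normal(0, 1, (8, 1, 28, 28))

# Copy the single channel three times along the channel dimension.
images = images.repeat(1, 3, 1, 1)
print(images.shape)  # torch.Size([8, 3, 28, 28])
```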

If what I've described here is not the desired behavior of Grayscale for the torch maintainers, I understand, but I wonder whether it would be useful to add another transform, GrayscaleToRGB, with an optional input argument map that defines how much energy to put into each color channel from the gray values (e.g. rgb = gray * map.view(3, 1, 1)).
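A minimal sketch of such a transform (the GrayscaleToRGB name and the map argument are part of the proposal above, not an existing torchvision API):

```python
import torch

class GrayscaleToRGB(torch.nn.Module):
    """Hypothetical transform: spread gray values across 3 channels per `map`."""

    def __init__(self, map=None):
        super().__init__()
        # Default: copy the gray values into each channel unchanged.
        self.map = torch.ones(3) if map is None else torch.as_tensor(map, dtype=torch.float32)

    def forward(self, gray):
        # gray: (1, H, W) -> (3, H, W), scaling each channel by its map entry.
        return gray * self.map.view(3, 1, 1)

gray = torch.rand(1, 28, 28)
rgb = GrayscaleToRGB(map=[1.0, 0.5, 0.25])(gray)
print(rgb.shape)  # torch.Size([3, 28, 28])
```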

Minimum Working Example

Here is a MWE with a toy dataset to show the observed behavior:

import torch
from torchvision.transforms import v2 as transforms

class RandomImageDataset(torch.utils.data.Dataset):
    """Toy dataset of random images with an optional transform."""

    def __init__(self, num_input_channels=1, transform=None):
        self.images = torch.normal(0, 1, (100, num_input_channels, 28, 28))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.images[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def get_untransformed(self, idx):
        # Return the raw image, bypassing the transform.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.images[idx]


def get_transform(num_output_channels=1):
    return transforms.Compose([
        transforms.ToImage(),
        transforms.Grayscale(num_output_channels=num_output_channels),
    ])


# Convert RGB to grayscale (works as expected)
dataset = RandomImageDataset(num_input_channels=3, transform=get_transform(num_output_channels=1))
item = dataset[1]
item_pretransform = dataset.get_untransformed(1)
print("\nConverting RGB image to grayscale (num_input_channels=3, num_output_channels=1)")
print("Before:", item_pretransform.shape, "  After:", item.shape)
print("This works as expected")

# Convert RGB to RGB (a no-op, but confirms Grayscale can emit 3 channels)
dataset = RandomImageDataset(num_input_channels=3, transform=get_transform(num_output_channels=3))
item = dataset[1]
item_pretransform = dataset.get_untransformed(1)
print("\nConverting RGB image to RGB (num_input_channels=3, num_output_channels=3)")
print("Before:", item_pretransform.shape, "  After:", item.shape)
print("This works as expected, (it's never really necessary, but confirms that Grayscale can maintain 3 color channels)")

# Convert grayscale to RGB... this doesn't work
dataset = RandomImageDataset(num_input_channels=1, transform=get_transform(num_output_channels=3))
item = dataset[1]
item_pretransform = dataset.get_untransformed(1)
print("\nConverting Grayscale image to RGB (num_input_channels=1, num_output_channels=3)")
print("Before:", item_pretransform.shape, "  After:", item.shape)
print("This fails to expand the grayscale image to have 3 channels")

Output

# Converting RGB image to grayscale (num_input_channels=3, num_output_channels=1)
# Before: torch.Size([3, 28, 28])   After: torch.Size([1, 28, 28])
# This works as expected

# Converting RGB image to RGB (num_input_channels=3, num_output_channels=3)
# Before: torch.Size([3, 28, 28])   After: torch.Size([3, 28, 28])
# This works as expected, (it's never really necessary, but confirms that Grayscale can maintain 3 color channels)

# Converting Grayscale image to RGB (num_input_channels=1, num_output_channels=3)
# Before: torch.Size([1, 28, 28])   After: torch.Size([1, 28, 28])
# This fails to expand the grayscale image to have 3 channels

Versions

Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Enterprise
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 03:40:31) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 536.23
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=3600
DeviceID=CPU0
Family=198
L2CacheSize=2048
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=3600
Name=12th Gen Intel(R) Core(TM) i7-12700K
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchvision==0.16.1
[conda] blas 2.120 mkl conda-forge
[conda] blas-devel 3.9.0 20_win64_mkl conda-forge
[conda] libblas 3.9.0 20_win64_mkl conda-forge
[conda] libcblas 3.9.0 20_win64_mkl conda-forge
[conda] liblapack 3.9.0 20_win64_mkl conda-forge
[conda] liblapacke 3.9.0 20_win64_mkl conda-forge
[conda] mkl 2023.2.0 h6a75c08_50497 conda-forge
[conda] mkl-devel 2023.2.0 h57928b3_50497 conda-forge
[conda] mkl-include 2023.2.0 h6a75c08_50497 conda-forge
[conda] numpy 1.26.2 py39hddb5d58_0 conda-forge
[conda] pytorch 2.1.1 py3.9_cuda12.1_cudnn8_0 pytorch
[conda] pytorch-cuda 12.1 hde6ce7c_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 2.1.1 pypi_0 pypi
[conda] torchvision 0.16.1 pypi_0 pypi

@NicolasHug
Member

NicolasHug commented Jan 15, 2024

Thank you for the report @landoskape. I agree with you that one would expect

Grayscale(num_output_channels=3)(img_with_one_channel)

to result in an output with 3 channels.
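For example, the expected semantics can be sketched as a simple channel repeat (this is only a sketch of the expected behavior, not the actual torchvision implementation):

```python
import torch

img = torch.rand(1, 28, 28)   # single-channel input
out = img.expand(3, -1, -1)   # expected result: the gray channel repeated 3x
print(out.shape)  # torch.Size([3, 28, 28])
```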

I think this is something @ahmadsharif1 might be interested in fixing. (If not, LMK Ahmad and we'll ask someone else to take it up). Thanks!

@NicolasHug
Member

Closed by #8229

Thanks for the report @landoskape and @ahmadsharif1 for the fix!
