Add gaussian noise transform #6192 #6233

Closed
Changes from 31 commits
44 commits
9be5e13
adds gaussian noise transform
parth-shastri Jul 3, 2022
42952f6
adds gaussian noise transform
parth-shastri Jul 3, 2022
db9756a
Update torchvision/transforms/transforms.py
parth-shastri Jul 4, 2022
05a52af
Update torchvision/transforms/transforms.py
parth-shastri Jul 4, 2022
c380281
Delete _C.pyd
parth-shastri Jul 4, 2022
431a7e0
added GaussianNoise transform
parth-shastri Jul 4, 2022
5fc7c85
adds the GaussianNoise transform
parth-shastri Jul 4, 2022
179908b
adds GaussianNoise transform
parth-shastri Jul 4, 2022
396abba
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 6, 2022
5d8b0f7
fixes on the lint tests and the plot_transforms
parth-shastri Jul 12, 2022
98e4e98
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 12, 2022
aa9b2e7
test
parth-shastri Jul 13, 2022
223074f
test
parth-shastri Jul 13, 2022
a26ed67
Merge branch 'add-gaussian-noise-transform' of https://github.com/par…
parth-shastri Jul 13, 2022
6d57443
fixes the plot_transforms bug
parth-shastri Jul 13, 2022
1d1fbcd
Update torchvision/transforms/transforms.py
parth-shastri Jul 27, 2022
ff80571
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 27, 2022
533e76f
Update gallery/plot_transforms.py
parth-shastri Jul 27, 2022
b8d98d7
Update test_transforms.py
parth-shastri Jul 27, 2022
35ac3c9
update
parth-shastri Jul 27, 2022
42d49ae
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 27, 2022
74f92b1
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 31, 2022
54234f9
Update test_transforms.py
parth-shastri Jul 31, 2022
28fbd4b
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Aug 3, 2022
6a95453
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Aug 10, 2022
24804f6
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Aug 12, 2022
b1bb81f
lint
parth-shastri Aug 12, 2022
e08c9c1
lint updated
parth-shastri Aug 12, 2022
7fc04aa
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Sep 7, 2022
8a500fe
adds functional transforms, fixed sigma
parth-shastri Sep 7, 2022
8b560fb
updated lint, adds functional transforms, fixed sigma
parth-shastri Sep 7, 2022
ac15585
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Oct 5, 2022
3a85c34
Update torchvision/transforms/functional_tensor.py
parth-shastri Oct 5, 2022
5892695
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Oct 6, 2022
e6b4e45
suggested changes
parth-shastri Oct 6, 2022
58d525d
update
parth-shastri Oct 6, 2022
021ecba
fixed docs
parth-shastri Oct 6, 2022
5956088
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Oct 6, 2022
92d024f
Merge branch 'main' into add-gaussian-noise-transform
datumbox Oct 27, 2022
bfde863
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Nov 9, 2022
dbc3e1a
Merge branch 'main' into add-gaussian-noise-transform
datumbox Nov 10, 2022
d18195b
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Nov 27, 2022
a8d8137
fixes for random calls in functional transforms
parth-shastri Nov 27, 2022
2f7f558
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Feb 16, 2023
1 change: 1 addition & 0 deletions docs/source/transforms.rst
@@ -123,6 +123,7 @@ Transforms on PIL Image and torch.\*Tensor
    Resize
    TenCrop
    GaussianBlur
    GaussianNoise
    RandomInvert
    RandomPosterize
    RandomSolarize
9 changes: 9 additions & 0 deletions gallery/plot_transforms.py
@@ -119,6 +119,15 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)

####################################
# GaussianNoise
# ~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.GaussianNoise` transform
# perturbs the input image with gaussian noise.
noisy = T.GaussianNoise(mean=0, sigma=(0.1, 2.0))
noisy_imgs = [noisy(orig_img) for _ in range(2)]
plot(noisy_imgs)

####################################
# RandomPerspective
# ~~~~~~~~~~~~~~~~~
18 changes: 18 additions & 0 deletions test/test_transforms.py
@@ -1314,6 +1314,24 @@ def test_gaussian_blur_asserts():
        transforms.GaussianBlur(3, "sigma_string")


def test_gaussian_noise():
    np_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
    img = F.to_pil_image(np_img, "RGB")
    transforms.GaussianNoise(2.0, (0.1, 2.0))(img)

    with pytest.raises(TypeError, match="Tensor is not a torch image"):
        transforms.GaussianNoise(2.0, (0.1, 2.0))(torch.ones(4))

    with pytest.raises(ValueError, match="Mean should be a positive number"):
        transforms.GaussianNoise(-1)

    with pytest.raises(ValueError, match="If sigma is a single number, it must be positive."):
        transforms.GaussianNoise(2.0, -1)

    with pytest.raises(ValueError, match="sigma should be a single number or a list/tuple with length 2."):
        transforms.GaussianNoise(2.0, (1, 2, 3))
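
The test above covers a smoke call and the constructor's error paths. A complementary check on the noise itself could look like the sketch below. It works on plain tensors rather than the API added in this PR, and `noise_statistics` is a hypothetical helper introduced only for illustration.

```python
import torch


def noise_statistics(mean: float, sigma: float, shape=(3, 256, 256), seed: int = 0):
    """Hypothetical helper: return the empirical (mean, std) of added Gaussian noise.

    Assumption: noise is generated as ``sigma * randn + mean``, the same
    formula used in the tensor implementation in this diff.
    """
    # A dedicated generator keeps the check reproducible without touching global RNG state.
    gen = torch.Generator().manual_seed(seed)
    base = torch.zeros(shape)
    noisy = base + (sigma * torch.randn(shape, generator=gen) + mean)
    # Subtracting the clean image isolates the noise component.
    residual = noisy - base
    return residual.mean().item(), residual.std().item()
```

With roughly 200k samples, the empirical mean and standard deviation should land well within 0.01 of the requested values, so a tolerance of that size is a reasonable starting point for an assertion.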


def test_lambda():
    trans = transforms.Lambda(lambda x: x.add(10))
    x = torch.randn(10)
49 changes: 49 additions & 0 deletions torchvision/transforms/functional.py
@@ -1404,6 +1404,55 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
    return output


def gaussian_noise(img: Tensor, mean: float, sigma: Optional[List[float]]) -> Tensor:
    """Performs Gaussian blurring on the image by given kernel.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        mean (float): Mean of the desired noise corruption.
        sigma (sequence of floats or float, optional): Gaussian noise standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions.

            .. note::
                In torchscript mode sigma as single float is
                not supported, use a sequence of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(gaussian_noise)

    if sigma is None:
        raise ValueError("The value of sigma cannot be None.")

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
Collaborator

It is unclear here whether sigma is Optional. The docstring and type hint say it is optional, but the code raises an error if sigma is None. Let's remove Optional and assume that sigma is not None.

Collaborator

@parth-shastri please check the comment above and the code. Why check whether sigma is None if sigma is not intended to be Optional?
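
One way to act on the two comments above is to drop `Optional` from the signature so that no `None` check is needed. The sketch below is an illustration under assumptions, not the code that was proposed or merged: `sample_sigma` is a hypothetical helper name, and it keeps the diff's convention that a length-1 sequence is a fixed sigma and a length-2 sequence is a `(min, max)` range sampled uniformly.

```python
from typing import List

import torch


def sample_sigma(sigma: List[float]) -> float:
    """Hypothetical helper: resolve a non-Optional sigma to a single float.

    Assumption: ``[s]`` means a fixed sigma and ``[lo, hi]`` means a range
    sampled uniformly, as in the branches of the diff. There is no None
    check because the type hint no longer admits None.
    """
    if not isinstance(sigma, (list, tuple)):
        raise TypeError(f"sigma should be a sequence of floats. Got {type(sigma)}")
    if len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"sigma should have length 1 or 2. Got {len(sigma)}")
    lo, hi = float(sigma[0]), float(sigma[1])
    if not 0.0 < lo <= hi:
        raise ValueError(f"sigma should be positive and ordered as (min, max). Got {sigma}")
    if lo == hi:
        # A degenerate range needs no random draw.
        return lo
    return torch.empty(1).uniform_(lo, hi).item()
```

Validating before sampling also means the error message reports the user's input rather than a sampled value.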

    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 2:
        sigma = torch.empty(1).uniform_(sigma[0], sigma[1]).item()
parth-shastri marked this conversation as resolved.
    if sigma <= 0.0:
        raise ValueError(f"sigma should have positive values. Got {sigma}")

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

        t_img = pil_to_tensor(img)

    output = F_t.gaussian_noise(t_img, mean, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)

    return output


def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

25 changes: 23 additions & 2 deletions torchvision/transforms/functional_tensor.py
@@ -748,8 +748,6 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
            kernel.dtype,
        ],
    )

    # padding = (left, right, top, bottom)
Collaborator

Let's keep this comment

    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])
@@ -758,6 +756,29 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
    return img


def gaussian_noise(img: Tensor, mean: float, sigma: float) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            dtype,
        ],
    )
parth-shastri marked this conversation as resolved.
    # add the gaussian noise with the given mean and sigma.
    normalize_img = img / 255.0
Collaborator

We can't assume that the image range is [0, 255].
By default, we assume [0, 1] for float images and [0, 255] for uint8 RGB images.
Let's not rescale the image range, and let the user pick an appropriate mean and sigma for their data range.
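
A minimal sketch of what this comment asks for, written on plain tensors: noise is added in the image's own value range with no division by 255. `add_gaussian_noise` is a hypothetical name, and the clamping policy (integer images clamped to their dtype's range, float images left unclipped) is an assumption of this sketch that the review does not specify.

```python
import torch
from torch import Tensor


def add_gaussian_noise(img: Tensor, mean: float, sigma: float) -> Tensor:
    """Hypothetical sketch: add N(mean, sigma^2) noise without rescaling the image."""
    if not isinstance(img, Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")
    if sigma < 0:
        raise ValueError(f"sigma should be non-negative. Got {sigma}")

    out_dtype = img.dtype
    # Compute in floating point; integer images are promoted to float32.
    work = img if img.is_floating_point() else img.to(torch.float32)
    noisy = work + (sigma * torch.randn_like(work) + mean)

    if not img.is_floating_point():
        # Assumption: round, then clamp to the integer dtype's range to avoid wrap-around.
        info = torch.iinfo(out_dtype)
        noisy = noisy.round().clamp(info.min, info.max)
    return noisy.to(out_dtype)
```

Under this design the caller scales `mean` and `sigma` to the data: for example, `sigma=0.1` for a float image in [0, 1] corresponds to roughly `sigma=25` for a uint8 image in [0, 255].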

    noise = sigma * torch.randn_like(img) + mean
    img = normalize_img + noise
    img = torch.clip(img, 0, 1)
    img = img * 255.0

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img


def invert(img: Tensor) -> Tensor:

    _assert_image_tensor(img)
59 changes: 59 additions & 0 deletions torchvision/transforms/transforms.py
@@ -46,6 +46,7 @@
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "GaussianNoise",
parth-shastri marked this conversation as resolved.
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
@@ -1837,6 +1838,64 @@ def __repr__(self) -> str:
        return s


class GaussianNoise(torch.nn.Module):
    """Adds Gaussian noise to the image with specified mean and standard deviation.
    If the image is torch Tensor, it is expected
    to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        mean (float or sequence): Mean of the sampling gaussian distribution .
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            sampling the gaussian noise. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Input image perturbed with Gaussian Noise.

    """

    def __init__(self, mean, sigma=(0.1, 0.5)):
        super().__init__()
        _log_api_usage_once(self)

        if mean < 0:
            raise ValueError("Mean should be a positive number")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.mean = mean
        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, image: Tensor) -> Tensor:
        """
        Args:
            image (PIL Image or Tensor): image to be perturbed with gaussian noise.

        Returns:
            PIL Image or Tensor: Image added with gaussian noise.
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])
        output = F.gaussian_noise(image, self.mean, sigma)
        return output

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(mean={self.mean}, sigma={self.sigma})"
        return s


def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)