Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GaussianNoise #8381

Merged
merged 6 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions torchvision/transforms/_functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,22 @@ def equalize(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.equalize(img)

@torch.jit.unused
def gaussian_noise(img: Image.Image, mean: float = 0.0, var: float = 1.0) -> Image.Image:
    """Return a copy of a PIL image with additive Gaussian noise.

    Noise is drawn independently for every pixel and channel from a normal
    distribution with the given mean and variance, expressed in 8-bit pixel
    units. The noisy values are rounded and clipped to ``[0, 255]``.

    Args:
        img: PIL image in mode ``"L"`` or ``"RGB"``.
        mean: Mean of the sampled Gaussian distribution.
        var: Variance of the sampled Gaussian distribution. Must be non-negative.

    Returns:
        A new PIL image with the same size and mode as ``img``.

    Raises:
        TypeError: If ``img`` is not a PIL image.
        ValueError: If ``var`` is negative or the image mode is unsupported.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if var < 0:
        raise ValueError(f"var shouldn't be negative. Got {var}")

    # Restricting to 8-bit modes keeps the clip range below well defined and
    # lets ``Image.fromarray`` recover the original mode from the array shape.
    if img.mode not in ("L", "RGB"):
        raise ValueError(f"img should be in mode 'L' or 'RGB'. Got {img.mode}")

    # PIL images do not support arithmetic with arrays, so operate on a NumPy
    # view. Its shape, (H, W) or (H, W, C), is also the correct noise shape.
    pixels = np.asarray(img)

    # ``scale`` is a standard deviation, hence the square root of the variance.
    noise = np.random.normal(loc=mean, scale=var ** 0.5, size=pixels.shape)

    # uint8 + float64 promotes to float64, so the sum cannot wrap around
    # before it is rounded and clipped back into the valid pixel range.
    noisy = np.clip(np.rint(pixels + noise), 0, 255).astype(np.uint8)
    return Image.fromarray(noisy)
1 change: 1 addition & 0 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"RandomAutocontrast",
"RandomEqualize",
"ElasticTransform",
"GaussianNoise",
]


Expand Down
21 changes: 21 additions & 0 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,24 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Forward the input to the sharpness kernel together with the factor
    # stored on this transform; ``params`` is not used here.
    kernel_kwargs = {"sharpness_factor": self.sharpness_factor}
    return self._call_kernel(F.adjust_sharpness, inpt, **kernel_kwargs)


class GaussianNoise(Transform):
    """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default.

    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        mean (float): Mean of the sampled gaussian distribution. Default is 0.
        var (float): Variance of the sampled gaussian distribution. Must be
            non-negative. Default is 1.

    Raises:
        ValueError: If ``var`` is negative.
    """

    def __init__(self, mean: float = 0.0, var: float = 1.0) -> None:
        super().__init__()
        # Fail at construction time rather than on the first transformed
        # sample; the message matches the one raised by the kernels.
        if var < 0:
            raise ValueError(f"var shouldn't be negative. Got {var}")
        self.mean = mean
        self.var = var

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The kernel result must be returned; without the ``return`` the
        # transform would silently replace every input with ``None``.
        return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, var=self.var)
25 changes: 25 additions & 0 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,28 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int])
@_register_kernel_internal(permute_channels, tv_tensors.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    # The video kernel simply delegates to the image kernel with the same
    # permutation.
    permuted = permute_channels_image(video, permutation=permutation)
    return permuted

def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, var: float = 1.0) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.GaussianNoise`"""
    if not torch.jit.is_scripting():
        # Eager mode: record API usage, then dispatch on the input's type.
        _log_api_usage_once(gaussian_noise)
        dispatch = _get_kernel(gaussian_noise, type(inpt))
        return dispatch(inpt, mean=mean, var=var)

    # Under TorchScript the tensor kernel is called directly.
    return gaussian_noise_image(inpt, mean=mean, var=var)

@_register_kernel_internal(gaussian_noise, torch.Tensor)
@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, var: float = 1.0) -> torch.Tensor:
    """Add Gaussian noise with the given mean and variance to a tensor image.

    Args:
        image: Floating point tensor to perturb. An empty tensor is returned
            unchanged.
        mean: Mean of the sampled Gaussian distribution.
        var: Variance of the sampled Gaussian distribution. Must be non-negative.

    Returns:
        A new tensor ``image + noise`` with the same shape and dtype as
        ``image``. The result is not clamped.

    Raises:
        ValueError: If ``var`` is negative, or if a non-empty ``image`` does
            not have a floating point (or complex) dtype.
    """
    if var < 0:
        raise ValueError(f"var shouldn't be negative. Got {var}")

    if image.numel() == 0:
        return image

    # ``torch.randn_like`` cannot sample into integer tensors; raise an
    # explicit error instead of letting it fail with an opaque RuntimeError.
    if not (image.is_floating_point() or image.is_complex()):
        raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")

    # ``randn_like`` draws unit-variance samples, so scaling by the standard
    # deviation (not the variance) yields noise with variance ``var``.
    std = var ** 0.5
    noise = mean + torch.randn_like(image) * std

    return image + noise

# Register the PIL implementation from ``_FP`` as the ``gaussian_noise`` kernel
# for ``PIL.Image.Image`` inputs, and keep a module-level alias to it.
_gaussian_noise_pil = _register_kernel_internal(gaussian_noise, PIL.Image.Image)(_FP.gaussian_noise)