Skip to content

Commit

Permalink
Fix iso (#1829)
Browse files Browse the repository at this point in the history
* iso_noise supports float32

* Fix in ISONoise

* Cleanup

* fix

* fix
  • Loading branch information
ternaus authored Jul 3, 2024
1 parent e6219ab commit dae9bd6
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 68 deletions.
10 changes: 5 additions & 5 deletions albumentations/augmentations/domain_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from albumentations.augmentations.utils import read_rgb_image
from albumentations.core.transforms_interface import BaseTransformInitSchema, ImageOnlyTransform
from albumentations.augmentations import functional as fmain


from albumentations.core.pydantic import NonNegativeFloatRangeType, ZeroOneRangeType
Expand Down Expand Up @@ -108,8 +109,8 @@ def get_params(self) -> dict[str, np.ndarray]:
"blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]),
}

def get_transform_init_args_names(self) -> tuple[str, str, str]:
return ("reference_images", "blend_ratio", "read_fn")
def get_transform_init_args_names(self) -> tuple[str, ...]:
return "reference_images", "blend_ratio", "read_fn"

def to_dict_private(self) -> dict[str, Any]:
msg = "HistogramMatching can not be serialized."
Expand Down Expand Up @@ -318,9 +319,8 @@ def apply(self, img: np.ndarray, reference_image: np.ndarray, blend_ratio: float
weight=blend_ratio,
transform_type=self.transform_type,
)
if needs_reconvert:
adapted = adapted.astype("float32") * (1 / 255)
return adapted

return fmain.to_float(adapted) if needs_reconvert else adapted

def get_params(self) -> dict[str, Any]:
return {
Expand Down
18 changes: 10 additions & 8 deletions albumentations/augmentations/domain_adaptation_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import cv2
import numpy as np
from albucore.functions import add_weighted
from albucore.utils import clip, clipped, get_num_channels, preserve_channel_dim
from albucore.utils import clip, clipped, is_multispectral_image, preserve_channel_dim
from skimage.exposure import match_histograms
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from typing_extensions import Protocol

from albumentations.augmentations.functional import center
from albumentations.core.types import MONO_CHANNEL_DIMENSIONS
import albumentations.augmentations.functional as fmain

__all__ = [
"fourier_domain_adaptation",
Expand Down Expand Up @@ -57,7 +58,7 @@ def from_colorspace(self, img: np.ndarray) -> np.ndarray:

def flatten(self, img: np.ndarray) -> np.ndarray:
img = self.to_colorspace(img)
img = img.astype("float32") / 255.0
img = fmain.to_float(img)
return img.reshape(-1, 3)

def reconstruct(self, pixels: np.ndarray, height: int, width: int) -> np.ndarray:
Expand Down Expand Up @@ -86,18 +87,18 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
return self.reconstruct(result, height, width)


@clipped
@preserve_channel_dim
def adapt_pixel_distribution(
img: np.ndarray,
ref: np.ndarray,
transform_type: str = "pca",
weight: float = 0.5,
) -> np.ndarray:
initial_type = img.dtype
transformer = {"pca": PCA, "standard": StandardScaler, "minmax": MinMaxScaler}[transform_type]()
adapter = DomainAdapter(transformer=transformer, ref_img=ref)
result = adapter(img).astype("float32")
return (img.astype("float32") * (1 - weight) + result * weight).astype(initial_type)
result = adapter(img).astype(np.float32)
return img.astype(np.float32) * (1 - weight) + result * weight


def low_freq_mutate(amp_src: np.ndarray, amp_trg: np.ndarray, beta: float) -> np.ndarray:
Expand Down Expand Up @@ -157,20 +158,21 @@ def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: flo
return src_in_trg


@clipped
@preserve_channel_dim
def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: float) -> np.ndarray:
# Resize reference image only if necessary
if img.shape[:2] != reference_image.shape[:2]:
reference_image = cv2.resize(reference_image, dsize=(img.shape[1], img.shape[0]))

img, reference_image = np.squeeze(img), np.squeeze(reference_image)
img = np.squeeze(img)
reference_image = np.squeeze(reference_image)

# Determine if the images are multi-channel based on a predefined condition or shape analysis
is_multichannel = get_num_channels(img) > 1
is_multichannel = is_multispectral_image(img)

# Match histograms between the images
matched = match_histograms(img, reference_image, channel_axis=2 if is_multichannel else None)

# Blend the original image and the matched image

return add_weighted(matched, blend_ratio, img, 1 - blend_ratio)
74 changes: 28 additions & 46 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,7 @@ def image_compression(img: np.ndarray, quality: int, image_type: Literal[".jpg",
_, encoded_img = cv2.imencode(image_type, img, (int(quality_flag), quality))
img = cv2.imdecode(encoded_img, cv2.IMREAD_UNCHANGED)

if needs_float:
img = to_float(img, max_value=255)
return img
return to_float(img, max_value=255) if needs_float else img


@preserve_channel_dim
Expand Down Expand Up @@ -482,10 +480,7 @@ def add_snow(img: np.ndarray, snow_point: float, brightness_coeff: float) -> np.

image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)

if needs_float:
image_rgb = to_float(image_rgb, max_value=255)

return image_rgb
return to_float(image_rgb, max_value=255) if needs_float else image_rgb


@preserve_channel_dim
Expand Down Expand Up @@ -547,10 +542,7 @@ def add_rain(

image_rgb = cv2.cvtColor(image_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)

if needs_float:
return to_float(image_rgb, max_value=255)

return image_rgb
return to_float(image_rgb, max_value=255) if needs_float else image_rgb


@preserve_channel_dim
Expand Down Expand Up @@ -602,10 +594,7 @@ def add_fog(img: np.ndarray, fog_coef: float, alpha_coef: float, haze_list: list

image_rgb = cv2.blur(img, (hw // 10, hw // 10))

if needs_float:
image_rgb = to_float(image_rgb, max_value=255)

return image_rgb
return to_float(image_rgb, max_value=255) if needs_float else image_rgb


@preserve_channel_dim
Expand Down Expand Up @@ -659,10 +648,7 @@ def add_sun_flare(
alp = alpha[num_times - i - 1] * alpha[num_times - i - 1] * alpha[num_times - i - 1]
output = add_weighted(overlay, alp, output, 1 - alp)

if needs_float:
return to_float(output, max_value=255)

return output
return to_float(output, max_value=255) if needs_float else output


@contiguous
Expand Down Expand Up @@ -740,10 +726,7 @@ def add_gravel(img: np.ndarray, gravels: list[Any]) -> np.ndarray:

image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)

if needs_float:
image_rgb = to_float(image_rgb, max_value=255)

return image_rgb
return to_float(image_rgb, max_value=255) if needs_float else image_rgb


def invert(img: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -784,33 +767,35 @@ def iso_noise(
image: np.ndarray,
color_shift: float = 0.05,
intensity: float = 0.5,
random_state: int | None = None,
random_state: np.random.RandomState | None = None,
) -> np.ndarray:
"""Apply poisson noise to an image to simulate camera sensor noise.
Args:
image (np.ndarray): Input image. Currently, only RGB, uint8 images are supported.
image (np.ndarray): Input image. Currently, only RGB images are supported.
color_shift (float): The amount of color shift to apply. Default is 0.05.
intensity (float): Multiplication factor for noise values. Values of ~0.5 produce a noticeable,
yet acceptable level of noise. Default is 0.5.
random_state (Optional[int]): If specified, this will set the random seed for the noise generation,
ensuring consistent results for the same input and seed.
random_state (Optional[np.random.RandomState]): If specified, this random state is used
for noise generation.
Returns:
np.ndarray: The noised image.
Raises:
TypeError: If the input image's dtype is not uint8 or if the image is not RGB.
TypeError: If the input image is not RGB.
"""
if image.dtype != np.uint8:
msg = "Image must have uint8 channel type"
raise TypeError(msg)
if not is_rgb_image(image):
msg = "Image must be RGB"
raise TypeError(msg)

one_over_255 = float(1.0 / 255.0)
image = multiply(image, one_over_255).astype(np.float32)
input_dtype = image.dtype
factor = 1

if input_dtype == np.uint8:
image = to_float(image)
factor = MAX_VALUES_BY_DTYPE[input_dtype]

hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS)
_, stddev = cv2.meanStdDev(hls)

Expand All @@ -824,8 +809,7 @@ def iso_noise(
luminance = hls[..., 1]
luminance += (luminance_noise / 255) * (1.0 - luminance)

image = cv2.cvtColor(hls, cv2.COLOR_HLS2RGB) * 255
return image.astype(np.uint8)
return cv2.cvtColor(hls, cv2.COLOR_HLS2RGB) * factor


def to_gray(img: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -945,6 +929,7 @@ def mask_from_bbox(img: np.ndarray, bbox: tuple[int, int, int, int]) -> np.ndarr
return mask


@clipped
def fancy_pca(img: np.ndarray, alpha: float = 0.1) -> np.ndarray:
"""Perform 'Fancy PCA' augmentation
Expand All @@ -965,7 +950,7 @@ def fancy_pca(img: np.ndarray, alpha: float = 0.1) -> np.ndarray:

orig_img = img.astype(float).copy()

img = img / 255.0 # rescale to 0 to 1 range
img = to_float(img) # rescale to 0 to 1 range

# flatten image to columns of RGB
img_rs = img.reshape(-1, 3)
Expand Down Expand Up @@ -1005,8 +990,7 @@ def fancy_pca(img: np.ndarray, alpha: float = 0.1) -> np.ndarray:

# for image processing it was found that working with float 0.0 to 1.0
# was easier than integers between 0-255
# > orig_img /= 255.0
return clip(orig_img, np.uint8)
return orig_img


@preserve_channel_dim
Expand Down Expand Up @@ -1201,8 +1185,9 @@ def spatter(
) -> np.ndarray:
non_rgb_warning(img)

coef = MAX_VALUES_BY_DTYPE[img.dtype]
img = img.astype(np.float32) * (1 / coef)
dtype = img.dtype

img = to_float(img)

if mode == "rain":
if rain is None:
Expand All @@ -1222,7 +1207,7 @@ def spatter(
else:
raise ValueError("Unsupported spatter mode: " + str(mode))

return img * 255
return from_float(img, dtype=dtype)


def almost_equal_intervals(n: int, parts: int) -> np.ndarray:
Expand Down Expand Up @@ -1525,16 +1510,13 @@ def planckian_jitter(img: np.ndarray, temperature: int, mode: PlanckianJitterMod

coeffs = w_left * np.array(PLANCKIAN_COEFFS[mode][t_left]) + w_right * np.array(PLANCKIAN_COEFFS[mode][t_right])

image = img / 255.0 if img.dtype == np.uint8 else img
image = to_float(img) if img.dtype == np.uint8 else img

image[:, :, 0] = image[:, :, 0] * (coeffs[0] / coeffs[1])
image[:, :, 2] = image[:, :, 2] * (coeffs[2] / coeffs[1])
image[image > 1] = 1

if img.dtype == np.uint8:
return image * 255.0

return image
return from_float(image, dtype=img.dtype) if img.dtype == np.uint8 else image


def generate_approx_gaussian_noise(
Expand Down
11 changes: 7 additions & 4 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,18 +1787,21 @@ class ISONoise(ImageOnlyTransform):
image
Image types:
uint8
uint8, float32
Raises:
TypeError: If the input image is not RGB.
"""

class InitSchema(BaseTransformInitSchema):
color_shift: tuple[float, float] = Field(
color_shift: Annotated[tuple[float, float], AfterValidator(check_01), AfterValidator(nondecreasing)] = Field(
default=(0.01, 0.05),
description=(
"Variance range for color hue change. Measured as a fraction of 360 degree Hue angle in HLS colorspace."
),
)
intensity: tuple[float, float] = Field(
intensity: Annotated[tuple[float, float], AfterValidator(check_0plus), AfterValidator(nondecreasing)] = Field(
default=(0.1, 0.5),
description="Multiplicative factor that control strength of color and luminance noise.",
)
Expand Down Expand Up @@ -1832,7 +1835,7 @@ def get_params(self) -> dict[str, Any]:
}

def get_transform_init_args_names(self) -> tuple[str, str]:
return ("intensity", "color_shift")
return "intensity", "color_shift"


class CLAHE(ImageOnlyTransform):
Expand Down
2 changes: 0 additions & 2 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def test_image_only_augmentations_mask_persists(augmentation_cls, params):
A.Equalize,
A.FancyPCA,
A.FromFloat,
A.ISONoise,
A.Posterize,
},
),
Expand Down Expand Up @@ -296,7 +295,6 @@ def test_augmentations_wont_change_input(augmentation_cls, params):
A.CLAHE,
A.Equalize,
A.FancyPCA,
A.ISONoise,
A.Posterize,
A.RandomSizedBBoxSafeCrop,
A.BBoxSafeRandomCrop,
Expand Down
Loading

0 comments on commit dae9bd6

Please sign in to comment.