Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plasma brightness contrast #2152

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [Normalize](https://explore.albumentations.ai/transform/Normalize)
- [PixelDistributionAdaptation](https://explore.albumentations.ai/transform/PixelDistributionAdaptation)
- [PlanckianJitter](https://explore.albumentations.ai/transform/PlanckianJitter)
- [PlasmaBrightnessContrast](https://explore.albumentations.ai/transform/PlasmaBrightnessContrast)
- [Posterize](https://explore.albumentations.ai/transform/Posterize)
- [RGBShift](https://explore.albumentations.ai/transform/RGBShift)
- [RandomBrightness](https://explore.albumentations.ai/transform/RandomBrightness)
Expand Down
160 changes: 160 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2192,3 +2192,163 @@ def apply_salt_and_pepper(
result[salt_mask] = MAX_VALUES_BY_DTYPE[img.dtype]
result[pepper_mask] = 0
return result


def get_grid_size(size: int, target_shape: tuple[int, int]) -> int:
    """Return the smallest power of two that covers the requested grid.

    Args:
        size: Requested grid size.
        target_shape: Dimensions the grid must also be able to cover.

    Returns:
        The smallest power of 2 that is greater than or equal to the largest
        value among ``size`` and the entries of ``target_shape``.
    """
    largest = max(size, *target_shape)
    exponent = int(np.ceil(np.log2(largest)))
    return 2**exponent


def random_offset(current_size: int, total_size: int, roughness: float, random_generator: np.random.Generator) -> float:
    """Draw a zero-centred random displacement scaled by the relative cell size.

    Args:
        current_size: Size of the cell currently being subdivided.
        total_size: Size of the whole grid.
        roughness: Exponent control; the scale factor is
            ``(current_size / total_size) ** (roughness / 2)``.
        random_generator: NumPy random generator; exactly one value is drawn.

    Returns:
        A value drawn from ``random() - 0.5`` multiplied by the scale factor.
    """
    centered_noise = random_generator.random() - 0.5
    damping = (current_size / total_size) ** (roughness / 2)
    return centered_noise * damping


def initialize_grid(grid_size: int, random_generator: np.random.Generator) -> np.ndarray:
    """Create a zeroed square grid and seed its four corners with random values.

    Args:
        grid_size: Number of cells per side; the grid has ``grid_size + 1``
            points per side.
        random_generator: NumPy random generator; one value is drawn per corner.

    Returns:
        A float32 array of shape ``(grid_size + 1, grid_size + 1)`` that is zero
        everywhere except at the four corners.
    """
    side = grid_size + 1
    grid = np.zeros((side, side), dtype=np.float32)

    # Draw order matters for reproducibility: top-left, top-right,
    # bottom-left, bottom-right.
    grid[0, 0] = random_generator.random()
    grid[0, -1] = random_generator.random()
    grid[-1, 0] = random_generator.random()
    grid[-1, -1] = random_generator.random()
    return grid


def square_step(
    pattern: np.ndarray,
    y: int,
    x: int,
    step: int,
    grid_size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> float:
    """Compute the centre value of a square cell (Diamond-Square square step).

    Args:
        pattern: Grid being filled in.
        y: Row of the cell's top-left corner.
        x: Column of the cell's top-left corner.
        step: Side length of the cell.
        grid_size: Total grid size, forwarded to ``random_offset``.
        roughness: Roughness control, forwarded to ``random_offset``.
        random_generator: NumPy random generator, forwarded to ``random_offset``.

    Returns:
        The mean of the four corner values plus a random displacement.
    """
    far_y = y + step
    far_x = x + step
    corner_values = (
        pattern[y, x],  # top-left
        pattern[y, far_x],  # top-right
        pattern[far_y, x],  # bottom-left
        pattern[far_y, far_x],  # bottom-right
    )
    corner_mean = sum(corner_values) / 4.0
    displacement = random_offset(step, grid_size, roughness, random_generator)
    return corner_mean + displacement


def diamond_step(
    pattern: np.ndarray,
    y: int,
    x: int,
    half: int,
    grid_size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> float:
    """Compute an edge-midpoint value (Diamond-Square diamond step).

    Averages whichever of the up/down/left/right neighbours at distance
    ``half`` fall inside the grid, then adds a random displacement.

    Args:
        pattern: Grid being filled in.
        y: Row of the point to compute.
        x: Column of the point to compute.
        half: Distance from the point to each neighbour.
        grid_size: Largest valid index along each axis; also forwarded to
            ``random_offset``.
        roughness: Roughness control, forwarded to ``random_offset``.
        random_generator: NumPy random generator, forwarded to ``random_offset``.

    Returns:
        The mean of the available neighbours plus a random displacement.
    """
    # (is the neighbour inside the grid?, neighbour index), in the order
    # up, down, left, right.
    candidates = (
        (y >= half, (y - half, x)),
        (y + half <= grid_size, (y + half, x)),
        (x >= half, (y, x - half)),
        (x + half <= grid_size, (y, x + half)),
    )
    neighbours = [pattern[index] for inside, index in candidates if inside]

    neighbour_mean = sum(neighbours) / len(neighbours)
    displacement = random_offset(half * 2, grid_size, roughness, random_generator)
    return neighbour_mean + displacement


def generate_plasma_pattern(
    target_shape: tuple[int, int],
    size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> np.ndarray:
    """Generate a plasma fractal pattern using the Diamond-Square algorithm.

    The Diamond-Square algorithm creates a natural-looking noise pattern by recursively
    subdividing a grid and adding random displacements at each step. The roughness
    parameter controls how quickly the random displacements decrease with each iteration.

    Args:
        target_shape: Final shape (height, width) of the pattern.
        size: Initial size of the pattern grid. Will be rounded up to nearest power of 2
            (and to at least the largest target dimension).
        roughness: Controls how quickly the random displacements decay as the
            subdivision gets finer; forwarded to the square and diamond steps.
        random_generator: NumPy random generator.

    Returns:
        Normalized plasma pattern array of shape target_shape with values in [0, 1].
    """
    # Initialize grid
    grid_size = get_grid_size(size, target_shape)
    pattern = initialize_grid(grid_size, random_generator)

    # Diamond-Square algorithm
    step_size = grid_size
    while step_size > 1:
        # step_size > 1 here, so half_step is always >= 1.
        half_step = step_size // 2

        # Square step: fill the centre of every step_size x step_size cell.
        for y in range(0, grid_size, step_size):
            for x in range(0, grid_size, step_size):
                # Pass the full grid_size (not half_step) so the displacement is
                # scaled by step_size / grid_size and shrinks at finer levels,
                # consistent with the diamond step below.
                pattern[y + half_step, x + half_step] = square_step(
                    pattern,
                    y,
                    x,
                    step_size,
                    grid_size,
                    roughness,
                    random_generator,
                )

        # Diamond step: fill the edge midpoints; the x start offset alternates
        # between rows so only midpoints (not corners or centres) are visited.
        for y in range(0, grid_size + 1, half_step):
            for x in range((y + half_step) % step_size, grid_size + 1, step_size):
                pattern[y, x] = diamond_step(pattern, y, x, half_step, grid_size, roughness, random_generator)

        step_size = half_step

    # Normalize to [0, 1] range
    min_pattern = pattern.min()
    value_range = pattern.max() - min_pattern
    if value_range > 0:
        pattern = (pattern - min_pattern) / value_range
    else:
        # Degenerate constant grid: avoid a 0/0 division producing NaNs.
        pattern = np.zeros_like(pattern)

    if pattern.shape == target_shape:
        return pattern
    return fgeometric.resize(pattern, target_shape, interpolation=cv2.INTER_LINEAR)


@clipped
def apply_plasma_brightness_contrast(
    img: np.ndarray,
    brightness_factor: float,
    contrast_factor: float,
    plasma_pattern: np.ndarray,
) -> np.ndarray:
    """Apply spatially-varying brightness and contrast driven by a plasma pattern.

    Brightness is shifted by ``plasma_pattern * brightness_factor * max_value``.
    Contrast is then changed by scaling each pixel's distance from the global
    mean by ``plasma_pattern * contrast_factor + 1``. Each stage is skipped when
    its factor is 0, and each stage clips its output to ``[0, max_value]``.

    Args:
        img: Input image; its dtype selects ``max_value`` from MAX_VALUES_BY_DTYPE.
        brightness_factor: Strength of the brightness shift.
        contrast_factor: Strength of the contrast change.
        plasma_pattern: Pattern used to weight both adjustments
            (presumably in [0, 1] with the image's spatial shape; confirm at caller).

    Returns:
        The adjusted image (a copy of ``img`` when both factors are 0).
    """
    dtype_max = MAX_VALUES_BY_DTYPE[img.dtype]

    # Append a trailing axis to the pattern when the image has more dimensions
    # than MONO_CHANNEL_DIMENSIONS, so it broadcasts against the extra axis.
    if img.ndim > MONO_CHANNEL_DIMENSIONS:
        plasma_pattern = plasma_pattern[..., np.newaxis]

    adjusted = img.copy()

    # Brightness: additive, pattern-weighted shift.
    if brightness_factor != 0:
        shift = plasma_pattern * brightness_factor * dtype_max
        adjusted = np.clip(adjusted + shift, 0, dtype_max)

    # Contrast: pattern-weighted scaling around the global mean.
    if contrast_factor != 0:
        global_mean = adjusted.mean()
        gain = plasma_pattern * contrast_factor + 1
        adjusted = np.clip(global_mean + (adjusted - global_mean) * gain, 0, dtype_max)

    return adjusted
2 changes: 1 addition & 1 deletion albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def resize(img: np.ndarray, target_shape: tuple[int, int], interpolation: int) -
if target_shape == img.shape[:2]:
return img

height, width = target_shape
height, width = target_shape[:2]
resize_fn = maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation)
return resize_fn(img)

Expand Down
35 changes: 15 additions & 20 deletions albumentations/augmentations/tk/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
jpeg_quality: tuple[int, int] = (50, 50),
always_apply: bool = False,
p: float = 0.5,
always_apply: bool = False,
):
warn(
"RandomJPEG is a specialized version of ImageCompression. "
Expand Down Expand Up @@ -173,7 +173,7 @@ def __init__(
UserWarning,
stacklevel=2,
)
super().__init__(p=p, always_apply=always_apply)
super().__init__(p=p)


class RandomVerticalFlip(VerticalFlip):
Expand Down Expand Up @@ -342,8 +342,6 @@ class RandomPerspective(Perspective):
- Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomPerspective
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
distortion_scale: float = Field(ge=0, le=1)
fill: ColorType
Expand Down Expand Up @@ -439,8 +437,6 @@ class RandomAffine(Affine):
- Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomAffine
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
degrees: ScaleFloatType
translate: tuple[float, float]
Expand Down Expand Up @@ -760,8 +756,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
hue: tuple[float, float] = (0, 0),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomHue is a specialized version of ColorJitter. "
Expand Down Expand Up @@ -826,8 +822,8 @@ def __init__(
self,
clip_limit: float | tuple[float, float] = (1, 4),
tile_grid_size: tuple[int, int] = (8, 8),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomClahe is an alias for CLAHE transform. Consider using CLAHE directly from albumentations.CLAHE.",
Expand Down Expand Up @@ -892,8 +888,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
contrast: tuple[float, float] = (1, 1),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomContrast is a specialized version of RandomBrightnessContrast. "
Expand Down Expand Up @@ -958,8 +954,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
brightness: tuple[float, float] = (1, 1),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomBrightness is a specialized version of RandomBrightnessContrast. "
Expand Down Expand Up @@ -1027,8 +1023,8 @@ def __init__(
self,
num_drop_channels: int = 1,
fill_value: float = 0,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomChannelDropout is an alias for ChannelDropout transform. "
Expand Down Expand Up @@ -1087,8 +1083,8 @@ class InitSchema(BaseTransformInitSchema):

def __init__(
self,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomEqualize is a specialized version of Equalize transform. "
Expand Down Expand Up @@ -1159,8 +1155,8 @@ def __init__(
self,
kernel_size: ScaleIntType = (3, 7),
sigma: ScaleFloatType = 0,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomGaussianBlur is an alias for GaussianBlur transform. "
Expand All @@ -1172,7 +1168,6 @@ def __init__(
blur_limit=kernel_size,
sigma_limit=sigma,
p=p,
always_apply=always_apply,
)
self.kernel_size = kernel_size
self.sigma = sigma
Expand Down Expand Up @@ -1234,8 +1229,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
mode: Literal["blackbody", "cied"] = "blackbody",
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomPlanckianJitter is a specialized version of PlanckianJitter transform. "
Expand Down Expand Up @@ -1303,8 +1298,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
kernel_size: tuple[int, int] = (3, 3),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomMedianBlur is a specialized version of MedianBlur with a probability parameter. "
Expand Down Expand Up @@ -1370,8 +1365,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
thresholds: tuple[float, float] = (0.1, 0.1),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomSolarize is an alias for Solarize transform. "
Expand Down Expand Up @@ -1441,8 +1436,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
num_bits: tuple[int, int] = (3, 3),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomPosterize is an alias for Posterize transform. "
Expand Down Expand Up @@ -1503,8 +1498,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
saturation: tuple[float, float] = (1.0, 1.0),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
super().__init__(
brightness=(1.0, 1.0), # No brightness change
Expand Down Expand Up @@ -1568,8 +1563,8 @@ def __init__(
self,
mean: float = 0.0,
sigma: float = 0.1,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"GaussianNoise is a specialized version of GaussNoise that follows torchvision's API. "
Expand Down
Loading