Skip to content

Commit

Permalink
Updated sampling for kernel in blurs
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Nov 21, 2024
1 parent 03c8836 commit 07aa3ce
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 12 deletions.
41 changes: 37 additions & 4 deletions albumentations/augmentations/blur/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import cv2
import numpy as np
from albucore import clipped, float32_io, maybe_process_in_chunks, preserve_channel_dim
from albucore import clipped, float32_io, maybe_process_in_chunks, preserve_channel_dim, uint8_io
from pydantic import ValidationInfo

from albumentations.augmentations.functional import convolve
Expand All @@ -26,10 +26,8 @@ def blur(img: np.ndarray, ksize: int) -> np.ndarray:


@preserve_channel_dim
@uint8_io
def median_blur(img: np.ndarray, ksize: int) -> np.ndarray:
if img.dtype == np.float32 and ksize not in {3, 5}:
raise ValueError(f"Invalid ksize value {ksize}. For a float32 image the only valid ksize values are 3 and 5")

blur_fn = maybe_process_in_chunks(cv2.medianBlur, ksize=ksize)
return blur_fn(img)

Expand Down Expand Up @@ -224,3 +222,38 @@ def create_motion_kernel(
kernel[center, center] = 1

return kernel


def sample_odd_from_range(random_state: Random, low: int, high: int) -> int:
"""Sample an odd number from the range [low, high] (inclusive).
Args:
random_state: instance of random.Random
low: lower bound (will be converted to nearest valid odd number)
high: upper bound (will be converted to nearest valid odd number)
Returns:
Randomly sampled odd number from the range
Note:
- Input values will be converted to nearest valid odd numbers:
* Values less than 3 will become 3
* Even values will be rounded up to next odd number
- After normalization, high must be >= low
"""
# Normalize low value
low = max(3, low + (low % 2 == 0))
# Normalize high value
high = max(3, high + (high % 2 == 0))

# Ensure high >= low after normalization
high = max(high, low)

if low == high:
return low

# Calculate number of possible odd values
num_odd_values = (high - low) // 2 + 1
# Generate random index and convert to corresponding odd number
rand_idx = random_state.randint(0, num_odd_values - 1)
return low + (2 * rand_idx)
25 changes: 17 additions & 8 deletions albumentations/augmentations/blur/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ def apply(self, img: np.ndarray, kernel: int, **params: Any) -> np.ndarray:
return fblur.blur(img, kernel)

def get_params(self) -> dict[str, Any]:
return {"kernel": self.random_generator.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2)))}
kernel = fblur.sample_odd_from_range(
self.py_random,
self.blur_limit[0],
self.blur_limit[1],
)
return {"kernel": kernel}

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("blur_limit",)
Expand Down Expand Up @@ -237,9 +242,11 @@ def apply(self, img: np.ndarray, kernel: np.ndarray, **params: Any) -> np.ndarra
return fmain.convolve(img, kernel=kernel)

def get_params(self) -> dict[str, Any]:
ksize = self.py_random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2)))
if ksize <= TWO:
raise ValueError(f"ksize must be > 2. Got: {ksize}")
ksize = fblur.sample_odd_from_range(
self.py_random,
self.blur_limit[0],
self.blur_limit[1],
)

angle = self.py_random.uniform(*self.angle_range)
direction = self.py_random.uniform(*self.direction_range)
Expand Down Expand Up @@ -411,9 +418,11 @@ def apply(self, img: np.ndarray, ksize: int, sigma: float, **params: Any) -> np.
return fblur.gaussian_blur(img, ksize, sigma=sigma)

def get_params(self) -> dict[str, float]:
ksize = self.py_random.randrange(self.blur_limit[0], self.blur_limit[1] + 1)
if ksize != 0 and ksize % 2 != 1:
ksize = (ksize + 1) % (self.blur_limit[1] + 1)
ksize = fblur.sample_odd_from_range(
self.py_random,
self.blur_limit[0],
self.blur_limit[1],
)

return {"ksize": ksize, "sigma": self.py_random.uniform(*self.sigma_limit)}

Expand Down Expand Up @@ -667,7 +676,7 @@ def apply(self, img: np.ndarray, kernel: np.ndarray, **params: Any) -> np.ndarra
return fmain.convolve(img, kernel=kernel)

def get_params(self) -> dict[str, np.ndarray]:
ksize = self.py_random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2)
ksize = fblur.sample_odd_from_range(self.py_random, self.blur_limit[0], self.blur_limit[1])
sigma_x = self.py_random.uniform(*self.sigma_x_limit)
sigma_y = self.py_random.uniform(*self.sigma_y_limit)
angle = np.deg2rad(self.py_random.uniform(*self.rotate_limit))
Expand Down
40 changes: 40 additions & 0 deletions tests/functional/test_blur.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from random import Random
import pytest
from albumentations.augmentations.blur import functional as fblur


@pytest.mark.parametrize(
"low, high, expected_range",
[
(-8, 7, {3, 5, 7}), # negative low
(2, 6, {3, 5, 7}), # even values
(1, 4, {3, 5}), # low < 3
(4, 4, {5}), # same even value
(3, 3, {3}), # same odd value
(2, 2, {3}), # same even value < 3
(-4, -2, {3}), # all negative values
],
ids=[
"negative_low",
"even_values",
"low_less_than_3",
"same_even_value",
"same_odd_value",
"same_even_value_less_than_3",
"all_negative",
]
)
def test_sample_odd_from_range(low: int, high: int, expected_range: set[int]):
"""Test sampling odd numbers from a range."""
random_state = Random(42)

results = set()
for _ in range(50): # Sample multiple times to get all possible values
value = fblur.sample_odd_from_range(random_state, low, high)
results.add(value)
# Verify each value is odd
assert value % 2 == 1
# Verify value is >= 3
assert value >= 3

assert results == expected_range, f"Failed for low={low}, high={high}"

0 comments on commit 07aa3ce

Please sign in to comment.