Skip to content

Commit

Permalink
Add random_sharpness processing layer (#20697)
Browse files Browse the repository at this point in the history
* Add random_sharpness.py

* Update random_sharpness

* Add test cases

* Fix failed test case
  • Loading branch information
shashaka authored Dec 29, 2024
1 parent 2b073b6 commit 6ce93a4
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
RandomSaturation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
RandomSharpness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
RandomSaturation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
RandomSharpness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
RandomSaturation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
RandomSharpness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
168 changes: 168 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


@keras_export("keras.layers.RandomSharpness")
class RandomSharpness(BaseImagePreprocessingLayer):
"""Randomly performs the sharpness operation on given images.
The sharpness operation first performs a blur, then blends between the
original image and the processed image. This operation adjusts the clarity
of the edges in an image, ranging from blurred to enhanced sharpness.
Args:
factor: A tuple of two floats or a single float.
`factor` controls the extent to which the image sharpness
is impacted. `factor=0.0` results in a fully blurred image,
`factor=0.5` applies no operation (preserving the original image),
and `factor=1.0` enhances the sharpness beyond the original. Values
should be between `0.0` and `1.0`. If a tuple is used, a `factor`
is sampled between the two values for every image augmented.
If a single float is used, a value between `0.0` and the passed
float is sampled. To ensure the value is always the same,
pass a tuple with two identical floats: `(0.5, 0.5)`.
value_range: the range of values the incoming images will have.
Represented as a two-number tuple written `[low, high]`. This is
typically either `[0, 1]` or `[0, 255]` depending on how your
preprocessing pipeline is set up.
seed: Integer. Used to create a random seed.
"""

_USE_BASE_FACTOR = False
_FACTOR_BOUNDS = (0, 1)

_VALUE_RANGE_VALIDATION_ERROR = (
"The `value_range` argument should be a list of two numbers. "
)

def __init__(
self,
factor,
value_range=(0, 255),
data_format=None,
seed=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)
self._set_factor(factor)
self._set_value_range(value_range)
self.seed = seed
self.generator = SeedGenerator(seed)

def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self.value_range = sorted(value_range)

def get_random_transformation(self, data, training=True, seed=None):
if isinstance(data, dict):
images = data["images"]
else:
images = data
images_shape = self.backend.shape(images)
rank = len(images_shape)
if rank == 3:
batch_size = 1
elif rank == 4:
batch_size = images_shape[0]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received: "
f"inputs.shape={images_shape}"
)

if seed is None:
seed = self._get_seed_generator(self.backend._backend)

factor = self.backend.random.uniform(
(batch_size,),
minval=self.factor[0],
maxval=self.factor[1],
seed=seed,
)
return {"factor": factor}

def transform_images(self, images, transformation=None, training=True):
images = self.backend.cast(images, self.compute_dtype)
if training:
if self.data_format == "channels_first":
images = self.backend.numpy.swapaxes(images, -3, -1)

sharpness_factor = self.backend.cast(
transformation["factor"] * 2, dtype=self.compute_dtype
)
sharpness_factor = self.backend.numpy.reshape(
sharpness_factor, (-1, 1, 1, 1)
)

num_channels = self.backend.shape(images)[-1]

a, b = 1.0 / 13.0, 5.0 / 13.0
kernel = self.backend.convert_to_tensor(
[[a, a, a], [a, b, a], [a, a, a]], dtype=self.compute_dtype
)
kernel = self.backend.numpy.reshape(kernel, (3, 3, 1, 1))
kernel = self.backend.numpy.tile(kernel, [1, 1, num_channels, 1])
kernel = self.backend.cast(kernel, self.compute_dtype)

smoothed_image = self.backend.nn.depthwise_conv(
images,
kernel,
strides=1,
padding="same",
data_format="channels_last",
)

smoothed_image = self.backend.cast(
smoothed_image, dtype=self.compute_dtype
)
images = images + (1.0 - sharpness_factor) * (
smoothed_image - images
)

images = self.backend.numpy.clip(
images, self.value_range[0], self.value_range[1]
)

if self.data_format == "channels_first":
images = self.backend.numpy.swapaxes(images, -3, -1)

return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return segmentation_masks

def transform_bounding_boxes(
self, bounding_boxes, transformation, training=True
):
return bounding_boxes

def get_config(self):
config = super().get_config()
config.update(
{
"factor": self.factor,
"value_range": self.value_range,
"seed": self.seed,
}
)
return config

def compute_output_shape(self, input_shape):
return input_shape
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

import keras
from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomSharpnessTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomSharpness,
init_kwargs={
"factor": 0.75,
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_random_sharpness_value_range(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

layer = layers.RandomSharpness(0.2)
adjusted_image = layer(image)

self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))

def test_random_sharpness_no_op(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))

layer = layers.RandomSharpness((0.5, 0.5))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)

def test_random_sharpness_randomness(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]

layer = layers.RandomSharpness(0.2)
adjusted_images = layer(image)

self.assertNotAllClose(adjusted_images, image)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandomSharpness(
factor=0.5, data_format=data_format, seed=1337
)

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

0 comments on commit 6ce93a4

Please sign in to comment.