Skip to content

Commit

Permalink
Add random_posterization processing layer (#20688)
Browse files Browse the repository at this point in the history
* Add random_posterization processing layer

* Add test cases

* correct failed case
  • Loading branch information
shashaka authored Dec 27, 2024
1 parent 67d1ddf commit be1191f
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
RandomPosterization,
)
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
RandomPosterization,
)
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
RandomPosterization,
)
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)


@keras_export("keras.layers.RandomPosterization")
class RandomPosterization(BaseImagePreprocessingLayer):
"""Reduces the number of bits for each color channel.
References:
- [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)
- [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)
Args:
value_range: a tuple or a list of two elements. The first value
represents the lower bound for values in passed images, the second
represents the upper bound. Images passed to the layer should have
values within `value_range`. Defaults to `(0, 255)`.
factor: integer, the number of bits to keep for each channel. Must be a
value between 1-8.
"""

_USE_BASE_FACTOR = False
_FACTOR_BOUNDS = (1, 8)
_MAX_FACTOR = 8
_VALUE_RANGE_VALIDATION_ERROR = (
"The `value_range` argument should be a list of two numbers. "
)

def __init__(
self,
factor,
value_range=(0, 255),
data_format=None,
seed=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)
self._set_factor(factor)
self._set_value_range(value_range)
self.seed = seed
self.generator = self.backend.random.SeedGenerator(seed)

def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self.value_range = sorted(value_range)

def get_random_transformation(self, data, training=True, seed=None):
if isinstance(data, dict):
images = data["images"]
else:
images = data
images_shape = self.backend.shape(images)
rank = len(images_shape)
if rank == 3:
batch_size = 1
elif rank == 4:
batch_size = images_shape[0]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received: "
f"inputs.shape={images_shape}"
)

if seed is None:
seed = self._get_seed_generator(self.backend._backend)

if self.factor[0] != self.factor[1]:
factor = self.backend.random.randint(
(batch_size,),
minval=self.factor[0],
maxval=self.factor[1],
seed=seed,
dtype="uint8",
)
else:
factor = (
self.backend.numpy.ones((batch_size,), dtype="uint8")
* self.factor[0]
)

shift_factor = self._MAX_FACTOR - factor
return {"shift_factor": shift_factor}

def transform_images(self, images, transformation=None, training=True):
if training:
shift_factor = transformation["shift_factor"]

shift_factor = self.backend.numpy.reshape(
shift_factor, self.backend.shape(shift_factor) + (1, 1, 1)
)

images = self._transform_value_range(
images,
original_range=self.value_range,
target_range=(0, 255),
dtype=self.compute_dtype,
)

images = self.backend.cast(images, "uint8")
images = self.backend.numpy.bitwise_left_shift(
self.backend.numpy.bitwise_right_shift(images, shift_factor),
shift_factor,
)
images = self.backend.cast(images, self.compute_dtype)

images = self._transform_value_range(
images,
original_range=(0, 255),
target_range=self.value_range,
dtype=self.compute_dtype,
)

return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return segmentation_masks

def transform_bounding_boxes(
self, bounding_boxes, transformation, training=True
):
return bounding_boxes

def get_config(self):
config = super().get_config()
config.update(
{
"factor": self.factor,
"value_range": self.value_range,
"seed": self.seed,
}
)
return config

def compute_output_shape(self, input_shape):
return input_shape
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

import keras
from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomPosterizationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomPosterization,
init_kwargs={
"factor": 1,
"value_range": (20, 200),
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_random_posterization_inference(self):
seed = 3481
layer = layers.RandomPosterization(1, [0, 255])
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_random_posterization_basic(self):
seed = 3481
layer = layers.RandomPosterization(
1, [0, 255], data_format="channels_last", seed=seed
)
np.random.seed(seed)
inputs = np.asarray(
[[[128.0, 235.0, 87.0], [12.0, 1.0, 23.0], [24.0, 18.0, 121.0]]]
)
output = layer(inputs)
expected_output = np.asarray(
[[[128.0, 128.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
)
self.assertAllClose(expected_output, output)

def test_random_posterization_value_range_0_to_1(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

layer = layers.RandomPosterization(1, [0, 1.0])
adjusted_image = layer(image)

self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))

def test_random_posterization_value_range_0_to_255(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255)

layer = layers.RandomPosterization(1, [0, 255])
adjusted_image = layer(image)

self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255))

def test_random_posterization_randomness(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

layer = layers.RandomPosterization(1, [0, 255])
adjusted_images = layer(image)

self.assertNotAllClose(adjusted_images, image)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandomPosterization(1, [0, 255])

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

0 comments on commit be1191f

Please sign in to comment.