-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add random_posterization processing layer (#20688)
* Add random_posterization processing layer
* Add test cases
* Correct failed case
- Loading branch information
Showing 5 changed files with 245 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
151 changes: 151 additions & 0 deletions
151
keras/src/layers/preprocessing/image_preprocessing/random_posterization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
from keras.src.api_export import keras_export | ||
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 | ||
BaseImagePreprocessingLayer, | ||
) | ||
|
||
|
||
@keras_export("keras.layers.RandomPosterization")
class RandomPosterization(BaseImagePreprocessingLayer):
    """Reduces the number of bits for each color channel.

    Images are rescaled from `value_range` to `[0, 255]`, cast to `uint8`,
    stripped of their low-order bits (right shift followed by left shift by
    the same amount), and finally rescaled back to `value_range`.

    References:
    - [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)
    - [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)

    Args:
        factor: integer, the number of bits to keep for each channel. Must
            be a value between 1-8. The value is handed to `_set_factor`
            and read back as a `(lower, upper)` pair; when the two ends
            differ, a bit depth is drawn per image from that range, both
            ends included.
        value_range: a tuple or a list of two elements. The first value
            represents the lower bound for values in passed images, the
            second represents the upper bound. Images passed to the layer
            should have values within `value_range`. Defaults to
            `(0, 255)`.
        data_format: forwarded unchanged to the base layer constructor.
        seed: seed stored on the layer and used to build its random seed
            generator.
    """

    _USE_BASE_FACTOR = False
    _FACTOR_BOUNDS = (1, 8)
    # Bit width of a channel once images are expressed in `[0, 255]`.
    _MAX_FACTOR = 8
    _VALUE_RANGE_VALIDATION_ERROR = (
        "The `value_range` argument should be a list of two numbers. "
    )

    def __init__(
        self,
        factor,
        value_range=(0, 255),
        data_format=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(data_format=data_format, **kwargs)
        self._set_factor(factor)
        self._set_value_range(value_range)
        self.seed = seed
        self.generator = self.backend.random.SeedGenerator(seed)

    def _set_value_range(self, value_range):
        """Validates `value_range` and stores it sorted in ascending order.

        Raises:
            ValueError: if `value_range` is not a tuple/list of length 2.
        """
        is_pair = (
            isinstance(value_range, (tuple, list)) and len(value_range) == 2
        )
        if not is_pair:
            raise ValueError(
                self._VALUE_RANGE_VALIDATION_ERROR
                + f"Received: value_range={value_range}"
            )
        self.value_range = sorted(value_range)

    def get_random_transformation(self, data, training=True, seed=None):
        """Draws the per-image bit shift used by `transform_images`.

        Args:
            data: either the images themselves or a dict holding them under
                the `"images"` key. Images must be rank 3 or rank 4.
            training: unused here; a transformation is always produced.
            seed: optional seed passed to the random op. When `None`, the
                layer's own seed generator is used.

        Returns:
            A dict with a single `"shift_factor"` entry holding, for every
            image, the number of low-order bits to drop
            (`_MAX_FACTOR - bits_to_keep`).

        Raises:
            ValueError: if the images are neither rank 3 nor rank 4.
        """
        if isinstance(data, dict):
            images = data["images"]
        else:
            images = data
        images_shape = self.backend.shape(images)
        rank = len(images_shape)
        if rank == 3:
            batch_size = 1
        elif rank == 4:
            batch_size = images_shape[0]
        else:
            raise ValueError(
                "Expected the input image to be rank 3 or 4. Received: "
                f"inputs.shape={images_shape}"
            )

        if seed is None:
            seed = self._get_seed_generator(self.backend._backend)

        lower, upper = self.factor[0], self.factor[1]
        if lower != upper:
            # `randint` excludes `maxval`; add one so that the upper bit
            # depth of the configured range can actually be drawn.
            factor = self.backend.random.randint(
                (batch_size,),
                minval=lower,
                maxval=upper + 1,
                seed=seed,
                dtype="uint8",
            )
        else:
            # Degenerate range: every image keeps the same number of bits.
            factor = (
                self.backend.numpy.ones((batch_size,), dtype="uint8") * lower
            )

        shift_factor = self._MAX_FACTOR - factor
        return {"shift_factor": shift_factor}

    def transform_images(self, images, transformation=None, training=True):
        """Posterizes `images` using the shifts stored in `transformation`.

        When `training` is falsy the images are returned untouched.
        """
        if training:
            shift_factor = transformation["shift_factor"]

            # Append three singleton axes so the per-image shift broadcasts
            # over the remaining image axes.
            shift_factor = self.backend.numpy.reshape(
                shift_factor, self.backend.shape(shift_factor) + (1, 1, 1)
            )

            images = self._transform_value_range(
                images,
                original_range=self.value_range,
                target_range=(0, 255),
                dtype=self.compute_dtype,
            )

            # NOTE(review): values are cast to uint8 without clipping, so
            # inputs are expected to already lie inside `value_range`.
            images = self.backend.cast(images, "uint8")
            # Shifting right then left by the same amount zeroes the
            # low-order bits while leaving the high-order ones in place.
            images = self.backend.numpy.bitwise_left_shift(
                self.backend.numpy.bitwise_right_shift(images, shift_factor),
                shift_factor,
            )
            images = self.backend.cast(images, self.compute_dtype)

            images = self._transform_value_range(
                images,
                original_range=(0, 255),
                target_range=self.value_range,
                dtype=self.compute_dtype,
            )

        return images

    def transform_labels(self, labels, transformation, training=True):
        """Returns `labels` unchanged."""
        return labels

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        """Returns `segmentation_masks` unchanged."""
        return segmentation_masks

    def transform_bounding_boxes(
        self, bounding_boxes, transformation, training=True
    ):
        """Returns `bounding_boxes` unchanged."""
        return bounding_boxes

    def get_config(self):
        """Returns the base config extended with this layer's arguments."""
        config = super().get_config()
        config.update(
            {
                "factor": self.factor,
                "value_range": self.value_range,
                "seed": self.seed,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        """Posterization never changes the shape of its input."""
        return input_shape
85 changes: 85 additions & 0 deletions
85
keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
import pytest | ||
from tensorflow import data as tf_data | ||
|
||
import keras | ||
from keras.src import backend | ||
from keras.src import layers | ||
from keras.src import testing | ||
|
||
|
||
class RandomPosterizationTest(testing.TestCase):
    """Test suite for `layers.RandomPosterization`."""

    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        # Generic layer check: a batch of 8 images keeps its shape.
        batch_shape = (8, 3, 4, 3)
        layer_kwargs = {
            "factor": 1,
            "value_range": (20, 200),
            "seed": 1,
        }
        self.run_layer_test(
            layers.RandomPosterization,
            init_kwargs=layer_kwargs,
            input_shape=batch_shape,
            supports_masking=False,
            expected_output_shape=batch_shape,
        )

    def test_random_posterization_inference(self):
        # With `training=False` the layer must hand back its input as is.
        rng_seed = 3481
        posterize = layers.RandomPosterization(factor=1, value_range=[0, 255])
        np.random.seed(rng_seed)
        raw_images = np.random.randint(0, 255, size=(224, 224, 3))
        result = posterize(raw_images, training=False)
        self.assertAllClose(raw_images, result)

    def test_random_posterization_basic(self):
        # Keeping one bit leaves only the value 128 or 0 in each channel.
        rng_seed = 3481
        posterize = layers.RandomPosterization(
            factor=1,
            value_range=[0, 255],
            data_format="channels_last",
            seed=rng_seed,
        )
        np.random.seed(rng_seed)
        pixels = np.asarray(
            [[[128.0, 235.0, 87.0], [12.0, 1.0, 23.0], [24.0, 18.0, 121.0]]]
        )
        result = posterize(pixels)
        expected_output = np.asarray(
            [[[128.0, 128.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
        )
        self.assertAllClose(expected_output, result)

    def test_random_posterization_value_range_0_to_1(self):
        # Output must remain inside the declared `[0, 1]` range.
        unit_image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

        posterize = layers.RandomPosterization(factor=1, value_range=[0, 1.0])
        result = posterize(unit_image)

        self.assertTrue(keras.ops.numpy.all(result >= 0))
        self.assertTrue(keras.ops.numpy.all(result <= 1))

    def test_random_posterization_value_range_0_to_255(self):
        # Output must remain inside the declared `[0, 255]` range.
        byte_image = keras.random.uniform(
            shape=(3, 3, 3), minval=0, maxval=255
        )

        posterize = layers.RandomPosterization(factor=1, value_range=[0, 255])
        result = posterize(byte_image)

        self.assertTrue(keras.ops.numpy.all(result >= 0))
        self.assertTrue(keras.ops.numpy.all(result <= 255))

    def test_random_posterization_randomness(self):
        # The layer output is expected to differ from the image fed in.
        source_image = keras.random.uniform(
            shape=(3, 3, 3), minval=0, maxval=1
        )

        posterize = layers.RandomPosterization(factor=1, value_range=[0, 255])
        result = posterize(source_image)

        self.assertNotAllClose(result, source_image)

    def test_tf_data_compatibility(self):
        # The layer must be usable as a `tf.data` map function.
        if backend.config.image_data_format() == "channels_last":
            batch_shape = (2, 8, 8, 3)
        else:
            batch_shape = (2, 3, 8, 8)
        input_data = np.random.random(batch_shape)
        posterize = layers.RandomPosterization(factor=1, value_range=[0, 255])

        dataset = tf_data.Dataset.from_tensor_slices(input_data)
        dataset = dataset.batch(2).map(posterize)
        for batch_output in dataset.take(1):
            batch_output.numpy()