From 5b299743442b64afaeeec01e925ddbeb112aad3c Mon Sep 17 00:00:00 2001
From: Ugeun Park <37043543+shashaka@users.noreply.github.com>
Date: Thu, 2 Jan 2025 11:36:45 +0900
Subject: [PATCH] implement transform_bounding_boxes for random_shear (#20704)

---
 .../image_preprocessing/random_shear.py       | 142 +++++++++++++++++-
 .../image_preprocessing/random_shear_test.py  | 126 ++++++++++++++++
 2 files changed, 266 insertions(+), 2 deletions(-)

diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_shear.py b/keras/src/layers/preprocessing/image_preprocessing/random_shear.py
index 26b742e41fa..74390c77c77 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_shear.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_shear.py
@@ -2,7 +2,14 @@
 from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
     BaseImagePreprocessingLayer,
 )
+from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501
+    clip_to_image_size,
+)
+from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501
+    convert_format,
+)
 from keras.src.random.seed_generator import SeedGenerator
+from keras.src.utils import backend_utils
 
 
 @keras_export("keras.layers.RandomShear")
@@ -175,7 +182,7 @@ def get_random_transformation(self, data, training=True, seed=None):
             )
             * invert
         )
-        return {"shear_factor": shear_factor}
+        return {"shear_factor": shear_factor, "input_shape": images_shape}
 
     def transform_images(self, images, transformation, training=True):
         images = self.backend.cast(images, self.compute_dtype)
@@ -231,13 +238,144 @@ def _get_shear_matrix(self, shear_factors):
     def transform_labels(self, labels, transformation, training=True):
         return labels
 
+    def get_transformed_x_y(self, x, y, transform):
+        a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split(
+            transform, 8, axis=-1
+        )
+
+        k = c0 * x + c1 * y + 1
+        x_transformed = (a0 * x + a1 * y + a2) / k
+        y_transformed = (b0 * x + b1 * y + b2) / k
+        return x_transformed, y_transformed
+
+    def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor):
+        bboxes = bounding_boxes["boxes"]
+        x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1)
+
+        w_shift_factor = self.backend.convert_to_tensor(
+            w_shift_factor, dtype=x1.dtype
+        )
+        h_shift_factor = self.backend.convert_to_tensor(
+            h_shift_factor, dtype=x1.dtype
+        )
+
+        if len(bboxes.shape) == 3:
+            w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1)
+            h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1)
+
+        bounding_boxes["boxes"] = self.backend.numpy.concatenate(
+            [
+                x1 - w_shift_factor,
+                x2 - h_shift_factor,
+                x3 - w_shift_factor,
+                x4 - h_shift_factor,
+            ],
+            axis=-1,
+        )
+        return bounding_boxes
+
     def transform_bounding_boxes(
         self,
         bounding_boxes,
         transformation,
         training=True,
     ):
-        raise NotImplementedError
+        def _get_height_width(transformation):
+            if self.data_format == "channels_first":
+                height_axis = -2
+                width_axis = -1
+            else:
+                height_axis = -3
+                width_axis = -2
+            input_height, input_width = (
+                transformation["input_shape"][height_axis],
+                transformation["input_shape"][width_axis],
+            )
+            return input_height, input_width
+
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
+
+            input_height, input_width = _get_height_width(transformation)
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="rel_xyxy",
+                height=input_height,
+                width=input_width,
+                dtype=self.compute_dtype,
+            )
+
+            bounding_boxes = self._shear_bboxes(bounding_boxes, transformation)
+
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=input_height,
+                width=input_width,
+                bounding_box_format="rel_xyxy",
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target=self.bounding_box_format,
+                height=input_height,
+                width=input_width,
+                dtype=self.compute_dtype,
+            )
+
+            self.backend.reset()
+
+        return bounding_boxes
+
+    def _shear_bboxes(self, bounding_boxes, transformation):
+        shear_factor = self.backend.cast(
+            transformation["shear_factor"], dtype=self.compute_dtype
+        )
+        shear_x_amount, shear_y_amount = self.backend.numpy.split(
+            shear_factor, 2, axis=-1
+        )
+
+        x1, y1, x2, y2 = self.backend.numpy.split(
+            bounding_boxes["boxes"], 4, axis=-1
+        )
+        x1 = self.backend.numpy.squeeze(x1, axis=-1)
+        y1 = self.backend.numpy.squeeze(y1, axis=-1)
+        x2 = self.backend.numpy.squeeze(x2, axis=-1)
+        y2 = self.backend.numpy.squeeze(y2, axis=-1)
+
+        if shear_x_amount is not None:
+            x1_top = x1 - (shear_x_amount * y1)
+            x1_bottom = x1 - (shear_x_amount * y2)
+            x1 = self.backend.numpy.where(shear_x_amount < 0, x1_top, x1_bottom)
+
+            x2_top = x2 - (shear_x_amount * y1)
+            x2_bottom = x2 - (shear_x_amount * y2)
+            x2 = self.backend.numpy.where(shear_x_amount < 0, x2_bottom, x2_top)
+
+        if shear_y_amount is not None:
+            y1_left = y1 - (shear_y_amount * x1)
+            y1_right = y1 - (shear_y_amount * x2)
+            y1 = self.backend.numpy.where(shear_y_amount > 0, y1_right, y1_left)
+
+            y2_left = y2 - (shear_y_amount * x1)
+            y2_right = y2 - (shear_y_amount * x2)
+            y2 = self.backend.numpy.where(shear_y_amount > 0, y2_left, y2_right)
+
+        boxes = self.backend.numpy.concatenate(
+            [
+                self.backend.numpy.expand_dims(x1, axis=-1),
+                self.backend.numpy.expand_dims(y1, axis=-1),
+                self.backend.numpy.expand_dims(x2, axis=-1),
+                self.backend.numpy.expand_dims(y2, axis=-1),
+            ],
+            axis=-1,
+        )
+        bounding_boxes["boxes"] = boxes
+
+        return bounding_boxes
 
     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py
index 70e1745d9dc..b1ec2861182 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py
@@ -1,11 +1,13 @@
 import numpy as np
 import pytest
+from absl.testing import parameterized
 from tensorflow import data as tf_data
 
 import keras
 from keras.src import backend
 from keras.src import layers
 from keras.src import testing
+from keras.src.utils import backend_utils
 
 
 class RandomShearTest(testing.TestCase):
@@ -74,3 +76,127 @@ def test_tf_data_compatibility(self):
         ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
         for output in ds.take(1):
             output.numpy()
+
+    @parameterized.named_parameters(
+        (
+            "with_x_shift",
+            [[1.0, 0.0]],
+            [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],
+        ),
+        (
+            "with_y_shift",
+            [[0.0, 1.0]],
+            [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],
+        ),
+        (
+            "with_xy_shift",
+            [[1.0, 1.0]],
+            [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],
+        ),
+    )
+    def test_random_shear_bounding_boxes(self, translation, expected_boxes):
+        data_format = backend.config.image_data_format()
+        if data_format == "channels_last":
+            image_shape = (10, 8, 3)
+        else:
+            image_shape = (3, 10, 8)
+        input_image = np.random.random(image_shape)
+        bounding_boxes = {
+            "boxes": np.array(
+                [
+                    [2, 1, 4, 3],
+                    [6, 4, 8, 6],
+                ]
+            ),
+            "labels": np.array([[1, 2]]),
+        }
+        input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
+        layer = layers.RandomShear(
+            x_factor=0.5,
+            y_factor=0.5,
+            data_format=data_format,
+            seed=42,
+            bounding_box_format="xyxy",
+        )
+
+        transformation = {
+            "shear_factor": backend_utils.convert_tf_tensor(
+                np.array(translation)
+            ),
+            "input_shape": image_shape,
+        }
+        output = layer.transform_bounding_boxes(
+            input_data["bounding_boxes"],
+            transformation=transformation,
+            training=True,
+        )
+
+        self.assertAllClose(output["boxes"], expected_boxes)
+
+    @parameterized.named_parameters(
+        (
+            "with_x_shift",
+            [[1.0, 0.0]],
+            [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],
+        ),
+        (
+            "with_y_shift",
+            [[0.0, 1.0]],
+            [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],
+        ),
+        (
+            "with_xy_shift",
+            [[1.0, 1.0]],
+            [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],
+        ),
+    )
+    def test_random_shear_tf_data_bounding_boxes(
+        self, translation, expected_boxes
+    ):
+        data_format = backend.config.image_data_format()
+        if backend.config.image_data_format() == "channels_last":
+            image_shape = (1, 10, 8, 3)
+        else:
+            image_shape = (1, 3, 10, 8)
+        input_image = np.random.random(image_shape)
+        bounding_boxes = {
+            "boxes": np.array(
+                [
+                    [
+                        [2, 1, 4, 3],
+                        [6, 4, 8, 6],
+                    ]
+                ]
+            ),
+            "labels": np.array([[1, 2]]),
+        }
+
+        input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
+
+        ds = tf_data.Dataset.from_tensor_slices(input_data)
+        layer = layers.RandomShear(
+            x_factor=0.5,
+            y_factor=0.5,
+            data_format=data_format,
+            seed=42,
+            bounding_box_format="xyxy",
+        )
+
+        transformation = {
+            "shear_factor": backend_utils.convert_tf_tensor(
+                np.array(translation)
+            ),
+            "input_shape": image_shape,
+        }
+
+        ds = ds.map(
+            lambda x: layer.transform_bounding_boxes(
+                x["bounding_boxes"],
+                transformation=transformation,
+                training=True,
+            )
+        )
+
+        output = next(iter(ds))
+        expected_boxes = np.array(expected_boxes)
+        self.assertAllClose(output["boxes"], expected_boxes)