From 3dd958bcd6f7e9d5f7e007919205716ea1215df7 Mon Sep 17 00:00:00 2001
From: Ugeun Park <37043543+shashaka@users.noreply.github.com>
Date: Mon, 23 Dec 2024 04:01:34 +0900
Subject: [PATCH] Add training status condition during image processing
 (#20677)

* Add training status condition during image processing

* Revert "Add training status condition during image processing"

This reverts commit 8fc5ae2c28c239663fe0f2e8ac7fa15037f41a7d.

* Reapply "Add training status condition during image processing"

This reverts commit 25a4bd1332c7a5794dc872f5aa6ddddf6ed1606b.

* Revert center_crop
---
 .../image_preprocessing/equalization.py       |  55 +++--
 .../image_preprocessing/random_crop.py        | 198 +++++++++---------
 .../image_preprocessing/random_flip.py        |  73 +++----
 .../image_preprocessing/random_grayscale.py   |  24 ++-
 .../image_preprocessing/random_rotation.py    |  61 +++---
 .../image_preprocessing/random_translation.py |  87 ++++----
 .../image_preprocessing/random_zoom.py        | 139 ++++++------
 .../image_preprocessing/solarization.py       |  57 ++---
 8 files changed, 361 insertions(+), 333 deletions(-)

diff --git a/keras/src/layers/preprocessing/image_preprocessing/equalization.py b/keras/src/layers/preprocessing/image_preprocessing/equalization.py
index e58f4b254c0..555713bf854 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/equalization.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/equalization.py
@@ -170,25 +170,31 @@ def _apply_equalization(self, channel, hist):
         )
         return self.backend.numpy.take(lookup_table, indices)
 
-    def transform_images(self, images, transformations=None, **kwargs):
-        images = self.backend.cast(images, self.compute_dtype)
-
-        if self.data_format == "channels_first":
-            channels = []
-            for i in range(self.backend.core.shape(images)[-3]):
-                channel = images[..., i, :, :]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-3)
-        else:
-            channels = []
-            for i in range(self.backend.core.shape(images)[-1]):
-                channel = images[..., i]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-1)
-
-        return self.backend.cast(equalized_images, self.compute_dtype)
+    def transform_images(self, images, transformation, training=True):
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+
+            if self.data_format == "channels_first":
+                channels = []
+                for i in range(self.backend.core.shape(images)[-3]):
+                    channel = images[..., i, :, :]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-3)
+            else:
+                channels = []
+                for i in range(self.backend.core.shape(images)[-1]):
+                    channel = images[..., i]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-1)
+
+            return self.backend.cast(equalized_images, self.compute_dtype)
+        return images
 
     def compute_output_shape(self, input_shape):
         return input_shape
@@ -196,14 +202,19 @@ def compute_output_shape(self, input_shape):
     def compute_output_spec(self, inputs, **kwargs):
         return inputs
 
-    def transform_bounding_boxes(self, bounding_boxes, **kwargs):
+    def transform_bounding_boxes(
+        self,
+        bounding_boxes,
+        transformation,
+        training=True,
+    ):
         return bounding_boxes
 
-    def transform_labels(self, labels, transformations=None, **kwargs):
+    def transform_labels(self, labels, transformation, training=True):
         return labels
 
     def transform_segmentation_masks(
-        self, segmentation_masks, transformations, **kwargs
+        self, segmentation_masks, transformation, training=True
     ):
         return segmentation_masks
 
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
index 62571e69a93..f67469089f9 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
@@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None):
         return h_start, w_start
 
     def transform_images(self, images, transformation, training=True):
-        images = self.backend.cast(images, self.compute_dtype)
-        crop_box_hstart, crop_box_wstart = transformation
-        crop_height = self.height
-        crop_width = self.width
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+            crop_box_hstart, crop_box_wstart = transformation
+            crop_height = self.height
+            crop_width = self.width
 
-        if self.data_format == "channels_last":
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-            else:
-                images = images[
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-        else:
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
+            if self.data_format == "channels_last":
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
+                else:
+                    images = images[
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
             else:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
+                else:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
 
-        shape = self.backend.shape(images)
-        new_height = shape[self.height_axis]
-        new_width = shape[self.width_axis]
-        if (
-            not isinstance(new_height, int)
-            or not isinstance(new_width, int)
-            or new_height != self.height
-            or new_width != self.width
-        ):
-            # Resize images if size mismatch or
-            # if size mismatch cannot be determined
-            # (in the case of a TF dynamic shape).
-            images = self.backend.image.resize(
-                images,
-                size=(self.height, self.width),
-                data_format=self.data_format,
-            )
-            # Resize may have upcasted the outputs
-            images = self.backend.cast(images, self.compute_dtype)
+            shape = self.backend.shape(images)
+            new_height = shape[self.height_axis]
+            new_width = shape[self.width_axis]
+            if (
+                not isinstance(new_height, int)
+                or not isinstance(new_width, int)
+                or new_height != self.height
+                or new_width != self.width
+            ):
+                # Resize images if size mismatch or
+                # if size mismatch cannot be determined
+                # (in the case of a TF dynamic shape).
+ images = self.backend.image.resize( + images, + size=(self.height, self.width), + data_format=self.data_format, + ) + # Resize may have upcasted the outputs + images = self.backend.cast(images, self.compute_dtype) return images def transform_labels(self, labels, transformation, training=True): @@ -197,56 +198,59 @@ def transform_bounding_boxes( "labels": (num_boxes, num_classes), } """ - h_start, w_start = transformation - if not self.backend.is_tensor(bounding_boxes["boxes"]): - bounding_boxes = densify_bounding_boxes( - bounding_boxes, backend=self.backend - ) - boxes = bounding_boxes["boxes"] - # Convert to a standard xyxy as operations are done xyxy by default. - boxes = convert_format( - boxes=boxes, - source=self.bounding_box_format, - target="xyxy", - height=self.height, - width=self.width, - ) - h_start = self.backend.cast(h_start, boxes.dtype) - w_start = self.backend.cast(w_start, boxes.dtype) - if len(self.backend.shape(boxes)) == 3: - boxes = self.backend.numpy.stack( - [ - self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0), - self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0), - self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0), - self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0), - ], - axis=-1, - ) - else: - boxes = self.backend.numpy.stack( - [ - self.backend.numpy.maximum(boxes[:, 0] - h_start, 0), - self.backend.numpy.maximum(boxes[:, 1] - w_start, 0), - self.backend.numpy.maximum(boxes[:, 2] - h_start, 0), - self.backend.numpy.maximum(boxes[:, 3] - w_start, 0), - ], - axis=-1, + + if training: + h_start, w_start = transformation + if not self.backend.is_tensor(bounding_boxes["boxes"]): + bounding_boxes = densify_bounding_boxes( + bounding_boxes, backend=self.backend + ) + boxes = bounding_boxes["boxes"] + # Convert to a standard xyxy as operations are done xyxy by default. 
+            boxes = convert_format(
+                boxes=boxes,
+                source=self.bounding_box_format,
+                target="xyxy",
+                height=self.height,
+                width=self.width,
             )
+            h_start = self.backend.cast(h_start, boxes.dtype)
+            w_start = self.backend.cast(w_start, boxes.dtype)
+            if len(self.backend.shape(boxes)) == 3:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
+            else:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
 
-        # Convert to user defined bounding box format
-        boxes = convert_format(
-            boxes=boxes,
-            source="xyxy",
-            target=self.bounding_box_format,
-            height=self.height,
-            width=self.width,
-        )
+            # Convert to user defined bounding box format
+            boxes = convert_format(
+                boxes=boxes,
+                source="xyxy",
+                target=self.bounding_box_format,
+                height=self.height,
+                width=self.width,
+            )
 
-        return {
-            "boxes": boxes,
-            "labels": bounding_boxes["labels"],
-        }
+            return {
+                "boxes": boxes,
+                "labels": bounding_boxes["labels"],
+            }
+        return bounding_boxes
 
     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_flip.py b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py
index 519379685d1..83deff5fc05 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_flip.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py
@@ -101,9 +101,6 @@ def transform_bounding_boxes(
         transformation,
         training=True,
     ):
-        if backend_utils.in_tf_graph():
-            self.backend.set_backend("tensorflow")
-
         def _flip_boxes_horizontal(boxes):
             x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
             outputs = self.backend.numpy.concatenate(
@@ -134,46 +131,50 @@ def _transform_xyxy(boxes, box_flips):
             )
             return bboxes
 
-        flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
 
-        if self.data_format == "channels_first":
-            height_axis = -2
-            width_axis = -1
-        else:
-            height_axis = -3
-            width_axis = -2
+            flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
 
-        input_height, input_width = (
-            transformation["input_shape"][height_axis],
-            transformation["input_shape"][width_axis],
-        )
+            if self.data_format == "channels_first":
+                height_axis = -2
+                width_axis = -1
+            else:
+                height_axis = -3
+                width_axis = -2
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source=self.bounding_box_format,
-            target="rel_xyxy",
-            height=input_height,
-            width=input_width,
-        )
+            input_height, input_width = (
+                transformation["input_shape"][height_axis],
+                transformation["input_shape"][width_axis],
+            )
 
-        bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="rel_xyxy",
+                height=input_height,
+                width=input_width,
+            )
 
-        bounding_boxes = clip_to_image_size(
-            bounding_boxes=bounding_boxes,
-            height=input_height,
-            width=input_width,
-            bounding_box_format="xyxy",
-        )
+            bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source="rel_xyxy",
-            target=self.bounding_box_format,
-            height=input_height,
-            width=input_width,
-        )
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=input_height,
+                width=input_width,
+                bounding_box_format="xyxy",
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target=self.bounding_box_format,
+                height=input_height,
+                width=input_width,
+            )
 
-        self.backend.reset()
+            self.backend.reset()
 
         return bounding_boxes
 
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py
index 804e9323a0f..e03a626852e 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py
@@ -71,17 +71,21 @@ def get_random_transformation(self, images, training=True, seed=None):
         )
         return should_apply
 
-    def transform_images(self, images, transformations=None, **kwargs):
-        should_apply = (
-            transformations
-            if transformations is not None
-            else self.get_random_transformation(images)
-        )
+    def transform_images(self, images, transformation, training=True):
+        if training:
+            should_apply = (
+                transformation
+                if transformation is not None
+                else self.get_random_transformation(images)
+            )
 
-        grayscale_images = self.backend.image.rgb_to_grayscale(
-            images, data_format=self.data_format
-        )
-        return self.backend.numpy.where(should_apply, grayscale_images, images)
+            grayscale_images = self.backend.image.rgb_to_grayscale(
+                images, data_format=self.data_format
+            )
+            return self.backend.numpy.where(
+                should_apply, grayscale_images, images
+            )
+        return images
 
     def compute_output_shape(self, input_shape):
         return input_shape
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py
index 70221b9fa69..ea1e4b882fe 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py
@@ -131,37 +131,38 @@ def transform_bounding_boxes(
         transformation,
         training=True,
     ):
-        ops = self.backend
-        boxes = bounding_boxes["boxes"]
-        height = transformation["image_height"]
-        width = transformation["image_width"]
-        batch_size = transformation["batch_size"]
-        boxes = converters.affine_transform(
-            boxes=boxes,
-            angle=transformation["angle"],
-            translate_x=ops.numpy.zeros([batch_size]),
-            translate_y=ops.numpy.zeros([batch_size]),
-            scale=ops.numpy.ones([batch_size]),
-            shear_x=ops.numpy.zeros([batch_size]),
-            shear_y=ops.numpy.zeros([batch_size]),
-            height=height,
-            width=width,
-        )
+        if training:
+            ops = self.backend
+            boxes = bounding_boxes["boxes"]
+            height = transformation["image_height"]
+            width = transformation["image_width"]
+            batch_size = transformation["batch_size"]
+            boxes = converters.affine_transform(
+                boxes=boxes,
+                angle=transformation["angle"],
+                translate_x=ops.numpy.zeros([batch_size]),
+                translate_y=ops.numpy.zeros([batch_size]),
+                scale=ops.numpy.ones([batch_size]),
+                shear_x=ops.numpy.zeros([batch_size]),
+                shear_y=ops.numpy.zeros([batch_size]),
+                height=height,
+                width=width,
+            )
 
-        bounding_boxes["boxes"] = boxes
-        bounding_boxes = converters.clip_to_image_size(
-            bounding_boxes,
-            height=height,
-            width=width,
-            bounding_box_format="xyxy",
-        )
-        bounding_boxes = converters.convert_format(
-            bounding_boxes,
-            source="xyxy",
-            target=self.bounding_box_format,
-            height=height,
-            width=width,
-        )
+            bounding_boxes["boxes"] = boxes
+            bounding_boxes = converters.clip_to_image_size(
+                bounding_boxes,
+                height=height,
+                width=width,
+                bounding_box_format="xyxy",
+            )
+            bounding_boxes = converters.convert_format(
+                bounding_boxes,
+                source="xyxy",
+                target=self.bounding_box_format,
+                height=height,
+                width=width,
+            )
         return bounding_boxes
 
     def transform_segmentation_masks(
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_translation.py b/keras/src/layers/preprocessing/image_preprocessing/random_translation.py
index 60e29e0a5b9..1dc69a0db45 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_translation.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_translation.py
@@ -215,55 +215,56 @@ def transform_bounding_boxes(
         transformation,
         training=True,
     ):
-        if backend_utils.in_tf_graph():
-            self.backend.set_backend("tensorflow")
-
-        if self.data_format == "channels_first":
-            height_axis = -2
-            width_axis = -1
-        else:
-            height_axis = -3
-            width_axis = -2
-
-        input_height, input_width = (
-            transformation["input_shape"][height_axis],
-            transformation["input_shape"][width_axis],
-        )
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
+
+            if self.data_format == "channels_first":
+                height_axis = -2
+                width_axis = -1
+            else:
+                height_axis = -3
+                width_axis = -2
+
+            input_height, input_width = (
+                transformation["input_shape"][height_axis],
+                transformation["input_shape"][width_axis],
+            )
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source=self.bounding_box_format,
-            target="xyxy",
-            height=input_height,
-            width=input_width,
-        )
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="xyxy",
+                height=input_height,
+                width=input_width,
+            )
 
-        translations = transformation["translations"]
-        transform = self._get_translation_matrix(translations)
+            translations = transformation["translations"]
+            transform = self._get_translation_matrix(translations)
 
-        w_shift_factor, h_shift_factor = self.get_transformed_x_y(
-            0, 0, transform
-        )
-        bounding_boxes = self.get_shifted_bbox(
-            bounding_boxes, w_shift_factor, h_shift_factor
-        )
+            w_shift_factor, h_shift_factor = self.get_transformed_x_y(
+                0, 0, transform
+            )
+            bounding_boxes = self.get_shifted_bbox(
+                bounding_boxes, w_shift_factor, h_shift_factor
+            )
 
-        bounding_boxes = clip_to_image_size(
-            bounding_boxes=bounding_boxes,
-            height=input_height,
-            width=input_width,
-            bounding_box_format="xyxy",
-        )
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=input_height,
+                width=input_width,
+                bounding_box_format="xyxy",
+            )
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source="xyxy",
-            target=self.bounding_box_format,
-            height=input_height,
-            width=input_width,
-        )
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="xyxy",
+                target=self.bounding_box_format,
+                height=input_height,
+                width=input_width,
+            )
 
-        self.backend.reset()
+            self.backend.reset()
 
         return bounding_boxes
 
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py b/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py
index ec0f03d1c2e..80b29b8e6ad 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py
@@ -217,84 +217,87 @@ def transform_bounding_boxes(
         transformation,
         training=True,
     ):
-        if backend_utils.in_tf_graph():
-            self.backend.set_backend("tensorflow")
-
-        width_zoom = transformation["width_zoom"]
-        height_zoom = transformation["height_zoom"]
-        inputs_shape = transformation["input_shape"]
-
-        if self.data_format == "channels_first":
-            height = inputs_shape[-2]
-            width = inputs_shape[-1]
-        else:
-            height = inputs_shape[-3]
-            width = inputs_shape[-2]
-
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source=self.bounding_box_format,
-            target="xyxy",
-            height=height,
-            width=width,
-        )
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
+
+            width_zoom = transformation["width_zoom"]
+            height_zoom = transformation["height_zoom"]
+            inputs_shape = transformation["input_shape"]
+
+            if self.data_format == "channels_first":
+                height = inputs_shape[-2]
+                width = inputs_shape[-1]
+            else:
+                height = inputs_shape[-3]
+                width = inputs_shape[-2]
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="xyxy",
+                height=height,
+                width=width,
+            )
 
-        zooms = self.backend.cast(
-            self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1),
-            dtype="float32",
-        )
-        transform = self._get_zoom_matrix(zooms, height, width)
+            zooms = self.backend.cast(
+                self.backend.numpy.concatenate(
+                    [width_zoom, height_zoom], axis=1
+                ),
+                dtype="float32",
+            )
+            transform = self._get_zoom_matrix(zooms, height, width)
 
-        w_start, h_start = self.get_transformed_x_y(
-            0,
-            0,
-            transform,
-        )
+            w_start, h_start = self.get_transformed_x_y(
+                0,
+                0,
+                transform,
+            )
 
-        w_end, h_end = self.get_transformed_x_y(
-            width,
-            height,
-            transform,
-        )
+            w_end, h_end = self.get_transformed_x_y(
+                width,
+                height,
+                transform,
+            )
 
-        bounding_boxes = self.get_clipped_bbox(
-            bounding_boxes, h_end, h_start, w_end, w_start
-        )
+            bounding_boxes = self.get_clipped_bbox(
+                bounding_boxes, h_end, h_start, w_end, w_start
+            )
 
-        height_transformed = h_end - h_start
-        width_transformed = w_end - w_start
+            height_transformed = h_end - h_start
+            width_transformed = w_end - w_start
 
-        height_transformed = self.backend.numpy.expand_dims(
-            height_transformed, -1
-        )
-        width_transformed = self.backend.numpy.expand_dims(
-            width_transformed, -1
-        )
+            height_transformed = self.backend.numpy.expand_dims(
+                height_transformed, -1
+            )
+            width_transformed = self.backend.numpy.expand_dims(
+                width_transformed, -1
+            )
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source="xyxy",
-            target="rel_xyxy",
-            height=height_transformed,
-            width=width_transformed,
-        )
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="xyxy",
+                target="rel_xyxy",
+                height=height_transformed,
+                width=width_transformed,
+            )
 
-        bounding_boxes = clip_to_image_size(
-            bounding_boxes=bounding_boxes,
-            height=height_transformed,
-            width=width_transformed,
-            bounding_box_format="rel_xyxy",
-        )
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=height_transformed,
+                width=width_transformed,
+                bounding_box_format="rel_xyxy",
+            )
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source="rel_xyxy",
-            target=self.bounding_box_format,
-            height=height,
-            width=width,
-        )
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target=self.bounding_box_format,
+                height=height,
+                width=width,
+            )
 
-        self.backend.reset()
+            self.backend.reset()
 
         return bounding_boxes
 
diff --git a/keras/src/layers/preprocessing/image_preprocessing/solarization.py b/keras/src/layers/preprocessing/image_preprocessing/solarization.py
index a49d3930f8a..2a8fcee5efa 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/solarization.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/solarization.py
@@ -156,33 +156,36 @@ def get_random_transformation(self, data, training=True, seed=None):
 
     def transform_images(self, images, transformation, training=True):
         images = self.backend.cast(images, self.compute_dtype)
-        if transformation is None:
-            return images
-
-        thresholds = transformation["thresholds"]
-        additions = transformation["additions"]
-        images = self._transform_value_range(
-            images,
-            original_range=self.value_range,
-            target_range=(0, 255),
-            dtype=self.compute_dtype,
-        )
-        results = images + additions
-        results = self.backend.numpy.clip(results, 0, 255)
-        results = self.backend.numpy.where(
-            results < thresholds, results, 255 - results
-        )
-        results = self._transform_value_range(
-            results,
-            original_range=(0, 255),
-            target_range=self.value_range,
-            dtype=self.compute_dtype,
-        )
-        if results.dtype == images.dtype:
-            return results
-        if backend.is_int_dtype(images.dtype):
-            results = self.backend.numpy.round(results)
-        return _saturate_cast(results, images.dtype, self.backend)
+
+        if training:
+            if transformation is None:
+                return images
+
+            thresholds = transformation["thresholds"]
+            additions = transformation["additions"]
+            images = self._transform_value_range(
+                images,
+                original_range=self.value_range,
+                target_range=(0, 255),
+                dtype=self.compute_dtype,
+            )
+            results = images + additions
+            results = self.backend.numpy.clip(results, 0, 255)
+            results = self.backend.numpy.where(
+                results < thresholds, results, 255 - results
+            )
+            results = self._transform_value_range(
+                results,
+                original_range=(0, 255),
+                target_range=self.value_range,
+                dtype=self.compute_dtype,
+            )
+            if results.dtype == images.dtype:
+                return results
+            if backend.is_int_dtype(images.dtype):
+                results = self.backend.numpy.round(results)
+            return _saturate_cast(results, images.dtype, self.backend)
+        return images
 
     def transform_labels(self, labels, transformation, training=True):
         return labels
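
Reviewer note: the change is identical across all eight layers. Each transform_* hook now short-circuits when `training` is false, so the random augmentation runs only on the training path and inputs pass through untouched at inference. A minimal sketch of that control flow in plain Python (the class below is an illustrative stand-in, not one of the actual Keras layers):

    class TrainingGatedTransform:
        """Toy stand-in for a random image-preprocessing layer."""

        def transform_images(self, images, transformation, training=True):
            if training:
                # Stand-in for a real augmentation such as a crop or flip.
                return [pixel * transformation for pixel in images]
            # training=False: identity, mirroring the trailing
            # `return images` added in each layer above.
            return images

    layer = TrainingGatedTransform()
    print(layer.transform_images([1, 2, 3], transformation=2, training=True))   # [2, 4, 6]
    print(layer.transform_images([1, 2, 3], transformation=2, training=False))  # [1, 2, 3]

As the uniform new signatures suggest, the training flag is presumably forwarded into these hooks by the shared base preprocessing layer, so calling one of these augmentation layers with training=False should now behave as a no-op rather than applying a random transform.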