Add training status condition during image processing #20677

Merged · 4 commits · Dec 22, 2024
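The change is the same across every file in this diff: each `transform_*` method gains a `training` flag (and a unified `transformation` argument), applies its augmentation only while training, and passes data through untouched at inference. A minimal numpy sketch of the pattern, using an illustrative class name rather than the actual Keras base layer:

```python
import numpy as np

class ToyRandomLayer:
    """Illustrative stand-in for a Keras image preprocessing layer."""

    def get_random_transformation(self, images, training=True, seed=None):
        rng = np.random.default_rng(seed)
        return {"shift": float(rng.integers(1, 10))}

    def transform_images(self, images, transformation, training=True):
        # The pattern this PR applies throughout: augment only while
        # training; at inference the layer is the identity.
        if training:
            return np.clip(images + transformation["shift"], 0.0, 255.0)
        return images

layer = ToyRandomLayer()
x = np.zeros((2, 4, 4, 3), dtype="float32")
t = layer.get_random_transformation(x, seed=0)
print(layer.transform_images(x, t, training=True).max())   # shifted
print(layer.transform_images(x, t, training=False).max())  # 0.0, untouched
```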
55 changes: 33 additions & 22 deletions keras/src/layers/preprocessing/image_preprocessing/equalization.py
@@ -170,40 +170,51 @@ def _apply_equalization(self, channel, hist):
)
return self.backend.numpy.take(lookup_table, indices)

def transform_images(self, images, transformations=None, **kwargs):
images = self.backend.cast(images, self.compute_dtype)

if self.data_format == "channels_first":
channels = []
for i in range(self.backend.core.shape(images)[-3]):
channel = images[..., i, :, :]
equalized = self._equalize_channel(channel, self.value_range)
channels.append(equalized)
equalized_images = self.backend.numpy.stack(channels, axis=-3)
else:
channels = []
for i in range(self.backend.core.shape(images)[-1]):
channel = images[..., i]
equalized = self._equalize_channel(channel, self.value_range)
channels.append(equalized)
equalized_images = self.backend.numpy.stack(channels, axis=-1)

return self.backend.cast(equalized_images, self.compute_dtype)
def transform_images(self, images, transformation, training=True):
if training:
images = self.backend.cast(images, self.compute_dtype)

if self.data_format == "channels_first":
channels = []
for i in range(self.backend.core.shape(images)[-3]):
channel = images[..., i, :, :]
equalized = self._equalize_channel(
channel, self.value_range
)
channels.append(equalized)
equalized_images = self.backend.numpy.stack(channels, axis=-3)
else:
channels = []
for i in range(self.backend.core.shape(images)[-1]):
channel = images[..., i]
equalized = self._equalize_channel(
channel, self.value_range
)
channels.append(equalized)
equalized_images = self.backend.numpy.stack(channels, axis=-1)

return self.backend.cast(equalized_images, self.compute_dtype)
return images

def compute_output_shape(self, input_shape):
return input_shape

def compute_output_spec(self, inputs, **kwargs):
return inputs

def transform_bounding_boxes(self, bounding_boxes, **kwargs):
def transform_bounding_boxes(
self,
bounding_boxes,
transformation,
training=True,
):
return bounding_boxes

def transform_labels(self, labels, transformations=None, **kwargs):
def transform_labels(self, labels, transformation, training=True):
return labels

def transform_segmentation_masks(
self, segmentation_masks, transformations, **kwargs
self, segmentation_masks, transformation, training=True
):
return segmentation_masks
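For reference, a self-contained numpy sketch of what the gated `transform_images` now computes: per-channel histogram equalization during training, identity otherwise. The helper name and the `(0, 255)` default range are illustrative assumptions, not the backend-agnostic Keras code:

```python
import numpy as np

def equalize_channel(channel, value_range=(0, 255), bins=256):
    # Histogram -> CDF -> lookup table, then remap every pixel: the
    # same idea as _equalize_channel / _apply_equalization above.
    lo, hi = value_range
    hist, _ = np.histogram(channel, bins=bins, range=(lo, hi))
    cdf = hist.cumsum()
    nonzero = cdf[cdf > 0]
    if nonzero.size == 0:
        return channel
    cdf_min = nonzero[0]
    lookup = (cdf - cdf_min) / max(cdf[-1] - cdf_min, 1) * (hi - lo) + lo
    idx = np.clip((channel - lo) / (hi - lo) * (bins - 1), 0, bins - 1)
    return lookup[idx.astype(np.int64)]

def transform_images(images, training=True, data_format="channels_last"):
    if not training:  # the new gate: identity at inference
        return images
    axis = -1 if data_format == "channels_last" else -3
    channels = [equalize_channel(np.take(images, i, axis=axis))
                for i in range(images.shape[axis])]
    return np.stack(channels, axis=axis)
```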

198 changes: 101 additions & 97 deletions keras/src/layers/preprocessing/image_preprocessing/random_crop.py
@@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None):
return h_start, w_start

def transform_images(self, images, transformation, training=True):
images = self.backend.cast(images, self.compute_dtype)
crop_box_hstart, crop_box_wstart = transformation
crop_height = self.height
crop_width = self.width
if training:
images = self.backend.cast(images, self.compute_dtype)
crop_box_hstart, crop_box_wstart = transformation
crop_height = self.height
crop_width = self.width

if self.data_format == "channels_last":
if len(images.shape) == 4:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
else:
images = images[
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
else:
if len(images.shape) == 4:
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
if self.data_format == "channels_last":
if len(images.shape) == 4:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
else:
images = images[
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
else:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
if len(images.shape) == 4:
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
else:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]

shape = self.backend.shape(images)
new_height = shape[self.height_axis]
new_width = shape[self.width_axis]
if (
not isinstance(new_height, int)
or not isinstance(new_width, int)
or new_height != self.height
or new_width != self.width
):
# Resize images if size mismatch or
# if size mismatch cannot be determined
# (in the case of a TF dynamic shape).
images = self.backend.image.resize(
images,
size=(self.height, self.width),
data_format=self.data_format,
)
# Resize may have upcasted the outputs
images = self.backend.cast(images, self.compute_dtype)
shape = self.backend.shape(images)
new_height = shape[self.height_axis]
new_width = shape[self.width_axis]
if (
not isinstance(new_height, int)
or not isinstance(new_width, int)
or new_height != self.height
or new_width != self.width
):
# Resize images if size mismatch or
# if size mismatch cannot be determined
# (in the case of a TF dynamic shape).
images = self.backend.image.resize(
images,
size=(self.height, self.width),
data_format=self.data_format,
)
# Resize may have upcasted the outputs
images = self.backend.cast(images, self.compute_dtype)
return images

def transform_labels(self, labels, transformation, training=True):
@@ -197,56 +198,59 @@ def transform_bounding_boxes(
"labels": (num_boxes, num_classes),
}
"""
h_start, w_start = transformation
if not self.backend.is_tensor(bounding_boxes["boxes"]):
bounding_boxes = densify_bounding_boxes(
bounding_boxes, backend=self.backend
)
boxes = bounding_boxes["boxes"]
# Convert to a standard xyxy as operations are done xyxy by default.
boxes = convert_format(
boxes=boxes,
source=self.bounding_box_format,
target="xyxy",
height=self.height,
width=self.width,
)
h_start = self.backend.cast(h_start, boxes.dtype)
w_start = self.backend.cast(w_start, boxes.dtype)
if len(self.backend.shape(boxes)) == 3:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
],
axis=-1,
)
else:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
],
axis=-1,
)

if training:
h_start, w_start = transformation
if not self.backend.is_tensor(bounding_boxes["boxes"]):
bounding_boxes = densify_bounding_boxes(
bounding_boxes, backend=self.backend
)
boxes = bounding_boxes["boxes"]
# Convert to a standard xyxy as operations are done xyxy by default.
boxes = convert_format(
boxes=boxes,
source=self.bounding_box_format,
target="xyxy",
height=self.height,
width=self.width,
)
h_start = self.backend.cast(h_start, boxes.dtype)
w_start = self.backend.cast(w_start, boxes.dtype)
if len(self.backend.shape(boxes)) == 3:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
],
axis=-1,
)
else:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
],
axis=-1,
)

# Convert to user defined bounding box format
boxes = convert_format(
boxes=boxes,
source="xyxy",
target=self.bounding_box_format,
height=self.height,
width=self.width,
)
# Convert to user defined bounding box format
boxes = convert_format(
boxes=boxes,
source="xyxy",
target=self.bounding_box_format,
height=self.height,
width=self.width,
)

return {
"boxes": boxes,
"labels": bounding_boxes["labels"],
}
return {
"boxes": boxes,
"labels": bounding_boxes["labels"],
}
return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
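A numpy sketch of the gated crop path for `channels_last` inputs, including the fallback resize. In a TF graph the shape entries are tensors rather than Python ints, so the size check above cannot be decided statically and the resize always runs; with numpy only a genuine mismatch triggers it. `resize_nearest` is a stand-in for `backend.image.resize`:

```python
import numpy as np

def resize_nearest(images, height, width):
    # Tiny nearest-neighbour resize for channels_last arrays; stands in
    # for backend.image.resize in this sketch.
    in_h, in_w = images.shape[-3], images.shape[-2]
    rows = np.arange(height) * in_h // height
    cols = np.arange(width) * in_w // width
    return images[..., rows[:, None], cols, :]

def transform_images(images, transformation, height, width, training=True):
    if not training:  # the new gate: identity at inference
        return images
    h_start, w_start = transformation
    images = images[..., h_start:h_start + height,
                    w_start:w_start + width, :]
    # Mirrors the check above: with numpy, shapes are always ints, so
    # only a crop that ran past the border forces a resize here.
    if images.shape[-3] != height or images.shape[-2] != width:
        images = resize_nearest(images, height, width)
    return images
```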
73 changes: 37 additions & 36 deletions keras/src/layers/preprocessing/image_preprocessing/random_flip.py
@@ -101,9 +101,6 @@ def transform_bounding_boxes(
transformation,
training=True,
):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

def _flip_boxes_horizontal(boxes):
x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
outputs = self.backend.numpy.concatenate(
@@ -134,46 +131,50 @@ def _transform_xyxy(boxes, box_flips):
)
return bboxes

flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
if training:
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

if self.data_format == "channels_first":
height_axis = -2
width_axis = -1
else:
height_axis = -3
width_axis = -2
flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)

input_height, input_width = (
transformation["input_shape"][height_axis],
transformation["input_shape"][width_axis],
)
if self.data_format == "channels_first":
height_axis = -2
width_axis = -1
else:
height_axis = -3
width_axis = -2

bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="rel_xyxy",
height=input_height,
width=input_width,
)
input_height, input_width = (
transformation["input_shape"][height_axis],
transformation["input_shape"][width_axis],
)

bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="rel_xyxy",
height=input_height,
width=input_width,
)

bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=input_height,
width=input_width,
bounding_box_format="xyxy",
)
bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)

bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target=self.bounding_box_format,
height=input_height,
width=input_width,
)
bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=input_height,
width=input_width,
bounding_box_format="xyxy",
)

bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target=self.bounding_box_format,
height=input_height,
width=input_width,
)

self.backend.reset()
self.backend.reset()

return bounding_boxes
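The box math behind `_flip_boxes_horizontal` is easiest to see in `rel_xyxy` coordinates, where a horizontal flip maps `x` to `1 - x` and swaps the two x-edges so that `x1 <= x2` still holds afterwards. A numpy sketch, with the per-sample `flips` mask applied via `where` (argument shapes are assumptions for illustration):

```python
import numpy as np

def flip_boxes_horizontal(boxes, flips):
    # boxes: (batch, num_boxes, 4) in rel_xyxy; flips: (batch,) bools.
    x1, y1, x2, y2 = np.split(boxes, 4, axis=-1)
    flipped = np.concatenate([1.0 - x2, y1, 1.0 - x1, y2], axis=-1)
    # Replace boxes only in the samples that were actually flipped.
    return np.where(flips[:, None, None], flipped, boxes)
```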

keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py
@@ -71,17 +71,21 @@ def get_random_transformation(self, images, training=True, seed=None):
)
return should_apply

def transform_images(self, images, transformations=None, **kwargs):
should_apply = (
transformations
if transformations is not None
else self.get_random_transformation(images)
)
def transform_images(self, images, transformation, training=True):
if training:
should_apply = (
transformation
if transformation is not None
else self.get_random_transformation(images)
)

grayscale_images = self.backend.image.rgb_to_grayscale(
images, data_format=self.data_format
)
return self.backend.numpy.where(should_apply, grayscale_images, images)
grayscale_images = self.backend.image.rgb_to_grayscale(
images, data_format=self.data_format
)
return self.backend.numpy.where(
should_apply, grayscale_images, images
)
return images

def compute_output_shape(self, input_shape):
return input_shape
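Note the behavioural fix here beyond the signature change: the old `transform_images` always applied the grayscale mask (sampling a fresh transformation when none was passed), whereas the new version returns the input unchanged at inference. A numpy sketch of the training path, using the common BT.601 luma weights as a stand-in for `backend.image.rgb_to_grayscale`:

```python
import numpy as np

def transform_images(images, should_apply, training=True):
    # images: (batch, H, W, 3); should_apply: (batch, 1, 1, 1) bools.
    if not training:  # new behaviour: identity at inference
        return images
    weights = np.array([0.2989, 0.5870, 0.1140], dtype=images.dtype)
    gray = (images * weights).sum(axis=-1, keepdims=True)
    gray = np.repeat(gray, 3, axis=-1)  # back to three channels
    return np.where(should_apply, gray, images)
```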
keras/src/layers/preprocessing/image_preprocessing/random_rotation.py
@@ -131,37 +131,38 @@ def transform_bounding_boxes(
transformation,
training=True,
):
ops = self.backend
boxes = bounding_boxes["boxes"]
height = transformation["image_height"]
width = transformation["image_width"]
batch_size = transformation["batch_size"]
boxes = converters.affine_transform(
boxes=boxes,
angle=transformation["angle"],
translate_x=ops.numpy.zeros([batch_size]),
translate_y=ops.numpy.zeros([batch_size]),
scale=ops.numpy.ones([batch_size]),
shear_x=ops.numpy.zeros([batch_size]),
shear_y=ops.numpy.zeros([batch_size]),
height=height,
width=width,
)
if training:
ops = self.backend
boxes = bounding_boxes["boxes"]
height = transformation["image_height"]
width = transformation["image_width"]
batch_size = transformation["batch_size"]
boxes = converters.affine_transform(
boxes=boxes,
angle=transformation["angle"],
translate_x=ops.numpy.zeros([batch_size]),
translate_y=ops.numpy.zeros([batch_size]),
scale=ops.numpy.ones([batch_size]),
shear_x=ops.numpy.zeros([batch_size]),
shear_y=ops.numpy.zeros([batch_size]),
height=height,
width=width,
)

bounding_boxes["boxes"] = boxes
bounding_boxes = converters.clip_to_image_size(
bounding_boxes,
height=height,
width=width,
bounding_box_format="xyxy",
)
bounding_boxes = converters.convert_format(
bounding_boxes,
source="xyxy",
target=self.bounding_box_format,
height=height,
width=width,
)
bounding_boxes["boxes"] = boxes
bounding_boxes = converters.clip_to_image_size(
bounding_boxes,
height=height,
width=width,
bounding_box_format="xyxy",
)
bounding_boxes = converters.convert_format(
bounding_boxes,
source="xyxy",
target=self.bounding_box_format,
height=height,
width=width,
)
return bounding_boxes

def transform_segmentation_masks(
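The real code delegates to `converters.affine_transform`; geometrically, rotating an axis-aligned box means rotating its four corners about the image centre and taking the axis-aligned hull of the result. A numpy sketch of that idea (angle sign conventions vary between implementations, so treat this as illustrative):

```python
import numpy as np

def rotate_boxes_xyxy(boxes, angle_deg, height, width):
    # boxes: (num_boxes, 4) in absolute xyxy coordinates.
    theta = np.deg2rad(angle_deg)
    cos, sin = np.cos(theta), np.sin(theta)
    x1, y1, x2, y2 = boxes.T
    # All four corners, centred on the image midpoint.
    cx = np.stack([x1, x2, x1, x2], axis=-1) - width / 2
    cy = np.stack([y1, y1, y2, y2], axis=-1) - height / 2
    rx = cx * cos - cy * sin + width / 2
    ry = cx * sin + cy * cos + height / 2
    # Axis-aligned hull of the rotated corners, clipped to the image.
    out = np.stack([rx.min(-1), ry.min(-1), rx.max(-1), ry.max(-1)], -1)
    return np.clip(out, [0, 0, 0, 0], [width, height, width, height])
```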
keras/src/layers/preprocessing/image_preprocessing/random_translation.py
@@ -215,55 +215,56 @@ def transform_bounding_boxes(
transformation,
training=True,
):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

if self.data_format == "channels_first":
height_axis = -2
width_axis = -1
else:
height_axis = -3
width_axis = -2

input_height, input_width = (
transformation["input_shape"][height_axis],
transformation["input_shape"][width_axis],
)
if training:
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

if self.data_format == "channels_first":
height_axis = -2
width_axis = -1
else:
height_axis = -3
width_axis = -2

input_height, input_width = (
transformation["input_shape"][height_axis],
transformation["input_shape"][width_axis],
)

bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="xyxy",
height=input_height,
width=input_width,
)
bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="xyxy",
height=input_height,
width=input_width,
)

translations = transformation["translations"]
transform = self._get_translation_matrix(translations)
translations = transformation["translations"]
transform = self._get_translation_matrix(translations)

w_shift_factor, h_shift_factor = self.get_transformed_x_y(
0, 0, transform
)
bounding_boxes = self.get_shifted_bbox(
bounding_boxes, w_shift_factor, h_shift_factor
)
w_shift_factor, h_shift_factor = self.get_transformed_x_y(
0, 0, transform
)
bounding_boxes = self.get_shifted_bbox(
bounding_boxes, w_shift_factor, h_shift_factor
)

bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=input_height,
width=input_width,
bounding_box_format="xyxy",
)
bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=input_height,
width=input_width,
bounding_box_format="xyxy",
)

bounding_boxes = convert_format(
bounding_boxes,
source="xyxy",
target=self.bounding_box_format,
height=input_height,
width=input_width,
)
bounding_boxes = convert_format(
bounding_boxes,
source="xyxy",
target=self.bounding_box_format,
height=input_height,
width=input_width,
)

self.backend.reset()
self.backend.reset()

return bounding_boxes
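The translation case reduces to shifting every box by the sampled offsets and clipping to the image, which is what `get_shifted_bbox` followed by `clip_to_image_size` accomplish above. A numpy sketch in absolute `xyxy` coordinates:

```python
import numpy as np

def translate_boxes_xyxy(boxes, shift_x, shift_y, height, width):
    # boxes: (num_boxes, 4) in absolute xyxy coordinates.
    shifted = boxes + np.array([shift_x, shift_y, shift_x, shift_y])
    return np.clip(shifted, [0, 0, 0, 0], [width, height, width, height])
```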

139 changes: 71 additions & 68 deletions keras/src/layers/preprocessing/image_preprocessing/random_zoom.py
@@ -217,84 +217,87 @@ def transform_bounding_boxes(
transformation,
training=True,
):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

width_zoom = transformation["width_zoom"]
height_zoom = transformation["height_zoom"]
inputs_shape = transformation["input_shape"]

if self.data_format == "channels_first":
height = inputs_shape[-2]
width = inputs_shape[-1]
else:
height = inputs_shape[-3]
width = inputs_shape[-2]

bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="xyxy",
height=height,
width=width,
)
if training:
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

width_zoom = transformation["width_zoom"]
height_zoom = transformation["height_zoom"]
inputs_shape = transformation["input_shape"]

if self.data_format == "channels_first":
height = inputs_shape[-2]
width = inputs_shape[-1]
else:
height = inputs_shape[-3]
width = inputs_shape[-2]

bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="xyxy",
height=height,
width=width,
)

zooms = self.backend.cast(
self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1),
dtype="float32",
)
transform = self._get_zoom_matrix(zooms, height, width)
zooms = self.backend.cast(
self.backend.numpy.concatenate(
[width_zoom, height_zoom], axis=1
),
dtype="float32",
)
transform = self._get_zoom_matrix(zooms, height, width)

w_start, h_start = self.get_transformed_x_y(
0,
0,
transform,
)
w_start, h_start = self.get_transformed_x_y(
0,
0,
transform,
)

w_end, h_end = self.get_transformed_x_y(
width,
height,
transform,
)
w_end, h_end = self.get_transformed_x_y(
width,
height,
transform,
)

bounding_boxes = self.get_clipped_bbox(
bounding_boxes, h_end, h_start, w_end, w_start
)
bounding_boxes = self.get_clipped_bbox(
bounding_boxes, h_end, h_start, w_end, w_start
)

height_transformed = h_end - h_start
width_transformed = w_end - w_start
height_transformed = h_end - h_start
width_transformed = w_end - w_start

height_transformed = self.backend.numpy.expand_dims(
height_transformed, -1
)
width_transformed = self.backend.numpy.expand_dims(
width_transformed, -1
)
height_transformed = self.backend.numpy.expand_dims(
height_transformed, -1
)
width_transformed = self.backend.numpy.expand_dims(
width_transformed, -1
)

bounding_boxes = convert_format(
bounding_boxes,
source="xyxy",
target="rel_xyxy",
height=height_transformed,
width=width_transformed,
)
bounding_boxes = convert_format(
bounding_boxes,
source="xyxy",
target="rel_xyxy",
height=height_transformed,
width=width_transformed,
)

bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=height_transformed,
width=width_transformed,
bounding_box_format="rel_xyxy",
)
bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=height_transformed,
width=width_transformed,
bounding_box_format="rel_xyxy",
)

bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target=self.bounding_box_format,
height=height,
width=width,
)
bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target=self.bounding_box_format,
height=height,
width=width,
)

self.backend.reset()
self.backend.reset()

return bounding_boxes
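Conceptually, zooming scales box coordinates about the image centre; the layer derives the same effect from `_get_zoom_matrix` and the transformed corner points. A numpy sketch under one fixed convention (here `zoom > 1` pulls content toward the centre, i.e. zooms out; the Keras matrix convention may differ):

```python
import numpy as np

def zoom_boxes_xyxy(boxes, zoom, height, width):
    # boxes: (num_boxes, 4) in absolute xyxy coordinates.
    centre = np.array([width, height, width, height]) / 2.0
    scaled = (boxes - centre) / zoom + centre
    return np.clip(scaled, [0, 0, 0, 0], [width, height, width, height])
```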

57 changes: 30 additions & 27 deletions keras/src/layers/preprocessing/image_preprocessing/solarization.py
@@ -156,33 +156,36 @@ def get_random_transformation(self, data, training=True, seed=None):

def transform_images(self, images, transformation, training=True):
images = self.backend.cast(images, self.compute_dtype)
if transformation is None:
return images

thresholds = transformation["thresholds"]
additions = transformation["additions"]
images = self._transform_value_range(
images,
original_range=self.value_range,
target_range=(0, 255),
dtype=self.compute_dtype,
)
results = images + additions
results = self.backend.numpy.clip(results, 0, 255)
results = self.backend.numpy.where(
results < thresholds, results, 255 - results
)
results = self._transform_value_range(
results,
original_range=(0, 255),
target_range=self.value_range,
dtype=self.compute_dtype,
)
if results.dtype == images.dtype:
return results
if backend.is_int_dtype(images.dtype):
results = self.backend.numpy.round(results)
return _saturate_cast(results, images.dtype, self.backend)

if training:
if transformation is None:
return images

thresholds = transformation["thresholds"]
additions = transformation["additions"]
images = self._transform_value_range(
images,
original_range=self.value_range,
target_range=(0, 255),
dtype=self.compute_dtype,
)
results = images + additions
results = self.backend.numpy.clip(results, 0, 255)
results = self.backend.numpy.where(
results < thresholds, results, 255 - results
)
results = self._transform_value_range(
results,
original_range=(0, 255),
target_range=self.value_range,
dtype=self.compute_dtype,
)
if results.dtype == images.dtype:
return results
if backend.is_int_dtype(images.dtype):
results = self.backend.numpy.round(results)
return _saturate_cast(results, images.dtype, self.backend)
return images

def transform_labels(self, labels, transformation, training=True):
return labels
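The solarization math itself is compact once the value range is normalised to `(0, 255)`: add, clip, then invert every pixel at or above the threshold. A numpy sketch of the gated transform, assuming inputs already in that range:

```python
import numpy as np

def transform_images(images, thresholds, additions=0.0, training=True):
    # Assumes images already scaled to the (0, 255) range.
    if not training:  # the new gate: identity at inference
        return images
    results = np.clip(images + additions, 0.0, 255.0)
    # Pixels at or above the threshold are inverted.
    return np.where(results < thresholds, results, 255.0 - results)
```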