Add training status condition during image processing (#20677)
* Add training status condition during image processing

* Revert "Add training status condition during image processing"

This reverts commit 8fc5ae2.

* Reapply "Add training status condition during image processing"

This reverts commit 25a4bd1.

* Revert center_crop
shashaka authored Dec 22, 2024
1 parent 0d3ba37 commit 3dd958b
Showing 8 changed files with 361 additions and 333 deletions.
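
Every touched transform_* method follows the same pattern: the augmentation body is gated on the training flag, and inputs are returned untouched at inference. A minimal sketch of the idea (a hypothetical toy class, not code from this commit):

import numpy as np

class GatedTransform:
    # Toy stand-in for a Keras image-preprocessing layer.
    def transform_images(self, images, transformation, training=True):
        if training:
            # Augmentation runs only in training mode; a flip stands in
            # here for the real per-layer logic.
            return images[:, ::-1, :, :]
        return images  # inference: pass-through

layer = GatedTransform()
batch = np.zeros((2, 4, 4, 3))
assert layer.transform_images(batch, None, training=False) is batch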
55 changes: 33 additions & 22 deletions keras/src/layers/preprocessing/image_preprocessing/equalization.py
@@ -170,40 +170,51 @@ def _apply_equalization(self, channel, hist):
         )
         return self.backend.numpy.take(lookup_table, indices)
 
-    def transform_images(self, images, transformations=None, **kwargs):
-        images = self.backend.cast(images, self.compute_dtype)
-
-        if self.data_format == "channels_first":
-            channels = []
-            for i in range(self.backend.core.shape(images)[-3]):
-                channel = images[..., i, :, :]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-3)
-        else:
-            channels = []
-            for i in range(self.backend.core.shape(images)[-1]):
-                channel = images[..., i]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-1)
-
-        return self.backend.cast(equalized_images, self.compute_dtype)
+    def transform_images(self, images, transformation, training=True):
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+
+            if self.data_format == "channels_first":
+                channels = []
+                for i in range(self.backend.core.shape(images)[-3]):
+                    channel = images[..., i, :, :]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-3)
+            else:
+                channels = []
+                for i in range(self.backend.core.shape(images)[-1]):
+                    channel = images[..., i]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-1)
+
+            return self.backend.cast(equalized_images, self.compute_dtype)
+        return images
 
     def compute_output_shape(self, input_shape):
         return input_shape
 
     def compute_output_spec(self, inputs, **kwargs):
         return inputs
 
-    def transform_bounding_boxes(self, bounding_boxes, **kwargs):
+    def transform_bounding_boxes(
+        self,
+        bounding_boxes,
+        transformation,
+        training=True,
+    ):
         return bounding_boxes
 
-    def transform_labels(self, labels, transformations=None, **kwargs):
+    def transform_labels(self, labels, transformation, training=True):
         return labels
 
     def transform_segmentation_masks(
-        self, segmentation_masks, transformations, **kwargs
+        self, segmentation_masks, transformation, training=True
     ):
         return segmentation_masks
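With the guard in place, Equalization only does per-channel histogram work in training mode. A usage sketch (assuming the layer is exported as keras.layers.Equalization with a value_range argument; treat that as an assumption, not something this diff shows):

import numpy as np
import keras

layer = keras.layers.Equalization(value_range=(0, 255))  # assumed export/signature
images = np.random.uniform(0, 255, size=(2, 32, 32, 3)).astype("float32")

augmented = layer(images, training=True)     # histogram-equalized per channel
passthrough = layer(images, training=False)  # returned as-is per the new guard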
198 changes: 101 additions & 97 deletions keras/src/layers/preprocessing/image_preprocessing/random_crop.py
@@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None):
         return h_start, w_start
 
     def transform_images(self, images, transformation, training=True):
-        images = self.backend.cast(images, self.compute_dtype)
-        crop_box_hstart, crop_box_wstart = transformation
-        crop_height = self.height
-        crop_width = self.width
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+            crop_box_hstart, crop_box_wstart = transformation
+            crop_height = self.height
+            crop_width = self.width
 
-        if self.data_format == "channels_last":
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-            else:
-                images = images[
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-        else:
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
+            if self.data_format == "channels_last":
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
+                else:
+                    images = images[
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
             else:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
+                else:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
 
-        shape = self.backend.shape(images)
-        new_height = shape[self.height_axis]
-        new_width = shape[self.width_axis]
-        if (
-            not isinstance(new_height, int)
-            or not isinstance(new_width, int)
-            or new_height != self.height
-            or new_width != self.width
-        ):
-            # Resize images if size mismatch or
-            # if size mismatch cannot be determined
-            # (in the case of a TF dynamic shape).
-            images = self.backend.image.resize(
-                images,
-                size=(self.height, self.width),
-                data_format=self.data_format,
-            )
-            # Resize may have upcasted the outputs
-            images = self.backend.cast(images, self.compute_dtype)
+            shape = self.backend.shape(images)
+            new_height = shape[self.height_axis]
+            new_width = shape[self.width_axis]
+            if (
+                not isinstance(new_height, int)
+                or not isinstance(new_width, int)
+                or new_height != self.height
+                or new_width != self.width
+            ):
+                # Resize images if size mismatch or
+                # if size mismatch cannot be determined
+                # (in the case of a TF dynamic shape).
+                images = self.backend.image.resize(
+                    images,
+                    size=(self.height, self.width),
+                    data_format=self.data_format,
+                )
+                # Resize may have upcasted the outputs
+                images = self.backend.cast(images, self.compute_dtype)
         return images
 
     def transform_labels(self, labels, transformation, training=True):
@@ -197,56 +198,59 @@ def transform_bounding_boxes(
             "labels": (num_boxes, num_classes),
         }
         """
-        h_start, w_start = transformation
-        if not self.backend.is_tensor(bounding_boxes["boxes"]):
-            bounding_boxes = densify_bounding_boxes(
-                bounding_boxes, backend=self.backend
-            )
-        boxes = bounding_boxes["boxes"]
-        # Convert to a standard xyxy as operations are done xyxy by default.
-        boxes = convert_format(
-            boxes=boxes,
-            source=self.bounding_box_format,
-            target="xyxy",
-            height=self.height,
-            width=self.width,
-        )
-        h_start = self.backend.cast(h_start, boxes.dtype)
-        w_start = self.backend.cast(w_start, boxes.dtype)
-        if len(self.backend.shape(boxes)) == 3:
-            boxes = self.backend.numpy.stack(
-                [
-                    self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
-                ],
-                axis=-1,
-            )
-        else:
-            boxes = self.backend.numpy.stack(
-                [
-                    self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
-                ],
-                axis=-1,
-            )
+
+        if training:
+            h_start, w_start = transformation
+            if not self.backend.is_tensor(bounding_boxes["boxes"]):
+                bounding_boxes = densify_bounding_boxes(
+                    bounding_boxes, backend=self.backend
+                )
+            boxes = bounding_boxes["boxes"]
+            # Convert to a standard xyxy as operations are done xyxy by default.
+            boxes = convert_format(
+                boxes=boxes,
+                source=self.bounding_box_format,
+                target="xyxy",
+                height=self.height,
+                width=self.width,
+            )
+            h_start = self.backend.cast(h_start, boxes.dtype)
+            w_start = self.backend.cast(w_start, boxes.dtype)
+            if len(self.backend.shape(boxes)) == 3:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
+            else:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
 
-        # Convert to user defined bounding box format
-        boxes = convert_format(
-            boxes=boxes,
-            source="xyxy",
-            target=self.bounding_box_format,
-            height=self.height,
-            width=self.width,
-        )
+            # Convert to user defined bounding box format
+            boxes = convert_format(
+                boxes=boxes,
+                source="xyxy",
+                target=self.bounding_box_format,
+                height=self.height,
+                width=self.width,
+            )
 
-        return {
-            "boxes": boxes,
-            "labels": bounding_boxes["labels"],
-        }
+            return {
+                "boxes": boxes,
+                "labels": bounding_boxes["labels"],
+            }
+        return bounding_boxes
 
     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
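The training branch above shifts xyxy boxes by the crop offsets and clamps them at zero. A NumPy re-enactment of just that arithmetic (the offsets and boxes are made-up values; the maximum(... - offset, 0) calls mirror the diff):

import numpy as np

h_start, w_start = 4.0, 6.0  # crop offsets from get_random_transformation
boxes = np.array(
    [[10.0, 12.0, 20.0, 22.0],   # stays inside the crop after shifting
     [2.0, 3.0, 8.0, 9.0]]       # partially clipped to the crop origin
)
shifted = np.stack(
    [
        np.maximum(boxes[:, 0] - h_start, 0),
        np.maximum(boxes[:, 1] - w_start, 0),
        np.maximum(boxes[:, 2] - h_start, 0),
        np.maximum(boxes[:, 3] - w_start, 0),
    ],
    axis=-1,
)
print(shifted)  # [[ 6.  6. 16. 16.]
                #  [ 0.  0.  4.  3.]]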
73 changes: 37 additions & 36 deletions keras/src/layers/preprocessing/image_preprocessing/random_flip.py
@@ -101,9 +101,6 @@ def transform_bounding_boxes(
         transformation,
         training=True,
     ):
-        if backend_utils.in_tf_graph():
-            self.backend.set_backend("tensorflow")
-
         def _flip_boxes_horizontal(boxes):
             x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
             outputs = self.backend.numpy.concatenate(
@@ -134,46 +131,50 @@ def _transform_xyxy(boxes, box_flips):
             )
             return bboxes
 
-        flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
 
-        if self.data_format == "channels_first":
-            height_axis = -2
-            width_axis = -1
-        else:
-            height_axis = -3
-            width_axis = -2
+            flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
 
-        input_height, input_width = (
-            transformation["input_shape"][height_axis],
-            transformation["input_shape"][width_axis],
-        )
+            if self.data_format == "channels_first":
+                height_axis = -2
+                width_axis = -1
+            else:
+                height_axis = -3
+                width_axis = -2
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source=self.bounding_box_format,
-            target="rel_xyxy",
-            height=input_height,
-            width=input_width,
-        )
+            input_height, input_width = (
+                transformation["input_shape"][height_axis],
+                transformation["input_shape"][width_axis],
+            )
 
-        bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="rel_xyxy",
+                height=input_height,
+                width=input_width,
+            )
 
-        bounding_boxes = clip_to_image_size(
-            bounding_boxes=bounding_boxes,
-            height=input_height,
-            width=input_width,
-            bounding_box_format="xyxy",
-        )
+            bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source="rel_xyxy",
-            target=self.bounding_box_format,
-            height=input_height,
-            width=input_width,
-        )
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=input_height,
+                width=input_width,
+                bounding_box_format="xyxy",
+            )
 
-        self.backend.reset()
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target=self.bounding_box_format,
+                height=input_height,
+                width=input_width,
+            )
+
+            self.backend.reset()
 
         return bounding_boxes

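Because transform_bounding_boxes now returns early when training=False, the flip (and the TensorFlow-graph backend switch) is skipped entirely at inference. A usage sketch, with the dict input shape assumed from the transform_bounding_boxes docstring in random_crop.py above, and the keras.layers.RandomFlip export and bounding_box_format argument assumed rather than shown by this diff:

import numpy as np
import keras

layer = keras.layers.RandomFlip("horizontal", bounding_box_format="xyxy")
data = {
    "images": np.zeros((2, 32, 32, 3), dtype="float32"),
    "bounding_boxes": {
        "boxes": np.array([[[2.0, 2.0, 10.0, 10.0]]] * 2, dtype="float32"),
        "labels": np.array([[1.0]] * 2, dtype="float32"),
    },
}

# With training=False the new guard returns the boxes untouched.
out = layer(data, training=False)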
(Diffs for the remaining 5 changed files were not loaded.)
