Add training status condition during image processing #20677

Merged: 4 commits, Dec 22, 2024

This PR threads the training flag through the image preprocessing transform hooks (transform_images, transform_labels, transform_bounding_boxes, transform_segmentation_masks) so that augmentations run only when training=True and inputs pass through unchanged at inference.
keras/src/layers/preprocessing/image_preprocessing/equalization.py (55 changes: 33 additions & 22 deletions)
@@ -170,40 +170,51 @@ def _apply_equalization(self, channel, hist):
         )
         return self.backend.numpy.take(lookup_table, indices)
 
-    def transform_images(self, images, transformations=None, **kwargs):
-        images = self.backend.cast(images, self.compute_dtype)
-
-        if self.data_format == "channels_first":
-            channels = []
-            for i in range(self.backend.core.shape(images)[-3]):
-                channel = images[..., i, :, :]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-3)
-        else:
-            channels = []
-            for i in range(self.backend.core.shape(images)[-1]):
-                channel = images[..., i]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-1)
-
-        return self.backend.cast(equalized_images, self.compute_dtype)
+    def transform_images(self, images, transformation, training=True):
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+
+            if self.data_format == "channels_first":
+                channels = []
+                for i in range(self.backend.core.shape(images)[-3]):
+                    channel = images[..., i, :, :]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-3)
+            else:
+                channels = []
+                for i in range(self.backend.core.shape(images)[-1]):
+                    channel = images[..., i]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-1)
+
+            return self.backend.cast(equalized_images, self.compute_dtype)
+        return images
 
     def compute_output_shape(self, input_shape):
         return input_shape
 
     def compute_output_spec(self, inputs, **kwargs):
         return inputs
 
-    def transform_bounding_boxes(self, bounding_boxes, **kwargs):
+    def transform_bounding_boxes(
+        self,
+        bounding_boxes,
+        transformation,
+        training=True,
+    ):
         return bounding_boxes
 
-    def transform_labels(self, labels, transformations=None, **kwargs):
+    def transform_labels(self, labels, transformation, training=True):
         return labels
 
     def transform_segmentation_masks(
-        self, segmentation_masks, transformations, **kwargs
+        self, segmentation_masks, transformation, training=True
     ):
         return segmentation_masks

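For context, a minimal usage sketch of the gated behavior, assuming the layer is exported as keras.layers.Equalization with a value_range argument (an assumption about the public API, not part of this diff):

    # Sketch: with this change, equalization runs only when training=True;
    # at inference, transform_images returns the inputs unchanged.
    import numpy as np
    import keras

    layer = keras.layers.Equalization(value_range=(0, 255))
    images = np.random.randint(0, 256, size=(2, 32, 32, 3)).astype("float32")

    augmented = layer(images, training=True)     # histogram equalization applied
    passthrough = layer(images, training=False)  # passes through per the diff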
keras/src/layers/preprocessing/image_preprocessing/random_crop.py (198 changes: 101 additions & 97 deletions)
@@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None):
         return h_start, w_start
 
     def transform_images(self, images, transformation, training=True):
-        images = self.backend.cast(images, self.compute_dtype)
-        crop_box_hstart, crop_box_wstart = transformation
-        crop_height = self.height
-        crop_width = self.width
-
-        if self.data_format == "channels_last":
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-            else:
-                images = images[
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-        else:
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
-            else:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
-
-        shape = self.backend.shape(images)
-        new_height = shape[self.height_axis]
-        new_width = shape[self.width_axis]
-        if (
-            not isinstance(new_height, int)
-            or not isinstance(new_width, int)
-            or new_height != self.height
-            or new_width != self.width
-        ):
-            # Resize images if size mismatch or
-            # if size mismatch cannot be determined
-            # (in the case of a TF dynamic shape).
-            images = self.backend.image.resize(
-                images,
-                size=(self.height, self.width),
-                data_format=self.data_format,
-            )
-            # Resize may have upcasted the outputs
-            images = self.backend.cast(images, self.compute_dtype)
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+            crop_box_hstart, crop_box_wstart = transformation
+            crop_height = self.height
+            crop_width = self.width
+
+            if self.data_format == "channels_last":
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
+                else:
+                    images = images[
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
+            else:
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
+                else:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
+
+            shape = self.backend.shape(images)
+            new_height = shape[self.height_axis]
+            new_width = shape[self.width_axis]
+            if (
+                not isinstance(new_height, int)
+                or not isinstance(new_width, int)
+                or new_height != self.height
+                or new_width != self.width
+            ):
+                # Resize images if size mismatch or
+                # if size mismatch cannot be determined
+                # (in the case of a TF dynamic shape).
+                images = self.backend.image.resize(
+                    images,
+                    size=(self.height, self.width),
+                    data_format=self.data_format,
+                )
+                # Resize may have upcasted the outputs
+                images = self.backend.cast(images, self.compute_dtype)
         return images
 
     def transform_labels(self, labels, transformation, training=True):
@@ -197,56 +197,59 @@ def transform_bounding_boxes(
             "labels": (num_boxes, num_classes),
         }
         """
-        h_start, w_start = transformation
-        if not self.backend.is_tensor(bounding_boxes["boxes"]):
-            bounding_boxes = densify_bounding_boxes(
-                bounding_boxes, backend=self.backend
-            )
-        boxes = bounding_boxes["boxes"]
-        # Convert to a standard xyxy as operations are done xyxy by default.
-        boxes = convert_format(
-            boxes=boxes,
-            source=self.bounding_box_format,
-            target="xyxy",
-            height=self.height,
-            width=self.width,
-        )
-        h_start = self.backend.cast(h_start, boxes.dtype)
-        w_start = self.backend.cast(w_start, boxes.dtype)
-        if len(self.backend.shape(boxes)) == 3:
-            boxes = self.backend.numpy.stack(
-                [
-                    self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
-                ],
-                axis=-1,
-            )
-        else:
-            boxes = self.backend.numpy.stack(
-                [
-                    self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
-                ],
-                axis=-1,
-            )
-
-        # Convert to user defined bounding box format
-        boxes = convert_format(
-            boxes=boxes,
-            source="xyxy",
-            target=self.bounding_box_format,
-            height=self.height,
-            width=self.width,
-        )
-
-        return {
-            "boxes": boxes,
-            "labels": bounding_boxes["labels"],
-        }
+        if training:
+            h_start, w_start = transformation
+            if not self.backend.is_tensor(bounding_boxes["boxes"]):
+                bounding_boxes = densify_bounding_boxes(
+                    bounding_boxes, backend=self.backend
+                )
+            boxes = bounding_boxes["boxes"]
+            # Convert to a standard xyxy as operations are done xyxy by default.
+            boxes = convert_format(
+                boxes=boxes,
+                source=self.bounding_box_format,
+                target="xyxy",
+                height=self.height,
+                width=self.width,
+            )
+            h_start = self.backend.cast(h_start, boxes.dtype)
+            w_start = self.backend.cast(w_start, boxes.dtype)
+            if len(self.backend.shape(boxes)) == 3:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
+            else:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
+
+            # Convert to user defined bounding box format
+            boxes = convert_format(
+                boxes=boxes,
+                source="xyxy",
+                target=self.bounding_box_format,
+                height=self.height,
+                width=self.width,
+            )
+
+            return {
+                "boxes": boxes,
+                "labels": bounding_boxes["labels"],
+            }
+        return bounding_boxes
 
     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
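A usage sketch of the new gate, assuming the public layer keras.layers.RandomCrop fronts this file (shapes and values are illustrative):

    import numpy as np
    import keras

    layer = keras.layers.RandomCrop(height=24, width=24)
    images = np.random.uniform(0, 255, size=(2, 32, 32, 3)).astype("float32")

    # training=True: a random 24x24 window is cropped, with a resize as a
    # fallback when the cropped size cannot be verified statically.
    cropped = layer(images, training=True)

    # training=False: per the diff, transform_images now returns its input
    # unchanged instead of cropping.
    unchanged = layer(images, training=False)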
keras/src/layers/preprocessing/image_preprocessing/random_flip.py (73 changes: 37 additions & 36 deletions)
@@ -101,9 +101,6 @@ def transform_bounding_boxes(
         transformation,
         training=True,
     ):
-        if backend_utils.in_tf_graph():
-            self.backend.set_backend("tensorflow")
-
         def _flip_boxes_horizontal(boxes):
             x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
             outputs = self.backend.numpy.concatenate(
@@ -134,46 +131,50 @@ def _transform_xyxy(boxes, box_flips):
             )
             return bboxes
 
-        flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
-
-        if self.data_format == "channels_first":
-            height_axis = -2
-            width_axis = -1
-        else:
-            height_axis = -3
-            width_axis = -2
-
-        input_height, input_width = (
-            transformation["input_shape"][height_axis],
-            transformation["input_shape"][width_axis],
-        )
-
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source=self.bounding_box_format,
-            target="rel_xyxy",
-            height=input_height,
-            width=input_width,
-        )
-
-        bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
-
-        bounding_boxes = clip_to_image_size(
-            bounding_boxes=bounding_boxes,
-            height=input_height,
-            width=input_width,
-            bounding_box_format="xyxy",
-        )
-
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source="rel_xyxy",
-            target=self.bounding_box_format,
-            height=input_height,
-            width=input_width,
-        )
-
-        self.backend.reset()
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
+
+            flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
+
+            if self.data_format == "channels_first":
+                height_axis = -2
+                width_axis = -1
+            else:
+                height_axis = -3
+                width_axis = -2
+
+            input_height, input_width = (
+                transformation["input_shape"][height_axis],
+                transformation["input_shape"][width_axis],
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="rel_xyxy",
+                height=input_height,
+                width=input_width,
+            )
+
+            bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
+
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=input_height,
+                width=input_width,
+                bounding_box_format="xyxy",
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target=self.bounding_box_format,
+                height=input_height,
+                width=input_width,
+            )
+
+            self.backend.reset()
 
         return bounding_boxes

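Finally, a sketch for the bounding-box path, assuming keras.layers.RandomFlip accepts bounding_box_format and a dict input with "images" and "bounding_boxes" keys (the structure implied by the transform hooks; not spelled out in this PR):

    import numpy as np
    import keras

    layer = keras.layers.RandomFlip(
        mode="horizontal", bounding_box_format="xyxy", seed=42
    )
    data = {
        "images": np.random.uniform(0, 255, size=(2, 32, 32, 3)).astype("float32"),
        "bounding_boxes": {
            "boxes": np.array([[[2.0, 2.0, 20.0, 20.0]]] * 2),  # (2, 1, 4)
            "labels": np.array([[0]] * 2),
        },
    }

    # training=True: boxes are converted to rel_xyxy, flipped with the image,
    # clipped to the image size, and converted back.
    augmented = layer(data, training=True)

    # training=False: the new guard skips all of that and returns the boxes as-is.
    passthrough = layer(data, training=False)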