Skip to content

Commit

Permalink
Add training status condition during image processing
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka committed Dec 22, 2024
1 parent 0d3ba37 commit 8fc5ae2
Show file tree
Hide file tree
Showing 9 changed files with 468 additions and 437 deletions.
211 changes: 107 additions & 104 deletions keras/src/layers/preprocessing/image_preprocessing/center_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,84 +94,85 @@ def _get_clipped_bbox(bounding_boxes, h_end, h_start, w_end, w_start):
)
return bounding_boxes

input_shape = transformation["input_shape"]
if training:
input_shape = transformation["input_shape"]

init_height, init_width = _get_height_width(input_shape)
init_height, init_width = _get_height_width(input_shape)

bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="xyxy",
height=init_height,
width=init_width,
)
bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="xyxy",
height=init_height,
width=init_width,
)

h_diff = init_height - self.height
w_diff = init_width - self.width
h_diff = init_height - self.height
w_diff = init_width - self.width

if h_diff >= 0 and w_diff >= 0:
h_start = int(h_diff / 2)
w_start = int(w_diff / 2)
if h_diff >= 0 and w_diff >= 0:
h_start = int(h_diff / 2)
w_start = int(w_diff / 2)

h_end = h_start + self.height
w_end = w_start + self.width
h_end = h_start + self.height
w_end = w_start + self.width

bounding_boxes = _get_clipped_bbox(
bounding_boxes, h_end, h_start, w_end, w_start
)
else:
width = init_width
height = init_height
target_height = self.height
target_width = self.width

crop_height = int(float(width * target_height) / target_width)
crop_height = max(min(height, crop_height), 1)
crop_width = int(float(height * target_width) / target_height)
crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)

h_start = crop_box_hstart
w_start = crop_box_wstart

h_end = crop_box_hstart + crop_height
w_end = crop_box_wstart + crop_width
bounding_boxes = _get_clipped_bbox(
bounding_boxes, h_end, h_start, w_end, w_start
bounding_boxes = _get_clipped_bbox(
bounding_boxes, h_end, h_start, w_end, w_start
)
else:
width = init_width
height = init_height
target_height = self.height
target_width = self.width

crop_height = int(float(width * target_height) / target_width)
crop_height = max(min(height, crop_height), 1)
crop_width = int(float(height * target_width) / target_height)
crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)

h_start = crop_box_hstart
w_start = crop_box_wstart

h_end = crop_box_hstart + crop_height
w_end = crop_box_wstart + crop_width
bounding_boxes = _get_clipped_bbox(
bounding_boxes, h_end, h_start, w_end, w_start
)

bounding_boxes = convert_format(
bounding_boxes,
source="xyxy",
target="rel_xyxy",
height=crop_height,
width=crop_width,
)

bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target="xyxy",
height=self.height,
width=self.width,
)

bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=self.height,
width=self.width,
bounding_box_format="xyxy",
)

bounding_boxes = convert_format(
bounding_boxes,
source="xyxy",
target="rel_xyxy",
height=crop_height,
width=crop_width,
)

bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target="xyxy",
target=self.bounding_box_format,
height=self.height,
width=self.width,
)

bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=self.height,
width=self.width,
bounding_box_format="xyxy",
)

bounding_boxes = convert_format(
bounding_boxes,
source="xyxy",
target=self.bounding_box_format,
height=self.height,
width=self.width,
)

return bounding_boxes

def transform_segmentation_masks(
Expand All @@ -183,60 +184,62 @@ def transform_segmentation_masks(

def transform_images(self, images, transformation=None, training=True):
    """Center-crop `images` to `(self.height, self.width)`.

    Inputs that are at least as large as the target in both spatial
    dimensions (and are rank 3 or 4) are sliced around their center. Every
    other input is handed to `image_utils.smart_resize` with the target
    size.

    The operation is deterministic, so it is applied whether or not
    `training` is set. Gating it on `training` would return the raw,
    uncropped input at inference time.
    NOTE(review): this presumably also keeps the result consistent with
    `compute_output_shape` -- confirm against the full class.

    Args:
        images: Image tensor. Height/width are read from axes -2/-1 when
            `self.data_format == "channels_first"`, else from axes -3/-2.
        transformation: Unused; kept for signature compatibility.
        training: Unused; kept for signature compatibility.

    Returns:
        The images cast to `self.compute_dtype`, either center-sliced or
        passed through `image_utils.smart_resize`.

    Raises:
        ValueError: If the input height or width is not statically known.
    """
    inputs = self.backend.cast(images, self.compute_dtype)
    channels_first = self.data_format == "channels_first"

    if channels_first:
        init_height, init_width = inputs.shape[-2], inputs.shape[-1]
    else:
        init_height, init_width = inputs.shape[-3], inputs.shape[-2]

    if init_height is None or init_width is None:
        # Dynamic size case. TODO.
        raise ValueError(
            "At this time, CenterCrop can only "
            "process images with a static spatial "
            f"shape. Received: inputs.shape={inputs.shape}"
        )

    h_diff = init_height - self.height
    w_diff = init_width - self.width

    # Plain slicing only works when the input covers the target in both
    # dimensions; batched (rank 4) and unbatched (rank 3) inputs share the
    # same trailing-axis layout, so Ellipsis covers both.
    if h_diff >= 0 and w_diff >= 0 and len(inputs.shape) in (3, 4):
        h_start = h_diff // 2
        w_start = w_diff // 2
        rows = slice(h_start, h_start + self.height)
        cols = slice(w_start, w_start + self.width)
        if channels_first:
            return inputs[..., rows, cols]
        return inputs[..., rows, cols, :]

    return image_utils.smart_resize(
        inputs,
        [self.height, self.width],
        data_format=self.data_format,
        backend_module=self.backend,
    )

def compute_output_shape(self, input_shape):
input_shape = list(input_shape)
Expand Down
55 changes: 33 additions & 22 deletions keras/src/layers/preprocessing/image_preprocessing/equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,40 +170,51 @@ def _apply_equalization(self, channel, hist):
)
return self.backend.numpy.take(lookup_table, indices)

def transform_images(self, images, transformation, training=True):
    """Equalize each channel of `images` independently.

    When `training` is false the input is returned untouched (no cast).
    Otherwise the images are cast to `self.compute_dtype`, every channel is
    passed through `self._equalize_channel` together with
    `self.value_range`, and the results are stacked back along the channel
    axis (-3 for `"channels_first"`, -1 otherwise).

    Args:
        images: Image tensor laid out according to `self.data_format`.
        transformation: Unused; kept for signature compatibility.
        training: Whether to apply the equalization.

    Returns:
        The equalized images cast to `self.compute_dtype`, or `images`
        itself when `training` is false.
    """
    if not training:
        return images

    images = self.backend.cast(images, self.compute_dtype)

    # The two layouts differ only in where the channel axis sits, so pick
    # the axis and the trailing index once and share a single loop.
    if self.data_format == "channels_first":
        channel_axis = -3
        trailing = (slice(None), slice(None))
    else:
        channel_axis = -1
        trailing = ()

    num_channels = self.backend.core.shape(images)[channel_axis]
    equalized_channels = [
        self._equalize_channel(images[(..., i) + trailing], self.value_range)
        for i in range(num_channels)
    ]
    equalized_images = self.backend.numpy.stack(
        equalized_channels, axis=channel_axis
    )

    return self.backend.cast(equalized_images, self.compute_dtype)

def compute_output_shape(self, input_shape):
    """Return `input_shape` unchanged."""
    return input_shape

def compute_output_spec(self, inputs, **kwargs):
    """Return `inputs` unchanged; extra keyword arguments are ignored."""
    return inputs

def transform_bounding_boxes(
    self,
    bounding_boxes,
    transformation,
    training=True,
):
    """Return `bounding_boxes` unchanged.

    Args:
        bounding_boxes: Bounding boxes accompanying the images.
        transformation: Unused; kept for signature compatibility.
        training: Unused; kept for signature compatibility.

    Returns:
        The `bounding_boxes` argument, untouched.
    """
    return bounding_boxes

def transform_labels(self, labels, transformation, training=True):
    """Return `labels` unchanged.

    Args:
        labels: Labels accompanying the images.
        transformation: Unused; kept for signature compatibility.
        training: Unused; kept for signature compatibility.

    Returns:
        The `labels` argument, untouched.
    """
    return labels

def transform_segmentation_masks(
    self, segmentation_masks, transformation, training=True
):
    """Return `segmentation_masks` unchanged.

    Args:
        segmentation_masks: Masks accompanying the images.
        transformation: Unused; kept for signature compatibility.
        training: Unused; kept for signature compatibility.

    Returns:
        The `segmentation_masks` argument, untouched.
    """
    return segmentation_masks

Expand Down
Loading

0 comments on commit 8fc5ae2

Please sign in to comment.