Skip to content

Commit

Permalink
Correct failed test case
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka committed Dec 21, 2024
1 parent dcbad19 commit c366d3e
Showing 1 changed file with 39 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
BaseImagePreprocessingLayer,
)
from keras.src.random.seed_generator import SeedGenerator
from keras.src.utils import backend_utils


@keras_export("keras.layers.RandomColorJitter")
Expand Down Expand Up @@ -114,13 +115,47 @@ def __init__(
def transform_images(self, images, transformation, training=True):
if training:
if self.brightness_factor is not None:
images = self.random_brightness(images)
if backend_utils.in_tf_graph():
self.random_brightness.backend.set_backend("tensorflow")
transformation = (
self.random_brightness.get_random_transformation(
images,
seed=self._get_seed_generator(self.backend._backend),
)
)
images = self.random_brightness.transform_images(
images, transformation
)
if self.contrast_factor is not None:
images = self.random_contrast(images)
if backend_utils.in_tf_graph():
self.random_contrast.backend.set_backend("tensorflow")
transformation = self.random_contrast.get_random_transformation(
images, seed=self._get_seed_generator(self.backend._backend)
)
images = self.random_contrast.transform_images(
images, transformation
)
if self.saturation_factor is not None:
images = self.random_saturation(images)
if backend_utils.in_tf_graph():
self.random_saturation.backend.set_backend("tensorflow")
transformation = (
self.random_saturation.get_random_transformation(
images,
seed=self._get_seed_generator(self.backend._backend),
)
)
images = self.random_saturation.transform_images(
images, transformation
)
if self.hue_factor is not None:
images = self.random_hue(images)
if backend_utils.in_tf_graph():
self.random_hue.backend.set_backend("tensorflow")
transformation = self.random_hue.get_random_transformation(
images, seed=self._get_seed_generator(self.backend._backend)
)
images = self.random_hue.transform_images(
images, transformation
)
images = self.backend.cast(images, self.compute_dtype)
return images

Expand Down

0 comments on commit c366d3e

Please sign in to comment.