Skip to content

Commit

Permalink
Correct failed test case
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka committed Dec 21, 2024
1 parent c366d3e commit 6b25c98
Showing 1 changed file with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def __init__(

def transform_images(self, images, transformation, training=True):
if training:
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")
images = self.backend.cast(images, self.compute_dtype)
if self.brightness_factor is not None:
if backend_utils.in_tf_graph():
self.random_brightness.backend.set_backend("tensorflow")
Expand All @@ -132,6 +135,9 @@ def transform_images(self, images, transformation, training=True):
transformation = self.random_contrast.get_random_transformation(
images, seed=self._get_seed_generator(self.backend._backend)
)
transformation["contrast_factor"] = self.backend.cast(
transformation["contrast_factor"], dtype=self.compute_dtype
)
images = self.random_contrast.transform_images(
images, transformation
)
Expand Down

0 comments on commit 6b25c98

Please sign in to comment.