From 7731f74a980c6a130030e1cc5946393bc57a59bb Mon Sep 17 00:00:00 2001 From: James Tran Date: Sun, 16 Jul 2023 21:37:59 -0400 Subject: [PATCH] Remove dependency on tf.keras for RandomFlip --- .../layers/preprocessing/random_flip.py | 69 ++++++++++++++----- 1 file changed, 51 insertions(+), 18 deletions(-) diff --git a/keras_core/layers/preprocessing/random_flip.py b/keras_core/layers/preprocessing/random_flip.py index c036fd455..6601280ac 100644 --- a/keras_core/layers/preprocessing/random_flip.py +++ b/keras_core/layers/preprocessing/random_flip.py @@ -3,6 +3,9 @@ from keras_core import backend from keras_core.api_export import keras_core_export from keras_core.layers.layer import Layer +from keras_core.ops import convert_to_numpy +from keras_core.ops import convert_to_tensor +from keras_core.ops import shape from keras_core.utils import backend_utils from keras_core.utils.module_utils import tensorflow as tf @@ -52,33 +55,63 @@ def __init__( "Install it via `pip install tensorflow`." ) - super().__init__(name=name) + super().__init__(name=name, **kwargs) + + self.mode = mode + if mode == HORIZONTAL: + self.horizontal = True + self.vertical = False + elif mode == VERTICAL: + self.horizontal = False + self.vertical = True + elif mode == HORIZONTAL_AND_VERTICAL: + self.horizontal = True + self.vertical = True + else: + raise ValueError( + f"RandomFlip layer {self.name} received an unknown mode " + f"argument {mode}" + ) + self.seed = seed or backend.random.make_default_seed() - self.layer = tf.keras.layers.RandomFlip( - mode=mode, - name=name, - seed=self.seed, - **kwargs, - ) + self.supports_jit = False self._convert_input_args = False self._allow_non_tensor_positional_args = True def call(self, inputs, training=True): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): - inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs)) - outputs = self.layer.call(inputs, training=training) - if ( - backend.backend() != "tensorflow" - and not backend_utils.in_tf_graph() - ): - outputs = backend.convert_to_tensor(outputs) + inputs = convert_to_tensor(convert_to_numpy(inputs)) + + inputs = backend.cast(inputs, self.compute_dtype) + if training: + outputs = self._random_flipped_inputs(inputs) + else: + outputs = inputs + + if (backend.backend() != "tensorflow" and not backend_utils.in_tf_graph()): + outputs = convert_to_tensor(outputs) return outputs + def _random_flipped_inputs(self, inputs): + flipped_outputs = inputs + + if self.horizontal: + flipped_outputs = tf.image.stateless_random_flip_left_right(flipped_outputs, seed=[self.seed, 0]) + + if self.vertical: + flipped_outputs = tf.image.stateless_random_flip_up_down(flipped_outputs, seed=[self.seed, 0]) + + flipped_outputs.set_shape(shape(inputs)) + return flipped_outputs + def compute_output_shape(self, input_shape): - return tuple(self.layer.compute_output_shape(input_shape)) + return tuple(input_shape) def get_config(self): - config = self.layer.get_config() - config.update({"seed": self.seed}) - return config + config = { + "mode": self.mode, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config}