From 98cc52dc03931aed5775ebeba86465b493ae7463 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Fri, 21 Jul 2023 05:48:45 +0000 Subject: [PATCH 1/3] Converting RandomTranslation --- .../preprocessing/random_translation.py | 188 ++++++++++++++---- 1 file changed, 153 insertions(+), 35 deletions(-) diff --git a/keras_core/layers/preprocessing/random_translation.py b/keras_core/layers/preprocessing/random_translation.py index 6c3eef693..6bb3c553e 100644 --- a/keras_core/layers/preprocessing/random_translation.py +++ b/keras_core/layers/preprocessing/random_translation.py @@ -1,14 +1,11 @@ -import numpy as np - from keras_core import backend from keras_core.api_export import keras_core_export -from keras_core.layers.layer import Layer -from keras_core.utils import backend_utils -from keras_core.utils.module_utils import tensorflow as tf +from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer +from keras_core.random.seed_generator import SeedGenerator @keras_core_export("keras_core.layers.RandomTranslation") -class RandomTranslation(Layer): +class RandomTranslation(TFDataLayer): """A preprocessing layer which randomly translates images during training. This layer will apply random translations to each image during training, @@ -65,6 +62,13 @@ class RandomTranslation(Layer): `(..., height, width, channels)`, in `"channels_last"` format. """ + _FACTOR_VALIDATION_ERROR = ( + "The `factor` argument should be a number (or a list of two numbers) " + "in the range [-1.0, 1.0]. " + ) + _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + def __init__( self, height_factor, @@ -73,43 +77,157 @@ def __init__( interpolation="bilinear", seed=None, fill_value=0.0, - name=None, + data_format=None, **kwargs, ): - if not tf.available: - raise ImportError( - "Layer RandomTranslation requires TensorFlow. " - "Install it via `pip install tensorflow`." - ) - - super().__init__(name=name) - self.layer = tf.keras.layers.RandomTranslation( - height_factor=height_factor, - width_factor=width_factor, - fill_mode=fill_mode, - interpolation=interpolation, - seed=seed, - fill_value=fill_value, - name=name, - **kwargs, + super().__init__(**kwargs) + self.height_factor = height_factor + self.height_lower, self.height_upper = self._set_factor( + height_factor, "height_factor" ) + self.width_factor = width_factor + self.width_lower, self.width_upper = self._set_factor( + width_factor, "width_factor" + ) + self._check_fill_mode_and_interpolation(fill_mode, interpolation) + + self.fill_mode = fill_mode + self.fill_value = fill_value + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + self.data_format = backend.standardize_data_format(data_format) + self.supports_jit = False - self._convert_input_args = False - self._allow_non_tensor_positional_args = True + + def _set_factor(self, factor, factor_name): + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + self._check_factor_range(factor[0]) + self._check_factor_range(factor[1]) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + self._check_factor_range(factor) + factor = abs(factor) + lower, upper = [-factor, factor] + else: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + return lower, upper + + def _check_factor_range(self, input_number): + if input_number > 1.0 or input_number < -1.0: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: input_number={input_number}" + ) + + def _check_fill_mode_and_interpolation(self, fill_mode, interpolation): + if fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODE}." + ) + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) def call(self, inputs, training=True): - if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): - inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs)) - outputs = self.layer.call(inputs, training=training) - if ( - backend.backend() != "tensorflow" - and not backend_utils.in_tf_graph() - ): - outputs = backend.convert_to_tensor(outputs) + inputs = self.backend.cast(inputs, self.compute_dtype) + if training: + return self._randomly_translate_inputs(inputs) + else: + return inputs + + def _randomly_translate_inputs(self, inputs): + unbatched = len(inputs.shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + + batch_size = self.backend.shape(inputs)[0] + if self.data_format == "channels_first": + height = inputs.shape[-2] + width = inputs.shape[-1] + else: + height = inputs.shape[-3] + width = inputs.shape[-2] + + seed_generator = self._get_seed_generator(self.backend._backend) + height_translate = self.backend.random.uniform( + minval=self.height_lower, + maxval=self.height_upper, + shape=[batch_size, 1], + seed=seed_generator, + ) + height_translate = height_translate * height + width_translate = self.backend.random.uniform( + minval=self.width_lower, + maxval=self.width_upper, + shape=[batch_size, 1], + seed=seed_generator, + ) + width_translate = width_translate * width + translations = self.backend.cast( + self.backend.numpy.concatenate( + [width_translate, height_translate], axis=1 + ), + dtype="float32", + ) + + outputs = self.backend.image.affine_transform( + inputs, + transform=self._get_translation_matrix(translations), + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + if unbatched: + outputs = self.backend.numpy.squeeze(outputs, axis=0) return outputs + def _get_translation_matrix(self, translations): + num_translations = self.backend.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # translation matrices are always float32. + return self.backend.numpy.concatenate( + [ + self.backend.numpy.ones((num_translations, 1)), + self.backend.numpy.zeros((num_translations, 1)), + -translations[:, 0:1], + self.backend.numpy.zeros((num_translations, 1)), + self.backend.numpy.ones((num_translations, 1)), + -translations[:, 1:], + self.backend.numpy.zeros((num_translations, 2)), + ], + axis=1, + ) + def compute_output_shape(self, input_shape): - return tuple(self.layer.compute_output_shape(input_shape)) + return input_shape def get_config(self): - return self.layer.get_config() + base_config = super().get_config() + config = { + "height_factor": self.height_factor, + "width_factor": self.width_factor, + "fill_mode": self.fill_mode, + "interpolation": self.interpolation, + "seed": self.seed, + "fill_value": self.fill_value, + "data_format": self.data_format, + } + return {**base_config, **config} From df7bd1ab9a9d53b5126fa42c327982905d3f157b Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Fri, 21 Jul 2023 09:15:23 +0000 Subject: [PATCH 2/3] Update docstring --- .../preprocessing/random_translation.py | 93 +++++++++++-------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/keras_core/layers/preprocessing/random_translation.py b/keras_core/layers/preprocessing/random_translation.py index 6bb3c553e..de2726e0e 100644 --- a/keras_core/layers/preprocessing/random_translation.py +++ b/keras_core/layers/preprocessing/random_translation.py @@ -15,51 +15,64 @@ class RandomTranslation(TFDataLayer): of integer or floating point dtype. By default, the layer will output floats. - **Note:** This layer is safe to use inside a `tf.data` pipeline - (independently of which backend you're using). - - Args: - height_factor: a float represented as fraction of value, or a tuple of - size 2 representing lower and upper bound for shifting vertically. A - negative value means shifting image up, while a positive value means - shifting image down. When represented as a single positive float, this - value is used for both the upper and lower bound. For instance, - `height_factor=(-0.2, 0.3)` results in an output shifted by a random - amount in the range `[-20%, +30%]`. `height_factor=0.2` results in an - output height shifted by a random amount in the range `[-20%, +20%]`. - width_factor: a float represented as fraction of value, or a tuple of size - 2 representing lower and upper bound for shifting horizontally. A - negative value means shifting image left, while a positive value means - shifting image right. When represented as a single positive float, - this value is used for both the upper and lower bound. For instance, - `width_factor=(-0.2, 0.3)` results in an output shifted left by 20%, - and shifted right by 30%. `width_factor=0.2` results - in an output height shifted left or right by 20%. - fill_mode: Points outside the boundaries of the input are filled according - to the given mode - (one of `{"constant", "reflect", "wrap", "nearest"}`). - - *reflect*: `(d c b a | a b c d | d c b a)` The input is extended by - reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` The input is extended by - filling all values beyond the edge with the same constant value - k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by - wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` The input is extended by - the nearest pixel. - interpolation: Interpolation mode. Supported values: `"nearest"`, - `"bilinear"`. - seed: Integer. Used to create a random seed. - fill_value: a float represents the value to be filled outside the - boundaries when `fill_mode="constant"`. - Input shape: 3D (unbatched) or 4D (batched) tensor with shape: - `(..., height, width, channels)`, in `"channels_last"` format. + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. Output shape: 3D (unbatched) or 4D (batched) tensor with shape: - `(..., height, width, channels)`, in `"channels_last"` format. + `(..., target_height, target_width, channels)`, + or `(..., channels, target_height, target_width)`, + in `"channels_first"` format. + + **Note:** This layer is safe to use inside a `tf.data` pipeline + (independently of which backend you're using). + + Args: + height_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for shifting vertically. A + negative value means shifting image up, while a positive value means + shifting image down. When represented as a single positive float, + this value is used for both the upper and lower bound. For instance, + `height_factor=(-0.2, 0.3)` results in an output shifted by a random + amount in the range `[-20%, +30%]`. `height_factor=0.2` results in + an output height shifted by a random amount in the range + `[-20%, +20%]`. + width_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for shifting horizontally. + A negative value means shifting image left, while a positive value + means shifting image right. When represented as a single positive + float, this value is used for both the upper and lower bound. For + instance, `width_factor=(-0.2, 0.3)` results in an output shifted + left by 20%, and shifted right by 30%. `width_factor=0.2` results + in an output height shifted left or right by 20%. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode + (one of `{"constant", "reflect", "wrap", "nearest"}`). + - *reflect*: `(d c b a | a b c d | d c b a)` The input is extended + by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` The input is extended + by filling all values beyond the edge with the same constant + value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by + wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` The input is extended + by the nearest pixel. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + seed: Integer. Used to create a random seed. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode="constant"`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. """ _FACTOR_VALIDATION_ERROR = ( From 5de965d1a4b56bb8ed146ad51f0f1b1036a23061 Mon Sep 17 00:00:00 2001 From: chiuhongyu <20734616+james77777778@users.noreply.github.com> Date: Sat, 22 Jul 2023 15:26:27 +0800 Subject: [PATCH 3/3] Address comments --- .../preprocessing/random_translation.py | 24 ++- .../preprocessing/random_translation_test.py | 166 ++++++++++++++---- 2 files changed, 140 insertions(+), 50 deletions(-) diff --git a/keras_core/layers/preprocessing/random_translation.py b/keras_core/layers/preprocessing/random_translation.py index de2726e0e..1c9ea0a95 100644 --- a/keras_core/layers/preprocessing/random_translation.py +++ b/keras_core/layers/preprocessing/random_translation.py @@ -102,7 +102,17 @@ def __init__( self.width_lower, self.width_upper = self._set_factor( width_factor, "width_factor" ) - self._check_fill_mode_and_interpolation(fill_mode, interpolation) + + if fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODE}." + ) + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) self.fill_mode = fill_mode self.fill_value = fill_value @@ -141,18 +151,6 @@ def _check_factor_range(self, input_number): + f"Received: input_number={input_number}" ) - def _check_fill_mode_and_interpolation(self, fill_mode, interpolation): - if fill_mode not in self._SUPPORTED_FILL_MODE: - raise NotImplementedError( - f"Unknown `fill_mode` {fill_mode}. Expected of one " - f"{self._SUPPORTED_FILL_MODE}." - ) - if interpolation not in self._SUPPORTED_INTERPOLATION: - raise NotImplementedError( - f"Unknown `interpolation` {interpolation}. Expected of one " - f"{self._SUPPORTED_INTERPOLATION}." - ) - def call(self, inputs, training=True): inputs = self.backend.cast(inputs, self.compute_dtype) if training: diff --git a/keras_core/layers/preprocessing/random_translation_test.py b/keras_core/layers/preprocessing/random_translation_test.py index f0f7da095..f88ac0357 100644 --- a/keras_core/layers/preprocessing/random_translation_test.py +++ b/keras_core/layers/preprocessing/random_translation_test.py @@ -26,8 +26,38 @@ def test_random_translation(self, height_factor, width_factor): run_training_check=False, ) - def test_random_translation_up_numeric_reflect(self): - input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)) + @parameterized.named_parameters( + ("bad_len", [0.1, 0.2, 0.3], 0.0), + ("bad_type", {"dummy": 0.3}, 0.0), + ("exceed_range_single", -1.1, 0.0), + ("exceed_range_tuple", (-1.1, 0.0), 0.0), + ) + def test_random_translation_with_bad_factor( + self, height_factor, width_factor + ): + with self.assertRaises(ValueError): + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": height_factor, + "width_factor": width_factor, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 4), + supports_masking=False, + run_training_check=False, + ) + + def test_random_translation_with_inference_mode(self): + input_data = np.random.random((1, 4, 4, 3)) + expected_output = input_data + layer = layers.RandomTranslation(0.2, 0.1) + output = layer(input_data, training=False) + self.assertAllClose(output, expected_output) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_up_numeric_reflect(self, data_format): + input_image = np.arange(0, 25) expected_output = np.asarray( [ [5, 6, 7, 8, 9], @@ -37,14 +67,22 @@ def test_random_translation_up_numeric_reflect(self): [20, 21, 22, 23, 24], ] ) - expected_output = backend.convert_to_tensor( - np.reshape(expected_output, (1, 5, 5, 1)) - ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) self.run_layer_test( layers.RandomTranslation, init_kwargs={ "height_factor": (-0.2, -0.2), "width_factor": 0.0, + "data_format": data_format, }, input_shape=None, input_data=input_image, @@ -53,10 +91,9 @@ def test_random_translation_up_numeric_reflect(self): run_training_check=False, ) - def test_random_translation_up_numeric_constant(self): - input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)).astype( - "float32" - ) + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_up_numeric_constant(self, data_format): + input_image = np.arange(0, 25).astype("float32") # Shifting by -.2 * 5 = 1 pixel. expected_output = np.asarray( [ @@ -67,15 +104,23 @@ def test_random_translation_up_numeric_constant(self): [0, 0, 0, 0, 0], ] ) - expected_output = backend.convert_to_tensor( - np.reshape(expected_output, (1, 5, 5, 1)), dtype="float32" - ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)), dtype="float32" + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)), dtype="float32" + ) self.run_layer_test( layers.RandomTranslation, init_kwargs={ "height_factor": (-0.2, -0.2), "width_factor": 0.0, "fill_mode": "constant", + "data_format": data_format, }, input_shape=None, input_data=input_image, @@ -84,8 +129,9 @@ def test_random_translation_up_numeric_constant(self): run_training_check=False, ) - def test_random_translation_down_numeric_reflect(self): - input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)) + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_down_numeric_reflect(self, data_format): + input_image = np.arange(0, 25) # Shifting by .2 * 5 = 1 pixel. expected_output = np.asarray( [ @@ -96,14 +142,22 @@ def test_random_translation_down_numeric_reflect(self): [15, 16, 17, 18, 19], ] ) - expected_output = backend.convert_to_tensor( - np.reshape(expected_output, (1, 5, 5, 1)) - ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) self.run_layer_test( layers.RandomTranslation, init_kwargs={ "height_factor": (0.2, 0.2), "width_factor": 0.0, + "data_format": data_format, }, input_shape=None, input_data=input_image, @@ -112,8 +166,11 @@ def test_random_translation_down_numeric_reflect(self): run_training_check=False, ) - def test_random_translation_asymmetric_size_numeric_reflect(self): - input_image = np.reshape(np.arange(0, 16), (1, 8, 2, 1)) + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_asymmetric_size_numeric_reflect( + self, data_format + ): + input_image = np.arange(0, 16) # Shifting by .2 * 5 = 1 pixel. expected_output = np.asarray( [ @@ -127,14 +184,22 @@ def test_random_translation_asymmetric_size_numeric_reflect(self): [6, 7], ] ) - expected_output = backend.convert_to_tensor( - np.reshape(expected_output, (1, 8, 2, 1)) - ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 8, 2, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 8, 2, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 8, 2)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 8, 2)) + ) self.run_layer_test( layers.RandomTranslation, init_kwargs={ "height_factor": (0.5, 0.5), "width_factor": 0.0, + "data_format": data_format, }, input_shape=None, input_data=input_image, @@ -143,8 +208,9 @@ def test_random_translation_asymmetric_size_numeric_reflect(self): run_training_check=False, ) - def test_random_translation_down_numeric_constant(self): - input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)) + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_down_numeric_constant(self, data_format): + input_image = np.arange(0, 25) # Shifting by .2 * 5 = 1 pixel. expected_output = np.asarray( [ @@ -155,9 +221,16 @@ def test_random_translation_down_numeric_constant(self): [15, 16, 17, 18, 19], ] ) - expected_output = backend.convert_to_tensor( - np.reshape(expected_output, (1, 5, 5, 1)) - ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) self.run_layer_test( layers.RandomTranslation, init_kwargs={ @@ -165,6 +238,7 @@ def test_random_translation_down_numeric_constant(self): "width_factor": 0.0, "fill_mode": "constant", "fill_value": 0.0, + "data_format": data_format, }, input_shape=None, input_data=input_image, @@ -173,8 +247,9 @@ def test_random_translation_down_numeric_constant(self): run_training_check=False, ) - def test_random_translation_left_numeric_reflect(self): - input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)) + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_left_numeric_reflect(self, data_format): + input_image = np.arange(0, 25) # Shifting by .2 * 5 = 1 pixel. expected_output = np.asarray( [ @@ -185,14 +260,22 @@ def test_random_translation_left_numeric_reflect(self): [21, 22, 23, 24, 24], ] ) - expected_output = backend.convert_to_tensor( - np.reshape(expected_output, (1, 5, 5, 1)) - ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) self.run_layer_test( layers.RandomTranslation, init_kwargs={ "height_factor": 0.0, "width_factor": (-0.2, -0.2), + "data_format": data_format, }, input_shape=None, input_data=input_image, @@ -201,8 +284,9 @@ def test_random_translation_left_numeric_reflect(self): run_training_check=False, ) - def test_random_translation_left_numeric_constant(self): - input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)) + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_left_numeric_constant(self, data_format): + input_image = np.arange(0, 25) # Shifting by .2 * 5 = 1 pixel. expected_output = np.asarray( [ @@ -213,9 +297,16 @@ def test_random_translation_left_numeric_constant(self): [21, 22, 23, 24, 0], ] ) - expected_output = backend.convert_to_tensor( - np.reshape(expected_output, (1, 5, 5, 1)) - ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) self.run_layer_test( layers.RandomTranslation, init_kwargs={ @@ -223,6 +314,7 @@ def test_random_translation_left_numeric_constant(self): "width_factor": (-0.2, -0.2), "fill_mode": "constant", "fill_value": 0.0, + "data_format": data_format, }, input_shape=None, input_data=input_image,