From aa02010b6d89b05b43d87fbd62d865c626f3b415 Mon Sep 17 00:00:00 2001
From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Date: Fri, 3 Nov 2023 20:49:51 -0700
Subject: [PATCH] fix tests for channels_first (#18723)

* fix tests for channels_first

* bug fix
---
 .../convolutional/conv_transpose_test.py      |   2 +-
 keras/ops/nn_test.py                          | 897 ++++++++++--------
 2 files changed, 504 insertions(+), 395 deletions(-)

diff --git a/keras/layers/convolutional/conv_transpose_test.py b/keras/layers/convolutional/conv_transpose_test.py
index 5b08f557720..d854be1dbf9 100644
--- a/keras/layers/convolutional/conv_transpose_test.py
+++ b/keras/layers/convolutional/conv_transpose_test.py
@@ -122,7 +122,7 @@ def np_conv2d_transpose(
         strides,
         padding,
         output_padding,
-        data_format,
+        "channels_last",
         dilation_rate,
     )
     jax_padding = compute_conv_transpose_padding_args_for_jax(
diff --git a/keras/ops/nn_test.py b/keras/ops/nn_test.py
index c19e82d1842..5254b351eb1 100644
--- a/keras/ops/nn_test.py
+++ b/keras/ops/nn_test.py
@@ -80,37 +80,80 @@ def test_log_softmax(self):
         self.assertEqual(knn.log_softmax(x, axis=-1).shape, (None, 2, 3))
 
     def test_max_pool(self):
-        x = KerasTensor([None, 8, 3])
-        self.assertEqual(knn.max_pool(x, 2, 1).shape, (None, 7, 3))
+        data_format = backend.config.image_data_format()
+        if data_format == "channels_last":
+            input_shape = (None, 8, 3)
+        else:
+            input_shape = (None, 3, 8)
+        x = KerasTensor(input_shape)
         self.assertEqual(
-            knn.max_pool(x, 2, 2, padding="same").shape, (None, 4, 3)
+            knn.max_pool(x, 2, 1).shape,
+            (None, 7, 3) if data_format == "channels_last" else (None, 3, 7),
+        )
+        self.assertEqual(
+            knn.max_pool(x, 2, 2, padding="same").shape,
+            (None, 4, 3) if data_format == "channels_last" else (None, 3, 4),
         )
 
-        x = KerasTensor([None, 8, None, 3])
-        self.assertEqual(knn.max_pool(x, 2, 1).shape, (None, 7, None, 3))
+        if data_format == "channels_last":
+            input_shape = (None, 8, None, 3)
+        else:
+            input_shape = (None, 3, 8, None)
+        x = KerasTensor(input_shape)
         self.assertEqual(
-            knn.max_pool(x, 2, 2, padding="same").shape, (None, 4, None, 3)
+            knn.max_pool(x, 2, 1).shape, (None, 7, None, 3)
+        ) if data_format == "channels_last" else (None, 3, 7, None)
+        self.assertEqual(
+            knn.max_pool(x, 2, 2, padding="same").shape,
+            (None, 4, None, 3)
+            if data_format == "channels_last"
+            else (None, 3, 4, None),
         )
         self.assertEqual(
             knn.max_pool(x, (2, 2), (2, 2), padding="same").shape,
-            (None, 4, None, 3),
+            (None, 4, None, 3)
+            if data_format == "channels_last"
+            else (None, 3, 4, None),
         )
 
     def test_average_pool(self):
-        x = KerasTensor([None, 8, 3])
-        self.assertEqual(knn.average_pool(x, 2, 1).shape, (None, 7, 3))
+        data_format = backend.config.image_data_format()
+        if data_format == "channels_last":
+            input_shape = (None, 8, 3)
+        else:
+            input_shape = (None, 3, 8)
+        x = KerasTensor(input_shape)
+        self.assertEqual(
+            knn.average_pool(x, 2, 1).shape,
+            (None, 7, 3) if data_format == "channels_last" else (None, 3, 7),
+        )
         self.assertEqual(
-            knn.average_pool(x, 2, 2, padding="same").shape, (None, 4, 3)
+            knn.average_pool(x, 2, 2, padding="same").shape,
+            (None, 4, 3) if data_format == "channels_last" else (None, 3, 4),
         )
 
-        x = KerasTensor([None, 8, None, 3])
-        self.assertEqual(knn.average_pool(x, 2, 1).shape, (None, 7, None, 3))
+        if data_format == "channels_last":
+            input_shape = (None, 8, None, 3)
+        else:
+            input_shape = (None, 3, 8, None)
+        x = KerasTensor(input_shape)
+        self.assertEqual(
+            knn.average_pool(x, 2, 1).shape,
+            (None, 7, None, 3)
+            if data_format == "channels_last"
+            else (None, 3, 7, None),
+        )
         self.assertEqual(
-            knn.average_pool(x, 2, 2, padding="same").shape, (None, 4, None, 3)
+            knn.average_pool(x, 2, 2, padding="same").shape,
+            (None, 4, None, 3)
+            if data_format == "channels_last"
+            else (None, 3, 4, None),
         )
         self.assertEqual(
             knn.average_pool(x, (2, 2), (2, 2), padding="same").shape,
-            (None, 4, None, 3),
+            (None, 4, None, 3)
+            if data_format == "channels_last"
+            else (None, 3, 4, None),
         )
 
     def test_multi_hot(self):
@@ -127,205 +170,297 @@ def test_multi_hot_dtype(self, dtype):
         self.assertEqual(backend.standardize_dtype(out.dtype), dtype)
 
     def test_conv(self):
+        data_format = backend.config.image_data_format()
         # Test 1D conv.
-        inputs_1d = KerasTensor([None, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 20, 3)
+        else:
+            input_shape = (None, 3, 20)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([4, 3, 2])
         for padding in ["valid", "VALID"]:
             self.assertEqual(
                 knn.conv(inputs_1d, kernel, 1, padding=padding).shape,
-                (None, 17, 2),
+                (None, 17, 2)
+                if data_format == "channels_last"
+                else (None, 2, 17),
             )
         for padding in ["same", "SAME"]:
             self.assertEqual(
                 knn.conv(inputs_1d, kernel, 1, padding=padding).shape,
-                (None, 20, 2),
+                (None, 20, 2)
+                if data_format == "channels_last"
+                else (None, 2, 20),
             )
         self.assertEqual(
             knn.conv(inputs_1d, kernel, (2,), dilation_rate=2).shape,
-            (None, 7, 2),
+            (None, 7, 2) if data_format == "channels_last" else (None, 2, 7),
         )
 
         # Test 2D conv.
-        inputs_2d = KerasTensor([None, 10, None, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 10, None, 3)
+        else:
+            input_shape = (None, 3, 10, None)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 3, 2])
         for padding in ["valid", "VALID"]:
             self.assertEqual(
                 knn.conv(inputs_2d, kernel, 1, padding=padding).shape,
-                (None, 9, None, 2),
+                (None, 9, None, 2)
+                if data_format == "channels_last"
+                else (None, 2, 9, None),
             )
         for padding in ["same", "SAME"]:
             self.assertEqual(
                 knn.conv(inputs_2d, kernel, 1, padding=padding).shape,
-                (None, 10, None, 2),
+                (None, 10, None, 2)
+                if data_format == "channels_last"
+                else (None, 2, 10, None),
             )
         self.assertEqual(
             knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape,
-            (None, 4, None, 2),
+            (None, 4, None, 2)
+            if data_format == "channels_last"
+            else (None, 2, 4, None),
         )
 
         # Test 2D conv - H, W specified
-        inputs_2d = KerasTensor([None, 10, 10, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 10, 10, 3)
+        else:
+            input_shape = (None, 3, 10, 10)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 3, 2])
         for padding in ["valid", "VALID"]:
             self.assertEqual(
                 knn.conv(inputs_2d, kernel, 1, padding=padding).shape,
-                (None, 9, 9, 2),
+                (None, 9, 9, 2)
+                if data_format == "channels_last"
+                else (None, 2, 9, 9),
             )
         for padding in ["same", "SAME"]:
             self.assertEqual(
                 knn.conv(inputs_2d, kernel, 1, padding=padding).shape,
-                (None, 10, 10, 2),
+                (None, 10, 10, 2)
+                if data_format == "channels_last"
+                else (None, 2, 10, 10),
             )
         self.assertEqual(
             knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape,
-            (None, 4, 9, 2),
+            (None, 4, 9, 2)
+            if data_format == "channels_last"
+            else (None, 2, 4, 9),
         )
 
         # Test 3D conv.
-        inputs_3d = KerasTensor([None, 8, None, 8, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 8, None, 8, 3)
+        else:
+            input_shape = (None, 3, 8, None, 8)
+        inputs_3d = KerasTensor(input_shape)
         kernel = KerasTensor([3, 3, 3, 3, 2])
         for padding in ["valid", "VALID"]:
             self.assertEqual(
                 knn.conv(inputs_3d, kernel, 1, padding=padding).shape,
-                (None, 6, None, 6, 2),
+                (None, 6, None, 6, 2)
+                if data_format == "channels_last"
+                else (None, 2, 6, None, 6),
             )
         for padding in ["same", "SAME"]:
             self.assertEqual(
                 knn.conv(inputs_3d, kernel, (2, 1, 2), padding=padding).shape,
-                (None, 4, None, 4, 2),
+                (None, 4, None, 4, 2)
+                if data_format == "channels_last"
+                else (None, 2, 4, None, 4),
             )
         self.assertEqual(
             knn.conv(
                 inputs_3d, kernel, 1, padding="valid", dilation_rate=(1, 2, 2)
             ).shape,
-            (None, 6, None, 4, 2),
+            (None, 6, None, 4, 2)
+            if data_format == "channels_last"
+            else (None, 2, 6, None, 4),
         )
 
     def test_depthwise_conv(self):
+        data_format = backend.config.image_data_format()
         # Test 1D depthwise conv.
-        inputs_1d = KerasTensor([None, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 20, 3)
+        else:
+            input_shape = (None, 3, 20)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([4, 3, 1])
         for padding in ["valid", "VALID"]:
             self.assertEqual(
                 knn.depthwise_conv(inputs_1d, kernel, 1, padding=padding).shape,
-                (None, 17, 3),
+                (None, 17, 3)
+                if data_format == "channels_last"
+                else (None, 3, 17),
             )
         for padding in ["same", "SAME"]:
             self.assertEqual(
                 knn.depthwise_conv(
                     inputs_1d, kernel, (1,), padding=padding
                 ).shape,
-                (None, 20, 3),
+                (None, 20, 3)
+                if data_format == "channels_last"
+                else (None, 3, 20),
             )
         self.assertEqual(
             knn.depthwise_conv(inputs_1d, kernel, 2, dilation_rate=2).shape,
-            (None, 7, 3),
+            (None, 7, 3) if data_format == "channels_last" else (None, 3, 7),
         )
 
         # Test 2D depthwise conv.
-        inputs_2d = KerasTensor([None, 10, 10, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 10, 10, 3)
+        else:
+            input_shape = (None, 3, 10, 10)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 3, 1])
         for padding in ["valid", "VALID"]:
             self.assertEqual(
                 knn.depthwise_conv(inputs_2d, kernel, 1, padding=padding).shape,
-                (None, 9, 9, 3),
+                (None, 9, 9, 3)
+                if data_format == "channels_last"
+                else (None, 3, 9, 9),
             )
         for padding in ["same", "SAME"]:
             self.assertEqual(
                 knn.depthwise_conv(
                     inputs_2d, kernel, (1, 2), padding=padding
                 ).shape,
-                (None, 10, 5, 3),
+                (None, 10, 5, 3)
+                if data_format == "channels_last"
+                else (None, 3, 10, 5),
             )
         self.assertEqual(
             knn.depthwise_conv(inputs_2d, kernel, 2, dilation_rate=2).shape,
-            (None, 4, 4, 3),
+            (None, 4, 4, 3)
+            if data_format == "channels_last"
+            else (None, 3, 4, 4),
         )
         self.assertEqual(
             knn.depthwise_conv(
                 inputs_2d, kernel, 2, dilation_rate=(2, 1)
             ).shape,
-            (None, 4, 5, 3),
+            (None, 4, 5, 3)
+            if data_format == "channels_last"
+            else (None, 3, 4, 5),
         )
 
     def test_separable_conv(self):
+        data_format = backend.config.image_data_format()
         # Test 1D separable conv.
-        inputs_1d = KerasTensor([None, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 20, 3)
+        else:
+            input_shape = (None, 3, 20)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([4, 3, 2])
         pointwise_kernel = KerasTensor([1, 6, 5])
         self.assertEqual(
             knn.separable_conv(
                 inputs_1d, kernel, pointwise_kernel, 1, padding="valid"
             ).shape,
-            (None, 17, 5),
+            (None, 17, 5) if data_format == "channels_last" else (None, 5, 17),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_1d, kernel, pointwise_kernel, 1, padding="same"
             ).shape,
-            (None, 20, 5),
+            (None, 20, 5) if data_format == "channels_last" else (None, 5, 20),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_1d, kernel, pointwise_kernel, 2, dilation_rate=2
             ).shape,
-            (None, 7, 5),
+            (None, 7, 5) if data_format == "channels_last" else (None, 5, 7),
         )
 
         # Test 2D separable conv.
-        inputs_2d = KerasTensor([None, 10, 10, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 10, 10, 3)
+        else:
+            input_shape = (None, 3, 10, 10)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 3, 2])
         pointwise_kernel = KerasTensor([1, 1, 6, 5])
         self.assertEqual(
             knn.separable_conv(
                 inputs_2d, kernel, pointwise_kernel, 1, padding="valid"
             ).shape,
-            (None, 9, 9, 5),
+            (None, 9, 9, 5)
+            if data_format == "channels_last"
+            else (None, 5, 9, 9),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_2d, kernel, pointwise_kernel, (1, 2), padding="same"
             ).shape,
-            (None, 10, 5, 5),
+            (None, 10, 5, 5)
+            if data_format == "channels_last"
+            else (None, 5, 10, 5),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_2d, kernel, pointwise_kernel, 2, dilation_rate=(2, 1)
             ).shape,
-            (None, 4, 5, 5),
+            (None, 4, 5, 5)
+            if data_format == "channels_last"
+            else (None, 5, 4, 5),
         )
 
     def test_conv_transpose(self):
-        inputs_1d = KerasTensor([None, 4, 3])
+        data_format = backend.config.image_data_format()
+        if data_format == "channels_last":
+            input_shape = (None, 4, 3)
+        else:
+            input_shape = (None, 3, 4)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 5, 3])
         self.assertEqual(
-            knn.conv_transpose(inputs_1d, kernel, 2).shape, (None, 8, 5)
+            knn.conv_transpose(inputs_1d, kernel, 2).shape,
+            (None, 8, 5) if data_format == "channels_last" else (None, 5, 8),
         )
         self.assertEqual(
             knn.conv_transpose(inputs_1d, kernel, 2, padding="same").shape,
-            (None, 8, 5),
+            (None, 8, 5) if data_format == "channels_last" else (None, 5, 8),
         )
         self.assertEqual(
             knn.conv_transpose(
                 inputs_1d, kernel, 5, padding="valid", output_padding=4
             ).shape,
-            (None, 21, 5),
+            (None, 21, 5) if data_format == "channels_last" else (None, 5, 21),
         )
 
-        inputs_2d = KerasTensor([None, 4, 4, 3])
+        if data_format == "channels_last":
+            input_shape = (None, 4, 4, 3)
+        else:
+            input_shape = (None, 3, 4, 4)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 5, 3])
         self.assertEqual(
-            knn.conv_transpose(inputs_2d, kernel, 2).shape, (None, 8, 8, 5)
+            knn.conv_transpose(inputs_2d, kernel, 2).shape,
+            (None, 8, 8, 5)
+            if data_format == "channels_last"
+            else (None, 5, 8, 8),
         )
         self.assertEqual(
             knn.conv_transpose(inputs_2d, kernel, (2, 2), padding="same").shape,
-            (None, 8, 8, 5),
+            (None, 8, 8, 5)
+            if data_format == "channels_last"
+            else (None, 5, 8, 8),
         )
         self.assertEqual(
             knn.conv_transpose(
                 inputs_2d, kernel, (5, 5), padding="valid", output_padding=4
             ).shape,
-            (None, 21, 21, 5),
+            (None, 21, 21, 5)
+            if data_format == "channels_last"
+            else (None, 5, 21, 21),
         )
 
     def test_one_hot(self):
@@ -418,199 +553,293 @@ def test_log_softmax(self):
         self.assertEqual(knn.log_softmax(x, axis=-1).shape, (1, 2, 3))
 
     def test_max_pool(self):
-        x = KerasTensor([1, 8, 3])
-        self.assertEqual(knn.max_pool(x, 2, 1).shape, (1, 7, 3))
-        self.assertEqual(knn.max_pool(x, 2, 2, padding="same").shape, (1, 4, 3))
+        data_format = backend.config.image_data_format()
+        if data_format == "channels_last":
+            input_shape = (1, 8, 3)
+        else:
+            input_shape = (1, 3, 8)
+        x = KerasTensor(input_shape)
+        self.assertEqual(
+            knn.max_pool(x, 2, 1).shape,
+            (1, 7, 3) if data_format == "channels_last" else (1, 3, 7),
+        )
+        self.assertEqual(
+            knn.max_pool(x, 2, 2, padding="same").shape,
+            (1, 4, 3) if data_format == "channels_last" else (1, 3, 4),
+        )
 
-        x = KerasTensor([1, 8, 8, 3])
-        self.assertEqual(knn.max_pool(x, 2, 1).shape, (1, 7, 7, 3))
+        if data_format == "channels_last":
+            input_shape = (1, 8, 8, 3)
+        else:
+            input_shape = (1, 3, 8, 8)
+        x = KerasTensor(input_shape)
         self.assertEqual(
-            knn.max_pool(x, 2, 2, padding="same").shape, (1, 4, 4, 3)
+            knn.max_pool(x, 2, 1).shape,
+            (1, 7, 7, 3) if data_format == "channels_last" else (1, 3, 7, 7),
         )
         self.assertEqual(
-            knn.max_pool(x, (2, 2), (2, 2), padding="same").shape, (1, 4, 4, 3)
+            knn.max_pool(x, 2, 2, padding="same").shape,
+            (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4),
+        )
+        self.assertEqual(
+            knn.max_pool(x, (2, 2), (2, 2), padding="same").shape,
+            (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4),
         )
 
     def test_average_pool(self):
-        x = KerasTensor([1, 8, 3])
-        self.assertEqual(knn.average_pool(x, 2, 1).shape, (1, 7, 3))
+        data_format = backend.config.image_data_format()
+        if data_format == "channels_last":
+            input_shape = (1, 8, 3)
+        else:
+            input_shape = (1, 3, 8)
+        x = KerasTensor(input_shape)
+        self.assertEqual(
+            knn.average_pool(x, 2, 1).shape,
+            (1, 7, 3) if data_format == "channels_last" else (1, 3, 7),
+        )
         self.assertEqual(
-            knn.average_pool(x, 2, 2, padding="same").shape, (1, 4, 3)
+            knn.average_pool(x, 2, 2, padding="same").shape,
+            (1, 4, 3) if data_format == "channels_last" else (1, 3, 4),
         )
 
-        x = KerasTensor([1, 8, 8, 3])
-        self.assertEqual(knn.average_pool(x, 2, 1).shape, (1, 7, 7, 3))
+        if data_format == "channels_last":
+            input_shape = (1, 8, 8, 3)
+        else:
+            input_shape = (1, 3, 8, 8)
+        x = KerasTensor(input_shape)
+        self.assertEqual(
+            knn.average_pool(x, 2, 1).shape,
+            (1, 7, 7, 3) if data_format == "channels_last" else (1, 3, 7, 7),
+        )
         self.assertEqual(
-            knn.average_pool(x, 2, 2, padding="same").shape, (1, 4, 4, 3)
+            knn.average_pool(x, 2, 2, padding="same").shape,
+            (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4),
         )
         self.assertEqual(
             knn.average_pool(x, (2, 2), (2, 2), padding="same").shape,
-            (1, 4, 4, 3),
+            (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4),
         )
 
     def test_conv(self):
+        data_format = backend.config.image_data_format()
         # Test 1D conv.
-        inputs_1d = KerasTensor([2, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 20, 3)
+        else:
+            input_shape = (2, 3, 20)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([4, 3, 2])
         self.assertEqual(
-            knn.conv(inputs_1d, kernel, 1, padding="valid").shape, (2, 17, 2)
+            knn.conv(inputs_1d, kernel, 1, padding="valid").shape,
+            (2, 17, 2) if data_format == "channels_last" else (2, 2, 17),
         )
         self.assertEqual(
-            knn.conv(inputs_1d, kernel, 1, padding="same").shape, (2, 20, 2)
+            knn.conv(inputs_1d, kernel, 1, padding="same").shape,
+            (2, 20, 2) if data_format == "channels_last" else (2, 2, 20),
         )
         self.assertEqual(
-            knn.conv(inputs_1d, kernel, (2,), dilation_rate=2).shape, (2, 7, 2)
+            knn.conv(inputs_1d, kernel, (2,), dilation_rate=2).shape,
+            (2, 7, 2) if data_format == "channels_last" else (2, 2, 7),
         )
 
         # Test 2D conv.
-        inputs_2d = KerasTensor([2, 10, 10, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 10, 10, 3)
+        else:
+            input_shape = (2, 3, 10, 10)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 3, 2])
         self.assertEqual(
-            knn.conv(inputs_2d, kernel, 1, padding="valid").shape, (2, 9, 9, 2)
+            knn.conv(inputs_2d, kernel, 1, padding="valid").shape,
+            (2, 9, 9, 2) if data_format == "channels_last" else (2, 2, 9, 9),
         )
         self.assertEqual(
-            knn.conv(inputs_2d, kernel, 1, padding="same").shape, (2, 10, 10, 2)
+            knn.conv(inputs_2d, kernel, 1, padding="same").shape,
+            (2, 10, 10, 2)
+            if data_format == "channels_last"
+            else (2, 2, 10, 10),
         )
         self.assertEqual(
             knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape,
-            (2, 4, 9, 2),
+            (2, 4, 9, 2) if data_format == "channels_last" else (2, 2, 4, 9),
         )
 
         # Test 3D conv.
-        inputs_3d = KerasTensor([2, 8, 8, 8, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 8, 8, 8, 3)
+        else:
+            input_shape = (2, 3, 8, 8, 8)
+        inputs_3d = KerasTensor(input_shape)
         kernel = KerasTensor([3, 3, 3, 3, 2])
         self.assertEqual(
             knn.conv(inputs_3d, kernel, 1, padding="valid").shape,
-            (2, 6, 6, 6, 2),
+            (2, 6, 6, 6, 2)
+            if data_format == "channels_last"
+            else (2, 2, 6, 6, 6),
         )
         self.assertEqual(
             knn.conv(inputs_3d, kernel, (2, 1, 2), padding="same").shape,
-            (2, 4, 8, 4, 2),
+            (2, 4, 8, 4, 2)
+            if data_format == "channels_last"
+            else (2, 2, 4, 8, 4),
         )
         self.assertEqual(
             knn.conv(
                 inputs_3d, kernel, 1, padding="valid", dilation_rate=(1, 2, 2)
             ).shape,
-            (2, 6, 4, 4, 2),
+            (2, 6, 4, 4, 2)
+            if data_format == "channels_last"
+            else (2, 2, 6, 4, 4),
         )
 
     def test_depthwise_conv(self):
+        data_format = backend.config.image_data_format()
         # Test 1D depthwise conv.
-        inputs_1d = KerasTensor([2, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 20, 3)
+        else:
+            input_shape = (2, 3, 20)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([4, 3, 1])
         self.assertEqual(
             knn.depthwise_conv(inputs_1d, kernel, 1, padding="valid").shape,
-            (2, 17, 3),
+            (2, 17, 3) if data_format == "channels_last" else (2, 3, 17),
         )
         self.assertEqual(
             knn.depthwise_conv(inputs_1d, kernel, (1,), padding="same").shape,
-            (2, 20, 3),
+            (2, 20, 3) if data_format == "channels_last" else (2, 3, 20),
         )
         self.assertEqual(
             knn.depthwise_conv(inputs_1d, kernel, 2, dilation_rate=2).shape,
-            (2, 7, 3),
+            (2, 7, 3) if data_format == "channels_last" else (2, 3, 7),
         )
 
         # Test 2D depthwise conv.
-        inputs_2d = KerasTensor([2, 10, 10, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 10, 10, 3)
+        else:
+            input_shape = (2, 3, 10, 10)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 3, 1])
         self.assertEqual(
             knn.depthwise_conv(inputs_2d, kernel, 1, padding="valid").shape,
-            (2, 9, 9, 3),
+            (2, 9, 9, 3) if data_format == "channels_last" else (2, 3, 9, 9),
         )
         self.assertEqual(
             knn.depthwise_conv(inputs_2d, kernel, (1, 2), padding="same").shape,
-            (2, 10, 5, 3),
+            (2, 10, 5, 3) if data_format == "channels_last" else (2, 3, 10, 5),
         )
         self.assertEqual(
             knn.depthwise_conv(inputs_2d, kernel, 2, dilation_rate=2).shape,
-            (2, 4, 4, 3),
+            (2, 4, 4, 3) if data_format == "channels_last" else (2, 3, 4, 4),
         )
         self.assertEqual(
             knn.depthwise_conv(
                 inputs_2d, kernel, 2, dilation_rate=(2, 1)
             ).shape,
-            (2, 4, 5, 3),
+            (2, 4, 5, 3) if data_format == "channels_last" else (2, 3, 4, 5),
         )
 
     def test_separable_conv(self):
-        # Test 1D separable conv.
-        inputs_1d = KerasTensor([2, 20, 3])
+        data_format = backend.config.image_data_format()
+        # Test 1D max pooling.
+        if data_format == "channels_last":
+            input_shape = (2, 20, 3)
+        else:
+            input_shape = (2, 3, 20)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([4, 3, 2])
         pointwise_kernel = KerasTensor([1, 6, 5])
         self.assertEqual(
             knn.separable_conv(
                 inputs_1d, kernel, pointwise_kernel, 1, padding="valid"
             ).shape,
-            (2, 17, 5),
+            (2, 17, 5) if data_format == "channels_last" else (2, 5, 17),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_1d, kernel, pointwise_kernel, 1, padding="same"
             ).shape,
-            (2, 20, 5),
+            (2, 20, 5) if data_format == "channels_last" else (2, 5, 20),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_1d, kernel, pointwise_kernel, 2, dilation_rate=2
             ).shape,
-            (2, 7, 5),
+            (2, 7, 5) if data_format == "channels_last" else (2, 5, 7),
         )
 
         # Test 2D separable conv.
-        inputs_2d = KerasTensor([2, 10, 10, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 10, 10, 3)
+        else:
+            input_shape = (2, 3, 10, 10)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 3, 2])
         pointwise_kernel = KerasTensor([1, 1, 6, 5])
         self.assertEqual(
             knn.separable_conv(
                 inputs_2d, kernel, pointwise_kernel, 1, padding="valid"
             ).shape,
-            (2, 9, 9, 5),
+            (2, 9, 9, 5) if data_format == "channels_last" else (2, 5, 9, 9),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_2d, kernel, pointwise_kernel, (1, 2), padding="same"
             ).shape,
-            (2, 10, 5, 5),
+            (2, 10, 5, 5) if data_format == "channels_last" else (2, 5, 10, 5),
         )
         self.assertEqual(
             knn.separable_conv(
                 inputs_2d, kernel, pointwise_kernel, 2, dilation_rate=(2, 1)
             ).shape,
-            (2, 4, 5, 5),
+            (2, 4, 5, 5) if data_format == "channels_last" else (2, 5, 4, 5),
         )
 
     def test_conv_transpose(self):
-        inputs_1d = KerasTensor([2, 4, 3])
+        data_format = backend.config.image_data_format()
+        if data_format == "channels_last":
+            input_shape = (2, 4, 3)
+        else:
+            input_shape = (2, 3, 4)
+        inputs_1d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 5, 3])
         self.assertEqual(
-            knn.conv_transpose(inputs_1d, kernel, 2).shape, (2, 8, 5)
+            knn.conv_transpose(inputs_1d, kernel, 2).shape,
+            (2, 8, 5) if data_format == "channels_last" else (2, 5, 8),
         )
         self.assertEqual(
             knn.conv_transpose(inputs_1d, kernel, 2, padding="same").shape,
-            (2, 8, 5),
+            (2, 8, 5) if data_format == "channels_last" else (2, 5, 8),
         )
         self.assertEqual(
             knn.conv_transpose(
                 inputs_1d, kernel, 5, padding="valid", output_padding=4
             ).shape,
-            (2, 21, 5),
+            (2, 21, 5) if data_format == "channels_last" else (2, 5, 21),
         )
 
-        inputs_2d = KerasTensor([2, 4, 4, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 4, 4, 3)
+        else:
+            input_shape = (2, 3, 4, 4)
+        inputs_2d = KerasTensor(input_shape)
         kernel = KerasTensor([2, 2, 5, 3])
         self.assertEqual(
-            knn.conv_transpose(inputs_2d, kernel, 2).shape, (2, 8, 8, 5)
+            knn.conv_transpose(inputs_2d, kernel, 2).shape,
+            (2, 8, 8, 5) if data_format == "channels_last" else (2, 5, 8, 8),
         )
         self.assertEqual(
             knn.conv_transpose(inputs_2d, kernel, (2, 2), padding="same").shape,
-            (2, 8, 8, 5),
+            (2, 8, 8, 5) if data_format == "channels_last" else (2, 5, 8, 8),
         )
         self.assertEqual(
             knn.conv_transpose(
                 inputs_2d, kernel, (5, 5), padding="valid", output_padding=4
             ).shape,
-            (2, 21, 21, 5),
+            (2, 21, 21, 5)
+            if data_format == "channels_last"
+            else (2, 5, 21, 21),
         )
 
     def test_batched_and_unbatched_inputs_multi_hot(self):
@@ -793,43 +1022,59 @@ def test_log_softmax(self):
         )
 
     def test_max_pool(self):
+        data_format = backend.config.image_data_format()
         # Test 1D max pooling.
-        x = np.arange(120, dtype=float).reshape([2, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 20, 3)
+        else:
+            input_shape = (2, 3, 20)
+        x = np.arange(120, dtype=float).reshape(input_shape)
         self.assertAllClose(
             knn.max_pool(x, 2, 1, padding="valid"),
-            np_maxpool1d(x, 2, 1, padding="valid", data_format="channels_last"),
+            np_maxpool1d(x, 2, 1, padding="valid", data_format=data_format),
         )
         self.assertAllClose(
             knn.max_pool(x, 2, 2, padding="same"),
-            np_maxpool1d(x, 2, 2, padding="same", data_format="channels_last"),
+            np_maxpool1d(x, 2, 2, padding="same", data_format=data_format),
         )
 
         # Test 2D max pooling.
-        x = np.arange(540, dtype=float).reshape([2, 10, 9, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 10, 9, 3)
+        else:
+            input_shape = (2, 3, 10, 9)
+        x = np.arange(540, dtype=float).reshape(input_shape)
         self.assertAllClose(
             knn.max_pool(x, 2, 1, padding="valid"),
-            np_maxpool2d(x, 2, 1, padding="valid", data_format="channels_last"),
+            np_maxpool2d(x, 2, 1, padding="valid", data_format=data_format),
         )
         self.assertAllClose(
             knn.max_pool(x, 2, (2, 1), padding="same"),
-            np_maxpool2d(
-                x, 2, (2, 1), padding="same", data_format="channels_last"
-            ),
+            np_maxpool2d(x, 2, (2, 1), padding="same", data_format=data_format),
         )
 
     def test_average_pool_valid_padding(self):
+        data_format = backend.config.image_data_format()
         # Test 1D max pooling.
-        x = np.arange(120, dtype=float).reshape([2, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 20, 3)
+        else:
+            input_shape = (2, 3, 20)
+        x = np.arange(120, dtype=float).reshape(input_shape)
         self.assertAllClose(
             knn.average_pool(x, 2, 1, padding="valid"),
-            np_avgpool1d(x, 2, 1, padding="valid", data_format="channels_last"),
+            np_avgpool1d(x, 2, 1, padding="valid", data_format=data_format),
         )
 
         # Test 2D max pooling.
-        x = np.arange(540, dtype=float).reshape([2, 10, 9, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 10, 9, 3)
+        else:
+            input_shape = (2, 3, 10, 9)
+        x = np.arange(540, dtype=float).reshape(input_shape)
         self.assertAllClose(
             knn.average_pool(x, 2, 1, padding="valid"),
-            np_avgpool2d(x, 2, 1, padding="valid", data_format="channels_last"),
+            np_avgpool2d(x, 2, 1, padding="valid", data_format=data_format),
         )
 
     @pytest.mark.skipif(
@@ -837,20 +1082,28 @@ def test_average_pool_valid_padding(self):
         reason="Torch outputs differently from TF when using `same` padding.",
     )
     def test_average_pool_same_padding(self):
+        data_format = backend.config.image_data_format()
         # Test 1D max pooling.
-        x = np.arange(120, dtype=float).reshape([2, 20, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 20, 3)
+        else:
+            input_shape = (2, 3, 20)
+        x = np.arange(120, dtype=float).reshape(input_shape)
+
         self.assertAllClose(
             knn.average_pool(x, 2, 2, padding="same"),
-            np_avgpool1d(x, 2, 2, padding="same", data_format="channels_last"),
+            np_avgpool1d(x, 2, 2, padding="same", data_format=data_format),
         )
 
         # Test 2D max pooling.
-        x = np.arange(540, dtype=float).reshape([2, 10, 9, 3])
+        if data_format == "channels_last":
+            input_shape = (2, 10, 9, 3)
+        else:
+            input_shape = (2, 3, 10, 9)
+        x = np.arange(540, dtype=float).reshape(input_shape)
         self.assertAllClose(
             knn.average_pool(x, 2, (2, 1), padding="same"),
-            np_avgpool2d(
-                x, 2, (2, 1), padding="same", data_format="channels_last"
-            ),
+            np_avgpool2d(x, 2, (2, 1), padding="same", data_format=data_format),
         )
 
     @parameterized.product(
@@ -862,7 +1115,11 @@ def test_conv_1d(self, strides, padding, dilation_rate):
         if strides > 1 and dilation_rate > 1:
             pytest.skip("Unsupported configuration")
 
-        inputs_1d = np.arange(120, dtype=float).reshape([2, 20, 3])
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 20, 3)
+        else:
+            input_shape = (2, 3, 20)
+        inputs_1d = np.arange(120, dtype=float).reshape(input_shape)
         kernel = np.arange(24, dtype=float).reshape([4, 3, 2])
 
         outputs = knn.conv(
@@ -878,371 +1135,222 @@ def test_conv_1d(self, strides, padding, dilation_rate):
             bias_weights=np.zeros((2,)),
             strides=strides,
             padding=padding.lower(),
-            data_format="channels_last",
+            data_format=backend.config.image_data_format(),
             dilation_rate=dilation_rate,
             groups=1,
         )
         self.assertAllClose(outputs, expected)
 
-    def test_conv_2d(self):
-        inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
+    @parameterized.product(strides=(1, 2, (1, 2)), padding=("valid", "same"))
+    def test_conv_2d(self, strides, padding):
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 10, 10, 3)
+        else:
+            input_shape = (2, 3, 10, 10)
+        inputs_2d = np.arange(600, dtype=float).reshape(input_shape)
         kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
 
-        outputs = knn.conv(inputs_2d, kernel, 1, padding="valid")
+        outputs = knn.conv(inputs_2d, kernel, strides, padding=padding)
         expected = np_conv2d(
             inputs_2d,
             kernel,
             bias_weights=np.zeros((2,)),
-            strides=1,
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.conv(inputs_2d, kernel, (1, 2), padding="valid")
-        expected = np_conv2d(
-            inputs_2d,
-            kernel,
-            bias_weights=np.zeros((2,)),
-            strides=(1, 2),
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.conv(inputs_2d, kernel, (1, 2), padding="same")
-        expected = np_conv2d(
-            inputs_2d,
-            kernel,
-            bias_weights=np.zeros((2,)),
-            strides=(1, 2),
-            padding="same",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.conv(inputs_2d, kernel, 2, padding="same")
-        expected = np_conv2d(
-            inputs_2d,
-            kernel,
-            bias_weights=np.zeros((2,)),
-            strides=2,
-            padding="same",
-            data_format="channels_last",
+            strides=strides,
+            padding=padding,
+            data_format=backend.config.image_data_format(),
             dilation_rate=1,
             groups=1,
         )
         self.assertAllClose(outputs, expected)
 
-        # Test group > 1.
-        inputs_2d = np.ones([2, 10, 10, 4])
+    @parameterized.product(strides=(1, 2), dilation_rate=(1, (2, 1)))
+    def test_conv_2d_group_2(self, strides, dilation_rate):
+        if (
+            backend.backend() == "tensorflow"
+            and strides == 2
+            and dilation_rate == (2, 1)
+        ):
+            # This case is not supported by the TF backend.
+            return
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 10, 10, 4)
+        else:
+            input_shape = (2, 4, 10, 10)
+        inputs_2d = np.ones(input_shape)
         kernel = np.ones([2, 2, 2, 6])
         outputs = knn.conv(
-            inputs_2d, kernel, 2, padding="same", dilation_rate=1
-        )
-        expected = np_conv2d(
             inputs_2d,
             kernel,
-            bias_weights=np.zeros((6,)),
-            strides=2,
+            strides,
             padding="same",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.conv(
-            inputs_2d,
-            kernel,
-            1,
-            padding="same",
-            dilation_rate=(2, 1),
+            dilation_rate=dilation_rate,
         )
         expected = np_conv2d(
             inputs_2d,
             kernel,
             bias_weights=np.zeros((6,)),
-            strides=1,
+            strides=strides,
             padding="same",
-            data_format="channels_last",
-            dilation_rate=(2, 1),
+            data_format=backend.config.image_data_format(),
+            dilation_rate=dilation_rate,
             groups=1,
         )
         self.assertAllClose(outputs, expected)
 
-    def test_conv_3d(self):
-        inputs_3d = np.arange(3072, dtype=float).reshape([2, 8, 8, 8, 3])
+    @parameterized.product(strides=(1, (1, 1, 1), 2), padding=("valid", "same"))
+    def test_conv_3d(self, strides, padding):
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 8, 8, 8, 3)
+        else:
+            input_shape = (2, 3, 8, 8, 8)
+        inputs_3d = np.arange(3072, dtype=float).reshape(input_shape)
         kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2])
 
-        outputs = knn.conv(inputs_3d, kernel, 1, padding="valid")
-        expected = np_conv3d(
-            inputs_3d,
-            kernel,
-            bias_weights=np.zeros((2,)),
-            strides=(1, 1, 1),
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
-
-        outputs = knn.conv(
-            inputs_3d,
-            kernel,
-            (1, 1, 1),
-            padding="valid",
-            dilation_rate=(1, 1, 1),
-        )
-        expected = np_conv3d(
-            inputs_3d,
-            kernel,
-            bias_weights=np.zeros((2,)),
-            strides=(1, 1, 1),
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=(1, 1, 1),
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
-
-        outputs = knn.conv(inputs_3d, kernel, 2, padding="valid")
+        outputs = knn.conv(inputs_3d, kernel, strides, padding=padding)
         expected = np_conv3d(
             inputs_3d,
             kernel,
             bias_weights=np.zeros((2,)),
-            strides=2,
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
-
-        outputs = knn.conv(inputs_3d, kernel, 2, padding="same")
-        expected = np_conv3d(
-            inputs_3d,
-            kernel,
-            bias_weights=np.zeros((2,)),
-            strides=2,
-            padding="same",
-            data_format="channels_last",
+            strides=strides,
+            padding=padding,
+            data_format=backend.config.image_data_format(),
             dilation_rate=1,
             groups=1,
         )
         self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
 
-    def test_depthwise_conv_2d(self):
-        inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
+    @parameterized.product(
+        strides=(1, (1, 1), (2, 2)),
+        padding=("valid", "same"),
+        dilation_rate=(1, (2, 2)),
+    )
+    def test_depthwise_conv_2d(self, strides, padding, dilation_rate):
+        if (
+            backend.backend() == "tensorflow"
+            and strides == (2, 2)
+            and dilation_rate == (2, 2)
+        ):
+            # This case is not supported by the TF backend.
+            return
+        print(strides, padding, dilation_rate)
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 10, 10, 3)
+        else:
+            input_shape = (2, 3, 10, 10)
+        inputs_2d = np.arange(600, dtype=float).reshape(input_shape)
         kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
 
-        outputs = knn.depthwise_conv(inputs_2d, kernel, 1, padding="valid")
-        expected = np_depthwise_conv2d(
-            inputs_2d,
-            kernel,
-            bias_weights=np.zeros((6,)),
-            strides=1,
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.depthwise_conv(inputs_2d, kernel, (1, 1), padding="valid")
-        expected = np_depthwise_conv2d(
-            inputs_2d,
-            kernel,
-            bias_weights=np.zeros((6,)),
-            strides=(1, 1),
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.depthwise_conv(inputs_2d, kernel, (2, 2), padding="same")
-        expected = np_depthwise_conv2d(
+        outputs = knn.depthwise_conv(
             inputs_2d,
             kernel,
-            bias_weights=np.zeros((6,)),
-            strides=(2, 2),
-            padding="same",
-            data_format="channels_last",
-            dilation_rate=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.depthwise_conv(
-            inputs_2d, kernel, 1, padding="same", dilation_rate=(2, 2)
+            strides,
+            padding=padding,
+            dilation_rate=dilation_rate,
         )
         expected = np_depthwise_conv2d(
             inputs_2d,
             kernel,
             bias_weights=np.zeros((6,)),
-            strides=1,
-            padding="same",
-            data_format="channels_last",
-            dilation_rate=(2, 2),
+            strides=strides,
+            padding=padding,
+            data_format=backend.config.image_data_format(),
+            dilation_rate=dilation_rate,
         )
         self.assertAllClose(outputs, expected)
 
-    def test_separable_conv_2d(self):
+    @parameterized.product(
+        strides=(1, 2),
+        padding=("valid", "same"),
+        dilation_rate=(1, (2, 2)),
+    )
+    def test_separable_conv_2d(self, strides, padding, dilation_rate):
+        if (
+            backend.backend() == "tensorflow"
+            and strides == 2
+            and dilation_rate == (2, 2)
+        ):
+            # This case is not supported by the TF backend.
+            return
         # Test 2D conv.
-        inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 10, 10, 3)
+        else:
+            input_shape = (2, 3, 10, 10)
+        inputs_2d = np.arange(600, dtype=float).reshape(input_shape)
         depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
         pointwise_kernel = np.arange(72, dtype=float).reshape([1, 1, 6, 12])
 
         outputs = knn.separable_conv(
-            inputs_2d, depthwise_kernel, pointwise_kernel, 1, padding="valid"
-        )
-        # Depthwise followed by pointwise conv
-        expected_depthwise = np_depthwise_conv2d(
             inputs_2d,
             depthwise_kernel,
-            np.zeros(6),
-            strides=1,
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-        )
-        expected = np_conv2d(
-            expected_depthwise,
             pointwise_kernel,
-            np.zeros(6 * 12),
-            strides=1,
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.separable_conv(
-            inputs_2d,
-            depthwise_kernel,
-            pointwise_kernel,
-            (1, 1),
-            padding="valid",
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.separable_conv(
-            inputs_2d, depthwise_kernel, pointwise_kernel, 2, padding="same"
-        )
-        # Depthwise followed by pointwise conv
-        expected_depthwise = np_depthwise_conv2d(
-            inputs_2d,
-            depthwise_kernel,
-            np.zeros(6),
-            strides=2,
-            padding="same",
-            data_format="channels_last",
-            dilation_rate=1,
-        )
-        expected = np_conv2d(
-            expected_depthwise,
-            pointwise_kernel,
-            np.zeros(6 * 12),
-            strides=1,
-            padding="same",
-            data_format="channels_last",
-            dilation_rate=1,
-            groups=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.separable_conv(
-            inputs_2d,
-            depthwise_kernel,
-            pointwise_kernel,
-            1,
-            padding="same",
-            dilation_rate=(2, 2),
+            strides,
+            padding=padding,
+            dilation_rate=dilation_rate,
         )
         # Depthwise followed by pointwise conv
         expected_depthwise = np_depthwise_conv2d(
             inputs_2d,
             depthwise_kernel,
             np.zeros(6),
-            strides=1,
-            padding="same",
-            data_format="channels_last",
-            dilation_rate=(2, 2),
+            strides=strides,
+            padding=padding,
+            data_format=backend.config.image_data_format(),
+            dilation_rate=dilation_rate,
         )
         expected = np_conv2d(
             expected_depthwise,
             pointwise_kernel,
             np.zeros(6 * 12),
             strides=1,
-            padding="same",
-            data_format="channels_last",
-            dilation_rate=1,
+            padding=padding,
+            data_format=backend.config.image_data_format(),
+            dilation_rate=dilation_rate,
             groups=1,
         )
         self.assertAllClose(outputs, expected)
 
-    def test_conv_transpose_1d(self):
-        inputs_1d = np.arange(24, dtype=float).reshape([2, 4, 3])
+    @parameterized.product(padding=("valid", "same"))
+    def test_conv_transpose_1d(self, padding):
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 4, 3)
+        else:
+            input_shape = (2, 3, 4)
+        inputs_1d = np.arange(24, dtype=float).reshape(input_shape)
         kernel = np.arange(30, dtype=float).reshape([2, 5, 3])
-        outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="valid")
+        outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding=padding)
         expected = np_conv1d_transpose(
             inputs_1d,
             kernel,
             bias_weights=np.zeros(5),
             strides=2,
             output_padding=None,
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
-        )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="same")
-        expected = np_conv1d_transpose(
-            inputs_1d,
-            kernel,
-            bias_weights=np.zeros(5),
-            strides=2,
-            output_padding=None,
-            padding="same",
-            data_format="channels_last",
+            padding=padding,
+            data_format=backend.config.image_data_format(),
             dilation_rate=1,
         )
         self.assertAllClose(outputs, expected)
 
-    def test_conv_transpose_2d(self):
-        inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
+    @parameterized.product(strides=(2, (2, 2)), padding=("valid", "same"))
+    def test_conv_transpose_2d(self, strides, padding):
+        if backend.config.image_data_format() == "channels_last":
+            input_shape = (2, 4, 4, 3)
+        else:
+            input_shape = (2, 3, 4, 4)
+        inputs_2d = np.arange(96, dtype=float).reshape(input_shape)
         kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
 
-        outputs = knn.conv_transpose(inputs_2d, kernel, (2, 2), padding="valid")
-        expected = np_conv2d_transpose(
-            inputs_2d,
-            kernel,
-            bias_weights=np.zeros(5),
-            strides=(2, 2),
-            output_padding=None,
-            padding="valid",
-            data_format="channels_last",
-            dilation_rate=1,
+        outputs = knn.conv_transpose(
+            inputs_2d, kernel, strides, padding=padding
         )
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.conv_transpose(inputs_2d, kernel, 2, padding="same")
         expected = np_conv2d_transpose(
             inputs_2d,
             kernel,
             bias_weights=np.zeros(5),
-            strides=2,
+            strides=strides,
             output_padding=None,
-            padding="same",
-            data_format="channels_last",
+            padding=padding,
+            data_format=backend.config.image_data_format(),
             dilation_rate=1,
         )
         self.assertAllClose(outputs, expected)
@@ -1498,3 +1606,4 @@ def test_on_moments(inputs):
             mean, variance = strategy.run(test_on_moments, args=(inputs,))
             self.assertEqual(mean.values[0], 4.5)
             self.assertEqual(variance.values[0], 8.75)
+            self.assertEqual(variance.values[0], 8.75)