diff --git a/keras_core/layers/merging/merging_test.py b/keras_core/layers/merging/merging_test.py index 5d205363a..94630c1cf 100644 --- a/keras_core/layers/merging/merging_test.py +++ b/keras_core/layers/merging/merging_test.py @@ -1,21 +1,97 @@ import numpy as np import pytest +from absl.testing import parameterized from keras_core import backend from keras_core import layers from keras_core import models -from keras_core import ops from keras_core import testing +def np_dot(a, b, axes): + if isinstance(axes, int): + axes = (axes, axes) + axes = [axis if axis < 0 else axis - 1 for axis in axes] + res = np.stack([np.tensordot(a[i], b[i], axes) for i in range(a.shape[0])]) + if len(res.shape) == 1: + res = np.expand_dims(res, axis=1) + return res + + +TEST_PARAMETERS = [ + { + "testcase_name": "add", + "layer_class": layers.Add, + "np_op": np.add, + }, + { + "testcase_name": "substract", + "layer_class": layers.Subtract, + "np_op": np.subtract, + }, + { + "testcase_name": "minimum", + "layer_class": layers.Minimum, + "np_op": np.minimum, + }, + { + "testcase_name": "maximum", + "layer_class": layers.Maximum, + "np_op": np.maximum, + }, + { + "testcase_name": "multiply", + "layer_class": layers.Multiply, + "np_op": np.multiply, + }, + { + "testcase_name": "average", + "layer_class": layers.Average, + "np_op": lambda a, b: np.multiply(np.add(a, b), 0.5), + }, + { + "testcase_name": "concat", + "layer_class": layers.Concatenate, + "np_op": lambda a, b, **kwargs: np.concatenate((a, b), **kwargs), + "init_kwargs": {"axis": -1}, + "expected_output_shape": (2, 4, 10), + }, + { + "testcase_name": "dot_2d", + "layer_class": layers.Dot, + "np_op": np_dot, + "init_kwargs": {"axes": -1}, + "input_shape": (2, 4), + "expected_output_shape": (2, 1), + "skip_mask_test": True, + }, + { + "testcase_name": "dot_3d", + "layer_class": layers.Dot, + "np_op": np_dot, + "init_kwargs": {"axes": -1}, + "expected_output_shape": (2, 4, 4), + "skip_mask_test": True, + }, +] + + @pytest.mark.requires_trainable_backend -class MergingLayersTest(testing.TestCase): - def test_add_basic(self): +class MergingLayersTest(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters(TEST_PARAMETERS) + def test_basic( + self, + layer_class, + init_kwargs={}, + input_shape=(2, 4, 5), + expected_output_shape=(2, 4, 5), + **kwargs + ): self.run_layer_test( - layers.Add, - init_kwargs={}, - input_shape=[(2, 3), (2, 3)], - expected_output_shape=(2, 3), + layer_class, + init_kwargs=init_kwargs, + input_shape=[input_shape, input_shape], + expected_output_shape=expected_output_shape, expected_num_trainable_weights=0, expected_num_non_trainable_weights=0, expected_num_seed_generators=0, @@ -23,175 +99,119 @@ def test_add_basic(self): supports_masking=True, ) - def test_add_correctness_dynamic(self): - x1 = np.random.rand(2, 4, 5) - x2 = np.random.rand(2, 4, 5) - x3 = ops.convert_to_tensor(x1 + x2) + @parameterized.named_parameters(TEST_PARAMETERS) + def test_correctness_static( + self, + layer_class, + np_op, + init_kwargs={}, + input_shape=(2, 4, 5), + expected_output_shape=(2, 4, 5), + skip_mask_test=False, + ): + batch_size = input_shape[0] + shape = input_shape[1:] + x1 = np.random.rand(*input_shape) + x2 = np.random.rand(*input_shape) + x3 = np_op(x1, x2, **init_kwargs) - input_1 = layers.Input(shape=(4, 5)) - input_2 = layers.Input(shape=(4, 5)) - add_layer = layers.Add() - out = add_layer([input_1, input_2]) + input_1 = layers.Input(shape=shape, batch_size=batch_size) + input_2 = layers.Input(shape=shape, batch_size=batch_size) + layer = layer_class(**init_kwargs) + out = layer([input_1, input_2]) model = models.Model([input_1, input_2], out) res = model([x1, x2]) - self.assertEqual(res.shape, (2, 4, 5)) + self.assertEqual(res.shape, expected_output_shape) self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - add_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - add_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], + self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None])) + if not skip_mask_test: + self.assertTrue( + np.all( + backend.convert_to_numpy( + layer.compute_mask( + [input_1, input_2], + [backend.Variable(x1), backend.Variable(x2)], + ) ) ) ) - ) - def test_add_correctness_static(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - x2 = np.random.rand(batch_size, *shape) - x3 = ops.convert_to_tensor(x1 + x2) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - add_layer = layers.Add() - out = add_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (batch_size, *shape)) + @parameterized.named_parameters(TEST_PARAMETERS) + def test_correctness_dynamic( + self, + layer_class, + np_op, + init_kwargs={}, + input_shape=(2, 4, 5), + expected_output_shape=(2, 4, 5), + skip_mask_test=False, + ): + shape = input_shape[1:] + x1 = np.random.rand(*input_shape) + x2 = np.random.rand(*input_shape) + x3 = np_op(x1, x2, **init_kwargs) + + input_1 = layers.Input(shape=shape) + input_2 = layers.Input(shape=shape) + layer = layer_class(**init_kwargs) + out = layer([input_1, input_2]) + model = models.Model([input_1, input_2], out) + res = model([x1, x2]) + + self.assertEqual(res.shape, expected_output_shape) self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - add_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - add_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], + self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None])) + if not skip_mask_test: + self.assertTrue( + np.all( + backend.convert_to_numpy( + layer.compute_mask( + [input_1, input_2], + [backend.Variable(x1), backend.Variable(x2)], + ) ) ) ) - ) - def test_add_errors(self): - batch_size = 2 - shape = (4, 5) + @parameterized.named_parameters(TEST_PARAMETERS) + def test_errors( + self, + layer_class, + init_kwargs={}, + input_shape=(2, 4, 5), + skip_mask_test=False, + **kwargs + ): + if skip_mask_test: + pytest.skip("Masking not supported") + + batch_size = input_shape[0] + shape = input_shape[1:] + x1 = np.random.rand(*input_shape) x1 = np.random.rand(batch_size, *shape) input_1 = layers.Input(shape=shape, batch_size=batch_size) input_2 = layers.Input(shape=shape, batch_size=batch_size) - add_layer = layers.Add() + layer = layer_class(**init_kwargs) with self.assertRaisesRegex(ValueError, "`mask` should be a list."): - add_layer.compute_mask([input_1, input_2], x1) + layer.compute_mask([input_1, input_2], x1) with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): - add_layer.compute_mask(input_1, [None, None]) + layer.compute_mask(input_1, [None, None]) with self.assertRaisesRegex( ValueError, " should have the same length." ): - add_layer.compute_mask([input_1, input_2], [None]) + layer.compute_mask([input_1, input_2], [None]) - def test_subtract_basic(self): - self.run_layer_test( - layers.Subtract, - init_kwargs={}, - input_shape=[(2, 3), (2, 3)], - expected_output_shape=(2, 3), - expected_num_trainable_weights=0, - expected_num_non_trainable_weights=0, - expected_num_seed_generators=0, - expected_num_losses=0, - supports_masking=True, - ) - - def test_subtract_correctness_dynamic(self): - x1 = np.random.rand(2, 4, 5) - x2 = np.random.rand(2, 4, 5) - x3 = ops.convert_to_tensor(x1 - x2) - - input_1 = layers.Input(shape=(4, 5)) - input_2 = layers.Input(shape=(4, 5)) - subtract_layer = layers.Subtract() - out = subtract_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 4, 5)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - subtract_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - subtract_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_subtract_correctness_static(self): - batch_size = 2 + def test_subtract_layer_inputs_length_errors(self): shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - x2 = np.random.rand(batch_size, *shape) - x3 = ops.convert_to_tensor(x1 - x2) + input_1 = layers.Input(shape=shape) + input_2 = layers.Input(shape=shape) + input_3 = layers.Input(shape=shape) - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - subtract_layer = layers.Subtract() - out = subtract_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (batch_size, *shape)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - subtract_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - subtract_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_subtract_errors(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - input_3 = layers.Input(shape=shape, batch_size=batch_size) - subtract_layer = layers.Subtract() - - with self.assertRaisesRegex(ValueError, "`mask` should be a list."): - subtract_layer.compute_mask([input_1, input_2], x1) - - with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): - subtract_layer.compute_mask(input_1, [None, None]) - - with self.assertRaisesRegex( - ValueError, " should have the same length." - ): - subtract_layer.compute_mask([input_1, input_2], [None]) with self.assertRaisesRegex( ValueError, "layer should be called on exactly 2 inputs" ): @@ -200,515 +220,3 @@ def test_subtract_errors(self): ValueError, "layer should be called on exactly 2 inputs" ): layers.Subtract()([input_1]) - - def test_minimum_basic(self): - self.run_layer_test( - layers.Minimum, - init_kwargs={}, - input_shape=[(2, 3), (2, 3)], - expected_output_shape=(2, 3), - expected_num_trainable_weights=0, - expected_num_non_trainable_weights=0, - expected_num_seed_generators=0, - expected_num_losses=0, - supports_masking=True, - ) - - def test_minimum_correctness_dynamic(self): - x1 = np.random.rand(2, 4, 5) - x2 = np.random.rand(2, 4, 5) - x3 = ops.minimum(x1, x2) - - input_1 = layers.Input(shape=(4, 5)) - input_2 = layers.Input(shape=(4, 5)) - merge_layer = layers.Minimum() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 4, 5)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_minimum_correctness_static(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - x2 = np.random.rand(batch_size, *shape) - x3 = ops.minimum(x1, x2) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Minimum() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (batch_size, *shape)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_minimum_errors(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Minimum() - - with self.assertRaisesRegex(ValueError, "`mask` should be a list."): - merge_layer.compute_mask([input_1, input_2], x1) - - with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): - merge_layer.compute_mask(input_1, [None, None]) - - with self.assertRaisesRegex( - ValueError, " should have the same length." - ): - merge_layer.compute_mask([input_1, input_2], [None]) - - def test_maximum_basic(self): - self.run_layer_test( - layers.Maximum, - init_kwargs={}, - input_shape=[(2, 3), (2, 3)], - expected_output_shape=(2, 3), - expected_num_trainable_weights=0, - expected_num_non_trainable_weights=0, - expected_num_seed_generators=0, - expected_num_losses=0, - supports_masking=True, - ) - - def test_maximum_correctness_dynamic(self): - x1 = np.random.rand(2, 4, 5) - x2 = np.random.rand(2, 4, 5) - x3 = ops.maximum(x1, x2) - - input_1 = layers.Input(shape=(4, 5)) - input_2 = layers.Input(shape=(4, 5)) - merge_layer = layers.Maximum() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 4, 5)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_maximum_correctness_static(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - x2 = np.random.rand(batch_size, *shape) - x3 = ops.maximum(x1, x2) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Maximum() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (batch_size, *shape)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_maximum_errors(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Maximum() - - with self.assertRaisesRegex(ValueError, "`mask` should be a list."): - merge_layer.compute_mask([input_1, input_2], x1) - - with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): - merge_layer.compute_mask(input_1, [None, None]) - - with self.assertRaisesRegex( - ValueError, " should have the same length." - ): - merge_layer.compute_mask([input_1, input_2], [None]) - - def test_multiply_basic(self): - self.run_layer_test( - layers.Multiply, - init_kwargs={}, - input_shape=[(2, 3), (2, 3)], - expected_output_shape=(2, 3), - expected_num_trainable_weights=0, - expected_num_non_trainable_weights=0, - expected_num_seed_generators=0, - expected_num_losses=0, - supports_masking=True, - ) - - def test_multiply_correctness_dynamic(self): - x1 = np.random.rand(2, 4, 5) - x2 = np.random.rand(2, 4, 5) - x3 = ops.convert_to_tensor(x1 * x2) - - input_1 = layers.Input(shape=(4, 5)) - input_2 = layers.Input(shape=(4, 5)) - merge_layer = layers.Multiply() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 4, 5)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_multiply_correctness_static(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - x2 = np.random.rand(batch_size, *shape) - x3 = ops.convert_to_tensor(x1 * x2) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Multiply() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (batch_size, *shape)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_multiply_errors(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Multiply() - - with self.assertRaisesRegex(ValueError, "`mask` should be a list."): - merge_layer.compute_mask([input_1, input_2], x1) - - with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): - merge_layer.compute_mask(input_1, [None, None]) - - with self.assertRaisesRegex( - ValueError, " should have the same length." - ): - merge_layer.compute_mask([input_1, input_2], [None]) - - def test_average_basic(self): - self.run_layer_test( - layers.Average, - init_kwargs={}, - input_shape=[(2, 3), (2, 3)], - expected_output_shape=(2, 3), - expected_num_trainable_weights=0, - expected_num_non_trainable_weights=0, - expected_num_seed_generators=0, - expected_num_losses=0, - supports_masking=True, - ) - - def test_average_correctness_dynamic(self): - x1 = np.random.rand(2, 4, 5) - x2 = np.random.rand(2, 4, 5) - x3 = ops.average(np.array([x1, x2]), axis=0) - - input_1 = layers.Input(shape=(4, 5)) - input_2 = layers.Input(shape=(4, 5)) - merge_layer = layers.Average() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 4, 5)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_average_correctness_static(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - x2 = np.random.rand(batch_size, *shape) - x3 = ops.average(np.array([x1, x2]), axis=0) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Average() - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (batch_size, *shape)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_average_errors(self): - batch_size = 2 - shape = (4, 5) - x1 = np.random.rand(batch_size, *shape) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Average() - - with self.assertRaisesRegex(ValueError, "`mask` should be a list."): - merge_layer.compute_mask([input_1, input_2], x1) - - with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): - merge_layer.compute_mask(input_1, [None, None]) - - with self.assertRaisesRegex( - ValueError, " should have the same length." - ): - merge_layer.compute_mask([input_1, input_2], [None]) - - def test_concatenate_basic(self): - self.run_layer_test( - layers.Concatenate, - init_kwargs={"axis": 1}, - input_shape=[(2, 3), (2, 3)], - expected_output_shape=(2, 6), - expected_num_trainable_weights=0, - expected_num_non_trainable_weights=0, - expected_num_seed_generators=0, - expected_num_losses=0, - supports_masking=True, - ) - - def test_concatenate_correctness_dynamic(self): - x1 = np.random.rand(2, 4, 5) - x2 = np.random.rand(2, 4, 5) - axis = 1 - - x3 = ops.concatenate([x1, x2], axis=axis) - - input_1 = layers.Input(shape=(4, 5)) - input_2 = layers.Input(shape=(4, 5)) - merge_layer = layers.Concatenate(axis=axis) - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 8, 5)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [backend.Variable(x1), backend.Variable(x2)], - ) - ) - ) - ) - - def test_concatenate_correctness_static(self): - batch_size = 2 - shape = (4, 5) - axis = 1 - x1 = np.random.rand(batch_size, *shape) - x2 = np.random.rand(batch_size, *shape) - x3 = ops.concatenate([x1, x2], axis=axis) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Concatenate(axis=axis) - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (batch_size, 8, 5)) - self.assertAllClose(res, x3, atol=1e-4) - self.assertIsNone( - merge_layer.compute_mask([input_1, input_2], [None, None]) - ) - self.assertTrue( - np.all( - backend.convert_to_numpy( - merge_layer.compute_mask( - [input_1, input_2], - [x1, x2], - ) - ) - ) - ) - - def test_concatenate_errors(self): - batch_size = 2 - shape = (4, 5) - axis = 1 - x1 = np.random.rand(batch_size, *shape) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Concatenate(axis=axis) - - with self.assertRaisesRegex(ValueError, "`mask` should be a list."): - merge_layer.compute_mask([input_1, input_2], x1) - - with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): - merge_layer.compute_mask(input_1, [None, None]) - - with self.assertRaisesRegex( - ValueError, " should have the same length." - ): - merge_layer.compute_mask([input_1, input_2], [None]) - - def test_dot_basic(self): - self.run_layer_test( - layers.Dot, - init_kwargs={"axes": -1}, - input_shape=[(4, 3), (4, 3)], - expected_output_shape=(4, 1), - expected_num_trainable_weights=0, - expected_num_non_trainable_weights=0, - expected_num_seed_generators=0, - expected_num_losses=0, - supports_masking=None, - ) - - def test_dot_correctness_dynamic(self): - x1 = np.random.rand(2, 4) - x2 = np.random.rand(2, 4) - axes = 1 - - expected = np.zeros((2, 1)) - expected[0, 0] = np.dot(x1[0], x2[0]) - expected[1, 0] = np.dot(x1[1], x2[1]) - - input_1 = layers.Input(shape=(4,)) - input_2 = layers.Input(shape=(4,)) - merge_layer = layers.Dot(axes=axes) - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 1)) - self.assertAllClose(res, expected, atol=1e-4) - - def test_dot_correctness_static(self): - batch_size = 2 - shape = (4,) - axes = 1 - - x1 = np.random.rand(2, 4) - x2 = np.random.rand(2, 4) - expected = np.zeros((2, 1)) - expected[0, 0] = np.dot(x1[0], x2[0]) - expected[1, 0] = np.dot(x1[1], x2[1]) - - input_1 = layers.Input(shape=shape, batch_size=batch_size) - input_2 = layers.Input(shape=shape, batch_size=batch_size) - merge_layer = layers.Dot(axes=axes) - out = merge_layer([input_1, input_2]) - model = models.Model([input_1, input_2], out) - res = model([x1, x2]) - - self.assertEqual(res.shape, (2, 1)) - self.assertAllClose(res, expected, atol=1e-4)