From 1d124bb2014c4576ba5025b360ff8a96d9aa63cf Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 13:28:47 +0800 Subject: [PATCH] Update `quantizers` tests --- keras/quantizers/quantizers_test.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/keras/quantizers/quantizers_test.py b/keras/quantizers/quantizers_test.py index d646f612b9d..61ae58a8428 100644 --- a/keras/quantizers/quantizers_test.py +++ b/keras/quantizers/quantizers_test.py @@ -9,16 +9,19 @@ def test_abs_max_quantizer(self): values = random.uniform([3, 4, 5], minval=-1, maxval=1) quantizer = quantizers.AbsMaxQuantizer(axis=-1) - # Test quantize + # Test quantizing quantized_values, scale = quantizer(values) - self.assertEqual(quantized_values.shape, [3, 4, 5]) - self.assertEqual(scale.shape, [3, 4, 1]) + self.assertEqual(tuple(quantized_values.shape), (3, 4, 5)) + self.assertEqual(tuple(scale.shape), (3, 4, 1)) self.assertLessEqual(ops.max(quantized_values), 127) self.assertGreaterEqual(ops.min(quantized_values), -127) - # Test dequantize + # Test dequantizing dequantized_values = ops.divide(quantized_values, scale) - self.assertAllClose(values, dequantized_values, atol=1) + rmse = ops.sqrt( + ops.mean(ops.square(ops.subtract(values, dequantized_values))) + ) + self.assertLess(rmse, 1e-1) # loose assertion # Test serialization self.run_class_serialization_test(quantizer)