From 544a2d8e426e2b25095eeb08c6a07a0509a2c7d6 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 3 Dec 2024 14:18:36 -0800 Subject: [PATCH] [sharktank] Fix flaxy quantizer test (attempt 2). (#641) Follow-up to https://github.com/nod-ai/shark-ai/pull/631. The test is flaking at this line too: https://github.com/nod-ai/shark-ai/actions/runs/12148591430/job/33877368151?pr=639#step:6:25 ``` =================================== FAILURES =================================== ________________ StaticScaledQuantizerTest.testPerAxisRoundtrip ________________ [gw0] linux -- Python 3.11.10 /opt/hostedtoolcache/Python/3.11.10/x64/bin/python self = def testPerAxisRoundtrip(self): ssq = StaticScaledQuantizer( name="poodoo", axis=1, scale=torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32), dtype=torch.float16, ) ssq = self._roundtrip(ssq, "_ssq") self.assertEqual(ssq.axis, 1) torch.testing.assert_close( ssq.scale, torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32) ) > torch.testing.assert_close(ssq.reciprocal_scale, torch.tensor([5.0, 2.5, 1.25])) E AssertionError: The values for attribute 'dtype' do not match: torch.float32 != torch.bfloat16. sharktank/tests/types/quantizers_test.py:78: AssertionError ``` --- sharktank/tests/types/quantizers_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sharktank/tests/types/quantizers_test.py b/sharktank/tests/types/quantizers_test.py index a5fbc2b95..c389357fa 100644 --- a/sharktank/tests/types/quantizers_test.py +++ b/sharktank/tests/types/quantizers_test.py @@ -75,7 +75,9 @@ def testPerAxisRoundtrip(self): torch.testing.assert_close( ssq.scale, torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32) ) - torch.testing.assert_close(ssq.reciprocal_scale, torch.tensor([5.0, 2.5, 1.25])) + torch.testing.assert_close( + ssq.reciprocal_scale, torch.tensor([5.0, 2.5, 1.25], dtype=torch.float32) + ) self.assertIs(ssq.dtype, torch.float16) def testPerAxisQuantDequant(self):