Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[sharktank] Fix flaxy quantizer test (attempt 2). (#641)
Follow-up to #631. The test is flaking at this line too: https://github.com/nod-ai/shark-ai/actions/runs/12148591430/job/33877368151?pr=639#step:6:25 ``` =================================== FAILURES =================================== ________________ StaticScaledQuantizerTest.testPerAxisRoundtrip ________________ [gw0] linux -- Python 3.11.10 /opt/hostedtoolcache/Python/3.11.10/x64/bin/python self = <tests.types.quantizers_test.StaticScaledQuantizerTest testMethod=testPerAxisRoundtrip> def testPerAxisRoundtrip(self): ssq = StaticScaledQuantizer( name="poodoo", axis=1, scale=torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32), dtype=torch.float16, ) ssq = self._roundtrip(ssq, "_ssq") self.assertEqual(ssq.axis, 1) torch.testing.assert_close( ssq.scale, torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32) ) > torch.testing.assert_close(ssq.reciprocal_scale, torch.tensor([5.0, 2.5, 1.25])) E AssertionError: The values for attribute 'dtype' do not match: torch.float32 != torch.bfloat16. sharktank/tests/types/quantizers_test.py:78: AssertionError ```
- Loading branch information