diff --git a/sharktank/tests/types/quantizers_test.py b/sharktank/tests/types/quantizers_test.py index a5fbc2b95..c389357fa 100644 --- a/sharktank/tests/types/quantizers_test.py +++ b/sharktank/tests/types/quantizers_test.py @@ -75,7 +75,9 @@ def testPerAxisRoundtrip(self): torch.testing.assert_close( ssq.scale, torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32) ) - torch.testing.assert_close(ssq.reciprocal_scale, torch.tensor([5.0, 2.5, 1.25])) + torch.testing.assert_close( + ssq.reciprocal_scale, torch.tensor([5.0, 2.5, 1.25], dtype=torch.float32) + ) self.assertIs(ssq.dtype, torch.float16) def testPerAxisQuantDequant(self):