From 1c3c997aac2859005f8bc80f15acf5355e44d012 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Mon, 5 Feb 2024 08:49:50 -0800 Subject: [PATCH] Fix (tests): fix filter for NaN/inf values --- src/brevitas/core/quant/float.py | 4 ---- tests/brevitas/core/test_float_quant.py | 6 +++++- tests/brevitas/hyp_helper.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 384e85ed6..2fad8782b 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -13,9 +13,6 @@ from brevitas.function.ops import max_float from brevitas.function.ops_ste import floor_ste -# max int that can be passed to torch.exp2() without running into inf -MAX_REPRESENTABLE_INT = 127 - class FloatQuant(brevitas.jit.ScriptModule): __constants__ = ['signed'] @@ -67,7 +64,6 @@ def __init__( def internal_scale(self, x): internal_scale = floor_ste(torch.log2(torch.abs(x))) - self.mantissa_bit_width() internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min()) - internal_scale = torch.clamp_max(internal_scale, MAX_REPRESENTABLE_INT) internal_scale = torch.exp2(internal_scale) return internal_scale diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 890102561..608c58130 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -151,4 +151,8 @@ def test_inner_scale(inp, minifloat_format, scale): True if val == 0. or val.isnan() else False for val in expected_out.flatten() ]).all() else: - assert torch.equal(out, expected_out) + # filter out NaN values as we can't compare them + # Note: this still checks if NaN appears at the same values + out_nans = out.isnan() + expected_out_nans = expected_out.isnan() + assert torch.equal(out[~out_nans], expected_out[~expected_out_nans]) diff --git a/tests/brevitas/hyp_helper.py b/tests/brevitas/hyp_helper.py index 9905d2f03..45e72e52b 100644 --- a/tests/brevitas/hyp_helper.py +++ b/tests/brevitas/hyp_helper.py @@ -227,6 +227,7 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with= """" Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed. """ + # TODO: add support for new minifloat format that comes with FloatQuantTensor bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with)) exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width)) signed = draw(st.booleans())