From 33257923f00a8b5019541a2b9a44a59d13fc8e85 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 10:58:24 +0100 Subject: [PATCH] missing fix --- src/brevitas/core/scaling/float_scaling.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/scaling/float_scaling.py b/src/brevitas/core/scaling/float_scaling.py index 7cf99d73d..86451523b 100644 --- a/src/brevitas/core/scaling/float_scaling.py +++ b/src/brevitas/core/scaling/float_scaling.py @@ -9,6 +9,7 @@ import brevitas from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_float +from brevitas.utils.quant_utils import MAX_MANTISSA_DICT class FloatScaling(brevitas.jit.ScriptModule): @@ -25,6 +26,7 @@ def __init__( self.inf_values = inf_values self.nan_values = nan_values self.saturating = saturating + self.max_mantissa_dict = MAX_MANTISSA_DICT if max_available_float: max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype) @@ -36,7 +38,8 @@ def __init__( def forward( self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor) -> Tensor: - max_value = max_float(exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) + max_value = max_float( + exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) return max_value