diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 5b582e195..195d42a96 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -52,9 +52,6 @@ def __init__( if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) - if float_scaling_impl is None: - float_scaling_impl = ConstScaling(1., device=device, dtype=dtype) - # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl @@ -68,10 +65,13 @@ def __init__( @brevitas.jit.script_method def quantize(self, x: torch.Tensor): - scaling_impl_value = self.scaling_impl(x) - float_scaling_impl_value = self.float_scaling_impl( - self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - scale = scaling_impl_value / float_scaling_impl_value + scale = self.scaling_impl(x) + + if self.float_scaling_impl is not None: + float_scaling_impl_value = self.float_scaling_impl( + self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) + scale = scale / float_scaling_impl_value + scaled_x = x / scale internal_scale = float_internal_scale( scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)