diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 26b12814f..d5e3d06d9 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -49,10 +49,10 @@ def __init__( self.fp_internal_scale_min = StatelessBuffer( 1. - self.exponent_bias() - self.mantissa_bit_width()) - if float_scaling_impl is None: - float_scaling_impl = ConstScaling(1., device=device, dtype=dtype) + if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) + # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl @@ -68,9 +68,10 @@ def internal_scale(self, x): @brevitas.jit.script_method def quantize(self, x: torch.Tensor): - scale_impl_value = self.scaling_impl( + scaling_impl_value = self.scaling_impl(x) + float_scaling_impl_value = self.float_scaling_impl( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - scale = scale_impl_value / self.float_scaling_impl(x) + scale = scaling_impl_value / float_scaling_impl_value scaled_x = x / scale internal_scale = self.internal_scale(scaled_x) val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 2022ca1d9..0bc808e4e 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -40,18 +40,18 @@ def test_float_quant_defaults(minifloat_format): inf_values=None, nan_values=None, saturating=True) - scaling = FloatScaling(None, None, True) + float_scaling = FloatScaling(None, None, True) float_quant = FloatQuant( bit_width=bit_width, - scaling_impl=scaling, + float_scaling_impl=float_scaling, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, float_clamp_impl=float_clamp) assert isinstance(float_quant.float_to_int_impl, RoundSte) - assert isinstance(float_quant.float_scaling_impl, ConstScaling) - assert isinstance(float_quant.scaling_impl, FloatScaling) + assert isinstance(float_quant.float_scaling_impl, FloatScaling) + assert isinstance(float_quant.scaling_impl, ConstScaling) @given(minifloat_format=random_minifloat_format()) @@ -81,10 +81,10 @@ def test_float_to_quant_float(inp, minifloat_format): inf_values=None, nan_values=None, saturating=True) - scaling = FloatScaling(None, None, True) + float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.) float_quant = FloatQuant( bit_width=bit_width, - scaling_impl=scaling, + float_scaling_impl=float_scaling_impl, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, @@ -103,8 +103,8 @@ def test_float_to_quant_float(inp, minifloat_format): @jit_disabled_for_mock() def test_scaling_impls_called_once(inp, minifloat_format): bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format - scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.) - float_scaling_impl = mock.Mock(side_effect=lambda x: 1.) + scaling_impl = mock.Mock(side_effect=lambda x: 1.) + float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.) if exponent_bit_width == 0 or mantissa_bit_width == 0: with pytest.raises(RuntimeError): float_quant = FloatQuant( @@ -135,11 +135,11 @@ def test_scaling_impls_called_once(inp, minifloat_format): float_clamp_impl=float_clamp) _ = float_quant.quantize(inp) # scaling implementations should be called exaclty once on the input - scaling_impl.assert_called_once_with( + float_scaling_impl.assert_called_once_with( torch.tensor(exponent_bit_width), torch.tensor(mantissa_bit_width), torch.tensor(exponent_bias)) - float_scaling_impl.assert_called_once_with(inp) + scaling_impl.assert_called_once_with(inp) @given( @@ -150,8 +150,8 @@ def test_scaling_impls_called_once(inp, minifloat_format): def test_inner_scale(inp, minifloat_format, scale): bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here - scaling_impl = mock.Mock(side_effect=lambda x, y, z: scale) - float_scaling_impl = mock.Mock(side_effect=lambda x: 1.) + float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.) + scaling_impl = mock.Mock(side_effect=lambda x: scale) if exponent_bit_width == 0 or mantissa_bit_width == 0: with pytest.raises(RuntimeError): float_quant = FloatQuant(