Fix (quant/float): correct scaling_impl and float_scaling_impl
Giuseppe5 committed May 6, 2024
1 parent 0c52c9a commit 20510a9
Showing 2 changed files with 17 additions and 16 deletions.
9 changes: 5 additions & 4 deletions src/brevitas/core/quant/float.py
@@ -49,10 +49,10 @@ def __init__(

        self.fp_internal_scale_min = StatelessBuffer(
            1. - self.exponent_bias() - self.mantissa_bit_width())
-        if float_scaling_impl is None:
-            float_scaling_impl = ConstScaling(1., device=device, dtype=dtype)
+        if scaling_impl is None:
+            scaling_impl = ConstScaling(1., device=device, dtype=dtype)

        # Zero-point is currently hardcoded to 0
        self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
        self.float_scaling_impl = float_scaling_impl
@@ -68,9 +68,10 @@ def internal_scale(self, x):

    @brevitas.jit.script_method
    def quantize(self, x: torch.Tensor):
-        scale_impl_value = self.scaling_impl(
+        scaling_impl_value = self.scaling_impl(x)
+        float_scaling_impl_value = self.float_scaling_impl(
            self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-        scale = scale_impl_value / self.float_scaling_impl(x)
+        scale = scaling_impl_value / float_scaling_impl_value
        scaled_x = x / scale
        internal_scale = self.internal_scale(scaled_x)
        val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
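
For orientation, here is a minimal, self-contained sketch of what the corrected quantize() path computes: scaling_impl now sees the input tensor, float_scaling_impl sees only the minifloat format parameters, and the quantization scale is their ratio. The ConstScaling stand-in and the FloatScaling max-value formula below are assumptions for illustration, not Brevitas' actual implementations.

import torch

def const_scaling(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for ConstScaling(1.): a data-independent scale of 1.
    return torch.tensor(1.)

def float_scaling(exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: int) -> torch.Tensor:
    # Assumed format-dependent scale: the largest value representable in a
    # saturating minifloat format that reserves no inf/nan encodings.
    max_exponent = 2 ** exponent_bit_width - 1 - exponent_bias
    max_mantissa = 2. - 2. ** (-mantissa_bit_width)
    return torch.tensor(max_mantissa * 2. ** max_exponent)

def compute_scale(x, exponent_bit_width, mantissa_bit_width, exponent_bias):
    # Mirrors the fixed quantize(): the data-dependent scale divided by the
    # format-dependent one, so x / scale lands in the minifloat range.
    scaling_impl_value = const_scaling(x)
    float_scaling_impl_value = float_scaling(
        exponent_bit_width, mantissa_bit_width, exponent_bias)
    return scaling_impl_value / float_scaling_impl_value

# Example with e4m3-style parameters (assumed): max value 480, scale 1/480.
scale = compute_scale(torch.randn(8), 4, 3, 7)
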
24 changes: 12 additions & 12 deletions tests/brevitas/core/test_float_quant.py
@@ -40,18 +40,18 @@ def test_float_quant_defaults(minifloat_format):
        inf_values=None,
        nan_values=None,
        saturating=True)
-    scaling = FloatScaling(None, None, True)
+    float_scaling = FloatScaling(None, None, True)
    float_quant = FloatQuant(
        bit_width=bit_width,
-        scaling_impl=scaling,
+        float_scaling_impl=float_scaling,
        exponent_bit_width=exponent_bit_width,
        mantissa_bit_width=mantissa_bit_width,
        exponent_bias=exponent_bias,
        signed=signed,
        float_clamp_impl=float_clamp)
    assert isinstance(float_quant.float_to_int_impl, RoundSte)
-    assert isinstance(float_quant.float_scaling_impl, ConstScaling)
-    assert isinstance(float_quant.scaling_impl, FloatScaling)
+    assert isinstance(float_quant.float_scaling_impl, FloatScaling)
+    assert isinstance(float_quant.scaling_impl, ConstScaling)


@given(minifloat_format=random_minifloat_format())
@@ -81,10 +81,10 @@ def test_float_to_quant_float(inp, minifloat_format):
        inf_values=None,
        nan_values=None,
        saturating=True)
-    scaling = FloatScaling(None, None, True)
+    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
    float_quant = FloatQuant(
        bit_width=bit_width,
-        scaling_impl=scaling,
+        float_scaling_impl=float_scaling_impl,
        exponent_bit_width=exponent_bit_width,
        mantissa_bit_width=mantissa_bit_width,
        exponent_bias=exponent_bias,
@@ -103,8 +103,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
@jit_disabled_for_mock()
def test_scaling_impls_called_once(inp, minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
-    scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
-    float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
+    scaling_impl = mock.Mock(side_effect=lambda x: 1.)
+    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
    if exponent_bit_width == 0 or mantissa_bit_width == 0:
        with pytest.raises(RuntimeError):
            float_quant = FloatQuant(
@@ -135,11 +135,11 @@ def test_scaling_impls_called_once(inp, minifloat_format):
            float_clamp_impl=float_clamp)
        _ = float_quant.quantize(inp)
        # scaling implementations should be called exactly once on the input
-        scaling_impl.assert_called_once_with(
+        float_scaling_impl.assert_called_once_with(
            torch.tensor(exponent_bit_width),
            torch.tensor(mantissa_bit_width),
            torch.tensor(exponent_bias))
-        float_scaling_impl.assert_called_once_with(inp)
+        scaling_impl.assert_called_once_with(inp)


@given(
@@ -150,8 +150,8 @@ def test_inner_scale(inp, minifloat_format, scale):
def test_inner_scale(inp, minifloat_format, scale):
    bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
    # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
-    scaling_impl = mock.Mock(side_effect=lambda x, y, z: scale)
-    float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
+    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
+    scaling_impl = mock.Mock(side_effect=lambda x: scale)
    if exponent_bit_width == 0 or mantissa_bit_width == 0:
        with pytest.raises(RuntimeError):
            float_quant = FloatQuant(
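
Taken together, the test changes encode the corrected calling convention: scaling_impl must accept the single input tensor, while float_scaling_impl must accept the three format parameters. A hypothetical standalone check of that contract (mirroring the mocks above, not part of the diff) could look like:

import torch
from unittest import mock

# Post-fix convention (names assumed for illustration):
scaling_impl = mock.Mock(side_effect=lambda x: 1.)
float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)

inp = torch.randn(4)
exp_bw, mant_bw, exp_bias = torch.tensor(4.), torch.tensor(3.), torch.tensor(7.)

# Simulate the calls the fixed quantize() makes:
scaling_impl(inp)
float_scaling_impl(exp_bw, mant_bw, exp_bias)

scaling_impl.assert_called_once_with(inp)
float_scaling_impl.assert_called_once_with(exp_bw, mant_bw, exp_bias)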

0 comments on commit 20510a9

Please sign in to comment.