Feat (FloatQuant): catch pointless exponent/mantissa bit widths
fabianandresgrob committed Feb 1, 2024
1 parent 6782af0 commit bf30126
Showing 2 changed files with 104 additions and 68 deletions.
4 changes: 4 additions & 0 deletions src/brevitas/core/quant/float.py
@@ -38,8 +38,12 @@ def __init__(
self.bit_width = StatelessBuffer(torch.tensor(float(bit_width), device=device, dtype=dtype))
self.signed: bool = signed
self.float_to_int_impl = float_to_int_impl
if exponent_bit_width == 0:
raise RuntimeError("Exponent bit width cannot be 0.")
self.exponent_bit_width = StatelessBuffer(
torch.tensor(float(exponent_bit_width), device=device, dtype=dtype))
if mantissa_bit_width == 0:
raise RuntimeError("Mantissa bit width cannot be 0.")
self.mantissa_bit_width = StatelessBuffer(
(torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype)))
if exponent_bias is None:
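For illustration (not part of the commit): with the guard above, requesting a minifloat format with a zero-width exponent or mantissa now fails at construction time instead of silently producing a degenerate quantizer. A minimal sketch, assuming the same constructor defaults the tests below rely on (scaling and rounding implementations left unset):

```python
# Illustrative usage of the new guard; not part of the diff.
import pytest

from brevitas.core.quant.float import FloatQuant

# 1 sign bit + 0 exponent bits + 7 mantissa bits: rejected at construction time
with pytest.raises(RuntimeError):
    FloatQuant(bit_width=8, exponent_bit_width=0, mantissa_bit_width=7, signed=True)

# 1 sign bit + 4 exponent bits + 0 mantissa bits: rejected as well
with pytest.raises(RuntimeError):
    FloatQuant(bit_width=5, exponent_bit_width=4, mantissa_bit_width=0, signed=True)
```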
168 changes: 100 additions & 68 deletions tests/brevitas/core/test_float_quant.py
@@ -2,16 +2,13 @@
# SPDX-License-Identifier: BSD-3-Clause

from hypothesis import given
import hypothesis.strategies as st
import mock
import pytest
import torch

from brevitas.core.function_wrapper import RoundSte
from brevitas.core.quant.float import FloatQuant
from brevitas.core.scaling import ConstScaling
from tests.brevitas.core.bit_width_fixture import * # noqa
from tests.brevitas.core.int_quant_fixture import * # noqa
from tests.brevitas.core.shared_quant_fixture import * # noqa
from tests.brevitas.hyp_helper import float_st
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
from tests.brevitas.hyp_helper import random_minifloat_format
@@ -23,15 +20,23 @@ def test_float_quant_defaults(minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
# specifically don't set exponent bias to see if default works
expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1
float_quant = FloatQuant(
bit_width=bit_width,
signed=signed,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width)
assert expected_exponent_bias == float_quant.exponent_bias()
assert isinstance(float_quant.float_to_int_impl, RoundSte)
assert isinstance(float_quant.float_scaling_impl, ConstScaling)
assert isinstance(float_quant.scaling_impl, ConstScaling)
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed)
else:
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed)
assert expected_exponent_bias == float_quant.exponent_bias()
assert isinstance(float_quant.float_to_int_impl, RoundSte)
assert isinstance(float_quant.float_scaling_impl, ConstScaling)
assert isinstance(float_quant.scaling_impl, ConstScaling)


@given(minifloat_format=random_minifloat_format())
@@ -44,36 +49,52 @@ def test_minifloat(minifloat_format):
@jit_disabled_for_mock()
def test_float_to_quant_float(inp, minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
exponent_bias = 2 ** (exponent_bit_width - 1) - 1
float_quant = FloatQuant(
bit_width=bit_width,
signed=signed,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias)
expected_out, _, _, bit_width_out = float_quant(inp)
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed)
else:
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed)
expected_out, _, _, bit_width_out = float_quant(inp)

out_quant, scale = float_quant.quantize(inp)
assert bit_width_out == bit_width
assert torch.equal(expected_out, out_quant * scale)
out_quant, scale = float_quant.quantize(inp)
assert bit_width_out == bit_width
assert torch.equal(expected_out, out_quant * scale)


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
def test_scaling_impls_called_once(inp, minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
scaling_impl = mock.Mock(side_effect=lambda x: 1.)
float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
output = float_quant.quantize(inp)
# scaling implementations should be called exactly once on the input
scaling_impl.assert_called_once_with(inp)
float_scaling_impl.assert_called_once_with(inp)
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
else:
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
output = float_quant.quantize(inp)
# scaling implementations should be called exactly once on the input
scaling_impl.assert_called_once_with(inp)
float_scaling_impl.assert_called_once_with(inp)


@given(
@@ -85,37 +106,48 @@ def test_inner_scale(inp, minifloat_format, scale):
# set scaling_impl to scale and float_scaling_impl to 1 so the same scale is used as in the manual computation below
scaling_impl = mock.Mock(side_effect=lambda x: scale)
float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)

# scale inp manually
scaled_inp = inp / scale

# call internal scale
internal_scale = float_quant.internal_scale(scaled_inp)
val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale)
if signed:
val_fp_quant = torch.clip(
val_fp_quant, -1. * float_quant.fp_max_val(), float_quant.fp_max_val())
else:
val_fp_quant = torch.clip(val_fp_quant, 0., float_quant.fp_max_val())

# dequantize manually
out = val_fp_quant * scale

expected_out, expected_scale, _, _ = float_quant(inp)

assert scale == expected_scale
if scale == 0.0:
# outputs should only receive 0s or nan
assert torch.tensor([True if val == 0. or val.isnan() else False for val in out.flatten()
]).all()
assert torch.tensor([
True if val == 0. or val.isnan() else False for val in expected_out.flatten()]).all()
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
else:
assert torch.equal(out, expected_out)
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)

# scale inp manually
scaled_inp = inp / scale

# call internal scale
internal_scale = float_quant.internal_scale(scaled_inp)
val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale)
if signed:
val_fp_quant = torch.clip(
val_fp_quant, -1. * float_quant.fp_max_val(), float_quant.fp_max_val())
else:
val_fp_quant = torch.clip(val_fp_quant, 0., float_quant.fp_max_val())

# dequantize manually
out = val_fp_quant * scale

expected_out, expected_scale, _, _ = float_quant(inp)

assert scale == expected_scale
if scale == 0.0:
# outputs should only receive 0s or nan
assert torch.tensor([
True if val == 0. or val.isnan() else False for val in out.flatten()]).all()
assert torch.tensor([
True if val == 0. or val.isnan() else False for val in expected_out.flatten()
]).all()
else:
assert torch.equal(out, expected_out)
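The branches on `exponent_bit_width == 0 or mantissa_bit_width == 0` above suggest that the `random_minifloat_format` hypothesis strategy can draw formats with zero-width fields. Its implementation is not part of this diff; a hypothetical sketch of such a strategy (names, bounds, and the sign-bit handling are assumptions, not the actual `tests.brevitas.hyp_helper` code) might look like:

```python
# Hypothetical sketch of a random_minifloat_format-style strategy; the real
# helper in tests.brevitas.hyp_helper is not shown in this commit.
import hypothesis.strategies as st


@st.composite
def random_minifloat_format(draw, min_bit_width=3, max_bit_width=8):
    """Draw (bit_width, exponent_bit_width, mantissa_bit_width, signed),
    allowing zero-width exponent or mantissa fields."""
    bit_width = draw(st.integers(min_bit_width, max_bit_width))
    signed = draw(st.booleans())
    # The exponent may claim all non-sign bits (mantissa 0) or none (exponent 0).
    exponent_bit_width = draw(st.integers(0, bit_width - int(signed)))
    mantissa_bit_width = bit_width - int(signed) - exponent_bit_width
    return bit_width, exponent_bit_width, mantissa_bit_width, signed
```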
