diff --git a/src/brevitas/core/function_wrapper/__init__.py b/src/brevitas/core/function_wrapper/__init__.py
index 9d5d929d6..3b3e5428b 100644
--- a/src/brevitas/core/function_wrapper/__init__.py
+++ b/src/brevitas/core/function_wrapper/__init__.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from .clamp import ClampMin
+from .clamp import FloatClamp
 from .clamp import ScalarClamp
 from .clamp import TensorClamp
 from .learned_round import LearnedRoundSte
diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py
index c49eec726..0bfb79374 100644
--- a/src/brevitas/core/function_wrapper/clamp.py
+++ b/src/brevitas/core/function_wrapper/clamp.py
@@ -4,11 +4,14 @@
 """
 ScriptModule wrappers for various variants of clamping.
 """
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 
 import brevitas
+from brevitas.core.utils import StatelessBuffer
 from brevitas.function import tensor_clamp
 
 
@@ -73,3 +76,53 @@ def __init__(self, min_val: float) -> None:
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         return x.clamp_min(self.min_val)
+
+
+class FloatClamp(brevitas.jit.ScriptModule):
+    """
+    ScriptModule for clamping minifloat formats to their inf/NaN implementations.
+
+    Currently, inf/NaN codes have to be encoded through the mantissa.
+    I.e. setting inf to 1101.111 (E4M3) is not a valid code.
+    """
+
+    __constants__ = ['saturating', 'has_inf_values']
+
+    def __init__(
+            self,
+            max_value: float,
+            tensor_clamp_impl: Module,
+            inf_values: Optional[Tuple[str]] = None,
+            saturating: bool = True) -> None:
+        super(FloatClamp, self).__init__()
+
+        self.tensor_clamp_impl = tensor_clamp_impl
+
+        self.max_value = StatelessBuffer(torch.tensor(max_value))
+        self.saturating = saturating
+        self.has_inf_values = bool(inf_values)
+
+    @brevitas.jit.script_method
+    def forward(self, x: Tensor):
+        inf_mask = x.isinf()
+        p_max_val_mask = x > self.max_value()
+        n_max_val_mask = -x > self.max_value()
+
+        # first clamp everything to +- max_value, basically the saturating case
+        x = self.tensor_clamp_impl(x, min_val=-self.max_value(), max_val=self.max_value())
+
+        if not self.saturating:
+            # if non-saturating, we need to map values greater than max_val to nan or inf
+            if self.has_inf_values:
+                # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
+                x[p_max_val_mask] = torch.tensor(float('inf'))
+                x[n_max_val_mask] = torch.tensor(float('-inf'))
+            else:
+                # no inf values, so we need to map them to NaN
+                full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
+                x[full_max_val_mask] = torch.tensor(float('nan'))
+
+                # we also map the inf values to NaN in this case
+                x[inf_mask] = torch.tensor(float('nan'))
+
+        return x
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 2fad8782b..11da5864b 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -23,7 +23,8 @@ def __init__(
             signed: bool,
             exponent_bit_width: int,
             mantissa_bit_width: int,
-            exponent_bias: Optional[int] = None,
+            exponent_bias: int,
+            float_clamp_impl: nn.Module,
             scaling_impl: Optional[nn.Module] = None,
             float_scaling_impl: Optional[nn.Module] = None,
             float_to_int_impl: nn.Module = RoundSte(),
@@ -43,8 +44,6 @@ def __init__(
             raise RuntimeError("Mantissa bit width cannot be 0.")
         self.mantissa_bit_width = StatelessBuffer(
             (torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype)))
-        if exponent_bias is None:
-            exponent_bias = 2 ** (exponent_bit_width - 1) - 1
         self.exponent_bias = StatelessBuffer(
             torch.tensor(float(exponent_bias), device=device, dtype=dtype))
         self.fp_max_val = StatelessBuffer(
@@ -59,6 +58,7 @@ def __init__(
         self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
         self.float_scaling_impl = float_scaling_impl
         self.scaling_impl = scaling_impl
+        self.float_clamp_impl = float_clamp_impl
 
     @brevitas.jit.script_method
     def internal_scale(self, x):
@@ -86,6 +86,8 @@ def dequantize(self, y, scale):
     @brevitas.jit.script_method
     def forward(self, x):
         y, scale = self.quantize(x)
+        # after quantizing, clamp to special cases like NaN/inf if they are set
+        y = self.float_clamp_impl(y)
         y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.bit_width()
diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py
index 6751ab69c..7bbffaec7 100644
--- a/src/brevitas/function/ops.py
+++ b/src/brevitas/function/ops.py
@@ -5,6 +5,7 @@
 Implementation of various core operations often performed as part of quantization.
 The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler.
 """
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py
index 8da867e27..61201578e 100644
--- a/src/brevitas/quant/experimental/float_base.py
+++ b/src/brevitas/quant/experimental/float_base.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from brevitas.core.function_wrapper import FloatClamp
+from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.quant.float import FloatQuant
 from brevitas.core.scaling.float_scaling import FloatScaling
 from brevitas.inject import ExtendedInjector
@@ -10,22 +12,44 @@
 from brevitas.quant.solver import ActQuantSolver
 from brevitas.quant.solver import WeightQuantSolver
 from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum
+from brevitas.utils.float_quant_utils import get_max_value
 
 
-class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum):
-    proxy_class = WeightQuantProxyFromInjector
+class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
     tensor_quant = FloatQuant
     signed = True
     float_to_int_impl_type = 'round'
     scaling_min_val = 1e-10
+    float_clamp_impl = FloatClamp
+    tensor_clamp_impl = TensorClamp
 
+    @value
+    def exponent_bias(exponent_bit_width):
+        return 2 ** (exponent_bit_width - 1) - 1
 
-class FloatActBase(SolveTensorQuantFloatToIntImplFromEnum):
+    @value
+    def max_value(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values=None,
+            inf_values=None,
+            saturating=True):
+        return get_max_value(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values,
+            inf_values,
+            saturating)
+
+
+class FloatWeightBase(FloatBase):
+    proxy_class = WeightQuantProxyFromInjector
+
+
+class FloatActBase(FloatBase):
     proxy_class = ActQuantProxyFromInjector
-    tensor_quant = FloatQuant
-    signed = True
-    float_to_int_impl_type = 'round'
-    scaling_min_val = 1e-10
 
 
 class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):
@@ -43,20 +67,15 @@ class ScaledFloatActBase(FloatActBase, ActQuantSolver):
     float_scaling_impl = FloatScaling
 
 
-class ExponentBiasMixin(ExtendedInjector):
-
-    @value
-    def exponent_bias(exponent_bit_width):
-        return 2 ** (exponent_bit_width - 1) - 1
-
-
-class Fp8e4m3Mixin(ExponentBiasMixin):
+class Fp8e4m3Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
+    saturating = True
 
 
-class Fp8e5m2Mixin(ExponentBiasMixin):
+class Fp8e5m2Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
+    saturating = True
diff --git a/src/brevitas/quant/experimental/float_quant_ocp.py b/src/brevitas/quant/experimental/float_quant_ocp.py
new file mode 100644
index 000000000..6dfda1304
--- /dev/null
+++ b/src/brevitas/quant/experimental/float_quant_ocp.py
@@ -0,0 +1,150 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from brevitas.quant.base import MSESymmetricScale
+from brevitas.quant.experimental.float_base import FloatActBase
+from brevitas.quant.experimental.float_base import FloatWeightBase
+from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
+from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
+from brevitas.quant.experimental.float_base import ScaledFloatActBase
+from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
+
+
+class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
+    nan_values = (('111',))
+    inf_values = None
+
+
+class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
+    nan_values = ('01', '11', '10')
+    inf_values = (('00',))
+
+
+class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer.
+    """
+    pass
+
+
+class Fp8e5m2OCPWeight(Fp8e5m2OCPMixin, FloatWeightBase):
+    """
+    FP8 signed E5M2 weight quantizer.
+    """
+    pass
+
+
+class Fp8e4m3OCPAct(Fp8e4m3OCPMixin, FloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer.
+    """
+    pass
+
+
+class Fp8e5m2OCPAct(Fp8e5m2OCPMixin, FloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer.
+    """
+    pass
+
+
+class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-tensor absmax-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e4m3OCPActPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e5m2OCPActPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e4m3OCPWeightPerChannelFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-channel absmax-based scaling.
+    """
+    scaling_per_output_channel = True
+
+
+class Fp8e5m2OCPWeightPerChannelFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling.
+    """
+    scaling_per_output_channel = True
+
+
+class Fp8e4m3OCPActPerChannelFloat2d(Fp8e4m3OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling.
+    """
+    scaling_per_output_channel = True
+    scaling_stats_permute_dims = (1, 0, 2, 3)
+
+
+class Fp8e5m2OCPActPerChannelFloat2d(Fp8e5m2OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling.
+ """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e4m3OCPActPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e5m2OCPActPerTensorFloatMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e4m3OCPActPerChannelFloat2dMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e5m2OCPActPerChannelFloat2dMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. + """ + scaling_per_output_channel = True + + +class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. + """ + scaling_per_output_channel = False diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py new file mode 100644 index 000000000..5d5c4037f --- /dev/null +++ b/src/brevitas/utils/float_quant_utils.py @@ -0,0 +1,69 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + + +def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float: + # computes the decimal place value from a given binary tensor + res = 1.0 + for i, val in enumerate(bits): + # iterating through from left to right + res += ((2 ** -(i + 1)) * float(val)) + if frexp_compatible: + return res / 2. + else: + return res + + +def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float: + """ + Returns the minifloat value for a given exponent, mantissa and exponent_bias. + It expects the exponent and mantissa in their binary format. 
+ """ + exponent_value = int(exponent, 2) + mantissa_value = mantissa_bits_to_float(mantissa) + return 2 ** (exponent_value - exponent_bias) * mantissa_value + + +def get_max_value( + exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values, saturating): + # Idea: take the smallest NaN/inf value, set max_value to the next smaller one + # inf without NaN not possible + if inf_values is None and nan_values is None: + # saturating has to be True if no NaN/inf value are used + assert saturating, 'cannot be non-saturating without NaN/inf values' + # no special cases, max_value is using all bits for exponent and mantissa + exponent = '1' * exponent_bit_width + mantissa = '1' * mantissa_bit_width + elif nan_values is not None: + # we at least have values for NaN, so initiate MaxValInfNaN + special_values = nan_values + inf_values if inf_values is not None else nan_values + + # check that NaN/inf values are all mantissa_bit_width long + if any(map(lambda x: len(x) > mantissa_bit_width, special_values)): + raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') + + # get the minimum special case, our max value is the next smaller value + min_special_case = min(map(lambda x: int(x, 2), special_values)) + + max_value_mantissa = min_special_case - 1 + + if max_value_mantissa < 0: + # all mantissa values are used, so we need to use decrease exponent values + exponent = '1' * (exponent_bit_width - 1) + # add trailing 0 to reach bit width + exponent += '0' + # since we decreased exponent, we can use full mantissa + mantissa = '1' * mantissa_bit_width + else: + # there is a free mantissa code, so use full exponent + exponent = '1' * exponent_bit_width + # get binary code for max_value_mantissa in the number of mantissa bits + mantissa = format(max_value_mantissa, f'0{mantissa_bit_width}b') + else: + # no NaN values but inf values + raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') + + # we don't need the sign since we're looking for the max value + max_value = get_minifloat_value( + exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias) + return max_value diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py new file mode 100644 index 000000000..681caf8ca --- /dev/null +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -0,0 +1,37 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause
+
+import pytest_cases
+from pytest_cases import fixture_union
+
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
+from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
+
+
+@pytest_cases.fixture
+@pytest_cases.parametrize('sat', [True, False])
+def fp8e4m3(sat):
+
+    class Fp8e4m3(Fp8e4m3OCPWeight):
+        saturating = sat
+        # for hypothesis and DI
+        hypothesis_internal_is_this_a_mock_check = True
+
+    return Fp8e4m3
+
+
+@pytest_cases.fixture
+@pytest_cases.parametrize('sat', [True, False])
+def fp8e5m2(sat):
+
+    class Fp8e5m2(Fp8e5m2OCPWeight):
+        saturating = sat
+        # for hypothesis and DI
+        hypothesis_internal_is_this_a_mock_check = True
+
+    return Fp8e5m2
+
+
+list_of_fixtures = ['fp8e4m3', 'fp8e5m2']
+
+fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures)
diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py
new file mode 100644
index 000000000..5ba5a0a32
--- /dev/null
+++ b/tests/brevitas/core/test_clamp.py
@@ -0,0 +1,47 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from hypothesis import given
+import pytest
+
+from brevitas.quant.experimental.float import Fp8e4m3Weight
+from brevitas.quant.experimental.float import Fp8e5m2Weight
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
+from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
+from tests.brevitas.hyp_helper import float_tensor_random_shape_st
+
+from .minifloat_fixtures import *
+
+FORMAT_MAXVAL_MAP = {
+    Fp8e5m2OCPWeight: 57344., Fp8e4m3OCPWeight: 448., Fp8e4m3Weight: 480., Fp8e5m2Weight: 114688.}
+
+
+@pytest.mark.parametrize(
+    'minifloat, expected_max_val',
+    ((format, max_val) for format, max_val in FORMAT_MAXVAL_MAP.items()))
+def test_max_value(minifloat, expected_max_val):
+    max_val = minifloat.float_clamp_impl.max_value()
+
+    assert expected_max_val == max_val
+
+
+@given(inp=float_tensor_random_shape_st())
+def test_float_clamp(inp, fp8_clamp):
+    max_val = fp8_clamp.float_clamp_impl.max_value()
+    # get values that exceed max_val
+    over_limit_mask = inp.abs() > max_val
+
+    # clamp inp
+    inp = fp8_clamp.float_clamp_impl(inp)
+
+    if fp8_clamp.float_clamp_impl.saturating:
+        # should be clamped to +- max val
+        assert (inp[over_limit_mask].abs() == max_val).all()
+    else:
+        # if inf_values, over limit mask should now be all inf
+        if fp8_clamp.float_clamp_impl.has_inf_values:
+            # all values exceeding max_val should be inf
+            assert inp[over_limit_mask].isinf().all()
+        else:
+            # all values should be NaN
+            assert inp[over_limit_mask].isnan().all()
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 608c58130..1e4058fb8 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -6,9 +6,12 @@
 import pytest
 import torch
 
+from brevitas.core.function_wrapper import FloatClamp
 from brevitas.core.function_wrapper import RoundSte
+from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.quant.float import FloatQuant
 from brevitas.core.scaling import ConstScaling
+from brevitas.utils.float_quant_utils import get_max_value
 from tests.brevitas.hyp_helper import float_st
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 from tests.brevitas.hyp_helper import random_minifloat_format
@@ -17,23 +20,29 @@
 
 @given(minifloat_format=random_minifloat_format())
 def test_float_quant_defaults(minifloat_format):
-    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
-    # specifically don't set exponent bias to see if default works
-    expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1
+    bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
+
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
                 bit_width=bit_width,
                 exponent_bit_width=exponent_bit_width,
                 mantissa_bit_width=mantissa_bit_width,
-                signed=signed)
+                exponent_bias=exponent_bias,
+                signed=signed,
+                float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
-            signed=signed)
-        assert expected_exponent_bias == float_quant.exponent_bias()
+            exponent_bias=exponent_bias,
+            signed=signed,
+            float_clamp_impl=float_clamp)
         assert isinstance(float_quant.float_to_int_impl, RoundSte)
         assert isinstance(float_quant.float_scaling_impl, ConstScaling)
         assert isinstance(float_quant.scaling_impl, ConstScaling)
@@ -41,26 +50,34 @@ def test_float_quant_defaults(minifloat_format):
 
 @given(minifloat_format=random_minifloat_format())
 def test_minifloat(minifloat_format):
-    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
+    bit_width, exponent_bit_width, mantissa_bit_width, signed, _ = minifloat_format
     assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)
 
 
 @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
 def test_float_to_quant_float(inp, minifloat_format):
-    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
+    bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
                 bit_width=bit_width,
                 exponent_bit_width=exponent_bit_width,
                 mantissa_bit_width=mantissa_bit_width,
-                signed=signed)
+                exponent_bias=exponent_bias,
+                signed=signed,
+                float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
            mantissa_bit_width=mantissa_bit_width,
-            signed=signed)
+            exponent_bias=exponent_bias,
+            signed=signed,
+            float_clamp_impl=float_clamp)
         expected_out, _, _, bit_width_out = float_quant(inp)
 
         out_quant, scale = float_quant.quantize(inp)
@@ -71,7 +88,7 @@ def test_float_to_quant_float(inp, minifloat_format):
 @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
 @jit_disabled_for_mock()
 def test_scaling_impls_called_once(inp, minifloat_format):
-    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
+    bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     scaling_impl = mock.Mock(side_effect=lambda x: 1.)
     float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
@@ -80,17 +97,25 @@ def test_scaling_impls_called_once(inp, minifloat_format):
                 bit_width=bit_width,
                 exponent_bit_width=exponent_bit_width,
                 mantissa_bit_width=mantissa_bit_width,
+                exponent_bias=exponent_bias,
                 signed=signed,
                 scaling_impl=scaling_impl,
-                float_scaling_impl=float_scaling_impl)
+                float_scaling_impl=float_scaling_impl,
+                float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
+            exponent_bias=exponent_bias,
             signed=signed,
             scaling_impl=scaling_impl,
-            float_scaling_impl=float_scaling_impl)
+            float_scaling_impl=float_scaling_impl,
+            float_clamp_impl=float_clamp)
     output = float_quant.quantize(inp)
     # scaling implementations should be called exaclty once on the input
     scaling_impl.assert_called_once_with(inp)
@@ -103,7 +128,7 @@ def test_scaling_impls_called_once(inp, minifloat_format):
     scale=float_st())
 @jit_disabled_for_mock()
 def test_inner_scale(inp, minifloat_format, scale):
-    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
+    bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
     scaling_impl = mock.Mock(side_effect=lambda x: scale)
     float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
@@ -113,17 +138,25 @@ def test_inner_scale(inp, minifloat_format, scale):
                 bit_width=bit_width,
                 exponent_bit_width=exponent_bit_width,
                 mantissa_bit_width=mantissa_bit_width,
+                exponent_bias=exponent_bias,
                 signed=signed,
                 scaling_impl=scaling_impl,
-                float_scaling_impl=float_scaling_impl)
+                float_scaling_impl=float_scaling_impl,
+                float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
+            exponent_bias=exponent_bias,
             signed=signed,
             scaling_impl=scaling_impl,
-            float_scaling_impl=float_scaling_impl)
+            float_scaling_impl=float_scaling_impl,
+            float_clamp_impl=float_clamp)
 
     # scale inp manually
     scaled_inp = inp / scale
diff --git a/tests/brevitas/hyp_helper.py b/tests/brevitas/hyp_helper.py
index 45e72e52b..1a9157214 100644
--- a/tests/brevitas/hyp_helper.py
+++ b/tests/brevitas/hyp_helper.py
@@ -231,11 +231,14 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=
     bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with))
     exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
     signed = draw(st.booleans())
+
+    exponent_bias = 2 ** (exponent_bit_width - 1) - 1
+
     # if no budget is left, return
     if bit_width == exponent_bit_width:
-        return bit_width, exponent_bit_width, 0, False
+        return bit_width, exponent_bit_width, 0, False, exponent_bias
     elif bit_width == (exponent_bit_width + int(signed)):
-        return bit_width, exponent_bit_width, 0, signed
+        return bit_width, exponent_bit_width, 0, signed, exponent_bias
     mantissa_bit_width = bit_width - exponent_bit_width - int(signed)
-    return bit_width, exponent_bit_width, mantissa_bit_width, signed
+    return bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias
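A minimal usage sketch of how the pieces introduced above fit together (not part of the patch). The FloatClamp signature and the OCP quantizer names are taken from the diff; the QuantLinear wiring and the quant_weight() call assume the existing Brevitas layer API.

import torch
from brevitas.nn import QuantLinear
from brevitas.core.function_wrapper import FloatClamp, TensorClamp
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat

# OCP FP8 E4M3 weights with per-tensor scaling; max_value resolves to 448 via
# get_max_value, since the all-ones mantissa code '111' is reserved for NaN
linear = QuantLinear(16, 8, bias=False, weight_quant=Fp8e4m3OCPWeightPerTensorFloat)
print(linear.quant_weight())

# FloatClamp in isolation: a non-saturating format with no inf codes maps
# out-of-range values to NaN after the initial clamp to +- max_value
clamp = FloatClamp(max_value=448., tensor_clamp_impl=TensorClamp(), inf_values=None, saturating=False)
print(clamp(torch.tensor([500., -3., 120.])))  # tensor([nan, -3., 120.])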