From 1b2a64be2e3cbf8d0e95e200db5f351796fad326 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Wed, 21 Feb 2024 05:41:42 -0800
Subject: [PATCH] Feat (minifloat): use tensor_clamp_impl for minifloat
 conversion and clean up

---
 src/brevitas/core/function_wrapper/clamp.py    | 108 +++++++++---------
 src/brevitas/core/quant/float.py               |   8 +-
 src/brevitas/quant/experimental/float_base.py  |   4 +-
 tests/brevitas/core/minifloat_fixtures.py      |   7 +-
 tests/brevitas/core/test_minifloat.py          |  10 +-
 5 files changed, 67 insertions(+), 70 deletions(-)

diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py
index 13b5da2a2..c919dabf0 100644
--- a/src/brevitas/core/function_wrapper/clamp.py
+++ b/src/brevitas/core/function_wrapper/clamp.py
@@ -8,6 +8,7 @@
 import torch
 from torch import Tensor
+from torch.nn import Module
 
 import brevitas
 from brevitas.core.utils import StatelessBuffer
@@ -88,79 +89,70 @@ class FloatClamp(brevitas.jit.ScriptModule):
     I.e. setting inf to 1101.111 (E4M3) is not a valid code.
     """
 
-    __constants__ = ['nan_values', 'inf_values', 'saturating']
-
     def __init__(
             self,
-            exponent_bit_width: Tensor,
-            mantissa_bit_width: Tensor,
-            exponent_bias: Tensor,
+            exponent_bit_width: int,
+            mantissa_bit_width: int,
+            exponent_bias: int,
+            tensor_clamp_impl: Module = TensorClamp(),
             nan_values: Optional[Tuple[str]] = None,
             inf_values: Optional[Tuple[str]] = None,
             saturating: bool = False) -> None:
         super(FloatClamp, self).__init__()
 
-        self.exponent_bit_width = torch.tensor(exponent_bit_width)
-        self.mantissa_bit_width = torch.tensor(mantissa_bit_width)
-        self.exponent_bias = torch.tensor(exponent_bias)
-
-        self.nan_values = nan_values
-        self.inf_values = inf_values
-        self.saturating = saturating
-
         # inf without NaN not possible
-        if self.inf_values is None and self.nan_values is None:
-            self.max_val_impl = StatelessBuffer(
-                max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias))
-        elif self.nan_values is not None:
+        if inf_values is None and nan_values is None:
+            max_val_impl = StatelessBuffer(
+                max_float(
+                    torch.tensor(exponent_bit_width),
+                    torch.tensor(mantissa_bit_width),
+                    torch.tensor(exponent_bias)))
+        elif nan_values is not None:
             # we at least have codes for NaN, so instantiate MaxFloatInfNaN
-            self.max_val_impl = MaxFloatInfNaN(
-                exponent_bit_width=self.exponent_bit_width,
-                mantissa_bit_width=self.mantissa_bit_width,
-                exponent_bias=self.exponent_bias,
-                nan_values=self.nan_values,
-                inf_values=self.inf_values,
-                saturating=self.saturating)
+            max_val_impl = MaxFloatInfNaN(
+                exponent_bit_width=exponent_bit_width,
+                mantissa_bit_width=mantissa_bit_width,
+                exponent_bias=exponent_bias,
+                nan_values=nan_values,
+                inf_values=inf_values)
         else:
             # no NaN values but inf values
             raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
 
-        self.clamp_impl = CaseClamp(inf_values=self.inf_values, saturating=self.saturating)
+        # class for clamping to inf/NaN values
+        self.fpx_clamp_impl = FpXClamp(
+            inf_values=inf_values, saturating=saturating, tensor_clamp_impl=tensor_clamp_impl)
+
+        # get max value for the minifloat config, no need to compute it during forward pass
+        self.max_value = max_val_impl()
 
     @brevitas.jit.script_method
     def forward(self, inp: Tensor):
-        # get max value for the minifloat config
-        max_value = self.max_val_impl()
-        return self.clamp_impl(inp, max_value)
+        return self.fpx_clamp_impl(inp, self.max_value)
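The reworked FloatClamp above now precomputes max_value once at construction time instead of on every forward call. A minimal usage sketch under the constructor shown in this hunk, using the OCP FP8 E4M3 format (standard bias 7, mantissa code '111' reserved for NaN, no inf codes), whose largest representable value is 448:

    import torch
    from brevitas.core.function_wrapper import FloatClamp

    clamp = FloatClamp(
        exponent_bit_width=4,
        mantissa_bit_width=3,
        exponent_bias=7,
        nan_values=('111',),
        inf_values=None,
        saturating=True)

    print(clamp.max_value)                            # tensor(448.)
    print(clamp(torch.tensor([500.0, -500.0, 1.0])))  # tensor([ 448., -448., 1.])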
 
 
 class MaxFloatInfNaN(brevitas.jit.ScriptModule):
 
     def __init__(
             self,
-            exponent_bit_width: Tensor,
-            mantissa_bit_width: Tensor,
-            exponent_bias: Tensor,
+            exponent_bit_width: int,
+            mantissa_bit_width: int,
+            exponent_bias: int,
             nan_values: Tuple[str],
-            inf_values: Tuple[str],
-            saturating: bool = False) -> None:
+            inf_values: Optional[Tuple[str]]) -> None:
         super(MaxFloatInfNaN, self).__init__()
-        self.exponent_bit_width = exponent_bit_width
-        self.mantissa_bit_width = mantissa_bit_width
-        self.exponent_bias = exponent_bias
-
-        self.inf_values = inf_values
-        self.nan_values = nan_values
-        self._special_values = nan_values + inf_values if inf_values is not None else nan_values
+        self.exponent_bit_width = StatelessBuffer(torch.tensor(exponent_bit_width))
+        self.mantissa_bit_width = StatelessBuffer(torch.tensor(mantissa_bit_width))
+        self.exponent_bias = StatelessBuffer(torch.tensor(exponent_bias))
 
-        self.saturating = saturating
+        _special_values = nan_values + inf_values if inf_values is not None else nan_values
 
         # check that NaN/inf codes are at most mantissa_bit_width long
-        if any(map(lambda x: len(x) > mantissa_bit_width, self._special_values)):
+        if any(map(lambda x: len(x) > mantissa_bit_width, _special_values)):
             raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
         # move computation of min for forward pass here so it's jit compatible
-        self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), self._special_values)))
+        self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), _special_values)))
 
     @brevitas.jit.script_method
     def forward(self):
@@ -169,29 +161,30 @@ def forward(self):
 
         if max_value_mantissa < 0:
             # all mantissa codes are taken by special values, so we need to decrease the exponent
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width - 1)
-            exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)
-                                 ])  # add trailing 0 to reach bit width
+            exponent = torch.tensor(1).repeat(self.exponent_bit_width() - 1)
+            # add trailing 0 to reach bit width
+            exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)])
 
             # since we decreased exponent, we can use full mantissa
-            mantissa = torch.tensor(1).repeat(self.mantissa_bit_width)
+            mantissa = torch.tensor(1).repeat(self.mantissa_bit_width())
         else:
             # there is a free mantissa code, so use full exponent
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width)
+            exponent = torch.tensor(1).repeat(self.exponent_bit_width())
             # get binary code for max_value_mantissa in the number of mantissa bits
-            mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width)
+            mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width())
 
         # we don't need the sign since we're looking for the max value
         max_value = get_minifloat_value(
-            exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias)
+            exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias())
         return max_value
 
 
-class CaseClamp(brevitas.jit.ScriptModule):
+class FpXClamp(brevitas.jit.ScriptModule):
 
-    def __init__(self, inf_values: Tuple[str], saturating: bool) -> None:
-        super(CaseClamp, self).__init__()
+    def __init__(self, inf_values: Tuple[str], saturating: bool, tensor_clamp_impl: Module) -> None:
+        super(FpXClamp, self).__init__()
         self.inf_values = inf_values
         self.saturating = saturating
+        self.tensor_clamp_impl = tensor_clamp_impl
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor, max_value: Tensor):
@@ -201,10 +194,11 @@ def forward(self, x: Tensor, max_value: Tensor):
         p_max_val_mask = x > max_value
         n_max_val_mask = -x > max_value
 
-        if self.saturating:
-            # clamp everything to +- max_value
-            x = x.clamp(-max_value, max_value)
-        else:
+        # first clamp everything to +- max_value, i.e. the saturating case
+        x = self.tensor_clamp_impl(x, min_val=-max_value, max_val=max_value)
+
+        if not self.saturating:
+            # if non-saturating, we need to map values greater than max_val to NaN or inf
             if self.inf_values is not None:
                 # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
                 x[p_max_val_mask] = torch.tensor(float('inf'))
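The max-value search in MaxFloatInfNaN can be checked by hand. Below is a standalone arithmetic sketch (plain Python, not Brevitas code) mirroring its logic for E4M3 with nan_values=('111',): the smallest special mantissa code is 0b111, so max_value_mantissa stays non-negative, the all-ones exponent remains usable, and the largest mantissa is 0b110:

    exponent_bits, mantissa_bits, bias = 4, 3, 7
    special_codes = [int(code, 2) for code in ('111',)]
    max_value_mantissa = min(special_codes) - 1   # 0b110 == 6
    exponent = 2 ** exponent_bits - 1             # all-ones exponent field: 15
    max_value = 2.0 ** (exponent - bias) * (1 + max_value_mantissa / 2 ** mantissa_bits)
    print(max_value)                              # 448.0

For E5M2, where the inf code '00' makes min(special_codes) zero, max_value_mantissa goes negative, so the module instead drops the exponent to 0b11110 and uses the full mantissa, yielding 2^15 * 1.75 = 57344.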
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index cf806b7fb..8a462fc0e 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -25,7 +25,7 @@ def __init__(
             exponent_bit_width: int,
             mantissa_bit_width: int,
             exponent_bias: int,
-            case_clamp_impl: Optional[nn.Module] = None,
+            float_clamp_impl: Optional[nn.Module] = None,
             scaling_impl: Optional[nn.Module] = None,
             float_scaling_impl: Optional[nn.Module] = None,
             float_to_int_impl: nn.Module = RoundSte(),
@@ -55,8 +55,8 @@ def __init__(
             float_scaling_impl = ConstScaling(1., device=device, dtype=dtype)
         if scaling_impl is None:
             scaling_impl = ConstScaling(1., device=device, dtype=dtype)
-        if case_clamp_impl is None:
-            self.case_clamp_impl = FloatClamp(
+        if float_clamp_impl is None:
+            self.float_clamp_impl = FloatClamp(
                 exponent_bit_width=self.exponent_bit_width(),
                 mantissa_bit_width=self.mantissa_bit_width(),
                 exponent_bias=self.exponent_bias())
@@ -92,7 +92,7 @@ def dequantize(self, y, scale):
     def forward(self, x):
         y, scale = self.quantize(x)
         # after quantizing, clamp to special cases like NaN/inf if they are set
-        y = self.case_clamp_impl(y)
+        y = self.float_clamp_impl(y)
         y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.bit_width()
diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py
index 929b3e02d..8421c3c13 100644
--- a/src/brevitas/quant/experimental/float_base.py
+++ b/src/brevitas/quant/experimental/float_base.py
@@ -55,7 +55,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
     nan_values = (('111',))
     inf_values = None
     saturating = True
@@ -65,7 +65,7 @@ class Fp8e5m2Mixin(ExponentBiasMixin):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
     nan_values = ('01', '11', '10')
     inf_values = (('00',))
     saturating = True
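With the injector attribute renamed from case_clamp_impl to float_clamp_impl, downstream quantizers only need to override the new name. A hypothetical subclass (the class name and override are illustrative, reusing Fp8e4m3Mixin and ScaledFloatWeightBase from this file):

    from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
    from brevitas.quant.experimental.float_base import ScaledFloatWeightBase

    class Fp8e4m3NonSaturatingWeight(Fp8e4m3Mixin, ScaledFloatWeightBase):
        # with saturating=False and no inf codes in E4M3,
        # out-of-range values are mapped to NaN instead of +- 448
        saturating = False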
diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py
index 8ee882049..e0f7528a6 100644
--- a/tests/brevitas/core/minifloat_fixtures.py
+++ b/tests/brevitas/core/minifloat_fixtures.py
@@ -5,6 +5,7 @@
 from pytest_cases import fixture_union
 
 from brevitas.core.function_wrapper import FloatClamp
+from brevitas.inject.enum import BitWidthImplType
 from brevitas.quant.experimental.float_base import ExponentBiasMixin
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
 
@@ -13,7 +14,8 @@ class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
+    bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False
 
@@ -22,7 +24,8 @@ class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
+    bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False
diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py
index b25296acb..a8b2f93b9 100644
--- a/tests/brevitas/core/test_minifloat.py
+++ b/tests/brevitas/core/test_minifloat.py
@@ -16,26 +16,26 @@
 @pytest.mark.parametrize(
     'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items()))
 def test_max_value(minifloat, expected_max_val):
-    max_val = minifloat.case_clamp_impl.max_val_impl()
+    max_val = minifloat.float_clamp_impl.max_value
 
     assert expected_max_val == max_val
 
 
 @given(inp=float_tensor_random_shape_st())
 def test_clamp(inp, fp8_clamp):
-    max_val = fp8_clamp.case_clamp_impl.max_val_impl()
+    max_val = fp8_clamp.float_clamp_impl.max_value
     # get values that exceed max_val
     over_limit_mask = inp.abs() > max_val
 
     # clamp inp
-    inp = fp8_clamp.case_clamp_impl(inp)
+    inp = fp8_clamp.float_clamp_impl(inp)
 
-    if fp8_clamp.case_clamp_impl.saturating:
+    if fp8_clamp.float_clamp_impl.fpx_clamp_impl.saturating:
         # should be clamped to +- max val
         assert (inp[over_limit_mask].abs() == max_val).all()
     else:
         # if inf_values, over limit mask should now be all inf
-        if fp8_clamp.case_clamp_impl.inf_values is not None:
+        if fp8_clamp.float_clamp_impl.fpx_clamp_impl.inf_values is not None:
             # all values exceeding max_val should be inf
             assert inp[over_limit_mask].isinf().all()
         else:
             # no inf values, so all values exceeding max_val should be NaN
             assert inp[over_limit_mask].isnan().all()
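As a complement to the updated test, an illustrative check of the non-saturating branch (not part of the patch; assumes the standard E5M2 bias of 15): because E5M2 reserves inf codes, values beyond max_value map to +- inf rather than saturating:

    import torch
    from brevitas.core.function_wrapper import FloatClamp

    clamp = FloatClamp(
        exponent_bit_width=5,
        mantissa_bit_width=2,
        exponent_bias=15,
        nan_values=('01', '11', '10'),
        inf_values=('00',),
        saturating=False)

    print(clamp.max_value)                        # tensor(57344.)
    print(clamp(torch.tensor([1e6, -1e6, 1.0])))  # tensor([inf, -inf, 1.])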