From bdd8acf189ce02013726cf139e342082f49c8f19 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Mon, 5 Feb 2024 09:52:38 -0800 Subject: [PATCH 01/15] Feat (FP8): implements conversion from fp32 to user specified fp8 format --- .../core/function_wrapper/__init__.py | 1 + src/brevitas/core/function_wrapper/clamp.py | 37 ++++++++++++ src/brevitas/core/quant/float.py | 4 ++ src/brevitas/function/ops.py | 56 +++++++++++++++++++ src/brevitas/quant/experimental/float_base.py | 11 ++++ 5 files changed, 109 insertions(+) diff --git a/src/brevitas/core/function_wrapper/__init__.py b/src/brevitas/core/function_wrapper/__init__.py index 9d5d929d6..3b3e5428b 100644 --- a/src/brevitas/core/function_wrapper/__init__.py +++ b/src/brevitas/core/function_wrapper/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from .clamp import ClampMin +from .clamp import FloatClamp from .clamp import ScalarClamp from .clamp import TensorClamp from .learned_round import LearnedRoundSte diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index c49eec726..0534aea12 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -9,6 +9,7 @@ from torch import Tensor import brevitas +from brevitas.function import clamp_to_fp_encoding from brevitas.function import tensor_clamp @@ -73,3 +74,39 @@ def __init__(self, min_val: float) -> None: @brevitas.jit.script_method def forward(self, x: Tensor): return x.clamp_min(self.min_val) + + +class FloatClamp(brevitas.jit.ScriptModule): + """" + ScriptModule for clamping minifloat formats to their inf/NaN implementations. + """ + + __constants__ = ['nan_value', 'inf_value', 'max_value', 'saturating'] + + def __init__(self, nan_value: str, inf_value: str, max_value: float, saturating: bool) -> None: + super(FloatClamp, self).__init__() + self.nan_value = self.mantissa_bits_to_float(nan_value) + self.inf_value = self.mantissa_bits_to_float(inf_value) if inf_value is not None else None + self.max_value = max_value + self.saturating = saturating + + def mantissa_bits_to_float(self, bits: str, frexp_compatible: bool = True) -> float: + res = 1.0 + for i, val in enumerate(bits): + # iterating through from left to right + res += ((2 ** -(i + 1)) * float(val)) + if frexp_compatible: + return res / 2. 
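+            # dividing by two maps res from [1, 2) into [0.5, 1), the mantissa
+            # range torch.frexp returns, so the two can be compared directly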
+ else: + return res + + @brevitas.jit.script_method + def forward(self, x: Tensor, exponent_bit_width: int, mantissa_bit_width: int): + return clamp_to_fp_encoding( + x, + exponent_bit_width, + mantissa_bit_width, + nan_value=self.nan_value, + inf_value=self.inf_value, + max_value=self.max_value, + saturating=self.saturating) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 2fad8782b..5357da1bf 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -23,6 +23,7 @@ def __init__( signed: bool, exponent_bit_width: int, mantissa_bit_width: int, + case_clamp_impl: nn.Module, exponent_bias: Optional[int] = None, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, @@ -59,6 +60,7 @@ def __init__( self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl self.scaling_impl = scaling_impl + self.case_clamp_impl = case_clamp_impl @brevitas.jit.script_method def internal_scale(self, x): @@ -86,6 +88,8 @@ def dequantize(self, y, scale): @brevitas.jit.script_method def forward(self, x): y, scale = self.quantize(x) + # after quantizing, clamp to special cases like NaN, inf + y = self.case_clamp_impl(y, self.exponent_bit_width(), self.mantissa_bit_width()) y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.bit_width() diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 6751ab69c..fa2acf23f 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -11,6 +11,10 @@ import brevitas +P_INF_TENSOR = torch.tensor(float('inf')) +N_INF_TENSOR = torch.tensor(float('-inf')) +NAN_TENSOR = torch.tensor(float('nan')) + @brevitas.jit.script def binary_sign(x: Tensor) -> Tensor: @@ -217,3 +221,55 @@ def get_upper_bound_on_l1_norm( max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1 max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) return max_accumulator_mag * max_input_mag_inverse + + +def clamp_to_fp_encoding( + x: Tensor, + exponent_bit_width: int, + mantissa_bit_width: int, + nan_value: Tensor, + inf_value: Tensor, + max_value: Tensor, + saturating: bool): + """ + Clamp any values that exceed inf/NaN special codes to these. Differentiates between saturating + and non-saturating mode. + + nan_value needs to be set to the min NaN value there is. + """ + # TODO: question regarding inf/NaN values + max_exponent_value = 2 ** (exponent_bit_width - 1) + # decompose value + mantissa, exponent = torch.frexp(x) + # check if any of the exponent values are all 1s, i.e. 
equal the max_exponent value + exponent_mask = exponent - 1 >= max_exponent_value # - 1 because frexp returns exponent in range [-125, 128], but actual exponent bits are in range [-126, 127] + is_nan_mask = mantissa.abs() >= nan_value # nan_value is the min NaN value + is_inf_mask = mantissa == inf_value + full_nan_mask = torch.logical_and(exponent_mask, is_nan_mask) + full_inf_mask = torch.logical_and(exponent_mask, is_inf_mask) + if saturating: + # set all values of mantissa_nan_mask and exponent_mask to NaN + x[full_nan_mask] = NAN_TENSOR + + # set all inf_values to max_val + x[full_inf_mask] = max_value + + # clamp absolute values greater than max to +- max val + x = torch.clamp(x, -max_value, max_value) + + return x + else: + # in non saturating case, just set all exceeding values to nan + x[full_nan_mask] = NAN_TENSOR + if inf_value is None: + # we just set all values to NaN + x[full_inf_mask] = NAN_TENSOR + else: + # set inf values to +- infinity + x[full_inf_mask] = torch.where(x[full_inf_mask] > 0, P_INF_TENSOR, N_INF_TENSOR) + # clamp all values greater than max_value to +inf + x = torch.where(x > max_value, P_INF_TENSOR, x) + # clamp all values smaller than min_value to -inf + x = torch.where(x < -max_value, N_INF_TENSOR, x) + + return x diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 8da867e27..ceb20e9e7 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from brevitas.core.function_wrapper import FloatClamp from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling.float_scaling import FloatScaling from brevitas.inject import ExtendedInjector @@ -54,9 +55,19 @@ class Fp8e4m3Mixin(ExponentBiasMixin): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 + case_clamp_impl = FloatClamp + nan_value = '111' + inf_value = None + max_value = 448. + saturating = True class Fp8e5m2Mixin(ExponentBiasMixin): bit_width = 8 exponent_bit_width = 5 mantissa_bit_width = 2 + case_clamp_impl = FloatClamp + nan_value = '01' # smallest NaN value. Others are '11' and '10' + inf_value = '00' + max_value = 57344. + saturating = True From 3200de80882550c6d373110ff9ad7013b239ec52 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Tue, 6 Feb 2024 09:53:24 -0800 Subject: [PATCH 02/15] Feat (minifloat): calculates max_value based on user format --- src/brevitas/core/function_wrapper/clamp.py | 77 ++++++++++++++++--- src/brevitas/core/quant/float.py | 6 +- src/brevitas/function/ops.py | 61 ++++++--------- src/brevitas/quant/experimental/float_base.py | 10 +-- 4 files changed, 101 insertions(+), 53 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 0534aea12..6af1b2dd3 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -4,6 +4,7 @@ """ ScriptModule wrappers for various variants of clamping. """ +from typing import Tuple import torch from torch import Tensor @@ -79,18 +80,24 @@ def forward(self, x: Tensor): class FloatClamp(brevitas.jit.ScriptModule): """" ScriptModule for clamping minifloat formats to their inf/NaN implementations. + + Currently, inf/NaN codes have to be encoded through the mantissa. + I.e. setting inf to 1101.111 (E4M3) is not a valid code. 
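+    For example, in e5m2 the mantissa code '00' under an all-ones exponent
+    encodes inf, while '01', '10' and '11' encode NaN.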
""" - __constants__ = ['nan_value', 'inf_value', 'max_value', 'saturating'] + __constants__ = ['nan_values', 'inf_values', 'saturating'] - def __init__(self, nan_value: str, inf_value: str, max_value: float, saturating: bool) -> None: + def __init__(self, nan_values: Tuple[str], inf_values: Tuple[str], saturating: bool) -> None: super(FloatClamp, self).__init__() - self.nan_value = self.mantissa_bits_to_float(nan_value) - self.inf_value = self.mantissa_bits_to_float(inf_value) if inf_value is not None else None - self.max_value = max_value + # TODO: check that NaN/inf values are all mantissa_bit_width long + self.nan_values = nan_values if nan_values is not None else tuple() + self.inf_values = inf_values if inf_values is not None else tuple() + # inf without NaN not possible + if self.inf_values is not None and self.nan_values is None: + raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') self.saturating = saturating - def mantissa_bits_to_float(self, bits: str, frexp_compatible: bool = True) -> float: + def mantissa_bits_to_float(self, bits: str, frexp_compatible: bool = False) -> float: res = 1.0 for i, val in enumerate(bits): # iterating through from left to right @@ -100,13 +107,63 @@ def mantissa_bits_to_float(self, bits: str, frexp_compatible: bool = True) -> fl else: return res + def get_minifloat_value( + self, + exponent_string: str, + mantissa_string: str, + exponent_bias: Tensor, + sign: str = '0') -> float: + exponent_value = int(exponent_string, 2) + mantissa_value = self.mantissa_bits_to_float(mantissa_string) + return ((-1) ** float(sign)) * 2 ** (exponent_value - exponent_bias) * mantissa_value + + def get_max_value( + self, exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: Tensor) -> float: + # calculate max possible value for this specific format + if not self.nan_values and not self.inf_values: + # we don't have any codes, so just return max possible value + exponent_string = '1' * exponent_bit_width + mantissa_string = '1' * mantissa_bit_width + else: + # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 + min_special_case = min(map(lambda x: int(x, 2), self.nan_values + self.inf_values)) + max_value_mantissa = min_special_case - 1 + if max_value_mantissa < 0: + # all mantissa values are used, so we need to use decrease exponent values + exponent_string = '1' * (exponent_bit_width - 1) + exponent_string += '0' # add trailing 0 to reach bit width + # since we decreased exponent, we can use full mantissa + mantissa_string = '1' * mantissa_bit_width + else: + # there is a free mantissa code, so use full exponent + exponent_string = '1' * exponent_bit_width + # get binary code for max_value_mantissa in the number of mantissa bits + mantissa_string = format(max_value_mantissa, f'0{mantissa_bit_width}b') + + # we don't need the sign since we're looking for the max value + max_value = self.get_minifloat_value( + exponent_string=exponent_string, + mantissa_string=mantissa_string, + exponent_bias=exponent_bias) + return max_value + @brevitas.jit.script_method - def forward(self, x: Tensor, exponent_bit_width: int, mantissa_bit_width: int): + def forward( + self, + x: Tensor, + exponent_bit_width: Tensor, + mantissa_bit_width: Tensor, + exponent_bias: Tensor): + max_value = self.get_max_value( + exponent_bit_width=exponent_bit_width.int().item(), + mantissa_bit_width=mantissa_bit_width.int().item(), + exponent_bias=exponent_bias) + # TODO: at this time, we just pass the codes for inf/NaN, we might 
need to change that return clamp_to_fp_encoding( x, exponent_bit_width, mantissa_bit_width, - nan_value=self.nan_value, - inf_value=self.inf_value, - max_value=self.max_value, + nan_values=self.nan_values, + inf_values=self.inf_values, + max_value=max_value, saturating=self.saturating) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 5357da1bf..375a640f5 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -89,7 +89,11 @@ def dequantize(self, y, scale): def forward(self, x): y, scale = self.quantize(x) # after quantizing, clamp to special cases like NaN, inf - y = self.case_clamp_impl(y, self.exponent_bit_width(), self.mantissa_bit_width()) + y = self.case_clamp_impl( + y, + exponent_bit_width=self.exponent_bit_width(), + mantissa_bit_width=self.mantissa_bit_width(), + exponent_bias=self.exponent_bias()) y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.bit_width() diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index fa2acf23f..41afeb01a 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -5,6 +5,7 @@ Implementation of various core operations often performed as part of quantization. The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler. """ +from typing import Tuple import torch from torch import Tensor @@ -227,8 +228,8 @@ def clamp_to_fp_encoding( x: Tensor, exponent_bit_width: int, mantissa_bit_width: int, - nan_value: Tensor, - inf_value: Tensor, + nan_values: Tuple[float], + inf_values: Tuple[float], max_value: Tensor, saturating: bool): """ @@ -237,39 +238,27 @@ def clamp_to_fp_encoding( nan_value needs to be set to the min NaN value there is. """ - # TODO: question regarding inf/NaN values - max_exponent_value = 2 ** (exponent_bit_width - 1) - # decompose value - mantissa, exponent = torch.frexp(x) - # check if any of the exponent values are all 1s, i.e. 
equal the max_exponent value - exponent_mask = exponent - 1 >= max_exponent_value # - 1 because frexp returns exponent in range [-125, 128], but actual exponent bits are in range [-126, 127] - is_nan_mask = mantissa.abs() >= nan_value # nan_value is the min NaN value - is_inf_mask = mantissa == inf_value - full_nan_mask = torch.logical_and(exponent_mask, is_nan_mask) - full_inf_mask = torch.logical_and(exponent_mask, is_inf_mask) - if saturating: - # set all values of mantissa_nan_mask and exponent_mask to NaN - x[full_nan_mask] = NAN_TENSOR - - # set all inf_values to max_val - x[full_inf_mask] = max_value - - # clamp absolute values greater than max to +- max val - x = torch.clamp(x, -max_value, max_value) + # TODO: think about setting NaN/inf values to the specific minifloat code, that's also why not all arguments are used at this time + # NaN values all stay at NaN, so no need to do anything with NaN values + # get all positive inf values + inf_mask = x.isinf() + p_max_val_mask = x > max_value + n_max_val_mask = -x > max_value - return x + if saturating: + # clamp everything to +- max_value + x = x.clamp(-max_value, max_value) else: - # in non saturating case, just set all exceeding values to nan - x[full_nan_mask] = NAN_TENSOR - if inf_value is None: - # we just set all values to NaN - x[full_inf_mask] = NAN_TENSOR - else: - # set inf values to +- infinity - x[full_inf_mask] = torch.where(x[full_inf_mask] > 0, P_INF_TENSOR, N_INF_TENSOR) - # clamp all values greater than max_value to +inf - x = torch.where(x > max_value, P_INF_TENSOR, x) - # clamp all values smaller than min_value to -inf - x = torch.where(x < -max_value, N_INF_TENSOR, x) - - return x + if inf_values: + # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf + x[p_max_val_mask] = P_INF_TENSOR + x[n_max_val_mask] = N_INF_TENSOR + if not inf_values: + # no inf values, so we need to map them to NaN + full_max_val_mask = torch.logical_and(p_max_val_mask, n_max_val_mask) + x[full_max_val_mask] = NAN_TENSOR + + # we also map the inf values to NaN in this case + x[inf_mask] = NAN_TENSOR + + return x diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index ceb20e9e7..70910456d 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -56,9 +56,8 @@ class Fp8e4m3Mixin(ExponentBiasMixin): exponent_bit_width = 4 mantissa_bit_width = 3 case_clamp_impl = FloatClamp - nan_value = '111' - inf_value = None - max_value = 448. + nan_values = tuple(('111',)) + inf_values = None saturating = True @@ -67,7 +66,6 @@ class Fp8e5m2Mixin(ExponentBiasMixin): exponent_bit_width = 5 mantissa_bit_width = 2 case_clamp_impl = FloatClamp - nan_value = '01' # smallest NaN value. Others are '11' and '10' - inf_value = '00' - max_value = 57344. 
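+    # with an all-ones exponent, '00' encodes inf and '01', '10', '11' encode
+    # NaN, so the largest regular e5m2 value remains 57344.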
+ nan_values = ('01', '11', '10') + inf_values = tuple(('00',)) saturating = True From 23961b64e23232a951237fe2af8f78540ea623e7 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Fri, 9 Feb 2024 10:33:15 -0800 Subject: [PATCH 03/15] Tests (minifloat): adds tests and fixes bugs --- src/brevitas/core/function_wrapper/clamp.py | 47 ++++++---------- src/brevitas/core/quant/float.py | 15 +++--- src/brevitas/function/ops.py | 18 +++---- src/brevitas/utils/float_quant_utils.py | 25 +++++++++ tests/brevitas/core/minifloat_fixtures.py | 55 +++++++++++++++++++ tests/brevitas/core/test_minifloat.py | 60 +++++++++++++++++++++ 6 files changed, 173 insertions(+), 47 deletions(-) create mode 100644 src/brevitas/utils/float_quant_utils.py create mode 100644 tests/brevitas/core/minifloat_fixtures.py create mode 100644 tests/brevitas/core/test_minifloat.py diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 6af1b2dd3..c9c5af375 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -12,6 +12,7 @@ import brevitas from brevitas.function import clamp_to_fp_encoding from brevitas.function import tensor_clamp +from brevitas.utils.float_quant_utils import get_minifloat_value class TensorClamp(brevitas.jit.ScriptModule): @@ -87,7 +88,8 @@ class FloatClamp(brevitas.jit.ScriptModule): __constants__ = ['nan_values', 'inf_values', 'saturating'] - def __init__(self, nan_values: Tuple[str], inf_values: Tuple[str], saturating: bool) -> None: + def __init__( + self, nan_values: Tuple[str], inf_values: Tuple[str], saturating: bool = False) -> None: super(FloatClamp, self).__init__() # TODO: check that NaN/inf values are all mantissa_bit_width long self.nan_values = nan_values if nan_values is not None else tuple() @@ -97,28 +99,11 @@ def __init__(self, nan_values: Tuple[str], inf_values: Tuple[str], saturating: b raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') self.saturating = saturating - def mantissa_bits_to_float(self, bits: str, frexp_compatible: bool = False) -> float: - res = 1.0 - for i, val in enumerate(bits): - # iterating through from left to right - res += ((2 ** -(i + 1)) * float(val)) - if frexp_compatible: - return res / 2. 
- else: - return res - - def get_minifloat_value( - self, - exponent_string: str, - mantissa_string: str, - exponent_bias: Tensor, - sign: str = '0') -> float: - exponent_value = int(exponent_string, 2) - mantissa_value = self.mantissa_bits_to_float(mantissa_string) - return ((-1) ** float(sign)) * 2 ** (exponent_value - exponent_bias) * mantissa_value - def get_max_value( - self, exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: Tensor) -> float: + self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor, + exponent_bias: Tensor) -> float: + exponent_bit_width = exponent_bit_width.int().item() + mantissa_bit_width = mantissa_bit_width.int().item() # calculate max possible value for this specific format if not self.nan_values and not self.inf_values: # we don't have any codes, so just return max possible value @@ -141,7 +126,7 @@ def get_max_value( mantissa_string = format(max_value_mantissa, f'0{mantissa_bit_width}b') # we don't need the sign since we're looking for the max value - max_value = self.get_minifloat_value( + max_value = get_minifloat_value( exponent_string=exponent_string, mantissa_string=mantissa_string, exponent_bias=exponent_bias) @@ -155,15 +140,15 @@ def forward( mantissa_bit_width: Tensor, exponent_bias: Tensor): max_value = self.get_max_value( - exponent_bit_width=exponent_bit_width.int().item(), - mantissa_bit_width=mantissa_bit_width.int().item(), + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias) # TODO: at this time, we just pass the codes for inf/NaN, we might need to change that return clamp_to_fp_encoding( - x, - exponent_bit_width, - mantissa_bit_width, - nan_values=self.nan_values, - inf_values=self.inf_values, + x=x, max_value=max_value, - saturating=self.saturating) + saturating=self.saturating, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + nan_values=self.nan_values, + inf_values=self.inf_values) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 375a640f5..bab61ec1f 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -23,7 +23,7 @@ def __init__( signed: bool, exponent_bit_width: int, mantissa_bit_width: int, - case_clamp_impl: nn.Module, + case_clamp_impl: Optional[nn.Module] = None, exponent_bias: Optional[int] = None, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, @@ -88,12 +88,13 @@ def dequantize(self, y, scale): @brevitas.jit.script_method def forward(self, x): y, scale = self.quantize(x) - # after quantizing, clamp to special cases like NaN, inf - y = self.case_clamp_impl( - y, - exponent_bit_width=self.exponent_bit_width(), - mantissa_bit_width=self.mantissa_bit_width(), - exponent_bias=self.exponent_bias()) + # after quantizing, clamp to special cases like NaN/inf if they are set + if self.case_clamp_impl is not None: + y = self.case_clamp_impl( + y, + exponent_bit_width=self.exponent_bit_width(), + mantissa_bit_width=self.mantissa_bit_width(), + exponent_bias=self.exponent_bias()) y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.bit_width() diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 41afeb01a..bddac59eb 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -5,7 +5,7 @@ Implementation of various core operations often performed as part of quantization. 
The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler. """ -from typing import Tuple +from typing import Optional, Tuple import torch from torch import Tensor @@ -226,12 +226,12 @@ def get_upper_bound_on_l1_norm( def clamp_to_fp_encoding( x: Tensor, - exponent_bit_width: int, - mantissa_bit_width: int, - nan_values: Tuple[float], - inf_values: Tuple[float], max_value: Tensor, - saturating: bool): + saturating: bool, + exponent_bit_width: Optional[Tensor] = None, + mantissa_bit_width: Optional[Tensor] = None, + nan_values: Optional[Tuple[float]] = None, + inf_values: Optional[Tuple[float]] = None): """ Clamp any values that exceed inf/NaN special codes to these. Differentiates between saturating and non-saturating mode. @@ -249,13 +249,13 @@ def clamp_to_fp_encoding( # clamp everything to +- max_value x = x.clamp(-max_value, max_value) else: - if inf_values: + if len(inf_values) > 0: # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf x[p_max_val_mask] = P_INF_TENSOR x[n_max_val_mask] = N_INF_TENSOR - if not inf_values: + else: # no inf values, so we need to map them to NaN - full_max_val_mask = torch.logical_and(p_max_val_mask, n_max_val_mask) + full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) x[full_max_val_mask] = NAN_TENSOR # we also map the inf values to NaN in this case diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py new file mode 100644 index 000000000..493546c0d --- /dev/null +++ b/src/brevitas/utils/float_quant_utils.py @@ -0,0 +1,25 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from torch import Tensor + + +def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float: + res = 1.0 + for i, val in enumerate(bits): + # iterating through from left to right + res += ((2 ** -(i + 1)) * float(val)) + if frexp_compatible: + return res / 2. + else: + return res + + +def get_minifloat_value( + exponent_string: str, + mantissa_string: str, + exponent_bias: Tensor, + sign: str = '0') -> float: + exponent_value = int(exponent_string, 2) + mantissa_value = mantissa_bits_to_float(mantissa_string) + return ((-1) ** float(sign)) * 2 ** (exponent_value - exponent_bias) * mantissa_value diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py new file mode 100644 index 000000000..ed0f3025f --- /dev/null +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -0,0 +1,55 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +import pytest_cases +from pytest_cases import fixture_union + +from brevitas.core.function_wrapper import FloatClamp +from brevitas.quant.experimental.float_base import ExponentBiasMixin + + +class Fp8e4m3Base(ExponentBiasMixin): + bit_width = 8 + exponent_bit_width = 4 + mantissa_bit_width = 3 + case_clamp_impl = FloatClamp + nan_values = tuple(('111',)) + inf_values = None + # hypothesis extra + hypothesis_internal_is_this_a_mock_check = False + + +class Fp8e5m2Base(ExponentBiasMixin): + bit_width = 8 + exponent_bit_width = 5 + mantissa_bit_width = 2 + case_clamp_impl = FloatClamp + nan_values = ('01', '11', '10') + inf_values = tuple(('00',)) + # hypothesis extra + hypothesis_internal_is_this_a_mock_check = False + + +@pytest_cases.fixture +@pytest_cases.parametrize('sat', [True, False]) +def fp8e4m3(sat): + + class Fp8e4m3(Fp8e4m3Base): + saturating = sat + + return Fp8e4m3 + + +@pytest_cases.fixture +@pytest_cases.parametrize('sat', [True, False]) +def fp8e5m2(sat): + + class Fp8e5m2(Fp8e5m2Base): + saturating = sat + + return Fp8e5m2 + + +list_of_fixtures = ['fp8e4m3', 'fp8e5m2'] + +fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures) diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py new file mode 100644 index 000000000..934336548 --- /dev/null +++ b/tests/brevitas/core/test_minifloat.py @@ -0,0 +1,60 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from hypothesis import given +import pytest +import torch + +from brevitas.quant.experimental.float_base import Fp8e4m3Mixin +from brevitas.quant.experimental.float_base import Fp8e5m2Mixin +from tests.brevitas.hyp_helper import float_tensor_random_shape_st + +from .minifloat_fixtures import * + +FORMATS = {Fp8e5m2Mixin: 57344., Fp8e4m3Mixin: 448.} + + +@pytest.mark.parametrize( + 'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items())) +def test_max_value(minifloat, expected_max_val): + # minifloat_format, expected_max_val = format + exponent_bit_width = torch.tensor(minifloat.exponent_bit_width) + mantissa_bit_width = torch.tensor(minifloat.mantissa_bit_width) + exponent_bias = torch.tensor(minifloat.exponent_bias) + + max_val = minifloat.case_clamp_impl.get_max_value( + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias) + + assert expected_max_val == max_val + + +@given(inp=float_tensor_random_shape_st()) +def test_clamp(inp, fp8_clamp): + # construct tensor which exceeds max val + exponent_bit_width = torch.tensor(fp8_clamp.exponent_bit_width) + mantissa_bit_width = torch.tensor(fp8_clamp.mantissa_bit_width) + exponent_bias = torch.tensor(fp8_clamp.exponent_bias) + + max_val = fp8_clamp.case_clamp_impl.get_max_value( + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias) + # get values that exceed max_val + over_limit_mask = inp.abs() > max_val + + # clamp inp + inp = fp8_clamp.case_clamp_impl(inp, exponent_bit_width, mantissa_bit_width, exponent_bias) + + if fp8_clamp.case_clamp_impl.saturating: + # should be clamped to +- max val + assert (inp[over_limit_mask].abs() == max_val).all() + else: + # if inf_values, over limit mask should now be all inf + if len(fp8_clamp.case_clamp_impl.inf_values) > 0: + # all values exceeding max_val should be inf + assert inp[over_limit_mask].isinf().all() + else: + # all 
values should be NaN + assert inp[over_limit_mask].isnan().all() From 21cadb7d56d9130cf4e1a96b166c896547dda8b0 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Mon, 12 Feb 2024 07:44:26 -0800 Subject: [PATCH 04/15] Fix (minifloat): restructure FloatClamp and max val calculation --- src/brevitas/core/function_wrapper/clamp.py | 142 +++++++++++++------- tests/brevitas/core/minifloat_fixtures.py | 5 +- tests/brevitas/core/test_minifloat.py | 22 +-- 3 files changed, 99 insertions(+), 70 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index c9c5af375..7e4d6e278 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -10,7 +10,9 @@ from torch import Tensor import brevitas +from brevitas.core.utils import StatelessBuffer from brevitas.function import clamp_to_fp_encoding +from brevitas.function import max_float from brevitas.function import tensor_clamp from brevitas.utils.float_quant_utils import get_minifloat_value @@ -86,69 +88,111 @@ class FloatClamp(brevitas.jit.ScriptModule): I.e. setting inf to 1101.111 (E4M3) is not a valid code. """ - __constants__ = ['nan_values', 'inf_values', 'saturating'] + __constants__ = [ + 'exponent_bit_width', + 'mantissa_bit_width', + 'exponent_bias', + 'nan_values', + 'inf_values', + 'saturating'] def __init__( - self, nan_values: Tuple[str], inf_values: Tuple[str], saturating: bool = False) -> None: + self, + exponent_bit_width: Tensor, + mantissa_bit_width: Tensor, + exponent_bias: Tensor, + nan_values: Tuple[str], + inf_values: Tuple[str], + saturating: bool = False) -> None: super(FloatClamp, self).__init__() - # TODO: check that NaN/inf values are all mantissa_bit_width long + + self.exponent_bit_width = exponent_bit_width + self.mantissa_bit_width = mantissa_bit_width + self.exponent_bias = exponent_bias + self.nan_values = nan_values if nan_values is not None else tuple() self.inf_values = inf_values if inf_values is not None else tuple() + self.saturating = saturating + + # check that NaN/inf values are all mantissa_bit_width long + if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values + self.inf_values)): + raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') + # inf without NaN not possible - if self.inf_values is not None and self.nan_values is None: + if len(self.inf_values) == 0 and len(self.nan_values) == 0: + self.max_val_impl = StatelessBuffer( + max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)) + elif len(self.inf_values) > 0 and len(self.nan_values) == 0: raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') + else: + # we at least have values for NaN, so initiate MaxValInfNaN + self.max_val_impl = MaxFloatInfNaN( + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + exponent_bias=self.exponent_bias, + nan_values=self.nan_values, + inf_values=self.inf_values, + saturating=self.saturating) + + @brevitas.jit.script_method + def forward(self, inp: Tensor): + # get max value for the minifloat config + max_value = self.max_val_impl() + # TODO: change this to a class? 
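+        # max_val_impl was picked in __init__: a StatelessBuffer wrapping
+        # max_float when no codes are reserved, MaxFloatInfNaN/MaxFloatNaN otherwise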
+ return clamp_to_fp_encoding( + x=inp, + max_value=max_value, + saturating=self.saturating, + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + nan_values=self.nan_values, + inf_values=self.inf_values) + + +class MaxFloatInfNaN(brevitas.jit.ScriptModule): + + def __init__( + self, + exponent_bit_width: Tensor, + mantissa_bit_width: Tensor, + exponent_bias: Tensor, + nan_values: Tuple[str], + inf_values: Tuple[str], + saturating: bool = False) -> None: + super(MaxFloatInfNaN, self).__init__() + self.exponent_bit_width = exponent_bit_width + self.mantissa_bit_width = mantissa_bit_width + self.exponent_bias = exponent_bias + + self.inf_values = inf_values + self.nan_values = nan_values + self.saturating = saturating - def get_max_value( - self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor, - exponent_bias: Tensor) -> float: - exponent_bit_width = exponent_bit_width.int().item() - mantissa_bit_width = mantissa_bit_width.int().item() - # calculate max possible value for this specific format - if not self.nan_values and not self.inf_values: - # we don't have any codes, so just return max possible value - exponent_string = '1' * exponent_bit_width + @brevitas.jit.script_method + def forward(self): + exponent_bit_width = self.exponent_bit_width + mantissa_bit_width = self.mantissa_bit_width + + # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 + min_special_case = min(map(lambda x: int(x, 2), self.nan_values + self.inf_values)) + max_value_mantissa = min_special_case - 1 + + if max_value_mantissa < 0: + # all mantissa values are used, so we need to use decrease exponent values + exponent_string = '1' * (exponent_bit_width - 1) + exponent_string += '0' # add trailing 0 to reach bit width + # since we decreased exponent, we can use full mantissa mantissa_string = '1' * mantissa_bit_width else: - # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 - min_special_case = min(map(lambda x: int(x, 2), self.nan_values + self.inf_values)) - max_value_mantissa = min_special_case - 1 - if max_value_mantissa < 0: - # all mantissa values are used, so we need to use decrease exponent values - exponent_string = '1' * (exponent_bit_width - 1) - exponent_string += '0' # add trailing 0 to reach bit width - # since we decreased exponent, we can use full mantissa - mantissa_string = '1' * mantissa_bit_width - else: - # there is a free mantissa code, so use full exponent - exponent_string = '1' * exponent_bit_width - # get binary code for max_value_mantissa in the number of mantissa bits - mantissa_string = format(max_value_mantissa, f'0{mantissa_bit_width}b') + # there is a free mantissa code, so use full exponent + exponent_string = '1' * exponent_bit_width + # get binary code for max_value_mantissa in the number of mantissa bits + mantissa_string = format(max_value_mantissa, f'0{mantissa_bit_width}b') # we don't need the sign since we're looking for the max value max_value = get_minifloat_value( exponent_string=exponent_string, mantissa_string=mantissa_string, - exponent_bias=exponent_bias) + exponent_bias=self.exponent_bias) return max_value - - @brevitas.jit.script_method - def forward( - self, - x: Tensor, - exponent_bit_width: Tensor, - mantissa_bit_width: Tensor, - exponent_bias: Tensor): - max_value = self.get_max_value( - exponent_bit_width=exponent_bit_width, - mantissa_bit_width=mantissa_bit_width, - exponent_bias=exponent_bias) - # TODO: at this time, we just pass the codes for 
inf/NaN, we might need to change that - return clamp_to_fp_encoding( - x=x, - max_value=max_value, - saturating=self.saturating, - exponent_bit_width=exponent_bit_width, - mantissa_bit_width=mantissa_bit_width, - nan_values=self.nan_values, - inf_values=self.inf_values) diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py index ed0f3025f..72e9d9ca6 100644 --- a/tests/brevitas/core/minifloat_fixtures.py +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -6,9 +6,10 @@ from brevitas.core.function_wrapper import FloatClamp from brevitas.quant.experimental.float_base import ExponentBiasMixin +from brevitas.quant.experimental.float_base import ScaledFloatWeightBase -class Fp8e4m3Base(ExponentBiasMixin): +class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 @@ -19,7 +20,7 @@ class Fp8e4m3Base(ExponentBiasMixin): hypothesis_internal_is_this_a_mock_check = False -class Fp8e5m2Base(ExponentBiasMixin): +class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase): bit_width = 8 exponent_bit_width = 5 mantissa_bit_width = 2 diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py index 934336548..949190c78 100644 --- a/tests/brevitas/core/test_minifloat.py +++ b/tests/brevitas/core/test_minifloat.py @@ -17,35 +17,19 @@ @pytest.mark.parametrize( 'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items())) def test_max_value(minifloat, expected_max_val): - # minifloat_format, expected_max_val = format - exponent_bit_width = torch.tensor(minifloat.exponent_bit_width) - mantissa_bit_width = torch.tensor(minifloat.mantissa_bit_width) - exponent_bias = torch.tensor(minifloat.exponent_bias) - - max_val = minifloat.case_clamp_impl.get_max_value( - exponent_bit_width=exponent_bit_width, - mantissa_bit_width=mantissa_bit_width, - exponent_bias=exponent_bias) + max_val = minifloat.case_clamp_impl.max_val_impl() assert expected_max_val == max_val @given(inp=float_tensor_random_shape_st()) def test_clamp(inp, fp8_clamp): - # construct tensor which exceeds max val - exponent_bit_width = torch.tensor(fp8_clamp.exponent_bit_width) - mantissa_bit_width = torch.tensor(fp8_clamp.mantissa_bit_width) - exponent_bias = torch.tensor(fp8_clamp.exponent_bias) - - max_val = fp8_clamp.case_clamp_impl.get_max_value( - exponent_bit_width=exponent_bit_width, - mantissa_bit_width=mantissa_bit_width, - exponent_bias=exponent_bias) + max_val = fp8_clamp.case_clamp_impl.max_val_impl() # get values that exceed max_val over_limit_mask = inp.abs() > max_val # clamp inp - inp = fp8_clamp.case_clamp_impl(inp, exponent_bit_width, mantissa_bit_width, exponent_bias) + inp = fp8_clamp.case_clamp_impl(inp) if fp8_clamp.case_clamp_impl.saturating: # should be clamped to +- max val From f0676f8da471561544a0b960e963e40bca7f5c82 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Wed, 14 Feb 2024 08:24:05 -0800 Subject: [PATCH 05/15] Fix (minifloat): restructure MaxValue and clamp computation --- src/brevitas/core/function_wrapper/clamp.py | 149 ++++++++++++++---- src/brevitas/core/quant/float.py | 14 +- src/brevitas/function/ops.py | 44 ------ src/brevitas/quant/experimental/float_base.py | 4 +- tests/brevitas/core/test_minifloat.py | 3 +- 5 files changed, 127 insertions(+), 87 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 7e4d6e278..01d507b66 100644 --- a/src/brevitas/core/function_wrapper/clamp.py 
+++ b/src/brevitas/core/function_wrapper/clamp.py @@ -4,18 +4,21 @@ """ ScriptModule wrappers for various variants of clamping. """ -from typing import Tuple +from typing import Optional, Tuple import torch from torch import Tensor import brevitas from brevitas.core.utils import StatelessBuffer -from brevitas.function import clamp_to_fp_encoding from brevitas.function import max_float from brevitas.function import tensor_clamp from brevitas.utils.float_quant_utils import get_minifloat_value +P_INF_TENSOR = torch.tensor(float('inf')) +N_INF_TENSOR = torch.tensor(float('-inf')) +NAN_TENSOR = torch.tensor(float('nan')) + class TensorClamp(brevitas.jit.ScriptModule): """ @@ -101,8 +104,8 @@ def __init__( exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor, - nan_values: Tuple[str], - inf_values: Tuple[str], + nan_values: Optional[Tuple[str]] = None, + inf_values: Optional[Tuple[str]] = None, saturating: bool = False) -> None: super(FloatClamp, self).__init__() @@ -110,22 +113,17 @@ def __init__( self.mantissa_bit_width = mantissa_bit_width self.exponent_bias = exponent_bias - self.nan_values = nan_values if nan_values is not None else tuple() - self.inf_values = inf_values if inf_values is not None else tuple() + self.nan_values = nan_values + self.inf_values = inf_values self.saturating = saturating - # check that NaN/inf values are all mantissa_bit_width long - if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values + self.inf_values)): - raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') - # inf without NaN not possible - if len(self.inf_values) == 0 and len(self.nan_values) == 0: + if self.inf_values is None and self.nan_values is None: self.max_val_impl = StatelessBuffer( - max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)) - elif len(self.inf_values) > 0 and len(self.nan_values) == 0: - raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') - else: - # we at least have values for NaN, so initiate MaxValInfNaN + max_float( + self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())) + elif self.inf_values is not None and self.nan_values is not None: + # we have values for NaN and inf, so initiate MaxValInfNaN self.max_val_impl = MaxFloatInfNaN( exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, @@ -133,20 +131,25 @@ def __init__( nan_values=self.nan_values, inf_values=self.inf_values, saturating=self.saturating) + elif self.inf_values is None and self.nan_values is not None: + # we only have values for NaN, so initiate MaxValNaN + self.max_val_impl = MaxFloatNaN( + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + exponent_bias=self.exponent_bias, + nan_values=self.nan_values, + saturating=self.saturating) + else: + # no NaN values but inf values + raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') + + self.clamp_impl = CaseClamp(inf_values=self.inf_values, saturating=self.saturating) @brevitas.jit.script_method def forward(self, inp: Tensor): # get max value for the minifloat config max_value = self.max_val_impl() - # TODO: change this to a class? 
- return clamp_to_fp_encoding( - x=inp, - max_value=max_value, - saturating=self.saturating, - exponent_bit_width=self.exponent_bit_width, - mantissa_bit_width=self.mantissa_bit_width, - nan_values=self.nan_values, - inf_values=self.inf_values) + return self.clamp_impl(inp, max_value) class MaxFloatInfNaN(brevitas.jit.ScriptModule): @@ -169,26 +172,27 @@ def __init__( self.saturating = saturating + # check that NaN/inf values are all mantissa_bit_width long + if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values + self.inf_values)): + raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') + @brevitas.jit.script_method def forward(self): - exponent_bit_width = self.exponent_bit_width - mantissa_bit_width = self.mantissa_bit_width - # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 min_special_case = min(map(lambda x: int(x, 2), self.nan_values + self.inf_values)) max_value_mantissa = min_special_case - 1 if max_value_mantissa < 0: # all mantissa values are used, so we need to use decrease exponent values - exponent_string = '1' * (exponent_bit_width - 1) + exponent_string = '1' * (self.exponent_bit_width - 1) exponent_string += '0' # add trailing 0 to reach bit width # since we decreased exponent, we can use full mantissa - mantissa_string = '1' * mantissa_bit_width + mantissa_string = '1' * self.mantissa_bit_width else: # there is a free mantissa code, so use full exponent - exponent_string = '1' * exponent_bit_width + exponent_string = '1' * self.exponent_bit_width # get binary code for max_value_mantissa in the number of mantissa bits - mantissa_string = format(max_value_mantissa, f'0{mantissa_bit_width}b') + mantissa_string = format(max_value_mantissa, f'0{self.mantissa_bit_width}b') # we don't need the sign since we're looking for the max value max_value = get_minifloat_value( @@ -196,3 +200,84 @@ def forward(self): mantissa_string=mantissa_string, exponent_bias=self.exponent_bias) return max_value + + +class MaxFloatNaN(brevitas.jit.ScriptModule): + + def __init__( + self, + exponent_bit_width: Tensor, + mantissa_bit_width: Tensor, + exponent_bias: Tensor, + nan_values: Tuple[str], + saturating: bool = False) -> None: + super(MaxFloatNaN, self).__init__() + self.exponent_bit_width = exponent_bit_width + self.mantissa_bit_width = mantissa_bit_width + self.exponent_bias = exponent_bias + + self.nan_values = nan_values + self.saturating = saturating + + # check that NaN values are all mantissa_bit_width long + if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values)): + raise RuntimeError('NaN codes need to be the same length as the mantissa.') + + @brevitas.jit.script_method + def forward(self): + # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 + min_special_case = min(map(lambda x: int(x, 2), self.nan_values)) + max_value_mantissa = min_special_case - 1 + + if max_value_mantissa < 0: + # all mantissa values are used, so we need to use decrease exponent values + exponent_string = '1' * (self.exponent_bit_width - 1) + exponent_string += '0' # add trailing 0 to reach bit width + # since we decreased exponent, we can use full mantissa + mantissa_string = '1' * self.mantissa_bit_width + else: + # there is a free mantissa code, so use full exponent + exponent_string = '1' * self.exponent_bit_width + # get binary code for max_value_mantissa in the number of mantissa bits + mantissa_string = format(max_value_mantissa, f'0{self.mantissa_bit_width}b') + + # we don't need 
the sign since we're looking for the max value + max_value = get_minifloat_value( + exponent_string=exponent_string, + mantissa_string=mantissa_string, + exponent_bias=self.exponent_bias) + return max_value + + +class CaseClamp(brevitas.jit.ScriptModule): + + def __init__(self, inf_values: Tuple[str], saturating: bool) -> None: + super(CaseClamp, self).__init__() + self.inf_values = inf_values + self.saturating = saturating + + @brevitas.jit.script_method + def forward(self, x: Tensor, max_value: Tensor): + # NaN values all stay at NaN, so no need to do anything with NaN values + # get all positive inf values + inf_mask = x.isinf() + p_max_val_mask = x > max_value + n_max_val_mask = -x > max_value + + if self.saturating: + # clamp everything to +- max_value + x = x.clamp(-max_value, max_value) + else: + if self.inf_values is not None: + # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf + x[p_max_val_mask] = P_INF_TENSOR + x[n_max_val_mask] = N_INF_TENSOR + else: + # no inf values, so we need to map them to NaN + full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) + x[full_max_val_mask] = NAN_TENSOR + + # we also map the inf values to NaN in this case + x[inf_mask] = NAN_TENSOR + + return x diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index bab61ec1f..0de9711c4 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -7,6 +7,7 @@ import torch.nn as nn import brevitas +from brevitas.core.function_wrapper import FloatClamp from brevitas.core.function_wrapper import RoundSte from brevitas.core.scaling import ConstScaling from brevitas.core.utils import StatelessBuffer @@ -56,11 +57,15 @@ def __init__( float_scaling_impl = ConstScaling(1., device=device, dtype=dtype) if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) + if case_clamp_impl is None: + self.case_clamp_impl = FloatClamp( + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + exponent_bias=self.exponent_bias) # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl self.scaling_impl = scaling_impl - self.case_clamp_impl = case_clamp_impl @brevitas.jit.script_method def internal_scale(self, x): @@ -89,12 +94,7 @@ def dequantize(self, y, scale): def forward(self, x): y, scale = self.quantize(x) # after quantizing, clamp to special cases like NaN/inf if they are set - if self.case_clamp_impl is not None: - y = self.case_clamp_impl( - y, - exponent_bit_width=self.exponent_bit_width(), - mantissa_bit_width=self.mantissa_bit_width(), - exponent_bias=self.exponent_bias()) + y = self.case_clamp_impl(y) y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.bit_width() diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index bddac59eb..7bbffaec7 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -12,10 +12,6 @@ import brevitas -P_INF_TENSOR = torch.tensor(float('inf')) -N_INF_TENSOR = torch.tensor(float('-inf')) -NAN_TENSOR = torch.tensor(float('nan')) - @brevitas.jit.script def binary_sign(x: Tensor) -> Tensor: @@ -222,43 +218,3 @@ def get_upper_bound_on_l1_norm( max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. 
# 2^{P-1}-1 max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) return max_accumulator_mag * max_input_mag_inverse - - -def clamp_to_fp_encoding( - x: Tensor, - max_value: Tensor, - saturating: bool, - exponent_bit_width: Optional[Tensor] = None, - mantissa_bit_width: Optional[Tensor] = None, - nan_values: Optional[Tuple[float]] = None, - inf_values: Optional[Tuple[float]] = None): - """ - Clamp any values that exceed inf/NaN special codes to these. Differentiates between saturating - and non-saturating mode. - - nan_value needs to be set to the min NaN value there is. - """ - # TODO: think about setting NaN/inf values to the specific minifloat code, that's also why not all arguments are used at this time - # NaN values all stay at NaN, so no need to do anything with NaN values - # get all positive inf values - inf_mask = x.isinf() - p_max_val_mask = x > max_value - n_max_val_mask = -x > max_value - - if saturating: - # clamp everything to +- max_value - x = x.clamp(-max_value, max_value) - else: - if len(inf_values) > 0: - # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf - x[p_max_val_mask] = P_INF_TENSOR - x[n_max_val_mask] = N_INF_TENSOR - else: - # no inf values, so we need to map them to NaN - full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) - x[full_max_val_mask] = NAN_TENSOR - - # we also map the inf values to NaN in this case - x[inf_mask] = NAN_TENSOR - - return x diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 70910456d..929b3e02d 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -56,7 +56,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin): exponent_bit_width = 4 mantissa_bit_width = 3 case_clamp_impl = FloatClamp - nan_values = tuple(('111',)) + nan_values = (('111',)) inf_values = None saturating = True @@ -67,5 +67,5 @@ class Fp8e5m2Mixin(ExponentBiasMixin): mantissa_bit_width = 2 case_clamp_impl = FloatClamp nan_values = ('01', '11', '10') - inf_values = tuple(('00',)) + inf_values = (('00',)) saturating = True diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py index 949190c78..58c4038c7 100644 --- a/tests/brevitas/core/test_minifloat.py +++ b/tests/brevitas/core/test_minifloat.py @@ -3,7 +3,6 @@ from hypothesis import given import pytest -import torch from brevitas.quant.experimental.float_base import Fp8e4m3Mixin from brevitas.quant.experimental.float_base import Fp8e5m2Mixin @@ -36,7 +35,7 @@ def test_clamp(inp, fp8_clamp): assert (inp[over_limit_mask].abs() == max_val).all() else: # if inf_values, over limit mask should now be all inf - if len(fp8_clamp.case_clamp_impl.inf_values) > 0: + if fp8_clamp.case_clamp_impl.inf_values is not None: # all values exceeding max_val should be inf assert inp[over_limit_mask].isinf().all() else: From a8cbd4a21243ea5c67171a2941037e30d54f4f30 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Thu, 15 Feb 2024 15:03:22 -0800 Subject: [PATCH 06/15] remove second max val class --- src/brevitas/core/function_wrapper/clamp.py | 64 ++------------------- src/brevitas/core/quant/float.py | 4 +- tests/brevitas/core/test_float_quant.py | 19 ++++-- tests/brevitas/hyp_helper.py | 9 ++- 4 files changed, 26 insertions(+), 70 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 01d507b66..c2a47acbc 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ 
b/src/brevitas/core/function_wrapper/clamp.py @@ -122,8 +122,8 @@ def __init__( self.max_val_impl = StatelessBuffer( max_float( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())) - elif self.inf_values is not None and self.nan_values is not None: - # we have values for NaN and inf, so initiate MaxValInfNaN + elif self.nan_values is not None: + # we at least have values for NaN, so initiate MaxValInfNaN self.max_val_impl = MaxFloatInfNaN( exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, @@ -131,14 +131,6 @@ def __init__( nan_values=self.nan_values, inf_values=self.inf_values, saturating=self.saturating) - elif self.inf_values is None and self.nan_values is not None: - # we only have values for NaN, so initiate MaxValNaN - self.max_val_impl = MaxFloatNaN( - exponent_bit_width=self.exponent_bit_width, - mantissa_bit_width=self.mantissa_bit_width, - exponent_bias=self.exponent_bias, - nan_values=self.nan_values, - saturating=self.saturating) else: # no NaN values but inf values raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') @@ -169,64 +161,18 @@ def __init__( self.inf_values = inf_values self.nan_values = nan_values + self.__special_values = nan_values + inf_values if inf_values is not None else nan_values self.saturating = saturating # check that NaN/inf values are all mantissa_bit_width long - if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values + self.inf_values)): + if any(map(lambda x: len(x) > mantissa_bit_width, self.__special_values)): raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') @brevitas.jit.script_method def forward(self): # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 - min_special_case = min(map(lambda x: int(x, 2), self.nan_values + self.inf_values)) - max_value_mantissa = min_special_case - 1 - - if max_value_mantissa < 0: - # all mantissa values are used, so we need to use decrease exponent values - exponent_string = '1' * (self.exponent_bit_width - 1) - exponent_string += '0' # add trailing 0 to reach bit width - # since we decreased exponent, we can use full mantissa - mantissa_string = '1' * self.mantissa_bit_width - else: - # there is a free mantissa code, so use full exponent - exponent_string = '1' * self.exponent_bit_width - # get binary code for max_value_mantissa in the number of mantissa bits - mantissa_string = format(max_value_mantissa, f'0{self.mantissa_bit_width}b') - - # we don't need the sign since we're looking for the max value - max_value = get_minifloat_value( - exponent_string=exponent_string, - mantissa_string=mantissa_string, - exponent_bias=self.exponent_bias) - return max_value - - -class MaxFloatNaN(brevitas.jit.ScriptModule): - - def __init__( - self, - exponent_bit_width: Tensor, - mantissa_bit_width: Tensor, - exponent_bias: Tensor, - nan_values: Tuple[str], - saturating: bool = False) -> None: - super(MaxFloatNaN, self).__init__() - self.exponent_bit_width = exponent_bit_width - self.mantissa_bit_width = mantissa_bit_width - self.exponent_bias = exponent_bias - - self.nan_values = nan_values - self.saturating = saturating - - # check that NaN values are all mantissa_bit_width long - if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values)): - raise RuntimeError('NaN codes need to be the same length as the mantissa.') - - @brevitas.jit.script_method - def forward(self): - # idea: take inf and nan values, select the smallest, set max_value to smallest_val 
- 1 - min_special_case = min(map(lambda x: int(x, 2), self.nan_values)) + min_special_case = min(map(lambda x: int(x, 2), self.__special_values)) max_value_mantissa = min_special_case - 1 if max_value_mantissa < 0: diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 0de9711c4..8ea60e27f 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -24,8 +24,8 @@ def __init__( signed: bool, exponent_bit_width: int, mantissa_bit_width: int, + exponent_bias: int, case_clamp_impl: Optional[nn.Module] = None, - exponent_bias: Optional[int] = None, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, float_to_int_impl: nn.Module = RoundSte(), @@ -45,8 +45,6 @@ def __init__( raise RuntimeError("Mantissa bit width cannot be 0.") self.mantissa_bit_width = StatelessBuffer( (torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype))) - if exponent_bias is None: - exponent_bias = 2 ** (exponent_bit_width - 1) - 1 self.exponent_bias = StatelessBuffer( torch.tensor(float(exponent_bias), device=device, dtype=dtype)) self.fp_max_val = StatelessBuffer( diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 608c58130..2e51f1db3 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -17,7 +17,8 @@ @given(minifloat_format=random_minifloat_format()) def test_float_quant_defaults(minifloat_format): - bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format + # specifically don't set exponent bias to see if default works expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1 if exponent_bit_width == 0 or mantissa_bit_width == 0: @@ -26,12 +27,14 @@ def test_float_quant_defaults(minifloat_format): bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed) else: float_quant = FloatQuant( bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed) assert expected_exponent_bias == float_quant.exponent_bias() assert isinstance(float_quant.float_to_int_impl, RoundSte) @@ -41,25 +44,27 @@ def test_float_quant_defaults(minifloat_format): @given(minifloat_format=random_minifloat_format()) def test_minifloat(minifloat_format): - bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + bit_width, exponent_bit_width, mantissa_bit_width, signed, _ = minifloat_format assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed) @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format()) def test_float_to_quant_float(inp, minifloat_format): - bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format if exponent_bit_width == 0 or mantissa_bit_width == 0: with pytest.raises(RuntimeError): float_quant = FloatQuant( bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed) else: float_quant = FloatQuant( bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed) expected_out, _, _, bit_width_out = 
float_quant(inp) @@ -71,7 +76,7 @@ def test_float_to_quant_float(inp, minifloat_format): @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format()) @jit_disabled_for_mock() def test_scaling_impls_called_once(inp, minifloat_format): - bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format scaling_impl = mock.Mock(side_effect=lambda x: 1.) float_scaling_impl = mock.Mock(side_effect=lambda x: 1.) if exponent_bit_width == 0 or mantissa_bit_width == 0: @@ -80,6 +85,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed, scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl) @@ -88,6 +94,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed, scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl) @@ -103,7 +110,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): scale=float_st()) @jit_disabled_for_mock() def test_inner_scale(inp, minifloat_format, scale): - bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here scaling_impl = mock.Mock(side_effect=lambda x: scale) float_scaling_impl = mock.Mock(side_effect=lambda x: 1.) @@ -113,6 +120,7 @@ def test_inner_scale(inp, minifloat_format, scale): bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed, scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl) @@ -121,6 +129,7 @@ def test_inner_scale(inp, minifloat_format, scale): bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, signed=signed, scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl) diff --git a/tests/brevitas/hyp_helper.py b/tests/brevitas/hyp_helper.py index 45e72e52b..1a9157214 100644 --- a/tests/brevitas/hyp_helper.py +++ b/tests/brevitas/hyp_helper.py @@ -231,11 +231,14 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with= bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with)) exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width)) signed = draw(st.booleans()) + + exponent_bias = 2 ** (exponent_bit_width - 1) - 1 + # if no budget is left, return if bit_width == exponent_bit_width: - return bit_width, exponent_bit_width, 0, False + return bit_width, exponent_bit_width, 0, False, exponent_bias elif bit_width == (exponent_bit_width + int(signed)): - return bit_width, exponent_bit_width, 0, signed + return bit_width, exponent_bit_width, 0, signed, exponent_bias mantissa_bit_width = bit_width - exponent_bit_width - int(signed) - return bit_width, exponent_bit_width, mantissa_bit_width, signed + return bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias From 1322fe88e1db85dbe437afd84dd09cdb23e311cc Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Fri, 16 Feb 2024 07:40:26 -0800 Subject: [PATCH 07/15] Fix (minifloat): make 
MaxFloatInfNaN jit compatible --- src/brevitas/core/function_wrapper/clamp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index c2a47acbc..784e0ad4d 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -169,11 +169,13 @@ def __init__( if any(map(lambda x: len(x) > mantissa_bit_width, self.__special_values)): raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') + # move computation of min for forward pass here so it's jit compatible + self.__min_special_case = min(map(lambda x: int(x, 2), self.__special_values)) + @brevitas.jit.script_method def forward(self): # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 - min_special_case = min(map(lambda x: int(x, 2), self.__special_values)) - max_value_mantissa = min_special_case - 1 + max_value_mantissa = self.__min_special_case - 1 if max_value_mantissa < 0: # all mantissa values are used, so we need to use decrease exponent values From 62fefe5b7082d3070c843b9f2cc5c3dc5b79ec36 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Mon, 19 Feb 2024 03:04:25 -0800 Subject: [PATCH 08/15] Fix (minifloat): fix private attr for jit and make hypothesis tests work with jit --- src/brevitas/core/function_wrapper/clamp.py | 8 ++++---- tests/brevitas/core/test_float_quant.py | 2 ++ tests/brevitas/core/test_minifloat.py | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 784e0ad4d..ff2882d01 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -161,21 +161,21 @@ def __init__( self.inf_values = inf_values self.nan_values = nan_values - self.__special_values = nan_values + inf_values if inf_values is not None else nan_values + self._special_values = nan_values + inf_values if inf_values is not None else nan_values self.saturating = saturating # check that NaN/inf values are all mantissa_bit_width long - if any(map(lambda x: len(x) > mantissa_bit_width, self.__special_values)): + if any(map(lambda x: len(x) > mantissa_bit_width, self._special_values)): raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') # move computation of min for forward pass here so it's jit compatible - self.__min_special_case = min(map(lambda x: int(x, 2), self.__special_values)) + self._min_special_case = min(map(lambda x: int(x, 2), self._special_values)) @brevitas.jit.script_method def forward(self): # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 - max_value_mantissa = self.__min_special_case - 1 + max_value_mantissa = self._min_special_case - 1 if max_value_mantissa < 0: # all mantissa values are used, so we need to use decrease exponent values diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 2e51f1db3..ec7536bf6 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -16,6 +16,7 @@ @given(minifloat_format=random_minifloat_format()) +@jit_disabled_for_mock() def test_float_quant_defaults(minifloat_format): bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format @@ -49,6 +50,7 @@ def test_minifloat(minifloat_format): @given(inp=float_tensor_random_shape_st(), 
minifloat_format=random_minifloat_format()) +@jit_disabled_for_mock() def test_float_to_quant_float(inp, minifloat_format): bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format if exponent_bit_width == 0 or mantissa_bit_width == 0: diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py index 58c4038c7..3cdcf7bbb 100644 --- a/tests/brevitas/core/test_minifloat.py +++ b/tests/brevitas/core/test_minifloat.py @@ -7,6 +7,7 @@ from brevitas.quant.experimental.float_base import Fp8e4m3Mixin from brevitas.quant.experimental.float_base import Fp8e5m2Mixin from tests.brevitas.hyp_helper import float_tensor_random_shape_st +from tests.marker import jit_disabled_for_mock from .minifloat_fixtures import * @@ -22,6 +23,7 @@ def test_max_value(minifloat, expected_max_val): @given(inp=float_tensor_random_shape_st()) +@jit_disabled_for_mock() def test_clamp(inp, fp8_clamp): max_val = fp8_clamp.case_clamp_impl.max_val_impl() # get values that exceed max_val From 4927a9f24edd8cdebc841c316e1c30d5714c90d5 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Mon, 19 Feb 2024 08:50:35 -0800 Subject: [PATCH 09/15] Fix (minifloat): fix jit issues with FloatClamp --- src/brevitas/core/function_wrapper/clamp.py | 47 ++++++++------------- src/brevitas/core/quant/float.py | 6 +-- src/brevitas/utils/float_quant_utils.py | 36 +++++++++++----- tests/brevitas/core/minifloat_fixtures.py | 35 ++++++++++++--- tests/brevitas/core/test_float_quant.py | 2 - tests/brevitas/core/test_minifloat.py | 4 +- 6 files changed, 76 insertions(+), 54 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index ff2882d01..13b5da2a2 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -13,12 +13,9 @@ from brevitas.core.utils import StatelessBuffer from brevitas.function import max_float from brevitas.function import tensor_clamp +from brevitas.utils.float_quant_utils import dec_to_bits from brevitas.utils.float_quant_utils import get_minifloat_value -P_INF_TENSOR = torch.tensor(float('inf')) -N_INF_TENSOR = torch.tensor(float('-inf')) -NAN_TENSOR = torch.tensor(float('nan')) - class TensorClamp(brevitas.jit.ScriptModule): """ @@ -91,13 +88,7 @@ class FloatClamp(brevitas.jit.ScriptModule): I.e. setting inf to 1101.111 (E4M3) is not a valid code. 
""" - __constants__ = [ - 'exponent_bit_width', - 'mantissa_bit_width', - 'exponent_bias', - 'nan_values', - 'inf_values', - 'saturating'] + __constants__ = ['nan_values', 'inf_values', 'saturating'] def __init__( self, @@ -109,9 +100,9 @@ def __init__( saturating: bool = False) -> None: super(FloatClamp, self).__init__() - self.exponent_bit_width = exponent_bit_width - self.mantissa_bit_width = mantissa_bit_width - self.exponent_bias = exponent_bias + self.exponent_bit_width = torch.tensor(exponent_bit_width) + self.mantissa_bit_width = torch.tensor(mantissa_bit_width) + self.exponent_bias = torch.tensor(exponent_bias) self.nan_values = nan_values self.inf_values = inf_values @@ -120,8 +111,7 @@ def __init__( # inf without NaN not possible if self.inf_values is None and self.nan_values is None: self.max_val_impl = StatelessBuffer( - max_float( - self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())) + max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)) elif self.nan_values is not None: # we at least have values for NaN, so initiate MaxValInfNaN self.max_val_impl = MaxFloatInfNaN( @@ -170,7 +160,7 @@ def __init__( raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') # move computation of min for forward pass here so it's jit compatible - self._min_special_case = min(map(lambda x: int(x, 2), self._special_values)) + self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), self._special_values))) @brevitas.jit.script_method def forward(self): @@ -179,21 +169,20 @@ def forward(self): if max_value_mantissa < 0: # all mantissa values are used, so we need to use decrease exponent values - exponent_string = '1' * (self.exponent_bit_width - 1) - exponent_string += '0' # add trailing 0 to reach bit width + exponent = torch.tensor(1).repeat(self.exponent_bit_width - 1) + exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype) + ]) # add trailing 0 to reach bit width # since we decreased exponent, we can use full mantissa - mantissa_string = '1' * self.mantissa_bit_width + mantissa = torch.tensor(1).repeat(self.mantissa_bit_width) else: # there is a free mantissa code, so use full exponent - exponent_string = '1' * self.exponent_bit_width + exponent = torch.tensor(1).repeat(self.exponent_bit_width) # get binary code for max_value_mantissa in the number of mantissa bits - mantissa_string = format(max_value_mantissa, f'0{self.mantissa_bit_width}b') + mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width) # we don't need the sign since we're looking for the max value max_value = get_minifloat_value( - exponent_string=exponent_string, - mantissa_string=mantissa_string, - exponent_bias=self.exponent_bias) + exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias) return max_value @@ -218,14 +207,14 @@ def forward(self, x: Tensor, max_value: Tensor): else: if self.inf_values is not None: # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf - x[p_max_val_mask] = P_INF_TENSOR - x[n_max_val_mask] = N_INF_TENSOR + x[p_max_val_mask] = torch.tensor(float('inf')) + x[n_max_val_mask] = torch.tensor(float('-inf')) else: # no inf values, so we need to map them to NaN full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask) - x[full_max_val_mask] = NAN_TENSOR + x[full_max_val_mask] = torch.tensor(float('nan')) # we also map the inf values to NaN in this case - x[inf_mask] = NAN_TENSOR + x[inf_mask] = torch.tensor(float('nan')) return x diff 
--git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 8ea60e27f..cf806b7fb 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -57,9 +57,9 @@ def __init__( scaling_impl = ConstScaling(1., device=device, dtype=dtype) if case_clamp_impl is None: self.case_clamp_impl = FloatClamp( - exponent_bit_width=self.exponent_bit_width, - mantissa_bit_width=self.mantissa_bit_width, - exponent_bias=self.exponent_bias) + exponent_bit_width=self.exponent_bit_width(), + mantissa_bit_width=self.mantissa_bit_width(), + exponent_bias=self.exponent_bias()) # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py index 493546c0d..33515f95e 100644 --- a/src/brevitas/utils/float_quant_utils.py +++ b/src/brevitas/utils/float_quant_utils.py @@ -1,25 +1,41 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import torch from torch import Tensor -def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float: +def mantissa_bits_to_float(bits: Tensor, frexp_compatible: bool = False) -> float: + # computes the decimal place value from a given binary tensor res = 1.0 for i, val in enumerate(bits): # iterating through from left to right - res += ((2 ** -(i + 1)) * float(val)) + res += ((2 ** -(i + 1)) * val) if frexp_compatible: return res / 2. else: return res -def get_minifloat_value( - exponent_string: str, - mantissa_string: str, - exponent_bias: Tensor, - sign: str = '0') -> float: - exponent_value = int(exponent_string, 2) - mantissa_value = mantissa_bits_to_float(mantissa_string) - return ((-1) ** float(sign)) * 2 ** (exponent_value - exponent_bias) * mantissa_value +def get_minifloat_value(exponent: Tensor, mantissa: Tensor, exponent_bias: Tensor) -> Tensor: + """ + Returns the minifloat value for a given exponent, mantissa and exponent_bias. + It expects the exponent and mantissa in their binary format. 
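+    For example, for exponent bits (1, 0, 1), mantissa bits (1, 0) and an exponent_bias of 3, this returns 2 ** (5 - 3) * 1.5 = 6.0.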
+ """ + exponent_value = bits_to_dec(exponent) + mantissa_value = mantissa_bits_to_float(mantissa) + return torch.exp2(exponent_value - exponent_bias) * mantissa_value + + +def dec_to_bits(value: Tensor, bits: int) -> Tensor: + # set up mask + mask = 2 ** torch.arange(bits - 1, -1, -1).to(value.device, value.dtype) + # add dimension, bitwise_and gets the bits needed for the value, the rest is converting to byte + return value.unsqueeze(-1).bitwise_and(mask).ne(0).byte() + + +def bits_to_dec(bits: Tensor) -> Tensor: + # get num of bits used + num_bits = len(bits) + # convert by summing decimal values of set bits + return torch.sum((2 ** torch.arange(num_bits - 1, -1, -1)) * bits) diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py index 72e9d9ca6..8ee882049 100644 --- a/tests/brevitas/core/minifloat_fixtures.py +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -14,8 +14,6 @@ class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase): exponent_bit_width = 4 mantissa_bit_width = 3 case_clamp_impl = FloatClamp - nan_values = tuple(('111',)) - inf_values = None # hypothesis extra hypothesis_internal_is_this_a_mock_check = False @@ -25,32 +23,55 @@ class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase): exponent_bit_width = 5 mantissa_bit_width = 2 case_clamp_impl = FloatClamp - nan_values = ('01', '11', '10') - inf_values = tuple(('00',)) # hypothesis extra hypothesis_internal_is_this_a_mock_check = False @pytest_cases.fixture @pytest_cases.parametrize('sat', [True, False]) -def fp8e4m3(sat): +def fp8e4m3_regular(sat): class Fp8e4m3(Fp8e4m3Base): saturating = sat + nan_values = tuple(('111',)) + inf_values = None return Fp8e4m3 @pytest_cases.fixture @pytest_cases.parametrize('sat', [True, False]) -def fp8e5m2(sat): +def fp8e5m2_regular(sat): class Fp8e5m2(Fp8e5m2Base): saturating = sat + nan_values = ('01', '11', '10') + inf_values = tuple(('00',)) return Fp8e5m2 -list_of_fixtures = ['fp8e4m3', 'fp8e5m2'] +@pytest_cases.fixture +@pytest_cases.parametrize('sat', [True, False]) +def fp8e4m3_no_special_values(sat): + + class Fp8e4m3None(Fp8e4m3Base): + saturating = sat + + return Fp8e4m3None + + +@pytest_cases.fixture +@pytest_cases.parametrize('sat', [True, False]) +def fp8e5m2_no_special_values(sat): + + class Fp8e5m2None(Fp8e5m2Base): + saturating = sat + + return Fp8e5m2None + + +list_of_fixtures = [ + 'fp8e4m3_regular', 'fp8e5m2_regular', 'fp8e4m3_no_special_values', 'fp8e5m2_no_special_values'] fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index ec7536bf6..2e51f1db3 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -16,7 +16,6 @@ @given(minifloat_format=random_minifloat_format()) -@jit_disabled_for_mock() def test_float_quant_defaults(minifloat_format): bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format @@ -50,7 +49,6 @@ def test_minifloat(minifloat_format): @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format()) -@jit_disabled_for_mock() def test_float_to_quant_float(inp, minifloat_format): bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format if exponent_bit_width == 0 or mantissa_bit_width == 0: diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py index 3cdcf7bbb..b25296acb 100644 --- 
a/tests/brevitas/core/test_minifloat.py +++ b/tests/brevitas/core/test_minifloat.py @@ -7,11 +7,10 @@ from brevitas.quant.experimental.float_base import Fp8e4m3Mixin from brevitas.quant.experimental.float_base import Fp8e5m2Mixin from tests.brevitas.hyp_helper import float_tensor_random_shape_st -from tests.marker import jit_disabled_for_mock from .minifloat_fixtures import * -FORMATS = {Fp8e5m2Mixin: 57344., Fp8e4m3Mixin: 448.} +FORMATS = {Fp8e5m2Mixin: 57344., Fp8e4m3Mixin: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.} @pytest.mark.parametrize( @@ -23,7 +22,6 @@ def test_max_value(minifloat, expected_max_val): @given(inp=float_tensor_random_shape_st()) -@jit_disabled_for_mock() def test_clamp(inp, fp8_clamp): max_val = fp8_clamp.case_clamp_impl.max_val_impl() # get values that exceed max_val From baa98181f51aaff14da1461be7f2b1069ae48fb3 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Wed, 21 Feb 2024 05:41:42 -0800 Subject: [PATCH 10/15] Feat (minifloat): use tensor_clamp_impl for minifloat conversion and clean up --- src/brevitas/core/function_wrapper/clamp.py | 108 +++++++++--------- src/brevitas/core/quant/float.py | 8 +- src/brevitas/quant/experimental/float_base.py | 4 +- tests/brevitas/core/minifloat_fixtures.py | 7 +- tests/brevitas/core/test_minifloat.py | 10 +- 5 files changed, 67 insertions(+), 70 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 13b5da2a2..c919dabf0 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -8,6 +8,7 @@ import torch from torch import Tensor +from torch.nn import Module import brevitas from brevitas.core.utils import StatelessBuffer @@ -88,79 +89,70 @@ class FloatClamp(brevitas.jit.ScriptModule): I.e. setting inf to 1101.111 (E4M3) is not a valid code. 
""" - __constants__ = ['nan_values', 'inf_values', 'saturating'] - def __init__( self, - exponent_bit_width: Tensor, - mantissa_bit_width: Tensor, - exponent_bias: Tensor, + exponent_bit_width: int, + mantissa_bit_width: int, + exponent_bias: int, + tensor_clamp_impl: Module = TensorClamp(), nan_values: Optional[Tuple[str]] = None, inf_values: Optional[Tuple[str]] = None, saturating: bool = False) -> None: super(FloatClamp, self).__init__() - self.exponent_bit_width = torch.tensor(exponent_bit_width) - self.mantissa_bit_width = torch.tensor(mantissa_bit_width) - self.exponent_bias = torch.tensor(exponent_bias) - - self.nan_values = nan_values - self.inf_values = inf_values - self.saturating = saturating - # inf without NaN not possible - if self.inf_values is None and self.nan_values is None: - self.max_val_impl = StatelessBuffer( - max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)) - elif self.nan_values is not None: + if inf_values is None and nan_values is None: + max_val_impl = StatelessBuffer( + max_float( + torch.tensor(exponent_bit_width), + torch.tensor(mantissa_bit_width), + torch.tensor(exponent_bias))) + elif nan_values is not None: # we at least have values for NaN, so initiate MaxValInfNaN - self.max_val_impl = MaxFloatInfNaN( - exponent_bit_width=self.exponent_bit_width, - mantissa_bit_width=self.mantissa_bit_width, - exponent_bias=self.exponent_bias, - nan_values=self.nan_values, - inf_values=self.inf_values, - saturating=self.saturating) + max_val_impl = MaxFloatInfNaN( + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, + nan_values=nan_values, + inf_values=inf_values) else: # no NaN values but inf values raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') - self.clamp_impl = CaseClamp(inf_values=self.inf_values, saturating=self.saturating) + # class for clamping to inf/NaN values + self.fpx_clamp_impl = FpXClamp( + inf_values=inf_values, saturating=saturating, tensor_clamp_impl=tensor_clamp_impl) + + # get max value for the minifloat config, no need to compute it during forward pass + self.max_value = max_val_impl() @brevitas.jit.script_method def forward(self, inp: Tensor): - # get max value for the minifloat config - max_value = self.max_val_impl() - return self.clamp_impl(inp, max_value) + return self.fpx_clamp_impl(inp, self.max_value) class MaxFloatInfNaN(brevitas.jit.ScriptModule): def __init__( self, - exponent_bit_width: Tensor, - mantissa_bit_width: Tensor, - exponent_bias: Tensor, + exponent_bit_width: int, + mantissa_bit_width: int, + exponent_bias: int, nan_values: Tuple[str], - inf_values: Tuple[str], - saturating: bool = False) -> None: + inf_values: Optional[Tuple[str]]) -> None: super(MaxFloatInfNaN, self).__init__() - self.exponent_bit_width = exponent_bit_width - self.mantissa_bit_width = mantissa_bit_width - self.exponent_bias = exponent_bias - - self.inf_values = inf_values - self.nan_values = nan_values - self._special_values = nan_values + inf_values if inf_values is not None else nan_values + self.exponent_bit_width = StatelessBuffer(torch.tensor(exponent_bit_width)) + self.mantissa_bit_width = StatelessBuffer(torch.tensor(mantissa_bit_width)) + self.exponent_bias = StatelessBuffer(torch.tensor(exponent_bias)) - self.saturating = saturating + _special_values = nan_values + inf_values if inf_values is not None else nan_values # check that NaN/inf values are all mantissa_bit_width long - if any(map(lambda x: len(x) > 
mantissa_bit_width, self._special_values)): + if any(map(lambda x: len(x) > mantissa_bit_width, _special_values)): raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') # move computation of min for forward pass here so it's jit compatible - self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), self._special_values))) + self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), _special_values))) @brevitas.jit.script_method def forward(self): @@ -169,29 +161,30 @@ def forward(self): if max_value_mantissa < 0: # all mantissa values are used, so we need to use decrease exponent values - exponent = torch.tensor(1).repeat(self.exponent_bit_width - 1) - exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype) - ]) # add trailing 0 to reach bit width + exponent = torch.tensor(1).repeat(self.exponent_bit_width() - 1) + # add trailing 0 to reach bit width + exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)]) # since we decreased exponent, we can use full mantissa - mantissa = torch.tensor(1).repeat(self.mantissa_bit_width) + mantissa = torch.tensor(1).repeat(self.mantissa_bit_width()) else: # there is a free mantissa code, so use full exponent - exponent = torch.tensor(1).repeat(self.exponent_bit_width) + exponent = torch.tensor(1).repeat(self.exponent_bit_width()) # get binary code for max_value_mantissa in the number of mantissa bits - mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width) + mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width()) # we don't need the sign since we're looking for the max value max_value = get_minifloat_value( - exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias) + exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias()) return max_value -class CaseClamp(brevitas.jit.ScriptModule): +class FpXClamp(brevitas.jit.ScriptModule): - def __init__(self, inf_values: Tuple[str], saturating: bool) -> None: - super(CaseClamp, self).__init__() + def __init__(self, inf_values: Tuple[str], saturating: bool, tensor_clamp_impl: Module) -> None: + super(FpXClamp, self).__init__() self.inf_values = inf_values self.saturating = saturating + self.tensor_clamp_impl = tensor_clamp_impl @brevitas.jit.script_method def forward(self, x: Tensor, max_value: Tensor): @@ -201,10 +194,11 @@ def forward(self, x: Tensor, max_value: Tensor): p_max_val_mask = x > max_value n_max_val_mask = -x > max_value - if self.saturating: - # clamp everything to +- max_value - x = x.clamp(-max_value, max_value) - else: + # first clamp everything to +- max_value, basically the saturating case + x = self.tensor_clamp_impl(x, min_val=-max_value, max_val=max_value) + + if not self.saturating: + # if non-saturating, we need to map values greater than max_val to nan or inf if self.inf_values is not None: # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf x[p_max_val_mask] = torch.tensor(float('inf')) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index cf806b7fb..8a462fc0e 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -25,7 +25,7 @@ def __init__( exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: int, - case_clamp_impl: Optional[nn.Module] = None, + float_clamp_impl: Optional[nn.Module] = None, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, float_to_int_impl: nn.Module = RoundSte(), @@ -55,8 +55,8 @@ def 
__init__( float_scaling_impl = ConstScaling(1., device=device, dtype=dtype) if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) - if case_clamp_impl is None: - self.case_clamp_impl = FloatClamp( + if float_clamp_impl is None: + self.float_clamp_impl = FloatClamp( exponent_bit_width=self.exponent_bit_width(), mantissa_bit_width=self.mantissa_bit_width(), exponent_bias=self.exponent_bias()) @@ -92,7 +92,7 @@ def dequantize(self, y, scale): def forward(self, x): y, scale = self.quantize(x) # after quantizing, clamp to special cases like NaN/inf if they are set - y = self.case_clamp_impl(y) + y = self.float_clamp_impl(y) y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.bit_width() diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 929b3e02d..8421c3c13 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -55,7 +55,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 - case_clamp_impl = FloatClamp + float_clamp_impl = FloatClamp nan_values = (('111',)) inf_values = None saturating = True @@ -65,7 +65,7 @@ class Fp8e5m2Mixin(ExponentBiasMixin): bit_width = 8 exponent_bit_width = 5 mantissa_bit_width = 2 - case_clamp_impl = FloatClamp + float_clamp_impl = FloatClamp nan_values = ('01', '11', '10') inf_values = (('00',)) saturating = True diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py index 8ee882049..e0f7528a6 100644 --- a/tests/brevitas/core/minifloat_fixtures.py +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -5,6 +5,7 @@ from pytest_cases import fixture_union from brevitas.core.function_wrapper import FloatClamp +from brevitas.inject.enum import BitWidthImplType from brevitas.quant.experimental.float_base import ExponentBiasMixin from brevitas.quant.experimental.float_base import ScaledFloatWeightBase @@ -13,7 +14,8 @@ class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 - case_clamp_impl = FloatClamp + float_clamp_impl = FloatClamp + bit_width_impl_type = BitWidthImplType.CONST # hypothesis extra hypothesis_internal_is_this_a_mock_check = False @@ -22,7 +24,8 @@ class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase): bit_width = 8 exponent_bit_width = 5 mantissa_bit_width = 2 - case_clamp_impl = FloatClamp + float_clamp_impl = FloatClamp + bit_width_impl_type = BitWidthImplType.CONST # hypothesis extra hypothesis_internal_is_this_a_mock_check = False diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py index b25296acb..a8b2f93b9 100644 --- a/tests/brevitas/core/test_minifloat.py +++ b/tests/brevitas/core/test_minifloat.py @@ -16,26 +16,26 @@ @pytest.mark.parametrize( 'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items())) def test_max_value(minifloat, expected_max_val): - max_val = minifloat.case_clamp_impl.max_val_impl() + max_val = minifloat.float_clamp_impl.max_value assert expected_max_val == max_val @given(inp=float_tensor_random_shape_st()) def test_clamp(inp, fp8_clamp): - max_val = fp8_clamp.case_clamp_impl.max_val_impl() + max_val = fp8_clamp.float_clamp_impl.max_value # get values that exceed max_val over_limit_mask = inp.abs() > max_val # clamp inp - inp = fp8_clamp.case_clamp_impl(inp) + inp = 
fp8_clamp.float_clamp_impl(inp) - if fp8_clamp.case_clamp_impl.saturating: + if fp8_clamp.float_clamp_impl.fpx_clamp_impl.saturating: # should be clamped to +- max val assert (inp[over_limit_mask].abs() == max_val).all() else: # if inf_values, over limit mask should now be all inf - if fp8_clamp.case_clamp_impl.inf_values is not None: + if fp8_clamp.float_clamp_impl.fpx_clamp_impl.inf_values is not None: # all values exceeding max_val should be inf assert inp[over_limit_mask].isinf().all() else: From dbc042beadbdf84bcde7ee738ab59158095b592f Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Fri, 23 Feb 2024 02:48:31 -0800 Subject: [PATCH 11/15] Fix (minifloat): compute max_value during dependency injection --- src/brevitas/core/function_wrapper/clamp.py | 108 ++---------------- src/brevitas/core/quant/float.py | 6 +- src/brevitas/quant/experimental/float_base.py | 13 ++- src/brevitas/utils/float_quant_utils.py | 61 +++++++--- tests/brevitas/core/minifloat_fixtures.py | 9 +- tests/brevitas/core/test_minifloat.py | 8 +- 6 files changed, 79 insertions(+), 126 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index c919dabf0..40cff68c0 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -12,10 +12,7 @@ import brevitas from brevitas.core.utils import StatelessBuffer -from brevitas.function import max_float from brevitas.function import tensor_clamp -from brevitas.utils.float_quant_utils import dec_to_bits -from brevitas.utils.float_quant_utils import get_minifloat_value class TensorClamp(brevitas.jit.ScriptModule): @@ -89,117 +86,34 @@ class FloatClamp(brevitas.jit.ScriptModule): I.e. setting inf to 1101.111 (E4M3) is not a valid code. 
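    For example, in E4M3 with nan_values=('111',), the largest valid code is 1111.110, so max_value is 2 ** 8 * 1.75 = 448.0; in E5M2, where ('00',) encodes inf and ('01', '10', '11') encode NaN, the whole 11111 exponent is taken, so the largest valid code is 11110.11 and max_value is 57344.0.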
""" + __constants__ = ['saturating', 'has_inf_values'] + def __init__( self, - exponent_bit_width: int, - mantissa_bit_width: int, - exponent_bias: int, + max_value: float, tensor_clamp_impl: Module = TensorClamp(), - nan_values: Optional[Tuple[str]] = None, inf_values: Optional[Tuple[str]] = None, saturating: bool = False) -> None: super(FloatClamp, self).__init__() - # inf without NaN not possible - if inf_values is None and nan_values is None: - max_val_impl = StatelessBuffer( - max_float( - torch.tensor(exponent_bit_width), - torch.tensor(mantissa_bit_width), - torch.tensor(exponent_bias))) - elif nan_values is not None: - # we at least have values for NaN, so initiate MaxValInfNaN - max_val_impl = MaxFloatInfNaN( - exponent_bit_width=exponent_bit_width, - mantissa_bit_width=mantissa_bit_width, - exponent_bias=exponent_bias, - nan_values=nan_values, - inf_values=inf_values) - else: - # no NaN values but inf values - raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') - - # class for clamping to inf/NaN values - self.fpx_clamp_impl = FpXClamp( - inf_values=inf_values, saturating=saturating, tensor_clamp_impl=tensor_clamp_impl) - - # get max value for the minifloat config, no need to compute it during forward pass - self.max_value = max_val_impl() - - @brevitas.jit.script_method - def forward(self, inp: Tensor): - return self.fpx_clamp_impl(inp, self.max_value) - - -class MaxFloatInfNaN(brevitas.jit.ScriptModule): - - def __init__( - self, - exponent_bit_width: int, - mantissa_bit_width: int, - exponent_bias: int, - nan_values: Tuple[str], - inf_values: Optional[Tuple[str]]) -> None: - super(MaxFloatInfNaN, self).__init__() - self.exponent_bit_width = StatelessBuffer(torch.tensor(exponent_bit_width)) - self.mantissa_bit_width = StatelessBuffer(torch.tensor(mantissa_bit_width)) - self.exponent_bias = StatelessBuffer(torch.tensor(exponent_bias)) - - _special_values = nan_values + inf_values if inf_values is not None else nan_values - - # check that NaN/inf values are all mantissa_bit_width long - if any(map(lambda x: len(x) > mantissa_bit_width, _special_values)): - raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') - - # move computation of min for forward pass here so it's jit compatible - self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), _special_values))) + self.tensor_clamp_impl = tensor_clamp_impl - @brevitas.jit.script_method - def forward(self): - # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1 - max_value_mantissa = self._min_special_case - 1 - - if max_value_mantissa < 0: - # all mantissa values are used, so we need to use decrease exponent values - exponent = torch.tensor(1).repeat(self.exponent_bit_width() - 1) - # add trailing 0 to reach bit width - exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)]) - # since we decreased exponent, we can use full mantissa - mantissa = torch.tensor(1).repeat(self.mantissa_bit_width()) - else: - # there is a free mantissa code, so use full exponent - exponent = torch.tensor(1).repeat(self.exponent_bit_width()) - # get binary code for max_value_mantissa in the number of mantissa bits - mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width()) - - # we don't need the sign since we're looking for the max value - max_value = get_minifloat_value( - exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias()) - return max_value - - -class FpXClamp(brevitas.jit.ScriptModule): - - def 
__init__(self, inf_values: Tuple[str], saturating: bool, tensor_clamp_impl: Module) -> None: - super(FpXClamp, self).__init__() - self.inf_values = inf_values + self.max_value = StatelessBuffer(torch.tensor(max_value)) self.saturating = saturating - self.tensor_clamp_impl = tensor_clamp_impl + self.has_inf_values = bool(inf_values) @brevitas.jit.script_method - def forward(self, x: Tensor, max_value: Tensor): - # NaN values all stay at NaN, so no need to do anything with NaN values - # get all positive inf values + def forward(self, x: Tensor): inf_mask = x.isinf() - p_max_val_mask = x > max_value - n_max_val_mask = -x > max_value + p_max_val_mask = x > self.max_value() + n_max_val_mask = -x > self.max_value() # first clamp everything to +- max_value, basically the saturating case - x = self.tensor_clamp_impl(x, min_val=-max_value, max_val=max_value) + x = self.tensor_clamp_impl(x, min_val=-self.max_value(), max_val=self.max_value()) if not self.saturating: # if non-saturating, we need to map values greater than max_val to nan or inf - if self.inf_values is not None: + if self.has_inf_values: # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf x[p_max_val_mask] = torch.tensor(float('inf')) x[n_max_val_mask] = torch.tensor(float('-inf')) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 8a462fc0e..8557b9974 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -13,6 +13,7 @@ from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_float from brevitas.function.ops_ste import floor_ste +from brevitas.utils.float_quant_utils import get_max_value class FloatQuant(brevitas.jit.ScriptModule): @@ -57,9 +58,8 @@ def __init__( scaling_impl = ConstScaling(1., device=device, dtype=dtype) if float_clamp_impl is None: self.float_clamp_impl = FloatClamp( - exponent_bit_width=self.exponent_bit_width(), - mantissa_bit_width=self.mantissa_bit_width(), - exponent_bias=self.exponent_bias()) + max_value=get_max_value( + exponent_bit_width, mantissa_bit_width, exponent_bias, None, None)) # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 8421c3c13..3777e6703 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -11,6 +11,7 @@ from brevitas.quant.solver import ActQuantSolver from brevitas.quant.solver import WeightQuantSolver from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum +from brevitas.utils.float_quant_utils import get_max_value class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum): @@ -51,7 +52,15 @@ def exponent_bias(exponent_bit_width): return 2 ** (exponent_bit_width - 1) - 1 -class Fp8e4m3Mixin(ExponentBiasMixin): +class MaxFloatInfNaNMixin(ExtendedInjector): + + @value + def max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values): + return get_max_value( + exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values) + + +class Fp8e4m3Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 @@ -61,7 +70,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin): saturating = True -class Fp8e5m2Mixin(ExponentBiasMixin): +class Fp8e5m2Mixin(ExponentBiasMixin, 
MaxFloatInfNaNMixin):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py
index 33515f95e..b90dbf82d 100644
--- a/src/brevitas/utils/float_quant_utils.py
+++ b/src/brevitas/utils/float_quant_utils.py
@@ -1,25 +1,41 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

-import torch
-from torch import Tensor
-

-def mantissa_bits_to_float(bits: Tensor, frexp_compatible: bool = False) -> float:
+def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
-    # computes the decimal place value from a given binary tensor
+    # computes the decimal place value from a given binary string
     res = 1.0
     for i, val in enumerate(bits):
         # iterating through from left to right
-        res += ((2 ** -(i + 1)) * val)
+        res += ((2 ** -(i + 1)) * float(val))
     if frexp_compatible:
         return res / 2.
     else:
         return res


-def get_minifloat_value(exponent: Tensor, mantissa: Tensor, exponent_bias: Tensor) -> Tensor:
+def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float:
     """
     Returns the minifloat value for a given exponent, mantissa and exponent_bias.
     It expects the exponent and mantissa in their binary format.
     """
-    exponent_value = bits_to_dec(exponent)
+    exponent_value = int(exponent, 2)
     mantissa_value = mantissa_bits_to_float(mantissa)
-    return torch.exp2(exponent_value - exponent_bias) * mantissa_value
+    return 2 ** (exponent_value - exponent_bias) * mantissa_value
+
+
+def get_max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values):
+    # Idea: take the smallest NaN/inf value, set max_value to the next smaller one
+    # inf without NaN not possible
+    if inf_values is None and nan_values is None:
+        # no special cases, max_value uses all bits for exponent and mantissa
+        exponent = '1' * exponent_bit_width
+        mantissa = '1' * mantissa_bit_width
+    elif nan_values is not None:
+        # we at least have values for NaN, so the special codes bound the max value
+        special_values = nan_values + inf_values if inf_values is not None else nan_values
+
+        # check that NaN/inf values are all mantissa_bit_width long
+        if any(map(lambda x: len(x) > mantissa_bit_width, special_values)):
+            raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
+
+        # get the minimum special case, our max value is the next smaller value
+        min_special_case = min(map(lambda x: int(x, 2), special_values))
+        max_value_mantissa = min_special_case - 1
+
+        if max_value_mantissa < 0:
+            # all mantissa values are used, so we need to decrease the exponent value
+            exponent = '1' * (exponent_bit_width - 1)
+            # add trailing 0 to reach bit width
+            exponent += '0'
+            # since we decreased exponent, we can use full mantissa
+            mantissa = '1' * mantissa_bit_width
+        else:
+            # there is a free mantissa code, so use full exponent
+            exponent = '1' * exponent_bit_width
+            # get binary code for max_value_mantissa in the number of mantissa bits
+            mantissa = format(max_value_mantissa, f'0{mantissa_bit_width}b')
+    else:
+        # no NaN values but inf values
+        raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')

-def bits_to_dec(bits: Tensor) -> Tensor:
-    # get num of bits used
-    num_bits = len(bits)
- # convert by summing decimal values of set bits - return torch.sum((2 ** torch.arange(num_bits - 1, -1, -1)) * bits) + # we don't need the sign since we're looking for the max value + max_value = get_minifloat_value( + exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias) + return max_value diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py index e0f7528a6..48cefc663 100644 --- a/tests/brevitas/core/minifloat_fixtures.py +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -7,24 +7,29 @@ from brevitas.core.function_wrapper import FloatClamp from brevitas.inject.enum import BitWidthImplType from brevitas.quant.experimental.float_base import ExponentBiasMixin +from brevitas.quant.experimental.float_base import MaxFloatInfNaNMixin from brevitas.quant.experimental.float_base import ScaledFloatWeightBase -class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase): +class Fp8e4m3Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 float_clamp_impl = FloatClamp + nan_values = None + inf_values = None bit_width_impl_type = BitWidthImplType.CONST # hypothesis extra hypothesis_internal_is_this_a_mock_check = False -class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase): +class Fp8e5m2Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase): bit_width = 8 exponent_bit_width = 5 mantissa_bit_width = 2 float_clamp_impl = FloatClamp + nan_values = None + inf_values = None bit_width_impl_type = BitWidthImplType.CONST # hypothesis extra hypothesis_internal_is_this_a_mock_check = False diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py index a8b2f93b9..2a4f6b000 100644 --- a/tests/brevitas/core/test_minifloat.py +++ b/tests/brevitas/core/test_minifloat.py @@ -16,26 +16,26 @@ @pytest.mark.parametrize( 'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items())) def test_max_value(minifloat, expected_max_val): - max_val = minifloat.float_clamp_impl.max_value + max_val = minifloat.float_clamp_impl.max_value() assert expected_max_val == max_val @given(inp=float_tensor_random_shape_st()) def test_clamp(inp, fp8_clamp): - max_val = fp8_clamp.float_clamp_impl.max_value + max_val = fp8_clamp.float_clamp_impl.max_value() # get values that exceed max_val over_limit_mask = inp.abs() > max_val # clamp inp inp = fp8_clamp.float_clamp_impl(inp) - if fp8_clamp.float_clamp_impl.fpx_clamp_impl.saturating: + if fp8_clamp.float_clamp_impl.saturating: # should be clamped to +- max val assert (inp[over_limit_mask].abs() == max_val).all() else: # if inf_values, over limit mask should now be all inf - if fp8_clamp.float_clamp_impl.fpx_clamp_impl.inf_values is not None: + if fp8_clamp.float_clamp_impl.has_inf_values: # all values exceeding max_val should be inf assert inp[over_limit_mask].isinf().all() else: From 2bb43cff2ee9b1e44017087ce6fe165585f1622f Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Fri, 23 Feb 2024 04:09:44 -0800 Subject: [PATCH 12/15] Fix (minifloat): add check for saturating and clean up --- src/brevitas/core/function_wrapper/clamp.py | 4 +- src/brevitas/core/quant/float.py | 7 +-- src/brevitas/quant/experimental/float_base.py | 14 +++++- src/brevitas/utils/float_quant_utils.py | 5 +- tests/brevitas/core/minifloat_fixtures.py | 25 ++-------- tests/brevitas/core/test_float_quant.py | 46 ++++++++++++++----- 6 files changed, 58 insertions(+), 43 deletions(-) diff --git 
a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 40cff68c0..0bfb79374 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -91,9 +91,9 @@ class FloatClamp(brevitas.jit.ScriptModule): def __init__( self, max_value: float, - tensor_clamp_impl: Module = TensorClamp(), + tensor_clamp_impl: Module, inf_values: Optional[Tuple[str]] = None, - saturating: bool = False) -> None: + saturating: bool = True) -> None: super(FloatClamp, self).__init__() self.tensor_clamp_impl = tensor_clamp_impl diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 8557b9974..29e1f31eb 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -26,7 +26,7 @@ def __init__( exponent_bit_width: int, mantissa_bit_width: int, exponent_bias: int, - float_clamp_impl: Optional[nn.Module] = None, + float_clamp_impl: nn.Module, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, float_to_int_impl: nn.Module = RoundSte(), @@ -56,14 +56,11 @@ def __init__( float_scaling_impl = ConstScaling(1., device=device, dtype=dtype) if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) - if float_clamp_impl is None: - self.float_clamp_impl = FloatClamp( - max_value=get_max_value( - exponent_bit_width, mantissa_bit_width, exponent_bias, None, None)) # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl self.scaling_impl = scaling_impl + self.float_clamp_impl = float_clamp_impl @brevitas.jit.script_method def internal_scale(self, x): diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 3777e6703..e3fa49c5b 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from brevitas.core.function_wrapper import FloatClamp +from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling.float_scaling import FloatScaling from brevitas.inject import ExtendedInjector @@ -55,9 +56,16 @@ def exponent_bias(exponent_bit_width): class MaxFloatInfNaNMixin(ExtendedInjector): @value - def max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values): + def max_value( + exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values, + saturating): return get_max_value( - exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values) + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + nan_values, + inf_values, + saturating) class Fp8e4m3Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin): @@ -65,6 +73,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin): exponent_bit_width = 4 mantissa_bit_width = 3 float_clamp_impl = FloatClamp + tensor_clamp_impl = TensorClamp nan_values = (('111',)) inf_values = None saturating = True @@ -75,6 +84,7 @@ class Fp8e5m2Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin): exponent_bit_width = 5 mantissa_bit_width = 2 float_clamp_impl = FloatClamp + tensor_clamp_impl = TensorClamp nan_values = ('01', '11', '10') inf_values = (('00',)) saturating = True diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py index b90dbf82d..5d5c4037f 100644 --- 
a/src/brevitas/utils/float_quant_utils.py
+++ b/src/brevitas/utils/float_quant_utils.py
@@ -24,10 +24,13 @@ def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> flo
     return 2 ** (exponent_value - exponent_bias) * mantissa_value


-def get_max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values):
+def get_max_value(
+        exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values, saturating):
     # Idea: take the smallest NaN/inf value, set max_value to the next smaller one
     # inf without NaN not possible
     if inf_values is None and nan_values is None:
+        # saturating has to be True if no NaN/inf values are used
+        assert saturating, 'cannot be non-saturating without NaN/inf values'
         # no special cases, max_value uses all bits for exponent and mantissa
         exponent = '1' * exponent_bit_width
         mantissa = '1' * mantissa_bit_width
diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py
index 48cefc663..a3cc644e4 100644
--- a/tests/brevitas/core/minifloat_fixtures.py
+++ b/tests/brevitas/core/minifloat_fixtures.py
@@ -18,6 +18,7 @@ class Fp8e4m3Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase)
     float_clamp_impl = FloatClamp
     nan_values = None
     inf_values = None
+    saturating = True
     bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False
@@ -30,6 +31,7 @@ class Fp8e5m2Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase)
     float_clamp_impl = FloatClamp
     nan_values = None
     inf_values = None
+    saturating = True
     bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False
@@ -59,27 +61,6 @@ class Fp8e5m2(Fp8e5m2Base):
     return Fp8e5m2


-@pytest_cases.fixture
-@pytest_cases.parametrize('sat', [True, False])
-def fp8e4m3_no_special_values(sat):
-
-    class Fp8e4m3None(Fp8e4m3Base):
-        saturating = sat
-
-    return Fp8e4m3None
-
-
-@pytest_cases.fixture
-@pytest_cases.parametrize('sat', [True, False])
-def fp8e5m2_no_special_values(sat):
-
-    class Fp8e5m2None(Fp8e5m2Base):
-        saturating = sat
-
-    return Fp8e5m2None
-
-
-list_of_fixtures = [
-    'fp8e4m3_regular', 'fp8e5m2_regular', 'fp8e4m3_no_special_values', 'fp8e5m2_no_special_values']
+list_of_fixtures = ['fp8e4m3_regular', 'fp8e5m2_regular']

 fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures)
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 2e51f1db3..1e4058fb8 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -6,9 +6,12 @@
 import pytest
 import torch

+from brevitas.core.function_wrapper import FloatClamp
 from brevitas.core.function_wrapper import RoundSte
+from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.quant.float import FloatQuant
 from brevitas.core.scaling import ConstScaling
+from brevitas.utils.float_quant_utils import get_max_value
 from tests.brevitas.hyp_helper import float_st
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 from tests.brevitas.hyp_helper import random_minifloat_format
@@ -19,8 +22,6 @@ def test_float_quant_defaults(minifloat_format):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format

-    # specifically don't set exponent bias to see if default works
-    expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with
pytest.raises(RuntimeError):
             float_quant = FloatQuant(
@@ -28,15 +29,20 @@ def test_float_quant_defaults(minifloat_format):
                 exponent_bit_width=exponent_bit_width,
                 mantissa_bit_width=mantissa_bit_width,
                 exponent_bias=exponent_bias,
-                signed=signed)
+                signed=signed,
+                float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
             exponent_bias=exponent_bias,
-            signed=signed)
-        assert expected_exponent_bias == float_quant.exponent_bias()
+            signed=signed,
+            float_clamp_impl=float_clamp)
         assert isinstance(float_quant.float_to_int_impl, RoundSte)
         assert isinstance(float_quant.float_scaling_impl, ConstScaling)
         assert isinstance(float_quant.scaling_impl, ConstScaling)
@@ -58,14 +64,20 @@ def test_float_to_quant_float(inp, minifloat_format):
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
             exponent_bias=exponent_bias,
-            signed=signed)
+            signed=signed,
+            float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
             mantissa_bit_width=mantissa_bit_width,
             exponent_bias=exponent_bias,
-            signed=signed)
+            signed=signed,
+            float_clamp_impl=float_clamp)
         expected_out, _, _, bit_width_out = float_quant(inp)

         out_quant, scale = float_quant.quantize(inp)
@@ -88,8 +100,13 @@ def test_scaling_impls_called_once(inp, minifloat_format):
             exponent_bias=exponent_bias,
             signed=signed,
             scaling_impl=scaling_impl,
-            float_scaling_impl=float_scaling_impl)
+            float_scaling_impl=float_scaling_impl,
+            float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
@@ -97,7 +114,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
             exponent_bias=exponent_bias,
             signed=signed,
             scaling_impl=scaling_impl,
-            float_scaling_impl=float_scaling_impl)
+            float_scaling_impl=float_scaling_impl,
+            float_clamp_impl=float_clamp)
     output = float_quant.quantize(inp)
     # scaling implementations should be called exactly once on the input
     scaling_impl.assert_called_once_with(inp)
@@ -123,8 +141,13 @@ def test_inner_scale(inp, minifloat_format, scale):
             exponent_bias=exponent_bias,
             signed=signed,
             scaling_impl=scaling_impl,
-            float_scaling_impl=float_scaling_impl)
+            float_scaling_impl=float_scaling_impl,
+            float_clamp_impl=None)
     else:
+        max_value = get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, None, None, True)
+        # init FloatClamp
+        float_clamp = FloatClamp(max_value=max_value, tensor_clamp_impl=TensorClamp())
         float_quant = FloatQuant(
             bit_width=bit_width,
             exponent_bit_width=exponent_bit_width,
@@ -132,7 +155,8 @@ def test_inner_scale(inp, minifloat_format, scale):
             exponent_bias=exponent_bias,
             signed=signed,
             scaling_impl=scaling_impl,
-            float_scaling_impl=float_scaling_impl)
+            float_scaling_impl=float_scaling_impl,
+            float_clamp_impl=float_clamp)
     # scale inp manually
     scaled_inp = inp / scale
From
From 0c1d9ce0a9e5f12bd8e01c7c1fbe3e5abc1317d1 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Fri, 23 Feb 2024 06:20:04 -0800
Subject: [PATCH 13/15] Fix (minifloat): restructuring quantizers

---
 src/brevitas/core/quant/float.py              |  2 -
 src/brevitas/quant/experimental/float_base.py | 61 ++++++++-----------
 tests/brevitas/core/minifloat_fixtures.py     |  9 +--
 tests/brevitas/core/test_minifloat.py         |  6 +-
 4 files changed, 31 insertions(+), 47 deletions(-)

diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 29e1f31eb..11da5864b 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -7,13 +7,11 @@
 import torch.nn as nn
 
 import brevitas
-from brevitas.core.function_wrapper import FloatClamp
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.scaling import ConstScaling
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_float
 from brevitas.function.ops_ste import floor_ste
-from brevitas.utils.float_quant_utils import get_max_value
 
 
 class FloatQuant(brevitas.jit.ScriptModule):
diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py
index e3fa49c5b..697a7b15a 100644
--- a/src/brevitas/quant/experimental/float_base.py
+++ b/src/brevitas/quant/experimental/float_base.py
@@ -15,20 +15,37 @@
 from brevitas.utils.float_quant_utils import get_max_value
 
 
-class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum):
-    proxy_class = WeightQuantProxyFromInjector
+class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
     tensor_quant = FloatQuant
     signed = True
     float_to_int_impl_type = 'round'
     scaling_min_val = 1e-10
+    float_clamp_impl = FloatClamp
+    tensor_clamp_impl = TensorClamp
 
+    @value
+    def exponent_bias(exponent_bit_width):
+        return 2 ** (exponent_bit_width - 1) - 1
 
-class FloatActBase(SolveTensorQuantFloatToIntImplFromEnum):
+    @value
+    def max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
+            saturating):
+        return get_max_value(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values,
+            inf_values,
+            saturating)
+
+
+class FloatWeightBase(FloatBase):
+    proxy_class = WeightQuantProxyFromInjector
+
+
+class FloatActBase(FloatBase):
     proxy_class = ActQuantProxyFromInjector
-    tensor_quant = FloatQuant
-    signed = True
-    float_to_int_impl_type = 'round'
-    scaling_min_val = 1e-10
 
 
 class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):
@@ -46,45 +63,19 @@ class ScaledFloatActBase(FloatActBase, ActQuantSolver):
     float_scaling_impl = FloatScaling
 
 
-class ExponentBiasMixin(ExtendedInjector):
-
-    @value
-    def exponent_bias(exponent_bit_width):
-        return 2 ** (exponent_bit_width - 1) - 1
-
-
-class MaxFloatInfNaNMixin(ExtendedInjector):
-
-    @value
-    def max_value(
-            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
-            saturating):
-        return get_max_value(
-            exponent_bit_width,
-            mantissa_bit_width,
-            exponent_bias,
-            nan_values,
-            inf_values,
-            saturating)
-
-
-class Fp8e4m3Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
+class Fp8e4m3Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    float_clamp_impl = FloatClamp
-    tensor_clamp_impl = TensorClamp
     nan_values = (('111',))
     inf_values = None
     saturating = True
 
 
-class Fp8e5m2Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
+class Fp8e5m2Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    float_clamp_impl = FloatClamp
-    tensor_clamp_impl = TensorClamp
     nan_values = ('01', '11', '10')
     inf_values = (('00',))
     saturating = True
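
Folding the former mixins into FloatBase means exponent_bias and max_value now resolve from the injector's other attributes. A plain-Python paraphrase of what the @value fields compute for e4m3 (a sketch of the resolved values, not the dependency-injection machinery itself):

    exponent_bit_width = 4
    exponent_bias = 2 ** (exponent_bit_width - 1) - 1  # 7, the IEEE-style default
    # max_value then follows from the format plus its reserved codes, e.g.
    # get_max_value(4, 3, 7, (('111',)), None, True) == 448.0 for OCP-style e4m3
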
diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py
index a3cc644e4..d16889c9b 100644
--- a/tests/brevitas/core/minifloat_fixtures.py
+++ b/tests/brevitas/core/minifloat_fixtures.py
@@ -4,18 +4,14 @@
 import pytest_cases
 from pytest_cases import fixture_union
 
-from brevitas.core.function_wrapper import FloatClamp
 from brevitas.inject.enum import BitWidthImplType
-from brevitas.quant.experimental.float_base import ExponentBiasMixin
-from brevitas.quant.experimental.float_base import MaxFloatInfNaNMixin
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
 
 
-class Fp8e4m3Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase):
+class Fp8e4m3Base(ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    float_clamp_impl = FloatClamp
     nan_values = None
     inf_values = None
     saturating = True
@@ -24,11 +20,10 @@ class Fp8e4m3Base(ScaledFloatWeightBase)
     hypothesis_internal_is_this_a_mock_check = False
 
 
-class Fp8e5m2Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase):
+class Fp8e5m2Base(ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    float_clamp_impl = FloatClamp
     nan_values = None
     inf_values = None
     saturating = True
diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py
index 2a4f6b000..f5f0e8d54 100644
--- a/tests/brevitas/core/test_minifloat.py
+++ b/tests/brevitas/core/test_minifloat.py
@@ -4,13 +4,13 @@
 from hypothesis import given
 import pytest
 
-from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
-from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
+from brevitas.quant.experimental.float import Fp8e4m3Weight
+from brevitas.quant.experimental.float import Fp8e5m2Weight
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 
 from .minifloat_fixtures import *
 
-FORMATS = {Fp8e5m2Mixin: 57344., Fp8e4m3Mixin: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.}
+FORMATS = {Fp8e5m2Weight: 57344., Fp8e4m3Weight: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.}
 
 
 @pytest.mark.parametrize(

From 721f596f8178d9b2fc4c71c6c308453b919ac15f Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Mon, 26 Feb 2024 02:25:49 -0800
Subject: [PATCH 14/15] Fix (minifloat): restructure OCP format quantizers

---
 src/brevitas/quant/experimental/float_base.py |  12 +-
 .../quant/experimental/float_quant_ocp.py     | 150 ++++++++++++++++++
 tests/brevitas/core/minifloat_fixtures.py     |  46 ++----
 tests/brevitas/core/test_minifloat.py         |   5 +-
 4 files changed, 171 insertions(+), 42 deletions(-)
 create mode 100644 src/brevitas/quant/experimental/float_quant_ocp.py

diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py
index 697a7b15a..61201578e 100644
--- a/src/brevitas/quant/experimental/float_base.py
+++ b/src/brevitas/quant/experimental/float_base.py
@@ -29,8 +29,12 @@ def exponent_bias(exponent_bit_width):
 
     @value
     def max_value(
-            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
-            saturating):
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values=None,
+            inf_values=None,
+            saturating=True):
         return get_max_value(
             exponent_bit_width,
             mantissa_bit_width,
             exponent_bias,
             nan_values,
             inf_values,
             saturating)
@@ -67,8 +71,6 @@ class Fp8e4m3Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    nan_values = (('111',))
-    inf_values = None
     saturating = True
@@ -76,6 +78,4 @@ class Fp8e5m2Mixin(ExtendedInjector):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    nan_values = ('01', '11', '10')
-    inf_values = (('00',))
     saturating = True
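
With nan_values and inf_values now defaulting to None, the plain Fp8e4m3Mixin/Fp8e5m2Mixin reserve no special codes and gain numeric range, while the OCP variants in the new file below reinstate the OCP codes. For orientation, the resulting bounds as asserted by the tests at the end of this series:

    Fp8e4m3Weight    (no reserved codes)       max_value = 480.0
    Fp8e4m3OCPWeight (NaN at mantissa '111')   max_value = 448.0
    Fp8e5m2Weight    (no reserved codes)       max_value = 114688.0
    Fp8e5m2OCPWeight (OCP inf/NaN codes)       max_value = 57344.0
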
diff --git a/src/brevitas/quant/experimental/float_quant_ocp.py b/src/brevitas/quant/experimental/float_quant_ocp.py
new file mode 100644
index 000000000..6dfda1304
--- /dev/null
+++ b/src/brevitas/quant/experimental/float_quant_ocp.py
@@ -0,0 +1,150 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from brevitas.quant.base import MSESymmetricScale
+from brevitas.quant.experimental.float_base import FloatActBase
+from brevitas.quant.experimental.float_base import FloatWeightBase
+from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
+from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
+from brevitas.quant.experimental.float_base import ScaledFloatActBase
+from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
+
+
+class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
+    nan_values = (('111',))
+    inf_values = None
+
+
+class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
+    nan_values = ('01', '11', '10')
+    inf_values = (('00',))
+
+
+class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer.
+    """
+    pass
+
+
+class Fp8e5m2OCPWeight(Fp8e5m2OCPMixin, FloatWeightBase):
+    """
+    FP8 signed E5M2 weight quantizer.
+    """
+    pass
+
+
+class Fp8e4m3OCPAct(Fp8e4m3OCPMixin, FloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer.
+    """
+    pass
+
+
+class Fp8e5m2OCPAct(Fp8e5m2OCPMixin, FloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer.
+    """
+    pass
+
+
+class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-tensor absmax-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e4m3OCPActPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e5m2OCPActPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e4m3OCPWeightPerChannelFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-channel absmax-based scaling.
+    """
+    scaling_per_output_channel = True
+
+
+class Fp8e5m2OCPWeightPerChannelFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
+    """
+    FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling.
+    """
+    scaling_per_output_channel = True
+
+
+class Fp8e4m3OCPActPerChannelFloat2d(Fp8e4m3OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling.
+    """
+    scaling_per_output_channel = True
+    scaling_stats_permute_dims = (1, 0, 2, 3)
+
+
+class Fp8e5m2OCPActPerChannelFloat2d(Fp8e5m2OCPMixin, ScaledFloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling.
+    """
+    scaling_per_output_channel = True
+    scaling_stats_permute_dims = (1, 0, 2, 3)
+
+
+class Fp8e4m3OCPActPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e5m2OCPActPerTensorFloatMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling.
+    """
+    scaling_per_output_channel = False
+
+
+class Fp8e4m3OCPActPerChannelFloat2dMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
+    """
+    FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling.
+    """
+    scaling_per_output_channel = True
+    scaling_stats_permute_dims = (1, 0, 2, 3)
+
+
+class Fp8e5m2OCPActPerChannelFloat2dMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
+    """
+    FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling.
+    """
+    scaling_per_output_channel = True
+    scaling_stats_permute_dims = (1, 0, 2, 3)
+
+
+class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-channel MSE-based scaling.
+    """
+    scaling_per_output_channel = True
+
+
+class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-tensor MSE-based scaling.
+    """
+    scaling_per_output_channel = False
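
The new OCP quantizers drop into quantized layers like any other Brevitas quantizer. A usage sketch (assuming the usual weight_quant hook on brevitas.nn layers; layer sizes are arbitrary):

    import torch
    from brevitas.nn import QuantLinear
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat

    # weights are quantized to OCP FP8 e4m3 with a per-tensor float scale
    layer = QuantLinear(64, 64, bias=False, weight_quant=Fp8e4m3OCPWeightPerTensorFloat)
    out = layer(torch.randn(8, 64))
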
+ """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e4m3OCPActPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e5m2OCPActPerTensorFloatMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_channel = False + + +class Fp8e4m3OCPActPerChannelFloat2dMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e5m2OCPActPerChannelFloat2dMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_channel = True + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. + """ + scaling_per_output_channel = True + + +class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. + """ + scaling_per_output_channel = False diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py index d16889c9b..681caf8ca 100644 --- a/tests/brevitas/core/minifloat_fixtures.py +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -4,58 +4,34 @@ import pytest_cases from pytest_cases import fixture_union -from brevitas.inject.enum import BitWidthImplType -from brevitas.quant.experimental.float_base import ScaledFloatWeightBase - - -class Fp8e4m3Base(ScaledFloatWeightBase): - bit_width = 8 - exponent_bit_width = 4 - mantissa_bit_width = 3 - nan_values = None - inf_values = None - saturating = True - bit_width_impl_type = BitWidthImplType.CONST - # hypothesis extra - hypothesis_internal_is_this_a_mock_check = False - - -class Fp8e5m2Base(ScaledFloatWeightBase): - bit_width = 8 - exponent_bit_width = 5 - mantissa_bit_width = 2 - nan_values = None - inf_values = None - saturating = True - bit_width_impl_type = BitWidthImplType.CONST - # hypothesis extra - hypothesis_internal_is_this_a_mock_check = False +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight +from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight @pytest_cases.fixture @pytest_cases.parametrize('sat', [True, False]) -def fp8e4m3_regular(sat): +def fp8e4m3(sat): - class Fp8e4m3(Fp8e4m3Base): + class Fp8e4m3(Fp8e4m3OCPWeight): saturating = sat - nan_values = tuple(('111',)) - inf_values = None + # for hypothesis and DI + hypothesis_internal_is_this_a_mock_check = True return Fp8e4m3 @pytest_cases.fixture @pytest_cases.parametrize('sat', [True, False]) -def fp8e5m2_regular(sat): +def fp8e5m2(sat): - class Fp8e5m2(Fp8e5m2Base): + class Fp8e5m2(Fp8e5m2OCPWeight): saturating = sat - nan_values = ('01', '11', '10') - inf_values = tuple(('00',)) + # for hypothesis and DI + hypothesis_internal_is_this_a_mock_check = True return Fp8e5m2 -list_of_fixtures = ['fp8e4m3_regular', 'fp8e5m2_regular'] +list_of_fixtures = ['fp8e4m3', 'fp8e5m2'] fp8_clamp = fixture_union('fp8_clamp', 
diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py
index f5f0e8d54..fb05c8141 100644
--- a/tests/brevitas/core/test_minifloat.py
+++ b/tests/brevitas/core/test_minifloat.py
@@ -6,11 +6,14 @@
 
 from brevitas.quant.experimental.float import Fp8e4m3Weight
 from brevitas.quant.experimental.float import Fp8e5m2Weight
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
+from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 
 from .minifloat_fixtures import *
 
-FORMATS = {Fp8e5m2Weight: 57344., Fp8e4m3Weight: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.}
+FORMATS = {
+    Fp8e5m2OCPWeight: 57344., Fp8e4m3OCPWeight: 448., Fp8e4m3Weight: 480., Fp8e5m2Weight: 114688.}
 
 
 @pytest.mark.parametrize(

From 3c73c3b73472e4478420824973a61865e39566a4 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Thu, 7 Mar 2024 01:11:04 -0800
Subject: [PATCH 15/15] Fix (minifloat): rename clamping test

---
 tests/brevitas/core/{test_minifloat.py => test_clamp.py} | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
 rename tests/brevitas/core/{test_minifloat.py => test_clamp.py} (90%)

diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_clamp.py
similarity index 90%
rename from tests/brevitas/core/test_minifloat.py
rename to tests/brevitas/core/test_clamp.py
index fb05c8141..5ba5a0a32 100644
--- a/tests/brevitas/core/test_minifloat.py
+++ b/tests/brevitas/core/test_clamp.py
@@ -12,12 +12,13 @@
 
 from .minifloat_fixtures import *
 
-FORMATS = {
+FORMAT_MAXVAL_MAP = {
     Fp8e5m2OCPWeight: 57344., Fp8e4m3OCPWeight: 448., Fp8e4m3Weight: 480., Fp8e5m2Weight: 114688.}
 
 
 @pytest.mark.parametrize(
-    'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items()))
+    'minifloat, expected_max_val',
+    ((format, max_val) for format, max_val in FORMAT_MAXVAL_MAP.items()))
 def test_max_value(minifloat, expected_max_val):
     max_val = minifloat.float_clamp_impl.max_value()
 
@@ -25,7 +26,7 @@
 
 
 @given(inp=float_tensor_random_shape_st())
-def test_clamp(inp, fp8_clamp):
+def test_float_clamp(inp, fp8_clamp):
     max_val = fp8_clamp.float_clamp_impl.max_value()
     # get values that exceed max_val
     over_limit_mask = inp.abs() > max_val