From 5619a30b568e84a56e6fd13a3397fa61f17c942a Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Mon, 5 Feb 2024 09:52:38 -0800 Subject: [PATCH] Feat (FP8): implements conversion from fp32 to user specified fp8 format --- .../core/function_wrapper/__init__.py | 1 + src/brevitas/core/function_wrapper/clamp.py | 37 ++++++++++++ src/brevitas/core/quant/float.py | 4 ++ src/brevitas/function/ops.py | 56 +++++++++++++++++++ src/brevitas/quant/experimental/float_base.py | 11 ++++ 5 files changed, 109 insertions(+) diff --git a/src/brevitas/core/function_wrapper/__init__.py b/src/brevitas/core/function_wrapper/__init__.py index 9d5d929d6..3b3e5428b 100644 --- a/src/brevitas/core/function_wrapper/__init__.py +++ b/src/brevitas/core/function_wrapper/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from .clamp import ClampMin +from .clamp import FloatClamp from .clamp import ScalarClamp from .clamp import TensorClamp from .learned_round import LearnedRoundSte diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index c49eec726..0534aea12 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -9,6 +9,7 @@ from torch import Tensor import brevitas +from brevitas.function import clamp_to_fp_encoding from brevitas.function import tensor_clamp @@ -73,3 +74,39 @@ def __init__(self, min_val: float) -> None: @brevitas.jit.script_method def forward(self, x: Tensor): return x.clamp_min(self.min_val) + + +class FloatClamp(brevitas.jit.ScriptModule): + """ + ScriptModule for clamping minifloat formats to their inf/NaN implementations. 
+ """ + + __constants__ = ['nan_value', 'inf_value', 'max_value', 'saturating'] + + def __init__(self, nan_value: str, inf_value: str, max_value: float, saturating: bool) -> None: + super(FloatClamp, self).__init__() + self.nan_value = self.mantissa_bits_to_float(nan_value) + self.inf_value = self.mantissa_bits_to_float(inf_value) if inf_value is not None else None + self.max_value = max_value + self.saturating = saturating + + def mantissa_bits_to_float(self, bits: str, frexp_compatible: bool = True) -> float: + res = 1.0 + for i, val in enumerate(bits): + # iterating through from left to right + res += ((2 ** -(i + 1)) * float(val)) + if frexp_compatible: + return res / 2. + else: + return res + + @brevitas.jit.script_method + def forward(self, x: Tensor, exponent_bit_width: int, mantissa_bit_width: int): + return clamp_to_fp_encoding( + x, + exponent_bit_width, + mantissa_bit_width, + nan_value=self.nan_value, + inf_value=self.inf_value, + max_value=self.max_value, + saturating=self.saturating) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index fcbed7b66..23ac2b5de 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -23,6 +23,7 @@ def __init__( signed: bool, exponent_bit_width: int, mantissa_bit_width: int, + case_clamp_impl: nn.Module, exponent_bias: Optional[int] = None, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, @@ -53,6 +54,7 @@ def __init__( self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl self.scaling_impl = scaling_impl + self.case_clamp_impl = case_clamp_impl @brevitas.jit.script_method def internal_scale(self, x): @@ -80,6 +82,8 @@ def dequantize(self, y, scale): @brevitas.jit.script_method def forward(self, x): y, scale = self.quantize(x) + # after quantizing, clamp to special cases like NaN, inf + y = self.case_clamp_impl(y, self.exponent_bit_width(), 
self.mantissa_bit_width()) y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.bit_width() diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 6751ab69c..fa2acf23f 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -11,6 +11,10 @@ import brevitas +P_INF_TENSOR = torch.tensor(float('inf')) +N_INF_TENSOR = torch.tensor(float('-inf')) +NAN_TENSOR = torch.tensor(float('nan')) + @brevitas.jit.script def binary_sign(x: Tensor) -> Tensor: @@ -217,3 +221,55 @@ def get_upper_bound_on_l1_norm( max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1 max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) return max_accumulator_mag * max_input_mag_inverse + + +def clamp_to_fp_encoding( + x: Tensor, + exponent_bit_width: int, + mantissa_bit_width: int, + nan_value: Tensor, + inf_value: Tensor, + max_value: Tensor, + saturating: bool): + """ + Clamp any values that exceed inf/NaN special codes to these. Differentiates between saturating + and non-saturating mode. + + nan_value needs to be set to the min NaN value there is. + """ + # TODO: question regarding inf/NaN values + max_exponent_value = 2 ** (exponent_bit_width - 1) + # decompose value + mantissa, exponent = torch.frexp(x) + # check if any of the exponent values are all 1s, i.e. 
equal the max_exponent value + exponent_mask = exponent - 1 >= max_exponent_value # - 1 because frexp returns exponent in range [-125, 128], but actual exponent bits are in range [-126, 127] + is_nan_mask = mantissa.abs() >= nan_value # nan_value is the min NaN value + is_inf_mask = mantissa == inf_value + full_nan_mask = torch.logical_and(exponent_mask, is_nan_mask) + full_inf_mask = torch.logical_and(exponent_mask, is_inf_mask) + if saturating: + # set all values selected by full_nan_mask (exponent_mask and is_nan_mask) to NaN + x[full_nan_mask] = NAN_TENSOR + + # set all values selected by full_inf_mask to max_value + x[full_inf_mask] = max_value + + # clamp absolute values greater than max to +- max val + x = torch.clamp(x, -max_value, max_value) + + return x + else: + # in non-saturating case, set all values selected by full_nan_mask to NaN + x[full_nan_mask] = NAN_TENSOR + if inf_value is None: + # no inf encoding given: set the values selected by full_inf_mask to NaN + x[full_inf_mask] = NAN_TENSOR + else: + # set inf values to +- infinity + x[full_inf_mask] = torch.where(x[full_inf_mask] > 0, P_INF_TENSOR, N_INF_TENSOR) + # clamp all values greater than max_value to +inf + x = torch.where(x > max_value, P_INF_TENSOR, x) + # clamp all values smaller than -max_value to -inf + x = torch.where(x < -max_value, N_INF_TENSOR, x) + + return x diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 8da867e27..ceb20e9e7 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from brevitas.core.function_wrapper import FloatClamp from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling.float_scaling import FloatScaling from brevitas.inject import ExtendedInjector @@ -54,9 +55,19 @@ class Fp8e4m3Mixin(ExponentBiasMixin): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 + case_clamp_impl = FloatClamp + nan_value = '111' + inf_value = None + max_value = 448. + saturating = True class Fp8e5m2Mixin(ExponentBiasMixin): bit_width = 8 exponent_bit_width = 5 mantissa_bit_width = 2 + case_clamp_impl = FloatClamp + nan_value = '01' # smallest NaN value. Others are '11' and '10' + inf_value = '00' + max_value = 57344. + saturating = True