Feat (minifloat): use tensor_clamp_impl for minifloat conversion and clean up
fabianandresgrob committed Feb 23, 2024
1 parent 260e20b commit 1b2a64b
Showing 5 changed files with 67 additions and 70 deletions.
108 changes: 51 additions & 57 deletions src/brevitas/core/function_wrapper/clamp.py
@@ -8,6 +8,7 @@

 import torch
 from torch import Tensor
+from torch.nn import Module
 
 import brevitas
 from brevitas.core.utils import StatelessBuffer
@@ -88,79 +89,70 @@ class FloatClamp(brevitas.jit.ScriptModule):
     I.e. setting inf to 1101.111 (E4M3) is not a valid code.
     """
 
-    __constants__ = ['nan_values', 'inf_values', 'saturating']
-
     def __init__(
             self,
-            exponent_bit_width: Tensor,
-            mantissa_bit_width: Tensor,
-            exponent_bias: Tensor,
+            exponent_bit_width: int,
+            mantissa_bit_width: int,
+            exponent_bias: int,
+            tensor_clamp_impl: Module = TensorClamp(),
             nan_values: Optional[Tuple[str]] = None,
             inf_values: Optional[Tuple[str]] = None,
             saturating: bool = False) -> None:
         super(FloatClamp, self).__init__()
 
-        self.exponent_bit_width = torch.tensor(exponent_bit_width)
-        self.mantissa_bit_width = torch.tensor(mantissa_bit_width)
-        self.exponent_bias = torch.tensor(exponent_bias)
-
-        self.nan_values = nan_values
-        self.inf_values = inf_values
-        self.saturating = saturating
-
         # inf without NaN not possible
-        if self.inf_values is None and self.nan_values is None:
-            self.max_val_impl = StatelessBuffer(
-                max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias))
-        elif self.nan_values is not None:
+        if inf_values is None and nan_values is None:
+            max_val_impl = StatelessBuffer(
+                max_float(
+                    torch.tensor(exponent_bit_width),
+                    torch.tensor(mantissa_bit_width),
+                    torch.tensor(exponent_bias)))
+        elif nan_values is not None:
             # we at least have values for NaN, so initiate MaxValInfNaN
-            self.max_val_impl = MaxFloatInfNaN(
-                exponent_bit_width=self.exponent_bit_width,
-                mantissa_bit_width=self.mantissa_bit_width,
-                exponent_bias=self.exponent_bias,
-                nan_values=self.nan_values,
-                inf_values=self.inf_values,
-                saturating=self.saturating)
+            max_val_impl = MaxFloatInfNaN(
+                exponent_bit_width=exponent_bit_width,
+                mantissa_bit_width=mantissa_bit_width,
+                exponent_bias=exponent_bias,
+                nan_values=nan_values,
+                inf_values=inf_values)
         else:
             # no NaN values but inf values
             raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
 
-        self.clamp_impl = CaseClamp(inf_values=self.inf_values, saturating=self.saturating)
+        # class for clamping to inf/NaN values
+        self.fpx_clamp_impl = FpXClamp(
+            inf_values=inf_values, saturating=saturating, tensor_clamp_impl=tensor_clamp_impl)
+
+        # get max value for the minifloat config, no need to compute it during forward pass
+        self.max_value = max_val_impl()
 
     @brevitas.jit.script_method
     def forward(self, inp: Tensor):
-        # get max value for the minifloat config
-        max_value = self.max_val_impl()
-        return self.clamp_impl(inp, max_value)
+        return self.fpx_clamp_impl(inp, self.max_value)
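This hunk changes FloatClamp in three ways: bit widths are passed as plain ints, the clamp operation is injectable via tensor_clamp_impl, and max_value is computed once at construction instead of on every forward call. A minimal usage sketch follows; the argument values are the E4M3 settings from float_base.py further down, not code from this commit:

```python
import torch

from brevitas.core.function_wrapper import FloatClamp

# E4M3-style config: mantissa code '111' is reserved for NaN, no inf codes
clamp = FloatClamp(
    exponent_bit_width=4,
    mantissa_bit_width=3,
    exponent_bias=7,
    nan_values=('111',),
    inf_values=None,
    saturating=True)

x = torch.tensor([1000.0, -1000.0, 0.5])
print(clamp.max_value)  # computed once at init: 448 for this config
print(clamp(x))         # saturating mode clamps into [-448, 448]
```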


 class MaxFloatInfNaN(brevitas.jit.ScriptModule):
 
     def __init__(
             self,
-            exponent_bit_width: Tensor,
-            mantissa_bit_width: Tensor,
-            exponent_bias: Tensor,
+            exponent_bit_width: int,
+            mantissa_bit_width: int,
+            exponent_bias: int,
             nan_values: Tuple[str],
-            inf_values: Tuple[str],
-            saturating: bool = False) -> None:
+            inf_values: Optional[Tuple[str]]) -> None:
         super(MaxFloatInfNaN, self).__init__()
-        self.exponent_bit_width = exponent_bit_width
-        self.mantissa_bit_width = mantissa_bit_width
-        self.exponent_bias = exponent_bias
-
         self.inf_values = inf_values
         self.nan_values = nan_values
-        self._special_values = nan_values + inf_values if inf_values is not None else nan_values
+        self.exponent_bit_width = StatelessBuffer(torch.tensor(exponent_bit_width))
+        self.mantissa_bit_width = StatelessBuffer(torch.tensor(mantissa_bit_width))
+        self.exponent_bias = StatelessBuffer(torch.tensor(exponent_bias))
 
-        self.saturating = saturating
+        _special_values = nan_values + inf_values if inf_values is not None else nan_values
 
         # check that NaN/inf values are all mantissa_bit_width long
-        if any(map(lambda x: len(x) > mantissa_bit_width, self._special_values)):
+        if any(map(lambda x: len(x) > mantissa_bit_width, _special_values)):
             raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
 
         # move computation of min for forward pass here so it's jit compatible
-        self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), self._special_values)))
+        self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), _special_values)))
 
     @brevitas.jit.script_method
     def forward(self):
@@ -169,29 +161,30 @@ def forward(self):

         if max_value_mantissa < 0:
             # all mantissa values are used, so we need to use decrease exponent values
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width - 1)
-            exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)
-                                 ])  # add trailing 0 to reach bit width
+            exponent = torch.tensor(1).repeat(self.exponent_bit_width() - 1)
+            # add trailing 0 to reach bit width
+            exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)])
             # since we decreased exponent, we can use full mantissa
-            mantissa = torch.tensor(1).repeat(self.mantissa_bit_width)
+            mantissa = torch.tensor(1).repeat(self.mantissa_bit_width())
         else:
             # there is a free mantissa code, so use full exponent
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width)
+            exponent = torch.tensor(1).repeat(self.exponent_bit_width())
             # get binary code for max_value_mantissa in the number of mantissa bits
-            mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width)
+            mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width())
 
         # we don't need the sign since we're looking for the max value
         max_value = get_minifloat_value(
-            exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias)
+            exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias())
         return max_value
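To see what this forward computes, take the E4M3 config from float_base.py below (nan_values=('111',), inf_values=None). A rough, self-contained recalculation, assuming the elided lines derive max_value_mantissa as one below the smallest special code:

```python
# E4M3 config: NaN code '111', no inf codes, exponent bias 7
nan_values = ('111',)
min_special_case = min(int(code, 2) for code in nan_values)  # 7
# assumed from the elided lines: the largest usable mantissa code sits one
# below the smallest special code
max_value_mantissa = min_special_case - 1  # 6, binary '110'

# max_value_mantissa >= 0, so the else branch applies: full exponent '1111' = 15
exponent, exponent_bias = 15, 7
mantissa_fraction = 1 + 1 / 2 + 1 / 4  # mantissa bits '110'
print(2 ** (exponent - exponent_bias) * mantissa_fraction)  # 448.0, the E4M3 max
```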


-class CaseClamp(brevitas.jit.ScriptModule):
+class FpXClamp(brevitas.jit.ScriptModule):
 
-    def __init__(self, inf_values: Tuple[str], saturating: bool) -> None:
-        super(CaseClamp, self).__init__()
+    def __init__(self, inf_values: Tuple[str], saturating: bool, tensor_clamp_impl: Module) -> None:
+        super(FpXClamp, self).__init__()
         self.inf_values = inf_values
         self.saturating = saturating
+        self.tensor_clamp_impl = tensor_clamp_impl
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor, max_value: Tensor):
@@ -201,10 +194,11 @@ def forward(self, x: Tensor, max_value: Tensor):
         p_max_val_mask = x > max_value
         n_max_val_mask = -x > max_value
 
-        if self.saturating:
-            # clamp everything to +- max_value
-            x = x.clamp(-max_value, max_value)
-        else:
+        # first clamp everything to +- max_value, basically the saturating case
+        x = self.tensor_clamp_impl(x, min_val=-max_value, max_val=max_value)
+
+        if not self.saturating:
             # if non-saturating, we need to map values greater than max_val to nan or inf
             if self.inf_values is not None:
                 # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
                 x[p_max_val_mask] = torch.tensor(float('inf'))
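The rewritten forward now always clamps to [-max_value, max_value] first through the injected tensor_clamp_impl, and only afterwards rewrites the overflow positions in the non-saturating case. A plain-PyTorch sketch of both modes, using the E5M2 max value for illustration; the NaN fallback for formats without inf codes is in the elided part of the hunk:

```python
import torch

x = torch.tensor([60000.0, -60000.0, 1.0])
max_value = torch.tensor(57344.0)  # E5M2 max once inf/NaN codes are reserved

# saturating: everything ends up in [-max_value, max_value]
print(x.clamp(-max_value, max_value))  # tensor([ 57344., -57344., 1.])

# non-saturating with inf codes: overflow maps to +-inf instead
y = x.clamp(-max_value, max_value)
y[x > max_value] = float('inf')
y[-x > max_value] = float('-inf')
print(y)  # tensor([inf, -inf, 1.])
```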
8 changes: 4 additions & 4 deletions src/brevitas/core/quant/float.py
@@ -25,7 +25,7 @@ def __init__(
             exponent_bit_width: int,
             mantissa_bit_width: int,
             exponent_bias: int,
-            case_clamp_impl: Optional[nn.Module] = None,
+            float_clamp_impl: Optional[nn.Module] = None,
             scaling_impl: Optional[nn.Module] = None,
             float_scaling_impl: Optional[nn.Module] = None,
             float_to_int_impl: nn.Module = RoundSte(),
@@ -55,8 +55,8 @@ def __init__(
             float_scaling_impl = ConstScaling(1., device=device, dtype=dtype)
         if scaling_impl is None:
             scaling_impl = ConstScaling(1., device=device, dtype=dtype)
-        if case_clamp_impl is None:
-            self.case_clamp_impl = FloatClamp(
+        if float_clamp_impl is None:
+            self.float_clamp_impl = FloatClamp(
                 exponent_bit_width=self.exponent_bit_width(),
                 mantissa_bit_width=self.mantissa_bit_width(),
                 exponent_bias=self.exponent_bias())
@@ -92,7 +92,7 @@ def dequantize(self, y, scale):
     def forward(self, x):
         y, scale = self.quantize(x)
         # after quantizing, clamp to special cases like NaN/inf if they are set
-        y = self.case_clamp_impl(y)
+        y = self.float_clamp_impl(y)
         y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.bit_width()
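For callers of FloatQuant the only visible change in this file is the keyword rename from case_clamp_impl to float_clamp_impl. A hedged sketch of injecting a custom clamp; any constructor argument not visible in these hunks (e.g. bit_width) is an assumption:

```python
from brevitas.core.function_wrapper import FloatClamp
from brevitas.core.quant.float import FloatQuant

custom_clamp = FloatClamp(
    exponent_bit_width=4, mantissa_bit_width=3, exponent_bias=7,
    nan_values=('111',), inf_values=None, saturating=True)

# was `case_clamp_impl=...` before this commit
quant = FloatQuant(
    bit_width=8,  # assumed parameter, not visible in the hunks above
    exponent_bit_width=4,
    mantissa_bit_width=3,
    exponent_bias=7,
    float_clamp_impl=custom_clamp)
```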
4 changes: 2 additions & 2 deletions src/brevitas/quant/experimental/float_base.py
@@ -55,7 +55,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
     nan_values = (('111',))
     inf_values = None
     saturating = True
@@ -65,7 +65,7 @@ class Fp8e5m2Mixin(ExponentBiasMixin):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
     nan_values = ('01', '11', '10')
     inf_values = (('00',))
     saturating = True
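These two mixins happen to exercise both branches of MaxFloatInfNaN.forward: E4M3 reserves only '111', leaving '110' free (max value 448), while E5M2 reserves all four mantissa codes, which forces the reduced-exponent branch. A quick sanity check of the E5M2 number, not part of the diff:

```python
# E5M2: all four 2-bit mantissa codes are special, so no code is left free at
# the top exponent; the exponent drops by one and the full mantissa is usable
special = ('01', '11', '10') + ('00',)
assert min(int(code, 2) for code in special) == 0  # max_value_mantissa would be -1

exponent = int('11110', 2)  # 30: all ones except a trailing 0
exponent_bias = 15
mantissa_fraction = 1 + 1 / 2 + 1 / 4  # full mantissa '11'
print(2 ** (exponent - exponent_bias) * mantissa_fraction)  # 57344.0
```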
7 changes: 5 additions & 2 deletions tests/brevitas/core/minifloat_fixtures.py
@@ -5,6 +5,7 @@
 from pytest_cases import fixture_union
 
 from brevitas.core.function_wrapper import FloatClamp
+from brevitas.inject.enum import BitWidthImplType
 from brevitas.quant.experimental.float_base import ExponentBiasMixin
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase

@@ -13,7 +14,8 @@ class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
+    bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False

@@ -22,7 +24,8 @@ class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
-    case_clamp_impl = FloatClamp
+    float_clamp_impl = FloatClamp
+    bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False

10 changes: 5 additions & 5 deletions tests/brevitas/core/test_minifloat.py
@@ -16,26 +16,26 @@
 @pytest.mark.parametrize(
     'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items()))
 def test_max_value(minifloat, expected_max_val):
-    max_val = minifloat.case_clamp_impl.max_val_impl()
+    max_val = minifloat.float_clamp_impl.max_value
 
     assert expected_max_val == max_val
 
 
 @given(inp=float_tensor_random_shape_st())
 def test_clamp(inp, fp8_clamp):
-    max_val = fp8_clamp.case_clamp_impl.max_val_impl()
+    max_val = fp8_clamp.float_clamp_impl.max_value
     # get values that exceed max_val
     over_limit_mask = inp.abs() > max_val
 
     # clamp inp
-    inp = fp8_clamp.case_clamp_impl(inp)
+    inp = fp8_clamp.float_clamp_impl(inp)
 
-    if fp8_clamp.case_clamp_impl.saturating:
+    if fp8_clamp.float_clamp_impl.fpx_clamp_impl.saturating:
         # should be clamped to +- max val
         assert (inp[over_limit_mask].abs() == max_val).all()
     else:
         # if inf_values, over limit mask should now be all inf
-        if fp8_clamp.case_clamp_impl.inf_values is not None:
+        if fp8_clamp.float_clamp_impl.fpx_clamp_impl.inf_values is not None:
             # all values exceeding max_val should be inf
             assert inp[over_limit_mask].isinf().all()
         else:
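A condensed, self-contained version of what the updated test_clamp exercises, built directly on the FloatClamp constructor from clamp.py above; the expectation that negative overflow maps to -inf comes from the elided part of FpXClamp.forward and is an assumption here:

```python
import torch

from brevitas.core.function_wrapper import FloatClamp

# E5M2-style clamp in non-saturating mode: overflow should become +-inf
clamp = FloatClamp(
    exponent_bit_width=5, mantissa_bit_width=2, exponent_bias=15,
    nan_values=('01', '11', '10'), inf_values=('00',), saturating=False)

inp = torch.tensor([1e6, -1e6, 2.0])
out = clamp(inp)

assert clamp.max_value == 57344.0  # cached at construction
assert out[:2].isinf().all()       # values beyond max_value go to +-inf
assert out[2] == 2.0               # in-range values pass through unchanged
```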
