Fix (minifloat): fix jit issues with FloatClamp
fabianandresgrob committed Feb 19, 2024
1 parent 966085e commit 3db9fdc
Showing 6 changed files with 76 additions and 54 deletions.
47 changes: 18 additions & 29 deletions src/brevitas/core/function_wrapper/clamp.py
@@ -13,12 +13,9 @@
from brevitas.core.utils import StatelessBuffer
from brevitas.function import max_float
from brevitas.function import tensor_clamp
+from brevitas.utils.float_quant_utils import dec_to_bits
from brevitas.utils.float_quant_utils import get_minifloat_value

-P_INF_TENSOR = torch.tensor(float('inf'))
-N_INF_TENSOR = torch.tensor(float('-inf'))
-NAN_TENSOR = torch.tensor(float('nan'))


class TensorClamp(brevitas.jit.ScriptModule):
"""
@@ -91,13 +88,7 @@ class FloatClamp(brevitas.jit.ScriptModule):
I.e. setting inf to 1101.111 (E4M3) is not a valid code.
"""

-__constants__ = [
-'exponent_bit_width',
-'mantissa_bit_width',
-'exponent_bias',
-'nan_values',
-'inf_values',
-'saturating']
+__constants__ = ['nan_values', 'inf_values', 'saturating']

def __init__(
self,
@@ -109,9 +100,9 @@ def __init__(
saturating: bool = False) -> None:
super(FloatClamp, self).__init__()

-self.exponent_bit_width = exponent_bit_width
-self.mantissa_bit_width = mantissa_bit_width
-self.exponent_bias = exponent_bias
+self.exponent_bit_width = torch.tensor(exponent_bit_width)
+self.mantissa_bit_width = torch.tensor(mantissa_bit_width)
+self.exponent_bias = torch.tensor(exponent_bias)

self.nan_values = nan_values
self.inf_values = inf_values
@@ -120,8 +111,7 @@ def __init__(
# inf without NaN not possible
if self.inf_values is None and self.nan_values is None:
self.max_val_impl = StatelessBuffer(
-max_float(
-self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()))
+max_float(self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias))
elif self.nan_values is not None:
# we at least have values for NaN, so initiate MaxValInfNaN
self.max_val_impl = MaxFloatInfNaN(
@@ -170,7 +160,7 @@ def __init__(
raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')

# move computation of min for forward pass here so it's jit compatible
-self._min_special_case = min(map(lambda x: int(x, 2), self._special_values))
+self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), self._special_values)))

@brevitas.jit.script_method
def forward(self):
@@ -179,21 +169,20 @@ def forward(self):

if max_value_mantissa < 0:
# all mantissa values are used, so we need to use decrease exponent values
-exponent_string = '1' * (self.exponent_bit_width - 1)
-exponent_string += '0' # add trailing 0 to reach bit width
+exponent = torch.tensor(1).repeat(self.exponent_bit_width - 1)
+exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)
+]) # add trailing 0 to reach bit width
# since we decreased exponent, we can use full mantissa
-mantissa_string = '1' * self.mantissa_bit_width
+mantissa = torch.tensor(1).repeat(self.mantissa_bit_width)
else:
# there is a free mantissa code, so use full exponent
-exponent_string = '1' * self.exponent_bit_width
+exponent = torch.tensor(1).repeat(self.exponent_bit_width)
# get binary code for max_value_mantissa in the number of mantissa bits
-mantissa_string = format(max_value_mantissa, f'0{self.mantissa_bit_width}b')
+mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width)

# we don't need the sign since we're looking for the max value
max_value = get_minifloat_value(
-exponent_string=exponent_string,
-mantissa_string=mantissa_string,
-exponent_bias=self.exponent_bias)
+exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias)
return max_value


@@ -218,14 +207,14 @@ def forward(self, x: Tensor, max_value: Tensor):
else:
if self.inf_values is not None:
# we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
-x[p_max_val_mask] = P_INF_TENSOR
-x[n_max_val_mask] = N_INF_TENSOR
+x[p_max_val_mask] = torch.tensor(float('inf'))
+x[n_max_val_mask] = torch.tensor(float('-inf'))
else:
# no inf values, so we need to map them to NaN
full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
-x[full_max_val_mask] = NAN_TENSOR
+x[full_max_val_mask] = torch.tensor(float('nan'))

# we also map the inf values to NaN in this case
-x[inf_mask] = NAN_TENSOR
+x[inf_mask] = torch.tensor(float('nan'))

return x
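Note: the changes above drop the module-level P_INF_TENSOR/N_INF_TENSOR/NAN_TENSOR constants and all string-based bit manipulation so that FloatClamp can be compiled by brevitas.jit (TorchScript). A minimal standalone sketch of the clamping behaviour implemented in the forward pass above follows; clamp_sketch and its arguments are illustrative names, not part of the commit.

import torch

def clamp_sketch(x: torch.Tensor, max_value: float, has_inf_values: bool, saturating: bool) -> torch.Tensor:
    # masks for existing infs and for values whose magnitude exceeds the largest representable value
    inf_mask = x.isinf()
    p_max_val_mask = x > max_value
    n_max_val_mask = -x > max_value
    if saturating:
        # saturating formats simply clip to [-max_value, max_value]
        return torch.clamp(x, -max_value, max_value)
    x = x.clone()
    if has_inf_values:
        # the format reserves inf codes: out-of-range values become +-inf
        x[p_max_val_mask] = float('inf')
        x[n_max_val_mask] = float('-inf')
    else:
        # no inf codes: out-of-range values and existing infs become NaN
        x[p_max_val_mask | n_max_val_mask] = float('nan')
        x[inf_mask] = float('nan')
    return x

# e.g. E4M3 (max 448, NaN codes only, non-saturating):
# clamp_sketch(torch.tensor([500., -500., 100.]), 448., has_inf_values=False, saturating=False) -> [nan, nan, 100.]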
6 changes: 3 additions & 3 deletions src/brevitas/core/quant/float.py
@@ -57,9 +57,9 @@ def __init__(
scaling_impl = ConstScaling(1., device=device, dtype=dtype)
if case_clamp_impl is None:
self.case_clamp_impl = FloatClamp(
-exponent_bit_width=self.exponent_bit_width,
-mantissa_bit_width=self.mantissa_bit_width,
-exponent_bias=self.exponent_bias)
+exponent_bit_width=self.exponent_bit_width(),
+mantissa_bit_width=self.mantissa_bit_width(),
+exponent_bias=self.exponent_bias())
# Zero-point is currently hardcoded to 0
self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
self.float_scaling_impl = float_scaling_impl
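Because FloatClamp now wraps its bit-width arguments in torch.tensor itself, the call site above invokes the StatelessBuffer attributes (note the added parentheses) so that plain values rather than buffers are passed in. A hedged example of constructing the clamp directly with the keyword arguments shown at this call site; the literal E4M3-style values are illustrative, and nan_values/inf_values are assumed to default to None as the constructor body suggests.

from brevitas.core.function_wrapper.clamp import FloatClamp

# assumed defaults: no reserved NaN/inf codes, non-saturating
clamp = FloatClamp(exponent_bit_width=4, mantissa_bit_width=3, exponent_bias=7)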
36 changes: 26 additions & 10 deletions src/brevitas/utils/float_quant_utils.py
@@ -1,25 +1,41 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

+import torch
from torch import Tensor


-def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
+def mantissa_bits_to_float(bits: Tensor, frexp_compatible: bool = False) -> float:
# computes the decimal place value from a given binary tensor
res = 1.0
for i, val in enumerate(bits):
# iterating through from left to right
-res += ((2 ** -(i + 1)) * float(val))
+res += ((2 ** -(i + 1)) * val)
if frexp_compatible:
return res / 2.
else:
return res


-def get_minifloat_value(
-exponent_string: str,
-mantissa_string: str,
-exponent_bias: Tensor,
-sign: str = '0') -> float:
-exponent_value = int(exponent_string, 2)
-mantissa_value = mantissa_bits_to_float(mantissa_string)
-return ((-1) ** float(sign)) * 2 ** (exponent_value - exponent_bias) * mantissa_value
+def get_minifloat_value(exponent: Tensor, mantissa: Tensor, exponent_bias: Tensor) -> Tensor:
+"""
+Returns the minifloat value for a given exponent, mantissa and exponent_bias.
+It expects the exponent and mantissa in their binary format.
+"""
+exponent_value = bits_to_dec(exponent)
+mantissa_value = mantissa_bits_to_float(mantissa)
+return torch.exp2(exponent_value - exponent_bias) * mantissa_value


+def dec_to_bits(value: Tensor, bits: int) -> Tensor:
+# set up mask
+mask = 2 ** torch.arange(bits - 1, -1, -1).to(value.device, value.dtype)
+# add dimension, bitwise_and gets the bits needed for the value, the rest is converting to byte
+return value.unsqueeze(-1).bitwise_and(mask).ne(0).byte()


+def bits_to_dec(bits: Tensor) -> Tensor:
+# get num of bits used
+num_bits = len(bits)
+# convert by summing decimal values of set bits
+return torch.sum((2 ** torch.arange(num_bits - 1, -1, -1)) * bits)
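The helpers above now operate on bit tensors instead of binary strings. A hedged usage sketch, with the values worked out by hand from the definitions above (assuming this commit is checked out):

import torch
from brevitas.utils.float_quant_utils import bits_to_dec, dec_to_bits, get_minifloat_value

mantissa = dec_to_bits(torch.tensor(6), 3)   # tensor([1, 1, 0], dtype=torch.uint8)
assert int(bits_to_dec(mantissa)) == 6       # round trip back to the integer

# largest regular E4M3 code: exponent 1111 (15), mantissa 110 (1.75), bias 7
exponent = torch.tensor([1, 1, 1, 1])
max_val = get_minifloat_value(exponent=exponent, mantissa=mantissa, exponent_bias=torch.tensor(7.))
print(max_val)  # tensor(448.) since 2 ** (15 - 7) * 1.75 == 448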
35 changes: 28 additions & 7 deletions tests/brevitas/core/minifloat_fixtures.py
@@ -14,8 +14,6 @@ class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase):
exponent_bit_width = 4
mantissa_bit_width = 3
case_clamp_impl = FloatClamp
-nan_values = tuple(('111',))
-inf_values = None
# hypothesis extra
hypothesis_internal_is_this_a_mock_check = False

@@ -25,32 +23,55 @@ class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase):
exponent_bit_width = 5
mantissa_bit_width = 2
case_clamp_impl = FloatClamp
-nan_values = ('01', '11', '10')
-inf_values = tuple(('00',))
# hypothesis extra
hypothesis_internal_is_this_a_mock_check = False


@pytest_cases.fixture
@pytest_cases.parametrize('sat', [True, False])
-def fp8e4m3(sat):
+def fp8e4m3_regular(sat):

class Fp8e4m3(Fp8e4m3Base):
saturating = sat
+nan_values = tuple(('111',))
+inf_values = None

return Fp8e4m3


@pytest_cases.fixture
@pytest_cases.parametrize('sat', [True, False])
-def fp8e5m2(sat):
+def fp8e5m2_regular(sat):

class Fp8e5m2(Fp8e5m2Base):
saturating = sat
+nan_values = ('01', '11', '10')
+inf_values = tuple(('00',))

return Fp8e5m2


-list_of_fixtures = ['fp8e4m3', 'fp8e5m2']
+@pytest_cases.fixture
+@pytest_cases.parametrize('sat', [True, False])
+def fp8e4m3_no_special_values(sat):

+class Fp8e4m3None(Fp8e4m3Base):
+saturating = sat

+return Fp8e4m3None


+@pytest_cases.fixture
+@pytest_cases.parametrize('sat', [True, False])
+def fp8e5m2_no_special_values(sat):

+class Fp8e5m2None(Fp8e5m2Base):
+saturating = sat

+return Fp8e5m2None


+list_of_fixtures = [
+'fp8e4m3_regular', 'fp8e5m2_regular', 'fp8e4m3_no_special_values', 'fp8e5m2_no_special_values']

fp8_clamp = fixture_union('fp8_clamp', list_of_fixtures, ids=list_of_fixtures)
2 changes: 0 additions & 2 deletions tests/brevitas/core/test_float_quant.py
@@ -16,7 +16,6 @@


@given(minifloat_format=random_minifloat_format())
-@jit_disabled_for_mock()
def test_float_quant_defaults(minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format

@@ -50,7 +49,6 @@ def test_minifloat(minifloat_format):


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
-@jit_disabled_for_mock()
def test_float_to_quant_float(inp, minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
if exponent_bit_width == 0 or mantissa_bit_width == 0:
4 changes: 1 addition & 3 deletions tests/brevitas/core/test_minifloat.py
@@ -7,11 +7,10 @@
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
-from tests.marker import jit_disabled_for_mock

from .minifloat_fixtures import *

-FORMATS = {Fp8e5m2Mixin: 57344., Fp8e4m3Mixin: 448.}
+FORMATS = {Fp8e5m2Mixin: 57344., Fp8e4m3Mixin: 448., Fp8e4m3Base: 480., Fp8e5m2Base: 114688.}


@pytest.mark.parametrize(
@@ -23,7 +22,6 @@ def test_max_value(minifloat, expected_max_val):


@given(inp=float_tensor_random_shape_st())
-@jit_disabled_for_mock()
def test_clamp(inp, fp8_clamp):
max_val = fp8_clamp.case_clamp_impl.max_val_impl()
# get values that exceed max_val
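The two new entries in FORMATS follow from the base classes no longer reserving NaN/inf codes: without reserved codes the top exponent and full mantissa stay usable. A quick hand check of the expected maxima, assuming the usual bias of 2 ** (exponent_bit_width - 1) - 1 (7 for E4M3, 15 for E5M2):

# E4M3 with NaN code 111: top exponent 1111 (15) usable, largest non-NaN mantissa 110 -> 1.75
assert 2 ** (15 - 7) * 1.75 == 448.0
# E4M3 base class, no special codes: full mantissa 111 -> 1.875
assert 2 ** (15 - 7) * 1.875 == 480.0
# E5M2 with inf/NaN codes on exponent 11111: largest usable exponent 11110 (30), mantissa 11 -> 1.75
assert 2 ** (30 - 15) * 1.75 == 57344.0
# E5M2 base class, no special codes: exponent 11111 (31), mantissa 11 -> 1.75
assert 2 ** (31 - 15) * 1.75 == 114688.0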
