Fix (core): compatibility with FloatClamp and JIT
Giuseppe5 committed Apr 19, 2024
1 parent 578fd4c commit 4481928
Showing 3 changed files with 24 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/brevitas/core/function_wrapper/clamp.py
@@ -87,7 +87,7 @@ class FloatClamp(brevitas.jit.ScriptModule):
I.e. setting inf to 1101.111 (E4M3) is not a valid code.
"""

- __constants__ = ['saturating', 'has_inf_values']
+ __constants__ = ['saturating', 'inf_values', 'nan_values', 'signed', 'max_available_float']

def __init__(
self,
@@ -133,11 +133,11 @@ def forward(

if not self.saturating:
# if non-saturating, we need to map values greater than max_val to nan or inf
- if self.inf_values:
+ if self.inf_values is not None:
# we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
x[p_max_val_mask] = torch.tensor(float('inf'))
x[n_max_val_mask] = torch.tensor(float('-inf'))
- elif self.nan_values:
+ elif self.nan_values is not None:
# no inf values, so we need to map them to NaN
full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
x[full_max_val_mask] = torch.tensor(float('nan'))
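
Why this change helps the JIT: a minimal sketch (not Brevitas code) of the pattern the fix adopts. TorchScript only narrows an Optional-typed module attribute inside an explicit is-not-None check, and attributes read in a scripted forward are declared in __constants__ so the compiler can treat them as constants; the toy TinyClamp module below is purely illustrative.

from typing import Optional, Tuple

import torch


class TinyClamp(torch.jit.ScriptModule):
    # Attributes read inside forward() are declared as compile-time constants.
    __constants__ = ['saturating', 'inf_values']

    def __init__(self, saturating: bool, inf_values: Optional[Tuple[str]]):
        super().__init__()
        self.saturating = saturating
        self.inf_values = inf_values

    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.saturating:
            # "is not None" (rather than plain truthiness) is what lets
            # TorchScript refine the Optional type inside this branch.
            if self.inf_values is not None:
                x = torch.where(x > 1.0, torch.tensor(float('inf')), x)
        return x
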
23 changes: 19 additions & 4 deletions tests/brevitas/core/test_clamp.py
@@ -10,6 +10,7 @@
from brevitas.quant.experimental.float import Fp8e5m2Weight
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
+ from brevitas.utils.float_quant_utils import get_max_available_float
from tests.brevitas.hyp_helper import float_tensor_random_shape_st

from .minifloat_fixtures import *
@@ -26,8 +27,15 @@ def test_max_value(minifloat, expected_max_val):
torch.tensor(minifloat.exponent_bit_width, dtype=torch.float32),
torch.tensor(minifloat.mantissa_bit_width, dtype=torch.float32),
torch.tensor(minifloat.exponent_bias, dtype=torch.float32))
- max_available_float = minifloat.float_clamp_impl.max_available_float
- max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())
+ max_available_float = get_max_available_float(
+     minifloat.exponent_bit_width,
+     minifloat.mantissa_bit_width,
+     minifloat.exponent_bias,
+     minifloat.float_clamp_impl.nan_values,
+     minifloat.float_clamp_impl.inf_values,
+     minifloat.float_clamp_impl.saturating)
+ max_available_float = torch.tensor(max_available_float)
+ max_val = torch.min(max_val, max_available_float)

assert expected_max_val == max_val

@@ -39,8 +47,15 @@ def test_float_clamp(inp, fp8_clamp):
torch.tensor(fp8_clamp.exponent_bit_width, dtype=torch.float32),
torch.tensor(fp8_clamp.mantissa_bit_width, dtype=torch.float32),
torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32))
- max_available_float = fp8_clamp.float_clamp_impl.max_available_float
- max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())
+ max_available_float = get_max_available_float(
+     fp8_clamp.exponent_bit_width,
+     fp8_clamp.mantissa_bit_width,
+     fp8_clamp.exponent_bias,
+     fp8_clamp.float_clamp_impl.nan_values,
+     fp8_clamp.float_clamp_impl.inf_values,
+     fp8_clamp.float_clamp_impl.saturating)
+ max_available_float = torch.tensor(max_available_float)
+ max_val = torch.min(max_val, max_available_float)
# get values that exceed max_val
over_limit_mask = inp.abs() > max_val

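
For reference, the quantity these tests now obtain from get_max_available_float is the largest magnitude the minifloat format can represent once special codes are accounted for. A rough sketch of that computation follows; it is not the Brevitas implementation and assumes OCP-style conventions for how inf/NaN codes are reserved.

def max_available_float_sketch(exp_bits: int, man_bits: int, bias: int,
                               has_inf: bool, has_nan: bool) -> float:
    # Largest exponent code, after removing the bias.
    max_exp = (2 ** exp_bits - 1) - bias
    # Mantissa with all fractional bits set: 1.111...b
    full_mantissa = 2.0 - 2.0 ** (-man_bits)
    if has_inf:
        # The whole top exponent code is reserved for inf/NaN (e.g. OCP E5M2).
        return full_mantissa * 2.0 ** (max_exp - 1)
    if has_nan:
        # Only the all-ones mantissa at the top exponent is NaN (e.g. OCP E4M3).
        return (full_mantissa - 2.0 ** (-man_bits)) * 2.0 ** max_exp
    return full_mantissa * 2.0 ** max_exp

# max_available_float_sketch(4, 3, 7, False, True) -> 448.0 (OCP E4M3)
# max_available_float_sketch(5, 2, 15, True, False) -> 57344.0 (OCP E5M2)
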
5 changes: 2 additions & 3 deletions tests/brevitas/core/test_float_quant.py
@@ -90,14 +90,13 @@ def test_float_to_quant_float(inp, minifloat_format):
exponent_bias=exponent_bias,
signed=signed,
float_clamp_impl=float_clamp)
- expected_out, _, _, bit_width_out = float_quant(inp)
+ expected_out, *_ = float_quant(inp)

out_quant, scale = float_quant.quantize(inp)
exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
out_quant = float_quant.float_clamp_impl(
out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
- assert bit_width_out == bit_width
- assert torch.equal(expected_out, out_quant * scale)
+ assert torch.allclose(expected_out, out_quant * scale)


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
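
The switch from torch.equal to torch.allclose replaces bit-exact comparison with a tolerance-based one (rtol=1e-05, atol=1e-08 by default), which is the safer check when the reference output and the explicitly clamped output may round slightly differently. A small illustration:

import torch

a = torch.tensor([1.0])
b = torch.tensor([1.0 + 1e-6])  # differs by far less than the default rtol

print(torch.equal(a, b))     # False: requires identical shapes and values
print(torch.allclose(a, b))  # True: within default tolerances
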
