Skip to content

Commit

Permalink
remove second max val class
Browse files — browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Feb 15, 2024
1 parent cee0b58 commit 61ba479
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 70 deletions.
64 changes: 5 additions & 59 deletions src/brevitas/core/function_wrapper/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,15 @@ def __init__(
self.max_val_impl = StatelessBuffer(
max_float(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()))
elif self.inf_values is not None and self.nan_values is not None:
# we have values for NaN and inf, so initiate MaxValInfNaN
elif self.nan_values is not None:
# we at least have values for NaN, so initiate MaxValInfNaN
self.max_val_impl = MaxFloatInfNaN(
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
exponent_bias=self.exponent_bias,
nan_values=self.nan_values,
inf_values=self.inf_values,
saturating=self.saturating)
elif self.inf_values is None and self.nan_values is not None:
# we only have values for NaN, so initiate MaxValNaN
self.max_val_impl = MaxFloatNaN(
exponent_bit_width=self.exponent_bit_width,
mantissa_bit_width=self.mantissa_bit_width,
exponent_bias=self.exponent_bias,
nan_values=self.nan_values,
saturating=self.saturating)
else:
# no NaN values but inf values
raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
Expand Down Expand Up @@ -169,64 +161,18 @@ def __init__(

self.inf_values = inf_values
self.nan_values = nan_values
self.__special_values = nan_values + inf_values if inf_values is not None else nan_values

self.saturating = saturating

# check that NaN/inf values are all mantissa_bit_width long
if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values + self.inf_values)):
if any(map(lambda x: len(x) > mantissa_bit_width, self.__special_values)):
raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')

@brevitas.jit.script_method
def forward(self):
    """Return the maximum representable minifloat value for this format.

    The binary codes listed in ``self.nan_values`` / ``self.inf_values``
    are mantissa patterns reserved (at the top exponent) for special
    values, so the largest usable value is the code just below the
    smallest reserved special code.
    """
    # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1
    min_special_case = min(map(lambda x: int(x, 2), self.nan_values + self.inf_values))
    max_value_mantissa = min_special_case - 1

    if max_value_mantissa < 0:
        # All mantissa codes at the top exponent are reserved for
        # NaN/inf, so step down to the second-largest exponent ...
        # NOTE(review): '1' * self.exponent_bit_width assumes the bit
        # widths behave like Python ints here, although the __init__
        # signature annotates them as Tensor — confirm against callers.
        exponent_string = '1' * (self.exponent_bit_width - 1)
        exponent_string += '0'  # add trailing 0 to reach bit width
        # ... and since we decreased the exponent, the full mantissa is free.
        mantissa_string = '1' * self.mantissa_bit_width
    else:
        # There is at least one free mantissa code, so use the full exponent
        exponent_string = '1' * self.exponent_bit_width
        # binary code for max_value_mantissa, zero-padded to mantissa_bit_width bits
        mantissa_string = format(max_value_mantissa, f'0{self.mantissa_bit_width}b')

    # Sign bit is omitted: we only ever need the positive max value.
    max_value = get_minifloat_value(
        exponent_string=exponent_string,
        mantissa_string=mantissa_string,
        exponent_bias=self.exponent_bias)
    return max_value


class MaxFloatNaN(brevitas.jit.ScriptModule):

def __init__(
self,
exponent_bit_width: Tensor,
mantissa_bit_width: Tensor,
exponent_bias: Tensor,
nan_values: Tuple[str],
saturating: bool = False) -> None:
super(MaxFloatNaN, self).__init__()
self.exponent_bit_width = exponent_bit_width
self.mantissa_bit_width = mantissa_bit_width
self.exponent_bias = exponent_bias

self.nan_values = nan_values
self.saturating = saturating

# check that NaN values are all mantissa_bit_width long
if any(map(lambda x: len(x) > mantissa_bit_width, self.nan_values)):
raise RuntimeError('NaN codes need to be the same length as the mantissa.')

@brevitas.jit.script_method
def forward(self):
# idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1
min_special_case = min(map(lambda x: int(x, 2), self.nan_values))
min_special_case = min(map(lambda x: int(x, 2), self.__special_values))
max_value_mantissa = min_special_case - 1

if max_value_mantissa < 0:
Expand Down
4 changes: 1 addition & 3 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(
signed: bool,
exponent_bit_width: int,
mantissa_bit_width: int,
exponent_bias: int,
case_clamp_impl: Optional[nn.Module] = None,
exponent_bias: Optional[int] = None,
scaling_impl: Optional[nn.Module] = None,
float_scaling_impl: Optional[nn.Module] = None,
float_to_int_impl: nn.Module = RoundSte(),
Expand All @@ -45,8 +45,6 @@ def __init__(
raise RuntimeError("Mantissa bit width cannot be 0.")
self.mantissa_bit_width = StatelessBuffer(
(torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype)))
if exponent_bias is None:
exponent_bias = 2 ** (exponent_bit_width - 1) - 1
self.exponent_bias = StatelessBuffer(
torch.tensor(float(exponent_bias), device=device, dtype=dtype))
self.fp_max_val = StatelessBuffer(
Expand Down
19 changes: 14 additions & 5 deletions tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

@given(minifloat_format=random_minifloat_format())
def test_float_quant_defaults(minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format

# specifically don't set exponent bias to see if default works
expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1
if exponent_bit_width == 0 or mantissa_bit_width == 0:
Expand All @@ -26,12 +27,14 @@ def test_float_quant_defaults(minifloat_format):
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed)
else:
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed)
assert expected_exponent_bias == float_quant.exponent_bias()
assert isinstance(float_quant.float_to_int_impl, RoundSte)
Expand All @@ -41,25 +44,27 @@ def test_float_quant_defaults(minifloat_format):

@given(minifloat_format=random_minifloat_format())
def test_minifloat(minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
bit_width, exponent_bit_width, mantissa_bit_width, signed, _ = minifloat_format
assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
def test_float_to_quant_float(inp, minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed)
else:
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed)
expected_out, _, _, bit_width_out = float_quant(inp)

Expand All @@ -71,7 +76,7 @@ def test_float_to_quant_float(inp, minifloat_format):
@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
@jit_disabled_for_mock()
def test_scaling_impls_called_once(inp, minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
scaling_impl = mock.Mock(side_effect=lambda x: 1.)
float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
if exponent_bit_width == 0 or mantissa_bit_width == 0:
Expand All @@ -80,6 +85,7 @@ def test_scaling_impls_called_once(inp, minifloat_format):
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
Expand All @@ -88,6 +94,7 @@ def test_scaling_impls_called_once(inp, minifloat_format):
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
Expand All @@ -103,7 +110,7 @@ def test_scaling_impls_called_once(inp, minifloat_format):
scale=float_st())
@jit_disabled_for_mock()
def test_inner_scale(inp, minifloat_format, scale):
bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
# set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
scaling_impl = mock.Mock(side_effect=lambda x: scale)
float_scaling_impl = mock.Mock(side_effect=lambda x: 1.)
Expand All @@ -113,6 +120,7 @@ def test_inner_scale(inp, minifloat_format, scale):
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
Expand All @@ -121,6 +129,7 @@ def test_inner_scale(inp, minifloat_format, scale):
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed,
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl)
Expand Down
9 changes: 6 additions & 3 deletions tests/brevitas/hyp_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,14 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=
bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with))
exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
signed = draw(st.booleans())

exponent_bias = 2 ** (exponent_bit_width - 1) - 1

# if no budget is left, return
if bit_width == exponent_bit_width:
return bit_width, exponent_bit_width, 0, False
return bit_width, exponent_bit_width, 0, False, exponent_bias
elif bit_width == (exponent_bit_width + int(signed)):
return bit_width, exponent_bit_width, 0, signed
return bit_width, exponent_bit_width, 0, signed, exponent_bias
mantissa_bit_width = bit_width - exponent_bit_width - int(signed)

return bit_width, exponent_bit_width, mantissa_bit_width, signed
return bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias

0 comments on commit 61ba479

Please sign in to comment.