Skip to content

Commit

Permalink
Fix (minifloat): fix private attr for jit and make hypothesis tests w…
Browse files Browse the repository at this point in the history
…ork with jit
  • Loading branch information
fabianandresgrob committed Feb 19, 2024
1 parent 085f443 commit a737c18
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/brevitas/core/function_wrapper/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,21 +161,21 @@ def __init__(

self.inf_values = inf_values
self.nan_values = nan_values
self.__special_values = nan_values + inf_values if inf_values is not None else nan_values
self._special_values = nan_values + inf_values if inf_values is not None else nan_values

self.saturating = saturating

# check that NaN/inf values are all mantissa_bit_width long
if any(map(lambda x: len(x) > mantissa_bit_width, self.__special_values)):
if any(map(lambda x: len(x) > mantissa_bit_width, self._special_values)):
raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')

# move computation of min for forward pass here so it's jit compatible
self.__min_special_case = min(map(lambda x: int(x, 2), self.__special_values))
self._min_special_case = min(map(lambda x: int(x, 2), self._special_values))

@brevitas.jit.script_method
def forward(self):
# idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1
max_value_mantissa = self.__min_special_case - 1
max_value_mantissa = self._min_special_case - 1

if max_value_mantissa < 0:
# all mantissa values are used, so we need to use decrease exponent values
Expand Down
2 changes: 2 additions & 0 deletions tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


@given(minifloat_format=random_minifloat_format())
@jit_disabled_for_mock()
def test_float_quant_defaults(minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format

Expand Down Expand Up @@ -49,6 +50,7 @@ def test_minifloat(minifloat_format):


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
@jit_disabled_for_mock()
def test_float_to_quant_float(inp, minifloat_format):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
if exponent_bit_width == 0 or mantissa_bit_width == 0:
Expand Down
2 changes: 2 additions & 0 deletions tests/brevitas/core/test_minifloat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
from tests.marker import jit_disabled_for_mock

from .minifloat_fixtures import *

Expand All @@ -22,6 +23,7 @@ def test_max_value(minifloat, expected_max_val):


@given(inp=float_tensor_random_shape_st())
@jit_disabled_for_mock()
def test_clamp(inp, fp8_clamp):
max_val = fp8_clamp.case_clamp_impl.max_val_impl()
# get values that exceed max_val
Expand Down

0 comments on commit a737c18

Please sign in to comment.