Skip to content

Commit

Permalink
style: fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Dec 19, 2024
1 parent 8e15a53 commit 6b5f7c3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
4 changes: 3 additions & 1 deletion tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def test_inner_scale(inp, minifloat_format, scale):
assert torch.equal(out[~out_nans], expected_out[~expected_out_nans])


@given(minifloat_format_and_value=random_minifloat_format_and_value(min_bit_width=4, max_bit_with=10, rand_exp_bias=True))
@given(
minifloat_format_and_value=random_minifloat_format_and_value(
min_bit_width=4, max_bit_with=10, rand_exp_bias=True))
@settings(max_examples=10000)
@jit_disabled_for_mock()
@torch.no_grad()
Expand Down
21 changes: 12 additions & 9 deletions tests/brevitas/hyp_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,18 @@ def min_max_tensor_random_shape_st(draw, min_dims=1, max_dims=4, max_size=3, wid


@st.composite
def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=MAX_INT_BIT_WIDTH, rand_exp_bias=False):
def random_minifloat_format(
draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=MAX_INT_BIT_WIDTH, rand_exp_bias=False):
""""
Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed.
"""
# TODO: add support for new minifloat format that comes with FloatQuantTensor
bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with))
signed = draw(st.booleans())
exponent_bit_width = draw(st.integers(min_value=1, max_value=bit_width-1-int(signed)))
exponent_bit_width = draw(st.integers(min_value=1, max_value=bit_width - 1 - int(signed)))

if rand_exp_bias:
exponent_bias = draw(st.integers(min_value=-127, max_value=127))
exponent_bias = draw(st.integers(min_value=-127, max_value=127))
else:
exponent_bias = 2 ** (exponent_bit_width - 1) - 1

Expand All @@ -250,28 +251,30 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=


@st.composite
def random_valid_minifloat(draw, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias):
def random_valid_minifloat(
draw, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias):
""""
Generate a valid minifloat value, from the given format. Returns a valid minifloat value
"""
# Sanity-check that the format is valid
assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)
# Generate int values of the minifloat components
sign = draw(st.integers(min_value=0, max_value=int(signed)))
mantissa = draw(st.integers(min_value=0, max_value=int(2**mantissa_bit_width-1)))
exponent = draw(st.integers(min_value=0, max_value=int(2**exponent_bit_width-1)))
mantissa = draw(st.integers(min_value=0, max_value=int(2 ** mantissa_bit_width - 1)))
exponent = draw(st.integers(min_value=0, max_value=int(2 ** exponent_bit_width - 1)))
# Scale mantissa between 0-1
mantissa_fixed = mantissa / 2**mantissa_bit_width
mantissa_fixed = mantissa / 2 ** mantissa_bit_width
# Add 1 unless denormalised
mantissa_fixed += 0. if exponent == 0 else 1.
# Adjust exponent if denormalised, otherwise leave it unchanged
exponent_value = 1 if exponent == 0 else exponent
valid_minifloat = ((-1.)**sign) * (mantissa_fixed * 2**(exponent_value-exponent_bias))
valid_minifloat = ((-1.) ** sign) * (mantissa_fixed * 2 ** (exponent_value - exponent_bias))
return valid_minifloat, exponent, mantissa, sign


@st.composite
def random_minifloat_format_and_value(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=MAX_INT_BIT_WIDTH, rand_exp_bias=False):
def random_minifloat_format_and_value(
draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=MAX_INT_BIT_WIDTH, rand_exp_bias=False):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = draw(random_minifloat_format(min_bit_width=min_bit_width, max_bit_with=max_bit_with, rand_exp_bias=rand_exp_bias))
valid_minifloat, exponent, mantissa, sign = draw(random_valid_minifloat(bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, signed=signed, exponent_bias=exponent_bias))
return valid_minifloat, exponent, mantissa, sign, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias

0 comments on commit 6b5f7c3

Please sign in to comment.