Skip to content

Commit

Permalink
Precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 29, 2024
1 parent 6a1fd4a commit 57ab1fe
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/brevitas/quant_tensor/float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def minifloat(self, float_datatype=True):
int_scale = float_internal_scale(
minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
float_value = torch.round(self._pre_round_float_value) * int_scale

def check_input_type(tensor):
if not isinstance(tensor, FloatQuantTensor):
raise RuntimeError("Tensor is not a FloatQuantTensor")
Expand Down
8 changes: 5 additions & 3 deletions tests/brevitas/core/test_quant_mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# pylint: disable=missing-function-docstring, redefined-outer-name

import struct
from typing import Tuple
from typing import List, Tuple, Union

from hypothesis import given
import pytest_cases
Expand All @@ -20,7 +20,9 @@


# debug utility
def to_string(val: torch.Tensor | float, spaced: bool = True, code: str = "f") -> str | list[str]:
def to_string(val: Union[torch.Tensor, float],
spaced: bool = True,
code: str = "f") -> Union[str, List[str]]:
""" Debug util for visualizing float values """

def scalar_to_string(val: float, spaced: bool) -> str:
Expand All @@ -35,7 +37,7 @@ def scalar_to_string(val: float, spaced: bool) -> str:


# debug utility
def check_bits(val: torch.Tensor | float, mbits: int) -> Tuple[bool, int]:
def check_bits(val: Union[torch.Tensor, float], mbits: int) -> Tuple[bool, int]:
""" return (too many precision bits, lowest mantissa bit) """
strings = to_string(val, spaced=False)
if isinstance(strings, str):
Expand Down

0 comments on commit 57ab1fe

Please sign in to comment.