diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index c20e38a87..fa9fe1f73 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -154,6 +154,7 @@ def minifloat(self, float_datatype=True): int_scale = float_internal_scale( minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps) float_value = torch.round(self._pre_round_float_value) * int_scale + def check_input_type(tensor): if not isinstance(tensor, FloatQuantTensor): raise RuntimeError("Tensor is not a FloatQuantTensor") diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py index b2ab279d4..fb7c8ab44 100644 --- a/tests/brevitas/core/test_quant_mx.py +++ b/tests/brevitas/core/test_quant_mx.py @@ -4,7 +4,7 @@ # pylint: disable=missing-function-docstring, redefined-outer-name import struct -from typing import Tuple +from typing import List, Tuple, Union from hypothesis import given import pytest_cases @@ -20,7 +20,9 @@ # debug utility -def to_string(val: torch.Tensor | float, spaced: bool = True, code: str = "f") -> str | list[str]: +def to_string(val: Union[torch.Tensor, float], + spaced: bool = True, + code: str = "f") -> Union[str, List[str]]: """ Debug util for visualizing float values """ def scalar_to_string(val: float, spaced: bool) -> str: @@ -35,7 +37,7 @@ def scalar_to_string(val: float, spaced: bool) -> str: # debug utility -def check_bits(val: torch.Tensor | float, mbits: int) -> Tuple[bool, int]: +def check_bits(val: Union[torch.Tensor, float], mbits: int) -> Tuple[bool, int]: """ return (too many precision bits, lowest mantissa bit) """ strings = to_string(val, spaced=False) if isinstance(strings, str):