diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py
index cf4ba1420..ecb2c7ae2 100644
--- a/src/brevitas/quant_tensor/float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/float_quant_tensor.py
@@ -103,8 +103,9 @@ def _pre_round_float_value(self):
         scale = self.scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
+        eps = torch.finfo(self.scale.dtype).tiny
         int_scale = float_internal_scale(
-            self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
+            self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
@@ -140,8 +141,9 @@ def minifloat(self, float_datatype=True):
         if self.is_valid:
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
+            eps = torch.finfo(self.scale.dtype).tiny
             int_scale = float_internal_scale(
-                self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
+                self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
index fa91bdca1..9b703e0f7 100644
--- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -144,7 +144,9 @@ def _pre_round_float_value(self):
         scale = scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+        eps = torch.finfo(self.scale_.dtype).tiny
+        int_scale = float_internal_scale(
+            self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
@@ -180,7 +182,9 @@ def minifloat(self, float_datatype=True):
         if self.is_valid:
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+            eps = torch.finfo(self.scale_.dtype).tiny
+            int_scale = float_internal_scale(
+                self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py
index 4e3401057..6f6b4c7d2 100644
--- a/tests/brevitas/quant_tensor/test_quant_tensor.py
+++ b/tests/brevitas/quant_tensor/test_quant_tensor.py
@@ -4,11 +4,14 @@
 from packaging import version
 import pytest
+import pytest_cases
 import torch
 
 from brevitas import torch_version
 from brevitas.nn import QuantIdentity
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
 from brevitas.quant_tensor import FloatQuantTensor
 from brevitas.quant_tensor import IntQuantTensor
@@ -119,3 +122,19 @@ def test_quant_tensor_view():
     assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01)
     assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01)
     assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01)
+
+
+QUANT_CLASS = {'fp8': Fp8e4m3ActPerTensorFloat, 'mxfp8': MXFloat8e4m3Act}
+
+
+@pytest_cases.parametrize('quant_class_key_value', QUANT_CLASS.items())
+def test_minifloat(quant_class_key_value):
+    key, quant_class = quant_class_key_value
+
+    x = torch.randn((1, 32))
+    q = QuantIdentity(quant_class, group_dim=-1, return_quant_tensor=True)
+    q.eval()
+
+    qx = q(x)
+    # Check that minifloat doesn't raise an error
+    qx.minifloat()
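
Note for reviewers: the call-site change matters because `float_internal_scale` derives a per-element power-of-two scale from the tensor it is given, so it must see the minifloat value (`value / scale`) rather than the dequantized `value`; the groupwise call sites were additionally missing the `eps` argument entirely. The snippet below is a minimal sketch of that behavior, not Brevitas' actual implementation: the body of `float_internal_scale` is an assumption (a plain `torch.floor` stands in for the straight-through-estimator floor the library would use), and the parameter values are illustrative.

```python
import torch


def float_internal_scale(x, mantissa_bit_width, fp_internal_scale_min, eps):
    # Sketch: per-element power-of-two exponent of x, shifted so the mantissa
    # bits land on integer positions; eps guards against log2(0) == -inf.
    internal_scale = torch.floor(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width
    # Values in the subnormal range share the minimum exponent.
    internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min)
    return torch.exp2(internal_scale)


# e4m3-style parameters, mirroring the expressions in the patch:
mantissa_bit_width = 3.
exponent_bias = 7
fp_internal_scale = 1. - exponent_bias - mantissa_bit_width  # == -9

value = torch.tensor([0.0, 0.3, 1.7])  # dequantized values
scale = torch.tensor(0.5)              # quantization scale
eps = torch.finfo(scale.dtype).tiny    # as in the patched call sites

# The patch passes the minifloat value (value / scale), not value itself:
# with scale != 1 the two give different exponents, hence different rounding.
int_scale = float_internal_scale(
    value / scale, mantissa_bit_width, fp_internal_scale, eps)
minifloat = torch.round((value / scale) / int_scale) * int_scale
```

Under these assumptions, passing `self.value` when `scale != 1` shifts every computed exponent by roughly `log2(scale)`, which is the rounding error the first two files correct. The new `test_minifloat` case only asserts that `minifloat()` runs without raising, for both a per-tensor FP8 quantizer and a groupwise MXFP8 one.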