From b4c9d342501e7d93d50582cf834e12f0c5f52146 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 23 Oct 2024 13:34:08 +0100
Subject: [PATCH 1/3] Fix (minifloat): correct minifloat computation and tests
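
float_internal_scale was previously called on the dequantized tensor
(self.value) and with self.eps; it now receives the value in the
minifloat domain (self.value / self.scale) together with an eps taken
from the scale dtype (torch.finfo(...).tiny). The internal per-element
scale has to be picked on the minifloat grid: whenever scale != 1,
deriving it from the dequantized tensor can select the wrong
power-of-two step.

For reference, a simplified sketch of what float_internal_scale
computes, assuming plain torch.floor in place of the straight-through
floor used during training; illustrative only, not the Brevitas
implementation:

    import torch

    def internal_scale_sketch(x, mantissa_bit_width, fp_internal_scale_min, eps):
        # Power-of-two exponent of the minifloat bucket containing |x|,
        # shifted down by the mantissa width so that rounding happens on
        # the mantissa grid; eps guards log2 of exact zeros.
        e = torch.floor(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width
        # Subnormal range: never drop below the minimum internal scale.
        e = torch.clamp_min(e, fp_internal_scale_min)
        return torch.exp2(e)

A regression test is added that builds per-tensor FP8 and groupwise
MXFP8 activations and checks that minifloat() runs without error.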
---
 .../quant_tensor/float_quant_tensor.py       |  6 ++++--
 .../groupwise_float_quant_tensor.py          |  8 ++++++--
 .../quant_tensor/test_quant_tensor.py        | 19 +++++++++++++++++++
 3 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py
index cf4ba1420..ecb2c7ae2 100644
--- a/src/brevitas/quant_tensor/float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/float_quant_tensor.py
@@ -103,8 +103,9 @@ def _pre_round_float_value(self):
         scale = self.scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
+        eps = torch.finfo(self.scale.dtype).tiny
         int_scale = float_internal_scale(
-            self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
+            self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
@@ -140,8 +141,9 @@ def minifloat(self, float_datatype=True):
         if self.is_valid:
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
+            eps = torch.finfo(self.scale.dtype).tiny
             int_scale = float_internal_scale(
-                self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
+                self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
index fa91bdca1..9b703e0f7 100644
--- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -144,7 +144,9 @@ def _pre_round_float_value(self):
         scale = scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+        eps = torch.finfo(self.scale_.dtype).tiny
+        int_scale = float_internal_scale(
+            self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
@@ -180,7 +182,9 @@ def minifloat(self, float_datatype=True):
         if self.is_valid:
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+            eps = torch.finfo(self.scale_.dtype).tiny
+            int_scale = float_internal_scale(
+                self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py
index 4e3401057..6f6b4c7d2 100644
--- a/tests/brevitas/quant_tensor/test_quant_tensor.py
+++ b/tests/brevitas/quant_tensor/test_quant_tensor.py
@@ -4,11 +4,14 @@
 from packaging import version
 import pytest
+import pytest_cases
 import torch
 
 from brevitas import torch_version
 from brevitas.nn import QuantIdentity
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
 from brevitas.quant_tensor import FloatQuantTensor
 from brevitas.quant_tensor import IntQuantTensor
@@ -119,3 +122,19 @@ def test_quant_tensor_view():
     assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01)
     assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01)
     assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01)
+
+
+QUANT_CLASS = {'fp8': Fp8e4m3ActPerTensorFloat, 'mxfp8': MXFloat8e4m3Act}
+
+
+@pytest_cases.parametrize('quant_class_key_value', QUANT_CLASS.items())
+def test_minifloat(quant_class_key_value):
+    key, quant_class = quant_class_key_value
+
+    x = torch.randn((1, 32))
+    q = QuantIdentity(quant_class, group_dim=-1, return_quant_tensor=True)
+    q.eval()
+
+    qx = q(x)
+    # Check that minifloat() does not raise an error
+    qx.minifloat()

From 475bd7f242a956252c2bf54d754bd63e401ef46a Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sat, 26 Oct 2024 22:48:56 +0200
Subject: [PATCH 2/3] Clean-up
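
Reuse the locally upcast tensors inside minifloat(): bfloat16 value and
scale are promoted to float32 before the division, mirroring
_pre_round_float_value; the already computed minifloat_value is passed
to float_internal_scale instead of recomputing self.value / self.scale;
and eps is read from the dtype of the (possibly upcast) local scale.

The upcast matters because bfloat16 carries only 8 bits of mantissa
precision (7 stored), so value / scale would already be rounded before
the minifloat rounding is applied; the result is cast back to the
original scale dtype at the end. A quick look at the dtype constants
involved, as reported by torch.finfo:

    import torch

    # Spacing between 1.0 and the next representable number:
    print(torch.finfo(torch.bfloat16).eps)   # 0.0078125  (2**-7)
    print(torch.finfo(torch.float32).eps)    # ~1.192e-07 (2**-23)
    # Smallest positive normal number, used as the log2 guard (eps):
    print(torch.finfo(torch.bfloat16).tiny)  # ~1.175e-38 (2**-126)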
---
 src/brevitas/quant_tensor/float_quant_tensor.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py
index ecb2c7ae2..459f0eec7 100644
--- a/src/brevitas/quant_tensor/float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/float_quant_tensor.py
@@ -103,9 +103,9 @@ def _pre_round_float_value(self):
         scale = self.scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        eps = torch.finfo(self.scale.dtype).tiny
+        eps = torch.finfo(scale.dtype).tiny
         int_scale = float_internal_scale(
-            self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
+            minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
@@ -138,12 +138,17 @@ def device(self):
     def minifloat(self, float_datatype=True):
         # TODO: Check if OCP and cast to proper data-type if matching
         assert float_datatype, "Minifloat quant returns only higher precision dtype"
-
         if self.is_valid:
+            value = self.value
+            scale = self.scale
+            if self.scale.dtype == torch.bfloat16:
+                value = self.value.type(torch.float32)
+                scale = self.scale.type(torch.float32)
+            minifloat_value = value / scale
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            eps = torch.finfo(self.scale.dtype).tiny
+            eps = torch.finfo(scale.dtype).tiny
             int_scale = float_internal_scale(
-                self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
+                minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:

From 82576fcf49fef3a5d28c56fe5fa0fc70172a25e1 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sat, 26 Oct 2024 22:50:53 +0200
Subject: [PATCH 3/3] Clean-up groupwise

---
 .../quant_tensor/groupwise_float_quant_tensor.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
index 9b703e0f7..4a99b0207 100644
--- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -144,9 +144,9 @@ def _pre_round_float_value(self):
         scale = scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        eps = torch.finfo(self.scale_.dtype).tiny
+        eps = torch.finfo(scale.dtype).tiny
         int_scale = float_internal_scale(
-            self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
+            minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
@@ -181,10 +181,15 @@ def minifloat(self, float_datatype=True):
         assert float_datatype, "Minifloat quant returns only higher precision dtype"
 
         if self.is_valid:
+            value, scale, zp = self.expand()
+            if self.scale.dtype == torch.bfloat16:
+                value = value.type(torch.float32)
+                scale = scale.type(torch.float32)
+            minifloat_value = value / scale
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            eps = torch.finfo(self.scale_.dtype).tiny
+            eps = torch.finfo(scale.dtype).tiny
             int_scale = float_internal_scale(
-                self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
+                minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
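
For reference, the reconstruction path that this series corrects can be
summarized as a standalone round trip. A minimal sketch, assuming a
simplified float_internal_scale (plain torch.floor in place of the
straight-through floor) and hypothetical helper names; illustrative
only, not the Brevitas implementation:

    import torch

    def internal_scale(x, mantissa_bit_width, fp_internal_scale_min, eps):
        # Per-element power-of-two step of the minifloat grid containing x.
        e = torch.floor(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width
        return torch.exp2(torch.clamp_min(e, fp_internal_scale_min))

    def minifloat_roundtrip(value, scale, exponent_bias, mantissa_bit_width):
        # Mirror the patched upcast: compute in float32 for bfloat16 scales.
        if scale.dtype == torch.bfloat16:
            value = value.type(torch.float32)
            scale = scale.type(torch.float32)
        minifloat_value = value / scale
        fp_internal_scale_min = 1. - exponent_bias - mantissa_bit_width
        eps = torch.finfo(scale.dtype).tiny
        int_scale = internal_scale(
            minifloat_value, mantissa_bit_width, fp_internal_scale_min, eps)
        # Round on the mantissa grid, then undo the internal scaling.
        return torch.round(minifloat_value / int_scale) * int_scale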