From db7a6f5702deb76ddbfec70433eae5956a263f4b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 10 Jun 2024 13:54:52 +0100 Subject: [PATCH] Fix (quant/float): restore fix for log(0) --- src/brevitas/core/quant/float.py | 2 +- src/brevitas/quant_tensor/float_quant_tensor.py | 10 ++++++++-- src/brevitas/utils/torch_utils.py | 8 +++++--- tests/brevitas/core/test_float_quant.py | 3 ++- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 929024c63..c3d521964 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -71,7 +71,7 @@ def quantize(self, x: torch.Tensor): scale = scaling_impl_value / float_scaling_impl_value scaled_x = x / scale internal_scale = float_internal_scale( - scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min()) + scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale) return val_fp_quant, scale diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index b06466d2d..74f42dc94 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -74,6 +74,10 @@ def training(self): def saturating(self): return self.saturating_t.item() + @property + def eps(self): + return torch.finfo(self.scale.dtype).tiny + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} @@ -99,7 +103,8 @@ def _pre_round_float_value(self): scale = self.scale.type(torch.float32) minifloat_value = value / scale fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width - int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + int_scale = float_internal_scale( + self.value, self.mantissa_bit_width, fp_internal_scale, self.eps) minifloat_value = minifloat_value / int_scale return minifloat_value @@ -135,7 +140,8 @@ def minifloat(self, float_datatype=True): if self.is_valid: fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width - int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + int_scale = float_internal_scale( + self.value, self.mantissa_bit_width, fp_internal_scale, self.eps) float_value = torch.round(self._pre_round_float_value) * int_scale return float_value.type(self.scale.dtype) else: diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index f7dbe9ef3..f838bcf4f 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -93,10 +93,12 @@ def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int): @brevitas.jit.script def float_internal_scale( - x: torch.Tensor, mantissa_bit_width: torch.Tensor, - fp_internal_scale_min: torch.Tensor) -> torch.Tensor: + x: torch.Tensor, + mantissa_bit_width: torch.Tensor, + fp_internal_scale_min: torch.Tensor, + eps: torch.Tensor) -> torch.Tensor: - internal_scale = floor_ste(torch.log2(torch.abs(x))) - mantissa_bit_width + internal_scale = floor_ste(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) internal_scale = torch.exp2(internal_scale) return internal_scale diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 021365239..2d4c829f0 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -193,8 +193,9 @@ def test_inner_scale(inp, minifloat_format, scale): max_value = max_val if max_available_float is None else torch.min( max_value, max_available_float) # call internal scale + eps = torch.finfo(inp.dtype).tiny internal_scale = float_internal_scale( - scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min()) + scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min(), eps) val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale) if signed: val_fp_quant = torch.clip(val_fp_quant, -1. * max_val, max_val)