diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index bd4b6a378..c3d521964 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -58,6 +58,11 @@ def __init__( self.scaling_impl = scaling_impl self.float_clamp_impl = float_clamp_impl + # To avoid log(0), we add small a small value based on the used dtype + if dtype is None: + dtype = torch.get_default_dtype() + self.eps = torch.finfo(dtype).tiny + @brevitas.jit.script_method def quantize(self, x: torch.Tensor): scaling_impl_value = self.scaling_impl(x) @@ -66,7 +71,7 @@ def quantize(self, x: torch.Tensor): scale = scaling_impl_value / float_scaling_impl_value scaled_x = x / scale internal_scale = float_internal_scale( - scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min()) + scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale) return val_fp_quant, scale diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index b06466d2d..74f42dc94 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -74,6 +74,10 @@ def training(self): def saturating(self): return self.saturating_t.item() + @property + def eps(self): + return torch.finfo(self.scale.dtype).tiny + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} @@ -99,7 +103,8 @@ def _pre_round_float_value(self): scale = self.scale.type(torch.float32) minifloat_value = value / scale fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width - int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + int_scale = float_internal_scale( + self.value, self.mantissa_bit_width, fp_internal_scale, self.eps) minifloat_value = minifloat_value / int_scale return minifloat_value @@ -135,7 +140,8 @@ def minifloat(self, float_datatype=True): if self.is_valid: fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width - int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + int_scale = float_internal_scale( + self.value, self.mantissa_bit_width, fp_internal_scale, self.eps) float_value = torch.round(self._pre_round_float_value) * int_scale return float_value.type(self.scale.dtype) else: diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 29839e971..f838bcf4f 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -93,10 +93,10 @@ def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int): @brevitas.jit.script def float_internal_scale( - x: torch.Tensor, mantissa_bit_width: torch.Tensor, - fp_internal_scale_min: torch.Tensor) -> torch.Tensor: - # Add small EPS to avoid log(0) - eps = torch.finfo(x.dtype).tiny + x: torch.Tensor, + mantissa_bit_width: torch.Tensor, + fp_internal_scale_min: torch.Tensor, + eps: torch.Tensor) -> torch.Tensor: internal_scale = floor_ste(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 021365239..2d4c829f0 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -193,8 +193,9 @@ def test_inner_scale(inp, minifloat_format, scale): max_value = max_val if max_available_float is None else torch.min( max_value, max_available_float) # call internal scale + eps = torch.finfo(inp.dtype).tiny internal_scale = float_internal_scale( - scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min()) + scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min(), eps) val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale) if signed: val_fp_quant = torch.clip(val_fp_quant, -1. * max_val, max_val)