diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 9b703e0f7..4a99b0207 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -144,9 +144,9 @@ def _pre_round_float_value(self): scale = scale.type(torch.float32) minifloat_value = value / scale fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width - eps = torch.finfo(self.scale_.dtype).tiny + eps = torch.finfo(scale.dtype).tiny int_scale = float_internal_scale( - self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps) + minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps) minifloat_value = minifloat_value / int_scale return minifloat_value @@ -181,10 +181,15 @@ def minifloat(self, float_datatype=True): assert float_datatype, "Minifloat quant returns only higher precision dtype" if self.is_valid: + value, scale, zp = self.expand() + if self.scale.dtype == torch.bfloat16: + value = value.type(torch.float32) + scale = scale.type(torch.float32) + minifloat_value = value / scale fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width - eps = torch.finfo(self.scale_.dtype).tiny + eps = torch.finfo(scale.dtype).tiny int_scale = float_internal_scale( - self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps) + minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps) float_value = torch.round(self._pre_round_float_value) * int_scale return float_value.type(self.scale.dtype) else: