diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py
index d8bb253bb..bd1da8edd 100644
--- a/src/brevitas/quant_tensor/__init__.py
+++ b/src/brevitas/quant_tensor/__init__.py
@@ -15,6 +15,9 @@ from .torch_handler import QUANT_TENSOR_FN_HANDLER
 
 
 IS_VALID_ATOL = 2e-1
+# Tolerance used by is_valid for bfloat16 values. NOTE(review): |x - round(x)| is at most
+# 0.5, so with this value the integrality check only rejects exact .5 ties and NaN/inf.
+BFLOAT16_IS_VALID_ATOL = 0.5
 
 
 class QuantTensorBase(NamedTuple):
@@ -104,8 +107,16 @@ def is_not_none(self):
 
     @property
     def _pre_round_int_value(self):
-        int_value = self.value / self.scale
-        int_value = int_value + self.zero_point
+        value = self.value
+        scale = self.scale
+        zero_point = self.zero_point
+        # Divide in float32 when the scale is bfloat16; the result is then float32 as well
+        if self.scale.dtype == torch.bfloat16:
+            value = self.value.type(torch.float32)
+            scale = self.scale.type(torch.float32)
+            zero_point = self.zero_point.type(torch.float32)
+        int_value = value / scale
+        int_value = int_value + zero_point
         return int_value
 
     @property
@@ -114,8 +125,11 @@ def is_valid(self):
             with torch.no_grad():
                 pre_round_int_value = self._pre_round_int_value
                 rounded_int_value = torch.round(pre_round_int_value)
-                is_int = torch.isclose(
-                    pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all()
+                abs_diff = torch.abs(pre_round_int_value - rounded_int_value)
+                atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
+                # Elementwise compare + all() rather than torch.max(): max() raises on a
+                # zero-element tensor, while all() on an empty tensor is True
+                is_int = (abs_diff < atol).all()
                 if self.bit_width >= 2:
                     if self.signed:
                         is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
@@ -176,7 +190,13 @@ def int(self, float_datatype=False):
         if self.is_valid:
             int_value = round_ste(self._pre_round_int_value)
             if float_datatype:
-                return int_value
+                # Values at 8bit and lower can be represented exactly with float16 and bfloat16;
+                # otherwise (e.g. Int16 bias) use at least float32. promote_types leaves a
+                # float64 scale dtype untouched instead of narrowing it to float32
+                if self.bit_width <= 8.:
+                    return int_value.type(self.scale.dtype)
+                else:
+                    return int_value.type(torch.promote_types(self.scale.dtype, torch.float32))
             else:
                 if self.bit_width <= 8. and self.signed_t.item():
                     return int_value.to(torch.int8)
@@ -301,6 +321,8 @@ def cat(tensors, dim, out=None):
 
     def __neg__(self):
         neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale
+        # In case the dtype of self.int is different from the one of the scale
+        neg_value = neg_value.type(self.scale.dtype)
         if self.signed:
             return QuantTensor(
                 value=neg_value,
@@ -432,6 +454,8 @@ def __truediv__(self, other):
     def __abs__(self):
         if self.signed:
             abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale
+            # In case the dtype of self.int is different from the one of the scale
+            abs_value = abs_value.type(self.scale.dtype)
             return QuantTensor(
                 value=abs_value,
                 scale=self.scale,