Feat (quant_tensor): support for bfloat16
Giuseppe5 committed Dec 21, 2023
1 parent 6a5f496 commit 3caf0d5
Showing 1 changed file with 25 additions and 6 deletions.
src/brevitas/quant_tensor/__init__.py (31 changes: 25 additions & 6 deletions)
@@ -15,6 +15,7 @@
 from .torch_handler import QUANT_TENSOR_FN_HANDLER
 
 IS_VALID_ATOL = 2e-1
+BFLOAT16_IS_VALID_ATOL = 0.5
 
 
 class QuantTensorBase(NamedTuple):
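BFLOAT16_IS_VALID_ATOL is the looser tolerance used when checking that a dequantized bfloat16 tensor still maps onto integers. The short snippet below is an illustration added for context, not part of the commit; it shows why a wider tolerance is needed with only 8 significand bits.

import torch

# bfloat16 keeps 8 significand bits, so its machine epsilon is 2 ** -7
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
# Around magnitude 200 the spacing between representable bfloat16 values is
# eps * 128 = 1.0, so an integer recovered as value / scale can be off by
# roughly 0.5, hence the relaxed tolerance for bfloat16 inputs.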
@@ -104,18 +105,27 @@ def is_not_none(self):
 
     @property
     def _pre_round_int_value(self):
-        int_value = self.value / self.scale
-        int_value = int_value + self.zero_point
+        value = self.value
+        scale = self.scale
+        zero_point = self.zero_point
+        if self.scale.dtype == torch.bfloat16:
+            value = self.value.type(torch.float32)
+            scale = self.scale.type(torch.float32)
+            zero_point = self.zero_point.type(torch.float32)
+        int_value = value / scale
+        int_value = int_value + zero_point
         return int_value
 
     @property
     def is_valid(self):
         if self.is_not_none:
             with torch.no_grad():
                 pre_round_int_value = self._pre_round_int_value
-                rounded_int_value = torch.round(pre_round_int_value)
-                is_int = torch.isclose(
-                    pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all()
+                # After rounding, cast to the original dtype of the scale factor
+                rounded_int_value = torch.round(pre_round_int_value).type(self.scale.dtype)
+                max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
+                atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
+                is_int = max_abs_diff < atol
                 if self.bit_width >= 2:
                     if self.signed:
                         is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
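The reworked _pre_round_int_value upcasts bfloat16 inputs to float32 before dividing by the scale, and is_valid then compares pre-round and rounded values with the tolerance matching the value's dtype. Below is a minimal standalone sketch of that upcasting recipe, with a hypothetical helper and made-up values, not the Brevitas API itself.

import torch

def pre_round_int_value(value, scale, zero_point):
    # Mirror of the recipe in the diff: do the division in float32 when the
    # quantization parameters are bfloat16, so rounding the quotient is reliable.
    if scale.dtype == torch.bfloat16:
        value = value.type(torch.float32)
        scale = scale.type(torch.float32)
        zero_point = zero_point.type(torch.float32)
    return value / scale + zero_point

scale = torch.tensor(0.05, dtype=torch.bfloat16)
zero_point = torch.tensor(0., dtype=torch.bfloat16)
int_codes = torch.tensor([100., -42.], dtype=torch.bfloat16)
value = (int_codes - zero_point) * scale                # dequantized bfloat16 tensor
print(torch.round(pre_round_int_value(value, scale, zero_point)))  # tensor([100., -42.])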
@@ -176,7 +186,12 @@ def int(self, float_datatype=False):
         if self.is_valid:
             int_value = round_ste(self._pre_round_int_value)
             if float_datatype:
-                return int_value
+                # Values at 8bit and lower can be represented exactly with float16 and bfloat16
+                # otherwise (e.g. Int16 bias), we upscale to float32
+                if self.bit_width <= 8.:
+                    return int_value.type(self.scale.dtype)
+                else:
+                    return int_value.type(torch.float32)
             else:
                 if self.bit_width <= 8. and self.signed_t.item():
                     return int_value.to(torch.int8)
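int(float_datatype=True) now keeps the scale's 16-bit dtype only for bit widths up to 8 and falls back to float32 otherwise, since wider integer ranges (e.g. an Int16 bias) no longer fit exactly in bfloat16. A quick illustrative check, added here for context and not part of the commit:

import torch

# Integers up to 2 ** 8 are exact in bfloat16, wider ones are not.
print(torch.tensor(255., dtype=torch.bfloat16).item())   # 255.0, exact
print(torch.tensor(4097., dtype=torch.bfloat16).item())  # 4096.0, already lossy
print(torch.tensor(4097., dtype=torch.float32).item())   # 4097.0, exact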
@@ -301,6 +316,8 @@ def cat(tensors, dim, out=None):
 
     def __neg__(self):
         neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale
+        # In case the dtype of self.int is different from the one of the scale
+        neg_value = neg_value.type(self.scale.dtype)
         if self.signed:
             return QuantTensor(
                 value=neg_value,
@@ -432,6 +449,8 @@ def __truediv__(self, other):
     def __abs__(self):
         if self.signed:
             abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale
+            # In case the dtype of self.int is different from the one of the scale
+            abs_value = abs_value.type(self.scale.dtype)
             return QuantTensor(
                 value=abs_value,
                 scale=self.scale,
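The explicit .type(self.scale.dtype) casts in __neg__ and __abs__ cover the case where int(float_datatype=True) returns float32 (bit widths above 8): mixing that tensor with a bfloat16 scale promotes the result to float32, so it is cast back to keep the value's dtype aligned with the scale. An illustrative promotion check with hypothetical values, not part of the commit:

import torch

int_value = torch.tensor([3., -7.], dtype=torch.float32)  # e.g. from int(float_datatype=True)
scale = torch.tensor(0.1, dtype=torch.bfloat16)
neg_value = -int_value * scale
print(neg_value.dtype)                     # torch.float32 (type promotion)
print(neg_value.type(scale.dtype).dtype)   # torch.bfloat16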
