Clean-up
Giuseppe5 authored Oct 26, 2024
1 parent b4c9d34 commit 475bd7f
Showing 1 changed file with 10 additions and 5 deletions: src/brevitas/quant_tensor/float_quant_tensor.py
@@ -103,9 +103,9 @@ def _pre_round_float_value(self):
             scale = self.scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        eps = torch.finfo(self.scale.dtype).tiny
+        eps = torch.finfo(scale.dtype).tiny
         int_scale = float_internal_scale(
-            self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
+            minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
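
For orientation, the property touched by this hunk presumably reads as follows after the commit (it is consumed as a property, without parentheses, in the minifloat hunk below). The @property decorator and the value/scale preamble with the bfloat16 guard are not visible in this hunk and are inferred by analogy with the minifloat hunk, so take this as a sketch rather than the verbatim file:

    @property
    def _pre_round_float_value(self):
        # Preamble inferred from the minifloat hunk below: upcast to float32 only
        # when the stored tensors are bfloat16, otherwise keep the native dtype.
        value = self.value
        scale = self.scale
        if self.scale.dtype == torch.bfloat16:
            value = self.value.type(torch.float32)
            scale = self.scale.type(torch.float32)
        minifloat_value = value / scale
        fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
        # After this commit, eps follows the (possibly upcast) local scale dtype, and
        # the helper is fed the already-computed ratio instead of re-dividing
        # self.value / self.scale.
        eps = torch.finfo(scale.dtype).tiny
        int_scale = float_internal_scale(
            minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
        minifloat_value = minifloat_value / int_scale
        return minifloat_value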

@@ -138,12 +138,17 @@ def device(self):
     def minifloat(self, float_datatype=True):
         # TODO: Check if OCP and cast to proper data-type if matching
         assert float_datatype, "Minifloat quant returns only higher precision dtype"
-
         if self.is_valid:
+            value = self.value
+            scale = self.scale
+            if self.scale.dtype == torch.bfloat16:
+                value = self.value.type(torch.float32)
+                scale = self.scale.type(torch.float32)
+            minifloat_value = value / scale
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            eps = torch.finfo(self.scale.dtype).tiny
+            eps = torch.finfo(scale.dtype).tiny
             int_scale = float_internal_scale(
-                self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
+                minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
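
As a self-contained illustration of what these two methods compute together, here is a minimal PyTorch sketch of the minifloat rounding pipeline after this commit. The function and variable names are hypothetical, and the inline stand-in for float_internal_scale (floor of log2 of the magnitude minus the mantissa width, clamped below by fp_internal_scale, then exponentiated) is an assumption for illustration; Brevitas's actual helper may differ in detail.

    import torch

    def minifloat_sketch(value, scale, exponent_bias, mantissa_bit_width):
        # Hypothetical standalone sketch of the minifloat() pipeline shown in the diff above.
        out_dtype = scale.dtype
        if scale.dtype == torch.bfloat16:
            value = value.type(torch.float32)
            scale = scale.type(torch.float32)
        minifloat_value = value / scale
        fp_internal_scale = 1. - exponent_bias - mantissa_bit_width
        eps = torch.finfo(scale.dtype).tiny
        # Assumed stand-in for float_internal_scale: a per-element power-of-two step size.
        exponent = torch.floor(torch.log2(torch.abs(minifloat_value) + eps)) - mantissa_bit_width
        int_scale = torch.exp2(torch.clamp_min(exponent, fp_internal_scale))
        # Round on the aligned grid, then undo the internal scale, as minifloat() does.
        return (torch.round(minifloat_value / int_scale) * int_scale).type(out_dtype)

    # Example with e4m3-like parameters (bias 7, 3 mantissa bits) on bfloat16 inputs.
    x = torch.tensor([0.0123, -1.5, 7.3], dtype=torch.bfloat16)
    s = torch.tensor(0.05, dtype=torch.bfloat16)
    print(minifloat_sketch(x, s, exponent_bias=7, mantissa_bit_width=3))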
