Skip to content

Commit

Permalink
Feat (quant_tensor): update __truediv__ behaviour to match "standar…
Browse files Browse the repository at this point in the history
…d fixed point rules" (#769)

* [quant_tensor] Updated `__truediv__` behaviour based on #740

* [quant_tensor] Updated div behaviour to throw RuntimeError when non-zero zero-point operands are used

* Fix: changed other.tensor -> other.value
  • Loading branch information
nickfraser authored Dec 21, 2023
1 parent 2e6e179 commit d0c10a5
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/brevitas/quant_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,16 @@ def __sub__(self, other):

def __truediv__(self, other):
if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
output_tensor = self.value / other.tensor
output_scale = self.scale / other.scale
output_bit_width = self.bit_width - other.bit_width
output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid()
max_int_denominator = 2 ** (other.bit_width - int(other.signed))
output_scale = self.scale / (other.scale * max_int_denominator)
output_bit_width = self.bit_width + other.bit_width
output_signed = self.signed or other.signed
output_training = self.training or other.training
if self.is_zero_zero_point(self) and self.is_zero_zero_point(other):
output_zero_point = self.zero_point * other.zero_point # Output zero_point is a new, zero-valued tensor
else:
output_zero_point = None # TODO non-zero zero point
raise RuntimeError("Zero-points of div operands are non-zero, not supported.")
output = QuantTensor(
value=output_tensor,
scale=output_scale,
Expand Down

0 comments on commit d0c10a5

Please sign in to comment.