Commit

Remove workaround
Giuseppe5 committed Jan 18, 2024
1 parent 8f1f5ee commit 96ce806
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/brevitas/export/common/handler/qcdq.py
@@ -164,8 +164,7 @@ def symbolic_execution(self, x: Tensor):
         zero_point = dequantize_symbolic_kwargs['zero_point']
         bit_width = self.symbolic_kwargs['bit_width']
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
-        # Workaround to trick the tracer into believing all return values are used
-        self.assert_ge_zero(scale, zero_point, bit_width)
+
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
@@ -246,8 +245,7 @@ def symbolic_execution(self, x: Tensor):
         zero_point = dequantize_symbolic_kwargs['zero_point']
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         bit_width = self.symbolic_kwargs['bit_width']
-        # Workaround to trick the tracer into believing all return values are used
-        self.assert_ge_zero(scale, zero_point, bit_width)
+
         # If original dtype of the input is (b)float16, cast the input to float32
         if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = self.cast_fn(x, torch.float32)
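Context for the change: the deleted comment says the `assert_ge_zero(scale, zero_point, bit_width)` call existed only to make the tracer treat these return values as consumed, so they would not be pruned from the traced graph during export. A minimal sketch of what such a helper might look like follows; this is a hypothetical reconstruction for illustration, not the Brevitas implementation, and it assumes the helper simply asserted that each value is non-negative.

import torch

def assert_ge_zero(*values):
    # Hypothetical sketch: consume each traced value so the tracer sees it
    # used, while sanity-checking that the quantization parameters
    # (scale, zero_point, bit_width) are non-negative.
    for value in values:
        ge_zero = value >= 0
        if isinstance(ge_zero, torch.Tensor):
            ge_zero = bool(ge_zero.all())
        assert ge_zero

Removing the call presumably means the tracer no longer prunes these values on its own, so the extra assertion was dead weight in the exported graph.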
