From 96ce806eef56e9c9af5ff9f1ef981687968cfb00 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 18 Jan 2024 11:10:16 +0000
Subject: [PATCH] Remove workaround

---
 src/brevitas/export/common/handler/qcdq.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index a63b6b5b5..de5801b9d 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -164,8 +164,7 @@ def symbolic_execution(self, x: Tensor):
         zero_point = dequantize_symbolic_kwargs['zero_point']
         bit_width = self.symbolic_kwargs['bit_width']
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
-        # Workaround to trick the tracer into believing all return values are used
-        self.assert_ge_zero(scale, zero_point, bit_width)
+
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
@@ -246,8 +245,7 @@ def symbolic_execution(self, x: Tensor):
         zero_point = dequantize_symbolic_kwargs['zero_point']
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         bit_width = self.symbolic_kwargs['bit_width']
-        # Workaround to trick the tracer into believing all return values are used
-        self.assert_ge_zero(scale, zero_point, bit_width)
+
         # If original dtype of the input is (b)float16, cast the input to float32
         if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = self.cast_fn(x, torch.float32)
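
Background on the deleted lines: self.assert_ge_zero(scale, zero_point, bit_width)
was called purely so the tracer would see scale, zero_point and bit_width being
consumed and would keep the ops producing them in the exported graph. A minimal
sketch of what such a helper plausibly looks like, assuming a simple per-value
non-negativity check (only the call site appears in this patch; the body below
is an assumption, not Brevitas's actual implementation):

    import torch

    def assert_ge_zero(*values):
        # Reading every tensor inside an assertion means a tracer that prunes
        # "unused" values still records the nodes that produce scale,
        # zero_point and bit_width in the traced graph.
        for value in values:
            assert (torch.as_tensor(value) >= 0).all()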