From 6dd41d5ce01728497b847ef872f1d5b056dae917 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 7 Feb 2024 15:25:23 +0100 Subject: [PATCH] Fix (export): correct q node export (#829) --- src/brevitas/export/common/handler/qcdq.py | 6 +++--- src/brevitas/export/onnx/standard/qcdq/handler.py | 6 +++--- src/brevitas/export/onnx/standard/qcdq/manager.py | 2 +- src/brevitas/export/torch/qcdq/handler.py | 3 +++ 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 6ddee61a4..b8f3c1fc8 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -135,7 +135,6 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin): handled_layer = WeightQuantProxyFromInjector - _export_q_node = False def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): # compute axis before redefining scale @@ -194,6 +193,9 @@ def prepare_for_export(self, module): def quantize_from_floating_point(self, x: Tensor): quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] + # Before quantization, cast input to float32 + if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: + x = self.cast_fn(x, torch.float32) x = self.quantize_fn(x, *quantize_symbolic_kwargs.values()) return x @@ -230,7 +232,6 @@ def symbolic_execution(self, x: Tensor): class QCDQCastDecoupledWeightQuantProxyHandlerMixin(QCDQCastWeightQuantProxyHandlerMixin, ABC): handled_layer = DecoupledWeightQuantProxyFromInjector - _export_q_node = False def symbolic_execution(self, x: Tensor): out, scale, zero_point, bit_width = super().symbolic_execution(x) @@ -241,7 +242,6 @@ def symbolic_execution(self, x: Tensor): class QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin( QCDQCastDecoupledWeightQuantProxyHandlerMixin, ABC): handled_layer = DecoupledWeightQuantWithInputProxyFromInjector - _export_q_node = False def validate(self, module): assert not self._export_q_node, "This proxy requires to export integer weights" diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index fc5764caa..fbc8a3562 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -113,19 +113,19 @@ def quantize_fn(self, x, dtype): class StdQCDQCastONNXWeightQuantProxyHandler(StdQCDQCastONNXMixin, QCDQCastWeightQuantProxyHandlerMixin, ONNXBaseHandler): - pass + _export_q_node = False class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdQCDQCastONNXMixin, QCDQCastDecoupledWeightQuantProxyHandlerMixin, ONNXBaseHandler): - pass + _export_q_node = False class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler( StdQCDQCastONNXMixin, QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler): - pass + _export_q_node = False class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin, diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py index 7d938935b..b1b05f4ad 100644 --- a/src/brevitas/export/onnx/standard/qcdq/manager.py +++ b/src/brevitas/export/onnx/standard/qcdq/manager.py @@ -71,4 +71,4 @@ def export_onnx(cls, *args, export_weight_q_node: bool = False, **kwargs): def change_weight_export(cls, export_weight_q_node: bool = False): for handler in cls.handlers: if hasattr(handler, '_export_q_node'): - handler._export_weight_q_node = export_weight_q_node + handler._export_q_node = export_weight_q_node diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py index 4ef07598f..7c6b35ad6 100644 --- a/src/brevitas/export/torch/qcdq/handler.py +++ b/src/brevitas/export/torch/qcdq/handler.py @@ -99,6 +99,7 @@ def forward(self, *args, **kwargs): class TorchQCDQCastWeightQuantProxyHandler(TorchQCDQCastMixin, QCDQCastWeightQuantProxyHandlerMixin, TorchQCDQHandler): + _export_q_node = False @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -109,6 +110,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchQCDQCastMixin, QCDQCastDecoupledWeightQuantProxyHandlerMixin, TorchQCDQHandler): + _export_q_node = False @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -119,6 +121,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler( TorchQCDQCastMixin, QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): + _export_q_node = False @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):