From 9c09cc4a53893e93fa0a2bfe4b2bd39ecdaf819f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 18 Jan 2024 15:13:10 +0100 Subject: [PATCH] Fix (export): add CastMixin and rename classes (#794) --- src/brevitas/export/common/handler/qcdq.py | 28 ++++++------ .../export/onnx/standard/qcdq/handler.py | 44 ++++++++++--------- .../export/onnx/standard/qcdq/manager.py | 16 +++---- src/brevitas/export/torch/qcdq/handler.py | 42 +++++++++--------- src/brevitas/export/torch/qcdq/manager.py | 16 +++---- 5 files changed, 75 insertions(+), 71 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index a63b6b5b5..eba9c51f1 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -58,7 +58,7 @@ def cast_fn(self, x, dtype): pass -class CDQMixin(DQMixin, ABC): +class CDQCastMixin(DQCastMixin, ABC): @abstractmethod def clip_fn(self, x, min_val, max_val): @@ -102,7 +102,7 @@ def signed_dtype(cls, bit_width, is_signed): return dtype -class CDQProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQMixin, ABC): +class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC): def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): scale_orig_shape = scale.shape @@ -125,7 +125,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): 'scale_orig_shape': scale_orig_shape} -class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC): +class CDQCastWeightQuantProxyHandlerMixin(CDQCastProxyHandlerMixin, ABC): handled_layer = WeightQuantProxyFromInjector def prepare_for_export(self, module): @@ -179,7 +179,7 @@ def symbolic_execution(self, x: Tensor): return x, scale, zero_point, bit_width -class QCDQDecoupledWeightQuantProxyHandlerMixin(QCDQWeightQuantProxyHandlerMixin, ABC): +class CDQCastDecoupledWeightQuantProxyHandlerMixin(CDQCastWeightQuantProxyHandlerMixin, ABC): handled_layer = DecoupledWeightQuantProxyFromInjector def symbolic_execution(self, x: Tensor): @@ -188,15 +188,15 @@ def symbolic_execution(self, x: Tensor): return out, scale, zero_point, scale, zero_point, bit_width -class QCDQDecoupledWeightQuantWithInputProxyHandlerMixin(QCDQDecoupledWeightQuantProxyHandlerMixin, - ABC): +class CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin( + CDQCastDecoupledWeightQuantProxyHandlerMixin, ABC): handled_layer = DecoupledWeightQuantWithInputProxyFromInjector def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool): return super().symbolic_execution(x) -class QCDQActQuantProxyHandlerMixin(QMixin, CDQProxyHandlerMixin, ABC): +class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC): handled_layer = ActQuantProxyFromInjector def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): @@ -265,7 +265,7 @@ def symbolic_execution(self, x: Tensor): return x, scale, zero_point, bit_width -class QCDQBiasQuantProxyHandlerMixin(DQMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC): +class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC): handled_layer = BiasQuantProxyFromInjector def validate(self, module): @@ -325,12 +325,12 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): return y, scale, zero_point, bit_width -class QCDQTruncQuantProxyHandlerMixin(QuantAxisMixin, - ClipMixin, - ZeroPointHandlerMixin, - QMixin, - CDQMixin, - ABC): +class QCDQCastTruncQuantProxyHandlerMixin(QuantAxisMixin, + ClipMixin, + ZeroPointHandlerMixin, + QMixin, + CDQCastMixin, + ABC): handled_layer = TruncQuantProxyFromInjector def prepare_for_export(self, module: TruncQuantProxyFromInjector): diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 642ae9174..0723f9574 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -5,14 +5,15 @@ import torch -from brevitas.export.common.handler.qcdq import CDQMixin +from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import \ + CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastMixin +from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DQCastMixin -from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin from brevitas.export.onnx.handler import ONNXBaseHandler from brevitas.export.onnx.handler import QuantLSTMLayerHandler @@ -43,7 +44,7 @@ def validate(self, module): assert module.bit_width() > 1., 'Binary quant not supported' -class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC): +class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC): def clip_fn(self, x, min_val, max_val): return IntClipFn.apply(x, min_val, max_val) @@ -74,37 +75,38 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis): return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis) -class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin, - QCDQWeightQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin, + CDQCastWeightQuantProxyHandlerMixin, + ONNXBaseHandler): pass -class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin, + CDQCastDecoupledWeightQuantProxyHandlerMixin, + ONNXBaseHandler): pass -class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler( - StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler): +class StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler( + StdCDQCastONNXMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, + ONNXBaseHandler): pass class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin, - QCDQActQuantProxyHandlerMixin, + QCDQCastActQuantProxyHandlerMixin, ONNXBaseHandler): pass -class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin, - QCDQBiasQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin, + CDQCastBiasQuantProxyHandlerMixin, + ONNXBaseHandler): pass class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin, - QCDQTruncQuantProxyHandlerMixin, + QCDQCastTruncQuantProxyHandlerMixin, ONNXBaseHandler): pass diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py index ec712672a..d475ef335 100644 --- a/src/brevitas/export/onnx/standard/qcdq/manager.py +++ b/src/brevitas/export/onnx/standard/qcdq/manager.py @@ -14,13 +14,13 @@ from ..function import IntClipFn from ..function import QuantizeLinearFn from ..manager import StdONNXBaseManager +from .handler import StdCDQCastONNXBiasQuantProxyHandler +from .handler import StdCDQCastONNXDecoupledWeightQuantProxyHandler +from .handler import StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler +from .handler import StdCDQCastONNXWeightQuantProxyHandler from .handler import StdQCDQCastONNXActQuantProxyHandler -from .handler import StdQCDQCastONNXBiasQuantProxyHandler -from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler -from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler from .handler import StdQCDQCastONNXQuantLSTMLayerHandler from .handler import StdQCDQCastONNXTruncQuantProxyHandler -from .handler import StdQCDQCastONNXWeightQuantProxyHandler class StdQCDQONNXManager(StdONNXBaseManager): @@ -33,12 +33,12 @@ class StdQCDQONNXManager(StdONNXBaseManager): "eliminate_unused_initializer"] handlers = [ - StdQCDQCastONNXWeightQuantProxyHandler, - StdQCDQCastONNXBiasQuantProxyHandler, + StdCDQCastONNXWeightQuantProxyHandler, + StdCDQCastONNXBiasQuantProxyHandler, StdQCDQCastONNXActQuantProxyHandler, - StdQCDQCastONNXDecoupledWeightQuantProxyHandler, + StdCDQCastONNXDecoupledWeightQuantProxyHandler, StdQCDQCastONNXTruncQuantProxyHandler, - StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler, + StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler, StdQCDQCastONNXQuantLSTMLayerHandler] custom_fns = [ diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py index b3474a246..41460a6fa 100644 --- a/src/brevitas/export/torch/qcdq/handler.py +++ b/src/brevitas/export/torch/qcdq/handler.py @@ -6,13 +6,15 @@ import torch from brevitas.export.common.handler.base import BaseHandler +from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import \ + CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastMixin +from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DQCastMixin -from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin @@ -55,7 +57,7 @@ def validate(self, module): assert module.bit_width() > 1., 'Binary quant not supported' -class TorchCDQCastMixin(TorchDQCastMixin, ABC): +class TorchCDQCastMixin(CDQCastMixin, TorchDQCastMixin, ABC): def clip_fn(self, x, min_val, max_val): return torch.clamp(x, min_val, max_val) @@ -93,9 +95,9 @@ def forward(self, *args, **kwargs): return self.symbolic_execution(*args, **kwargs) -class TorchQCDQCastWeightQuantProxyHandler(TorchCDQCastMixin, - QCDQWeightQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchCDQCastWeightQuantProxyHandler(TorchCDQCastMixin, + CDQCastWeightQuantProxyHandlerMixin, + TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -103,9 +105,9 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): return _itemize_clip_bounds(clip_args) -class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin, + CDQCastDecoupledWeightQuantProxyHandlerMixin, + TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -113,8 +115,8 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): return _itemize_clip_bounds(clip_args) -class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler( - TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): +class TorchCDQCastDecoupledWeightQuantWithInputProxyHandler( + TorchCDQCastMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -123,7 +125,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQCastActQuantProxyHandler(TorchQCDQCastMixin, - QCDQActQuantProxyHandlerMixin, + QCDQCastActQuantProxyHandlerMixin, TorchQCDQHandler): @classmethod @@ -132,14 +134,14 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): return _itemize_clip_bounds(clip_args) -class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin, - QCDQBiasQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchCDQCastBiasQuantProxyHandler(TorchDQCastMixin, + CDQCastBiasQuantProxyHandlerMixin, + TorchQCDQHandler): pass class TorchQCDQCastTruncQuantProxyHandler(TorchQCDQCastMixin, - QCDQTruncQuantProxyHandlerMixin, + QCDQCastTruncQuantProxyHandlerMixin, TorchQCDQHandler): @classmethod diff --git a/src/brevitas/export/torch/qcdq/manager.py b/src/brevitas/export/torch/qcdq/manager.py index 1da072a2d..88987986f 100644 --- a/src/brevitas/export/torch/qcdq/manager.py +++ b/src/brevitas/export/torch/qcdq/manager.py @@ -11,23 +11,23 @@ from brevitas.export.manager import BaseManager from brevitas.export.manager import ExportContext +from .handler import TorchCDQCastBiasQuantProxyHandler +from .handler import TorchCDQCastDecoupledWeightQuantProxyHandler +from .handler import TorchCDQCastDecoupledWeightQuantWithInputProxyHandler +from .handler import TorchCDQCastWeightQuantProxyHandler from .handler import TorchQCDQCastActQuantProxyHandler -from .handler import TorchQCDQCastBiasQuantProxyHandler -from .handler import TorchQCDQCastDecoupledWeightQuantProxyHandler -from .handler import TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler from .handler import TorchQCDQCastTruncQuantProxyHandler -from .handler import TorchQCDQCastWeightQuantProxyHandler class TorchQCDQManager(BaseManager): target_name = 'torch' handlers = [ - TorchQCDQCastWeightQuantProxyHandler, - TorchQCDQCastDecoupledWeightQuantProxyHandler, - TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler, + TorchCDQCastWeightQuantProxyHandler, + TorchCDQCastDecoupledWeightQuantProxyHandler, + TorchCDQCastDecoupledWeightQuantWithInputProxyHandler, TorchQCDQCastActQuantProxyHandler, - TorchQCDQCastBiasQuantProxyHandler, + TorchCDQCastBiasQuantProxyHandler, TorchQCDQCastTruncQuantProxyHandler] @classmethod