Skip to content

Commit

Permalink
Fix (export): add CastMixin and rename classes (#794)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Giuseppe5 authored Jan 18, 2024
1 parent 8f1f5ee commit 9c09cc4
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 71 deletions.
28 changes: 14 additions & 14 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -58,7 +58,7 @@ def cast_fn(self, x, dtype):
pass


class CDQMixin(DQMixin, ABC):
class CDQCastMixin(DQCastMixin, ABC):

@abstractmethod
def clip_fn(self, x, min_val, max_val):
Expand Down Expand Up @@ -102,7 +102,7 @@ def signed_dtype(cls, bit_width, is_signed):
return dtype


class CDQProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQMixin, ABC):
class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC):

def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
scale_orig_shape = scale.shape
Expand All @@ -125,7 +125,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
'scale_orig_shape': scale_orig_shape}


class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
class CDQCastWeightQuantProxyHandlerMixin(CDQCastProxyHandlerMixin, ABC):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
Expand Down Expand Up @@ -179,7 +179,7 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantProxyHandlerMixin(QCDQWeightQuantProxyHandlerMixin, ABC):
class CDQCastDecoupledWeightQuantProxyHandlerMixin(CDQCastWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantProxyFromInjector

def symbolic_execution(self, x: Tensor):
Expand All @@ -188,15 +188,15 @@ def symbolic_execution(self, x: Tensor):
return out, scale, zero_point, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantWithInputProxyHandlerMixin(QCDQDecoupledWeightQuantProxyHandlerMixin,
ABC):
class CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin(
CDQCastDecoupledWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantWithInputProxyFromInjector

def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool):
return super().symbolic_execution(x)


class QCDQActQuantProxyHandlerMixin(QMixin, CDQProxyHandlerMixin, ABC):
class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC):
handled_layer = ActQuantProxyFromInjector

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
Expand Down Expand Up @@ -265,7 +265,7 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class QCDQBiasQuantProxyHandlerMixin(DQMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
handled_layer = BiasQuantProxyFromInjector

def validate(self, module):
Expand Down Expand Up @@ -325,12 +325,12 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
return y, scale, zero_point, bit_width


class QCDQTruncQuantProxyHandlerMixin(QuantAxisMixin,
ClipMixin,
ZeroPointHandlerMixin,
QMixin,
CDQMixin,
ABC):
class QCDQCastTruncQuantProxyHandlerMixin(QuantAxisMixin,
ClipMixin,
ZeroPointHandlerMixin,
QMixin,
CDQCastMixin,
ABC):
handled_layer = TruncQuantProxyFromInjector

def prepare_for_export(self, module: TruncQuantProxyFromInjector):
Expand Down
44 changes: 23 additions & 21 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,14 +5,15 @@

import torch

from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler
Expand Down Expand Up @@ -43,7 +44,7 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):
class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return IntClipFn.apply(x, min_val, max_val)
Expand Down Expand Up @@ -74,37 +75,38 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
class StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQCastONNXMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQActQuantProxyHandlerMixin,
QCDQCastActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
CDQCastBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
QCDQCastTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
pass

Expand Down
16 changes: 8 additions & 8 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,13 +14,13 @@
from ..function import IntClipFn
from ..function import QuantizeLinearFn
from ..manager import StdONNXBaseManager
from .handler import StdCDQCastONNXBiasQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdCDQCastONNXWeightQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXBiasQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
from .handler import StdQCDQCastONNXTruncQuantProxyHandler
from .handler import StdQCDQCastONNXWeightQuantProxyHandler


class StdQCDQONNXManager(StdONNXBaseManager):
Expand All @@ -33,12 +33,12 @@ class StdQCDQONNXManager(StdONNXBaseManager):
"eliminate_unused_initializer"]

handlers = [
StdQCDQCastONNXWeightQuantProxyHandler,
StdQCDQCastONNXBiasQuantProxyHandler,
StdCDQCastONNXWeightQuantProxyHandler,
StdCDQCastONNXBiasQuantProxyHandler,
StdQCDQCastONNXActQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
StdCDQCastONNXDecoupledWeightQuantProxyHandler,
StdQCDQCastONNXTruncQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQCastONNXQuantLSTMLayerHandler]

custom_fns = [
Expand Down
42 changes: 22 additions & 20 deletions src/brevitas/export/torch/qcdq/handler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,13 +6,15 @@
import torch

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin


Expand Down Expand Up @@ -55,7 +57,7 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class TorchCDQCastMixin(TorchDQCastMixin, ABC):
class TorchCDQCastMixin(CDQCastMixin, TorchDQCastMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return torch.clamp(x, min_val, max_val)
Expand Down Expand Up @@ -93,28 +95,28 @@ def forward(self, *args, **kwargs):
return self.symbolic_execution(*args, **kwargs)


class TorchQCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
CDQCastWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler(
TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
class TorchCDQCastDecoupledWeightQuantWithInputProxyHandler(
TorchCDQCastMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
Expand All @@ -123,7 +125,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQCastActQuantProxyHandler(TorchQCDQCastMixin,
QCDQActQuantProxyHandlerMixin,
QCDQCastActQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand All @@ -132,14 +134,14 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
return _itemize_clip_bounds(clip_args)


class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin,
QCDQBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchCDQCastBiasQuantProxyHandler(TorchDQCastMixin,
CDQCastBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
pass


class TorchQCDQCastTruncQuantProxyHandler(TorchQCDQCastMixin,
QCDQTruncQuantProxyHandlerMixin,
QCDQCastTruncQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand Down
16 changes: 8 additions & 8 deletions src/brevitas/export/torch/qcdq/manager.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,23 +11,23 @@
from brevitas.export.manager import BaseManager
from brevitas.export.manager import ExportContext

from .handler import TorchCDQCastBiasQuantProxyHandler
from .handler import TorchCDQCastDecoupledWeightQuantProxyHandler
from .handler import TorchCDQCastDecoupledWeightQuantWithInputProxyHandler
from .handler import TorchCDQCastWeightQuantProxyHandler
from .handler import TorchQCDQCastActQuantProxyHandler
from .handler import TorchQCDQCastBiasQuantProxyHandler
from .handler import TorchQCDQCastDecoupledWeightQuantProxyHandler
from .handler import TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler
from .handler import TorchQCDQCastTruncQuantProxyHandler
from .handler import TorchQCDQCastWeightQuantProxyHandler


class TorchQCDQManager(BaseManager):
target_name = 'torch'

handlers = [
TorchQCDQCastWeightQuantProxyHandler,
TorchQCDQCastDecoupledWeightQuantProxyHandler,
TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler,
TorchCDQCastWeightQuantProxyHandler,
TorchCDQCastDecoupledWeightQuantProxyHandler,
TorchCDQCastDecoupledWeightQuantWithInputProxyHandler,
TorchQCDQCastActQuantProxyHandler,
TorchQCDQCastBiasQuantProxyHandler,
TorchCDQCastBiasQuantProxyHandler,
TorchQCDQCastTruncQuantProxyHandler]

@classmethod
Expand Down

0 comments on commit 9c09cc4

Please sign in to comment.