Skip to content

Commit

Permalink
Fix (export): add CastMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 12, 2024
1 parent ade1036 commit 3472e7c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 40 deletions.
28 changes: 14 additions & 14 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def cast_fn(self, x, dtype):
pass


class CDQMixin(DQMixin, ABC):
class CDQCastMixin(DQCastMixin, ABC):

@abstractmethod
def clip_fn(self, x, min_val, max_val):
Expand Down Expand Up @@ -102,7 +102,7 @@ def signed_dtype(cls, bit_width, is_signed):
return dtype


class CDQProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQMixin, ABC):
class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC):

def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
scale_orig_shape = scale.shape
Expand All @@ -125,7 +125,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
'scale_orig_shape': scale_orig_shape}


class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
class CDQCastWeightQuantProxyHandlerMixin(CDQCastProxyHandlerMixin, ABC):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
Expand Down Expand Up @@ -179,7 +179,7 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantProxyHandlerMixin(QCDQWeightQuantProxyHandlerMixin, ABC):
class CDQCastDecoupledWeightQuantProxyHandlerMixin(CDQCastWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantProxyFromInjector

def symbolic_execution(self, x: Tensor):
Expand All @@ -188,15 +188,15 @@ def symbolic_execution(self, x: Tensor):
return out, scale, zero_point, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantWithInputProxyHandlerMixin(QCDQDecoupledWeightQuantProxyHandlerMixin,
ABC):
class CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin(
CDQCastDecoupledWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantWithInputProxyFromInjector

def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool):
return super().symbolic_execution(x)


class QCDQActQuantProxyHandlerMixin(QMixin, CDQProxyHandlerMixin, ABC):
class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC):
handled_layer = ActQuantProxyFromInjector

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
Expand Down Expand Up @@ -265,7 +265,7 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class QCDQBiasQuantProxyHandlerMixin(DQMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
handled_layer = BiasQuantProxyFromInjector

def validate(self, module):
Expand Down Expand Up @@ -325,12 +325,12 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
return y, scale, zero_point, bit_width


class QCDQTruncQuantProxyHandlerMixin(QuantAxisMixin,
ClipMixin,
ZeroPointHandlerMixin,
QMixin,
CDQMixin,
ABC):
class QCDQCastTruncQuantProxyHandlerMixin(QuantAxisMixin,
ClipMixin,
ZeroPointHandlerMixin,
QMixin,
CDQCastMixin,
ABC):
handled_layer = TruncQuantProxyFromInjector

def prepare_for_export(self, module: TruncQuantProxyFromInjector):
Expand Down
30 changes: 16 additions & 14 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

import torch

from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler
Expand Down Expand Up @@ -43,7 +44,7 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):
class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return IntClipFn.apply(x, min_val, max_val)
Expand Down Expand Up @@ -75,36 +76,37 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):


class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
CDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
StdCDQCastONNXMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQActQuantProxyHandlerMixin,
QCDQCastActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
CDQCastBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
QCDQCastTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
pass

Expand Down
25 changes: 13 additions & 12 deletions src/brevitas/export/torch/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import torch

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin


Expand Down Expand Up @@ -94,7 +95,7 @@ def forward(self, *args, **kwargs):


class TorchQCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQWeightQuantProxyHandlerMixin,
CDQCastWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand All @@ -104,7 +105,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand All @@ -114,7 +115,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler(
TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
TorchCDQCastMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
Expand All @@ -123,7 +124,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQCastActQuantProxyHandler(TorchQCDQCastMixin,
QCDQActQuantProxyHandlerMixin,
QCDQCastActQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand All @@ -133,13 +134,13 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin,
QCDQBiasQuantProxyHandlerMixin,
CDQCastBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
pass


class TorchQCDQCastTruncQuantProxyHandler(TorchQCDQCastMixin,
QCDQTruncQuantProxyHandlerMixin,
QCDQCastTruncQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand Down

0 comments on commit 3472e7c

Please sign in to comment.