Skip to content

Commit

Permalink
Feat (export/qcdq): add handler for explicit weight Q
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Pappalardo <[email protected]>
  • Loading branch information
volcacius committed Oct 20, 2023
1 parent 1412174 commit e0cefda
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 29 deletions.
78 changes: 59 additions & 19 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,24 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
'scale_orig_shape': scale_orig_shape}


class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
class QCDQProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
    """Mixin adding the quantize-step argument builder on top of the clip/dequantize mixin."""

    def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
        """Build the arguments handed to the quantize function.

        NOTE(review): the first parameter is named ``cls`` but no ``@classmethod``
        decorator is visible here, so it receives whatever the method is bound to.

        Args:
            scale: scale tensor; it is flattened before being stored.
            zero_point: zero-point tensor; it is flattened and expanded to match the scale.
            bit_width: bit width forwarded to the zero-point and dtype helpers.
            is_signed: signedness forwarded to the zero-point and dtype helpers.

        Returns:
            A dict with the keys ``'scale'``, ``'zero_point'``, ``'dtype'`` and ``'axis'``,
            inserted in that order.
        """
        # The axis has to be derived from the scale as received, before it gets flattened.
        quant_axis = cls.quant_axis(scale)
        flat_scale = to_0dim_if_scalar(scale.flatten())
        flat_zp = to_0dim_if_scalar(zero_point.flatten())
        # Expanding is only done once both tensors went through the 0-dim check.
        flat_zp = cls.zero_point_with_dtype(is_signed, bit_width, flat_zp.expand_as(flat_scale))
        if cls.itemize_quantize_scalar_params:
            flat_scale, flat_zp = to_item_if_0dim(flat_scale), to_item_if_0dim(flat_zp)
        return {
            'scale': flat_scale,
            'zero_point': flat_zp,
            'dtype': cls.signed_dtype(bit_width, is_signed),
            'axis': quant_axis}


class CDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
Expand Down Expand Up @@ -162,7 +179,44 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantProxyHandlerMixin(QCDQWeightQuantProxyHandlerMixin, ABC):
class QCDQWeightQuantProxyHandlerMixin(QCDQProxyHandlerMixin, ABC):
    """Weight-quant proxy handler that applies explicit quantize, clip and dequantize steps."""
    handled_layer = WeightQuantProxyFromInjector

    def prepare_for_export(self, module):
        """Cache the arguments of every export step for ``module``.

        When ``module.is_quant_enabled`` is true the module is validated and the
        bit width plus the quantize / clip / dequantize arguments are stored in
        ``self.symbolic_kwargs``; otherwise ``self.symbolic_kwargs`` is set to ``None``.
        """
        if module.is_quant_enabled:
            self.validate(module)
            self.symbolic_kwargs['bit_width'] = module.bit_width()
            self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
            self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                module.is_narrow_range, module.is_signed, module.bit_width())
            self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
        else:
            self.symbolic_kwargs = None

    def symbolic_execution(self, x: Tensor):
        """Run quantize -> (optional) clip -> dequantize on ``x`` with the cached arguments.

        Returns:
            A tuple ``(x, scale, zero_point, bit_width)`` where ``scale`` is viewed back
            to its original shape and ``zero_point`` is viewed as ``scale``.

        Raises:
            AssertionError: if ``prepare_for_export`` left ``symbolic_kwargs`` as ``None``.
        """
        assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
        quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
        clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
        # Work on a shallow copy: popping 'scale_orig_shape' from the cached dict itself
        # would make every later call fail with a KeyError on that key.
        dequantize_symbolic_kwargs = dict(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
        scale = dequantize_symbolic_kwargs['scale']
        zero_point = dequantize_symbolic_kwargs['zero_point']
        bit_width = self.symbolic_kwargs['bit_width']
        # The original shape is bookkeeping only and must not be splatted into dequantize_fn.
        scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
        # Workaround to trick the tracer into believing all return values are used
        self.assert_ge_zero(scale, zero_point, bit_width)
        x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
        if clip_symbolic_kwargs is not None:
            x = self.clip_fn(x, *clip_symbolic_kwargs.values())
        x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
        # Restore the original shapes to guarantee correct shape propagation downstream
        scale = scale.view(scale_orig_shape)
        zero_point = zero_point.view_as(scale)
        return x, scale, zero_point, bit_width


class CDQDecoupledWeightQuantProxyHandlerMixin(CDQWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantProxyFromInjector

def symbolic_execution(self, x: Tensor):
Expand All @@ -171,31 +225,17 @@ def symbolic_execution(self, x: Tensor):
return out, scale, zero_point, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantWithInputProxyHandlerMixin(QCDQDecoupledWeightQuantProxyHandlerMixin,
ABC):
class CDQDecoupledWeightQuantWithInputProxyHandlerMixin(
        CDQDecoupledWeightQuantProxyHandlerMixin, ABC):
    """Decoupled weight-quant handler whose entry point also receives input metadata."""
    handled_layer = DecoupledWeightQuantWithInputProxyFromInjector

    def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool):
        """Delegate to the parent implementation using ``x`` only.

        ``input_bit_width`` and ``input_is_signed`` are accepted but not forwarded.
        """
        parent_result = super().symbolic_execution(x)
        return parent_result


class QCDQActQuantProxyHandlerMixin(QMixin, CDQProxyHandlerMixin, ABC):
class QCDQActQuantProxyHandlerMixin(QMixin, QCDQProxyHandlerMixin, ABC):
handled_layer = ActQuantProxyFromInjector

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
Expand Down
17 changes: 12 additions & 5 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

import torch

from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import CDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
Expand Down Expand Up @@ -70,20 +71,26 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
class StdCDQONNXWeightQuantProxyHandler(
        StdCDQONNXMixin, CDQWeightQuantProxyHandlerMixin, ONNXBaseHandler):
    """Handler combining StdCDQONNXMixin, CDQWeightQuantProxyHandlerMixin and ONNXBaseHandler.

    Adds no members of its own; all behaviour comes from the listed bases.
    """


class StdQCDQONNXWeightQuantProxyHandler(
        StdQCDQONNXMixin, QCDQWeightQuantProxyHandlerMixin, ONNXBaseHandler):
    """Handler combining StdQCDQONNXMixin, QCDQWeightQuantProxyHandlerMixin and ONNXBaseHandler.

    Adds no members of its own; all behaviour comes from the listed bases.
    """


class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
CDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
StdCDQONNXMixin, CDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
pass


Expand Down
21 changes: 16 additions & 5 deletions src/brevitas/export/torch/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import torch

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
Expand Down Expand Up @@ -90,7 +91,17 @@ def forward(self, *args, **kwargs):
return self.symbolic_execution(*args, **kwargs)


class TorchQCDQWeightQuantProxyHandler(TorchCDQMixin,
class TorchCDQWeightQuantProxyHandler(
        TorchCDQMixin, CDQWeightQuantProxyHandlerMixin, TorchQCDQHandler):
    """Weight-quant handler built from the CDQ mixins and TorchQCDQHandler."""

    @classmethod
    def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
        """Return the parent's clip arguments after passing them through _itemize_clip_bounds."""
        return _itemize_clip_bounds(super().int_clip_symbolic_kwargs(narrow, signed, bit_width))


class TorchQCDQWeightQuantProxyHandler(TorchQCDQMixin,
QCDQWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

Expand All @@ -101,7 +112,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
CDQDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand All @@ -111,7 +122,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQDecoupledWeightQuantWithInputProxyHandler(
TorchCDQMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
TorchCDQMixin, CDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
Expand Down

0 comments on commit e0cefda

Please sign in to comment.