Feat: support for dynamic quant export
Giuseppe5 committed Jan 18, 2024
1 parent 26b2861 commit c28d8b4
Showing 8 changed files with 166 additions and 1 deletion.
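
For orientation, here is a minimal, hypothetical usage sketch (not part of this commit) of the new export path. It assumes the public Brevitas entry points QuantLinear and export_onnx_qcdq behave as in current releases, and uses the ShiftedUint8ActDynamicPerTensorFloat quantizer introduced below; exact names and signatures may differ across versions.

import torch
from brevitas.nn import QuantLinear
from brevitas.export import export_onnx_qcdq
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat

# Layer whose input activations are quantized dynamically (per-input scale and zero point).
model = QuantLinear(64, 64, bias=False, input_quant=ShiftedUint8ActDynamicPerTensorFloat)
model.eval()

# The dynamic input quantizer is exported through the new handler as an ONNX
# DynamicQuantizeLinear/DequantizeLinear pair.
export_onnx_qcdq(model, args=torch.randn(1, 64), export_path='dynamic_quant.onnx')
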
50 changes: 50 additions & 0 deletions src/brevitas/export/common/handler/qcdq.py
@@ -15,6 +15,7 @@
from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector

from .base import BitWidthHandlerMixin
@@ -102,6 +103,13 @@ def signed_dtype(cls, bit_width, is_signed):
return dtype


class DynamicQMixin(QMixin, ABC):

@abstractmethod
def quantize_fn(self, x, dtype):
pass


class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC):

def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
@@ -265,6 +273,48 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class DynamicQDQActQuantProxyHandlerMixin(DynamicQMixin, DQMixin, ABC):
handled_layer = DynamicActQuantProxyFromInjector

def validate(self, module):
super().validate(module)
assert module.is_signed == False, "Only unsigned quantization is supported"

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
bit_width = module.bit_width()
is_signed = module.is_signed
dtype = self.signed_dtype(bit_width, is_signed)
self.symbolic_kwargs['bit_width'] = bit_width
self.symbolic_kwargs['is_signed'] = is_signed
self.symbolic_kwargs['dtype'] = dtype
else:
self.symbolic_kwargs = None

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'

bit_width = self.symbolic_kwargs['bit_width']
int_dtype = self.symbolic_kwargs['dtype']
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(bit_width)
# If original dtype of the input is (b)float16, cast the input to float32
x_dtype = x.dtype
if x_dtype == torch.float16 or x_dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)

# x, scale = self.quantize_fn(x, int_dtype)
x, scale, zero_point = self.quantize_fn(x, int_dtype)

x = self.dequantize_fn(x, scale, zero_point, None)
# After dequantization, cast both output and scale to the correct dtype
if x_dtype == torch.float16 or x_dtype == torch.bfloat16:
x = self.cast_fn(x, x_dtype)
scale = self.cast_fn(scale, x_dtype)
return x, scale, zero_point, bit_width


class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
handled_layer = BiasQuantProxyFromInjector

16 changes: 16 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
@@ -75,3 +75,19 @@ def symbolic(g, x, output_scale, ouput_zero_point, output_dtype, output_axis):
@staticmethod
def forward(ctx, x, output_scale, ouput_zero_point, output_dtype, output_axis):
return x.type(output_dtype)


class DynamicQuantizeLinearFn(Function):

@staticmethod
def symbolic(g, x, output_dtype):
x, scale, zp = g.op('DynamicQuantizeLinear', x, outputs=3)
return x, scale, zp

@staticmethod
def forward(ctx, x, output_dtype):
device = x.device
dtype = x.dtype
scale = torch.empty(1, device=device, dtype=dtype)
zero_point = torch.empty(1, device=device, dtype=output_dtype)
return x.type(output_dtype), scale, zero_point
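
As a reference for what the traced op computes, the sketch below mirrors the ONNX DynamicQuantizeLinear specification (uint8 output, scale and zero point derived from the per-input min/max). It is illustrative only and explains why the forward above can return placeholder scale and zero-point tensors: the real values only exist at ONNX runtime.

import torch

def dynamic_quantize_linear_reference(x: torch.Tensor):
    # Adapted from the ONNX DynamicQuantizeLinear spec; no guard here for an all-zero input.
    qmin, qmax = 0.0, 255.0
    # The quantization range always includes zero so that zero stays exactly representable.
    min_x = torch.clamp(x.min(), max=0.0)
    max_x = torch.clamp(x.max(), min=0.0)
    scale = (max_x - min_x) / (qmax - qmin)
    zero_point = torch.clamp(torch.round(qmin - min_x / scale), qmin, qmax)
    y = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.uint8)
    return y, scale, zero_point.to(torch.uint8)
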
34 changes: 34 additions & 0 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -14,12 +14,15 @@
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DynamicQDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DynamicQMixin
from brevitas.export.common.handler.qcdq import QMixin
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler

from ..function import CastFn
from ..function import DequantizeLinearFn
from ..function import DynamicQuantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn

@@ -75,6 +78,31 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdDynamicQDQCastONNXMixin(DynamicQMixin, StdDQCastONNXMixin, ABC):

@classmethod
def int8_dtype(cls):
return torch.int8

@classmethod
def uint8_dtype(cls):
return torch.uint8

@classmethod
def int32_dtype(cls):
return torch.int32

def validate(self, module):
super().validate(module)
# ONNX DynamicQuantizeLinear supports only 8b unsigned output with round to nearest even,
# so only exactly 8b quantization is accepted here.
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
self.validate_8b_bit_width(module.bit_width(), le_then=False)

def quantize_fn(self, x, dtype):
return DynamicQuantizeLinearFn.apply(x, dtype)


class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
@@ -99,6 +127,12 @@ class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
pass


class StdDynamicQDQCastONNXActQuantProxyHandler(StdDynamicQDQCastONNXMixin,
DynamicQDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
CDQCastBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
4 changes: 4 additions & 0 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -11,9 +11,11 @@
from brevitas.export.onnx.function import LSTMCellFn

from ..function import DequantizeLinearFn
from ..function import DynamicQuantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn
from ..manager import StdONNXBaseManager
from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXBiasQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
@@ -36,6 +38,7 @@ class StdQCDQONNXManager(StdONNXBaseManager):
StdQCDQCastONNXWeightQuantProxyHandler,
StdQCDQCastONNXBiasQuantProxyHandler,
StdQCDQCastONNXActQuantProxyHandler,
StdDynamicQDQCastONNXActQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
StdQCDQCastONNXTruncQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
@@ -44,6 +47,7 @@
custom_fns = [
DebugMarkerFunction,
QuantizeLinearFn,
DynamicQuantizeLinearFn,
DequantizeLinearFn,
IntClipFn,
LSTMCellFn,]
30 changes: 30 additions & 0 deletions src/brevitas/export/torch/qcdq/handler.py
@@ -15,6 +15,7 @@
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DynamicQDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin


@@ -89,6 +90,29 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return y.int_repr()


class TorchDynamicQDQCastMixin(QMixin, TorchDQCastMixin, ABC):

@classmethod
def int8_dtype(cls):
return torch.qint8

@classmethod
def uint8_dtype(cls):
return torch.quint8

@classmethod
def int32_dtype(cls):
return torch.qint32

def validate(self, module):
super().validate(module)
self.validate_8b_bit_width(module.bit_width(), le_then=False)
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'

def quantize_fn(self, x, scale, zero_point, dtype, axis):
raise RuntimeError("Currently there is no representation for Dynamic Quantization in Torch")


class TorchQCDQHandler(BaseHandler):

def forward(self, *args, **kwargs):
@@ -134,6 +158,12 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
return _itemize_clip_bounds(clip_args)


class TorchDynamicQDQCastActQuantProxyHandler(TorchDynamicQDQCastMixin,
DynamicQDQActQuantProxyHandlerMixin,
TorchQCDQHandler):
pass


class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin,
CDQCastBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
13 changes: 13 additions & 0 deletions src/brevitas/proxy/runtime_quant.py
@@ -169,6 +169,19 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
return QuantTensor(x, training=self.training)


class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector):

def scale(self, force_eval=True):
raise RuntimeError("Scale for Dynamic Act Quant is input-dependant")

def zero_point(self, force_eval=True):
raise RuntimeError("Zero point for Dynamic Act Quant is input-dependant")

def bit_width(self):
bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
return bit_width


class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

def forward(self, x: QuantTensor):
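
A hedged illustration (again not part of the commit) of the intended runtime behaviour: with dynamic quantization the proxy cannot report a fixed scale or zero point, but the per-input values still travel on the QuantTensor produced by a quant layer. Layer and attribute names below are the standard Brevitas ones and are assumed, not verified against this exact revision.

import torch
from brevitas.nn import QuantIdentity
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat

act = QuantIdentity(act_quant=ShiftedUint8ActDynamicPerTensorFloat, return_quant_tensor=True)
act.eval()

out = act(torch.randn(2, 16))
print(out.scale, out.zero_point)  # computed from this specific input
# act.act_quant.scale() would raise RuntimeError: the scale is input-dependent.
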
4 changes: 3 additions & 1 deletion src/brevitas_examples/common/generative/quantize.py
@@ -40,6 +40,7 @@
from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloatMSE
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloatMSE
from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant
@@ -108,7 +109,8 @@
'float_scale': {
'stats': {
'per_tensor': {
-                'sym': Int8ActDynamicPerTensorFloat},
+                'sym': Int8ActDynamicPerTensorFloat,
+                'asym': ShiftedUint8ActDynamicPerTensorFloat},
'per_row': {
'sym': Int8ActDynamicPerRowFloat},
'per_group': {
16 changes: 16 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
@@ -15,10 +15,12 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import this
from brevitas.inject import value
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE

@@ -122,11 +124,25 @@ class Int8ActDynamicPerTensorFloat(Int8ActPerTensorFloat):
"""
Symmetric quantizer with per tensor dynamic scale.
"""
proxy_class = DynamicActQuantProxyFromInjector
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverBatchOverTensorView
scaling_stats_op = 'max'


class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat):
"""
Asymmetric quantizer with per tensor dynamic scale and zero point.
"""
proxy_class = DynamicActQuantProxyFromInjector
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverTensorView
scaling_stats_op = 'max'
zero_point_stats_impl = NegativeMinOrZero
dynamic_scaling_broadcastable_shape = (-1,)
stats_reduce_dim = 0


class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):
"""
Symmetric quantizer with per row dynamic scale.
