Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic Act Quant support #796

Merged
merged 14 commits into from
Jan 22, 2024
49 changes: 49 additions & 0 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector

from .base import BitWidthHandlerMixin
Expand Down Expand Up @@ -102,6 +103,13 @@ def signed_dtype(cls, bit_width, is_signed):
return dtype


class DynamicQMixin(QMixin, ABC):

@abstractmethod
def quantize_fn(self, x, dtype):
pass


class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC):

def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
Expand Down Expand Up @@ -265,6 +273,47 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class DynamicQDQCastActQuantProxyHandlerMixin(DynamicQMixin, DQCastMixin, ABC):
handled_layer = DynamicActQuantProxyFromInjector

def validate(self, module):
super().validate(module)
assert module.is_signed == False, "Only unsigned quantization is supported"

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
bit_width = module.bit_width()
is_signed = module.is_signed
dtype = self.signed_dtype(bit_width, is_signed)
self.symbolic_kwargs['bit_width'] = bit_width
self.symbolic_kwargs['is_signed'] = is_signed
self.symbolic_kwargs['dtype'] = dtype
else:
self.symbolic_kwargs = None

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'

bit_width = self.symbolic_kwargs['bit_width']
int_dtype = self.symbolic_kwargs['dtype']
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(bit_width)
# If original dtype of the input is (b)float16, cast the input to float32
x_dtype = x.dtype
if x_dtype == torch.float16 or x_dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)

x, scale, zero_point = self.quantize_fn(x, int_dtype)

x = self.dequantize_fn(x, scale, zero_point, None)
# After dequantization, cast both output and scale to the correct dtype
if x_dtype == torch.float16 or x_dtype == torch.bfloat16:
x = self.cast_fn(x, x_dtype)
scale = self.cast_fn(scale, x_dtype)
return x, scale, zero_point, bit_width


class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
handled_layer = BiasQuantProxyFromInjector

Expand Down
16 changes: 16 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,19 @@ def symbolic(g, x, output_scale, ouput_zero_point, output_dtype, output_axis):
@staticmethod
def forward(ctx, x, output_scale, ouput_zero_point, output_dtype, output_axis):
return x.type(output_dtype)


class DynamicQuantizeLinearFn(Function):

@staticmethod
def symbolic(g, x, output_dtype):
x, scale, zp = g.op('DynamicQuantizeLinear', x, outputs=3)
return x, scale, zp

@staticmethod
def forward(ctx, x, output_dtype):
device = x.device
dtype = x.dtype
scale = torch.empty(1, device=device, dtype=dtype)
zero_point = torch.empty(1, device=device, dtype=output_dtype)
return x.type(output_dtype), scale, zero_point
36 changes: 36 additions & 0 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import DynamicQDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DynamicQMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
Expand All @@ -20,6 +22,7 @@

from ..function import CastFn
from ..function import DequantizeLinearFn
from ..function import DynamicQuantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn

Expand Down Expand Up @@ -75,6 +78,33 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdDynamicQDQCastONNXMixin(DynamicQMixin, StdDQCastONNXMixin, ABC):
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def int8_dtype(cls):
return torch.int8

@classmethod
def uint8_dtype(cls):
return torch.uint8

@classmethod
def int32_dtype(cls):
return torch.int32

def validate(self, module):
super().validate(module)
# ONNX DynamicQuantizeLinear supports only 8b output with round to nearest even.
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
# Below 8b quantization is not supported.
self.validate_8b_bit_width(module.bit_width(), le_then=False)
# Only per tensor quantization is supported
assert module.is_per_output_channel_scaling, "Only per tensor scaling supported"

def quantize_fn(self, x, dtype):
return DynamicQuantizeLinearFn.apply(x, dtype)


class StdCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
Expand All @@ -99,6 +129,12 @@ class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
pass


class StdDynamicQDQCastONNXActQuantProxyHandler(StdDynamicQDQCastONNXMixin,
DynamicQDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
CDQCastBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
Expand Down
4 changes: 4 additions & 0 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from brevitas.export.onnx.function import LSTMCellFn

from ..function import DequantizeLinearFn
from ..function import DynamicQuantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn
from ..manager import StdONNXBaseManager
from .handler import StdCDQCastONNXBiasQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdCDQCastONNXWeightQuantProxyHandler
from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
from .handler import StdQCDQCastONNXTruncQuantProxyHandler
Expand All @@ -37,13 +39,15 @@ class StdQCDQONNXManager(StdONNXBaseManager):
StdCDQCastONNXBiasQuantProxyHandler,
StdQCDQCastONNXActQuantProxyHandler,
StdCDQCastONNXDecoupledWeightQuantProxyHandler,
StdDynamicQDQCastONNXActQuantProxyHandler,
StdQCDQCastONNXTruncQuantProxyHandler,
StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQCastONNXQuantLSTMLayerHandler]

custom_fns = [
DebugMarkerFunction,
QuantizeLinearFn,
DynamicQuantizeLinearFn,
DequantizeLinearFn,
IntClipFn,
LSTMCellFn,]
Expand Down
6 changes: 6 additions & 0 deletions src/brevitas/proxy/quant_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def _update_state_dict_impl(quant_injector):
return impl


def _is_per_output_channel_scaling(quant_injector):
if 'scaling_per_output_channel' in quant_injector:
return quant_injector.scaling_per_output_channel
return None


@runtime_checkable
class QuantProxyProtocol(Protocol):
is_quant_enabled: bool
Expand Down
18 changes: 18 additions & 0 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import brevitas
from brevitas.quant_tensor import QuantTensor

from .quant_proxy import _is_per_output_channel_scaling
from .quant_proxy import QuantProxyFromInjector
from .quant_proxy import QuantProxyProtocol

Expand Down Expand Up @@ -169,6 +170,23 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
return QuantTensor(x, training=self.training)


class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector):

def scale(self, force_eval=True):
raise RuntimeError("Scale for Dynamic Act Quant is input-dependant")

def zero_point(self, force_eval=True):
raise RuntimeError("Zero point for Dynamic Act Quant is input-dependant")

def bit_width(self):
bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
return bit_width

@property
def is_per_output_channel_scaling(self):
return _is_per_output_channel_scaling(self.quant_injector)


class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

def forward(self, x: QuantTensor):
Expand Down
12 changes: 9 additions & 3 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloatMSE
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerGroupFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloatMSE
from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant
Expand Down Expand Up @@ -108,11 +111,14 @@
'float_scale': {
'stats': {
'per_tensor': {
'sym': Int8ActDynamicPerTensorFloat},
'sym': Int8ActDynamicPerTensorFloat,
'asym': ShiftedUint8ActDynamicPerTensorFloat},
'per_row': {
'sym': Int8ActDynamicPerRowFloat},
'sym': Int8ActDynamicPerRowFloat,
'asym': ShiftedUint8ActDynamicPerRowFloat},
'per_group': {
'sym': Int8ActDynamicPerGroupFloat},}}}},
'sym': Int8ActDynamicPerGroupFloat,
'asym': ShiftedUint8ActDynamicPerGroupFloat},}}}},
'float': {
'static': {
'float_scale': {
Expand Down
46 changes: 44 additions & 2 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import this
from brevitas.inject import value
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE
Expand Down Expand Up @@ -65,6 +66,10 @@ def reshaped_scaling_shape(module):
block_size = None


class ActDynamicProxyMixin(ExtendedInjector):
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
proxy_class = DynamicActQuantProxyFromInjector


class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat):
"""
Block / group / vector signed symmetric int weight quantizer with float scales.
Expand Down Expand Up @@ -122,12 +127,13 @@ class Int8ActDynamicPerTensorFloat(Int8ActPerTensorFloat):
"""
Symmetric quantizer with per tensor dynamic scale.
"""
proxy_class = DynamicActQuantProxyFromInjector
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverBatchOverTensorView
scaling_stats_op = 'max'


class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):
class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat, ActDynamicProxyMixin):
"""
Symmetric quantizer with per row dynamic scale.
"""
Expand All @@ -136,7 +142,7 @@ class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):
scaling_stats_op = 'max'


class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat):
class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat, ActDynamicProxyMixin):
"""
Symmetric quantizer with per group scale.
"""
Expand All @@ -147,3 +153,39 @@ class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat):
@value
def stats_reduce_dim(group_dim):
return group_dim + 1


class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat, ActDynamicProxyMixin):
"""
Asymmetric quantizer with per tensor dynamic scale.
"""
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverTensorView
scaling_stats_op = 'max'
zero_point_stats_impl = NegativeMinOrZero
dynamic_scaling_broadcastable_shape = (-1,)
stats_reduce_dim = 0


class ShiftedUint8ActDynamicPerRowFloat(ShiftedUint8ActPerRowFloat, ActDynamicProxyMixin):
"""
Asymmetric quantizer with per row dynamic scale.
"""
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView
scaling_stats_op = 'max'
zero_point_stats_impl = NegativeMinOrZero


class ShiftedUint8ActDynamicPerGroupFloat(ShiftedUint8ActPerRowFloat, ActDynamicProxyMixin):
"""
Asymmetric quantizer with per group dynamic scale.
"""
scaling_impl = RuntimeDynamicGroupStatsScaling
keepdim = True
scaling_stats_op = 'max'
zero_point_stats_impl = NegativeMinOrZero

@value
def stats_reduce_dim(group_dim):
return group_dim + 1
Loading