diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index eba9c51f1..8953b3338 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -15,6 +15,7 @@
 from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
 from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
 from brevitas.proxy import WeightQuantProxyFromInjector
+from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector

 from .base import BitWidthHandlerMixin
@@ -102,6 +103,13 @@ def signed_dtype(cls, bit_width, is_signed):
         return dtype


+class DynamicQMixin(QMixin, ABC):
+
+    @abstractmethod
+    def quantize_fn(self, x, dtype):
+        pass
+
+
 class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC):

     def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
@@ -265,6 +273,43 @@ def symbolic_execution(self, x: Tensor):
         return x, scale, zero_point, bit_width


+class DynamicQDQCastActQuantProxyHandlerMixin(DynamicQMixin, DQCastMixin, ABC):
+    handled_layer = DynamicActQuantProxyFromInjector
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.validate(module)
+            bit_width = module.bit_width()
+            is_signed = module.is_signed
+            dtype = self.signed_dtype(bit_width, is_signed)
+            self.symbolic_kwargs['bit_width'] = bit_width
+            self.symbolic_kwargs['is_signed'] = is_signed
+            self.symbolic_kwargs['dtype'] = dtype
+        else:
+            self.symbolic_kwargs = None
+
+    def symbolic_execution(self, x: Tensor):
+        assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
+
+        bit_width = self.symbolic_kwargs['bit_width']
+        int_dtype = self.symbolic_kwargs['dtype']
+        # Workaround to trick the tracer into believing all return values are used
+        self.assert_ge_zero(bit_width)
+        # If the original dtype of the input is (b)float16, cast the input to float32
+        x_dtype = x.dtype
+        if x_dtype == torch.float16 or x_dtype == torch.bfloat16:
+            x = self.cast_fn(x, torch.float32)
+
+        x, scale, zero_point = self.quantize_fn(x, int_dtype)
+
+        x = self.dequantize_fn(x, scale, zero_point, None)
+        # After dequantization, cast both output and scale back to the original dtype
+        if x_dtype == torch.float16 or x_dtype == torch.bfloat16:
+            x = self.cast_fn(x, x_dtype)
+            scale = self.cast_fn(scale, x_dtype)
+        return x, scale, zero_point, bit_width
+
+
 class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC):
     handled_layer = BiasQuantProxyFromInjector

diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py
index e19ab4708..aed2782bc 100644
--- a/src/brevitas/export/onnx/standard/function.py
+++ b/src/brevitas/export/onnx/standard/function.py
@@ -75,3 +75,19 @@ def symbolic(g, x, output_scale, ouput_zero_point, output_dtype, output_axis):
     @staticmethod
     def forward(ctx, x, output_scale, ouput_zero_point, output_dtype, output_axis):
         return x.type(output_dtype)
+
+
+class DynamicQuantizeLinearFn(Function):
+
+    @staticmethod
+    def symbolic(g, x, output_dtype):
+        x, scale, zp = g.op('DynamicQuantizeLinear', x, outputs=3)
+        return x, scale, zp
+
+    @staticmethod
+    def forward(ctx, x, output_dtype):
+        device = x.device
+        dtype = x.dtype
+        scale = torch.empty(1, device=device, dtype=dtype)
+        zero_point = torch.empty(1, device=device, dtype=output_dtype)
+        return x.type(output_dtype), scale, zero_point
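For reference, the ONNX DynamicQuantizeLinear operator that the symbolic above lowers to derives scale and zero point from the runtime min/max of its input, which is why the eager forward only needs placeholder tensors for tracing. Below is a rough, self-contained sketch of the operator's documented uint8 semantics; it is not part of the patch, and the function name and the epsilon guard are illustrative.

import torch

def dynamic_quantize_linear_reference(x: torch.Tensor):
    # Reference semantics of ONNX DynamicQuantizeLinear (uint8 only):
    # the quantization range is taken from the observed min/max of the input,
    # adjusted so that it always contains zero.
    qmin, qmax = 0., 255.
    x_min = torch.clamp_max(x.min(), 0.)
    x_max = torch.clamp_min(x.max(), 0.)
    # Tiny guard for an all-zero input; the ONNX spec leaves this degenerate case implicit.
    scale = torch.clamp_min((x_max - x_min) / (qmax - qmin), torch.finfo(torch.float32).tiny)
    zero_point = torch.clamp(torch.round(qmin - x_min / scale), qmin, qmax).to(torch.uint8)
    y = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.uint8)
    return y, scale, zero_point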
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index 0723f9574..84f597eb6 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -12,6 +12,8 @@
 from brevitas.export.common.handler.qcdq import CDQCastMixin
 from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import DQCastMixin
+from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin
+from brevitas.export.common.handler.qcdq import DynamicQMixin
 from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QMixin
@@ -20,6 +22,7 @@
 from ..function import CastFn
 from ..function import DequantizeLinearFn
+from ..function import DynamicQuantizeLinearFn
 from ..function import IntClipFn
 from ..function import QuantizeLinearFn
@@ -75,6 +78,37 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
         return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


+class StdDynamicQDQCastONNXMixin(DynamicQMixin, StdDQCastONNXMixin, ABC):
+
+    @classmethod
+    def int8_dtype(cls):
+        return torch.int8
+
+    @classmethod
+    def uint8_dtype(cls):
+        return torch.uint8
+
+    @classmethod
+    def int32_dtype(cls):
+        return torch.int32
+
+    def validate(self, module):
+        super().validate(module)
+
+        assert module.is_signed == False, "Only unsigned quantization supported"
+        assert module.quant_injector.scaling_stats_op == 'min_max', "Only min_max scaling op supported"
+        # ONNX DynamicQuantizeLinear supports only 8b output with round to nearest even.
+        # Unlike the static QCDQ path, there is no clip step.
+        assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
+        # Below 8b quantization is not supported.
+        self.validate_8b_bit_width(module.bit_width(), le_then=False)
+        # Only per tensor quantization is supported
+        assert not module.quant_injector.scaling_per_output_channel, "Only per tensor scaling supported"
+
+    def quantize_fn(self, x, dtype):
+        return DynamicQuantizeLinearFn.apply(x, dtype)
+
+
 class StdCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
                                             CDQCastWeightQuantProxyHandlerMixin,
                                             ONNXBaseHandler):
@@ -99,6 +133,12 @@ class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
     pass


+class StdDynamicQDQCastONNXActQuantProxyHandler(StdDynamicQDQCastONNXMixin,
+                                                DynamicQDQCastActQuantProxyHandlerMixin,
+                                                ONNXBaseHandler):
+    pass
+
+
 class StdCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
                                           CDQCastBiasQuantProxyHandlerMixin,
                                           ONNXBaseHandler):
diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py
index d475ef335..fb2a2ef50 100644
--- a/src/brevitas/export/onnx/standard/qcdq/manager.py
+++ b/src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -11,6 +11,7 @@
 from brevitas.export.onnx.function import LSTMCellFn

 from ..function import DequantizeLinearFn
+from ..function import DynamicQuantizeLinearFn
 from ..function import IntClipFn
 from ..function import QuantizeLinearFn
 from ..manager import StdONNXBaseManager
@@ -18,6 +19,7 @@
 from .handler import StdCDQCastONNXDecoupledWeightQuantProxyHandler
 from .handler import StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
 from .handler import StdCDQCastONNXWeightQuantProxyHandler
+from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
 from .handler import StdQCDQCastONNXActQuantProxyHandler
 from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
 from .handler import StdQCDQCastONNXTruncQuantProxyHandler
@@ -37,6 +39,7 @@ class StdQCDQONNXManager(StdONNXBaseManager):
         StdCDQCastONNXBiasQuantProxyHandler,
         StdQCDQCastONNXActQuantProxyHandler,
         StdCDQCastONNXDecoupledWeightQuantProxyHandler,
+        StdDynamicQDQCastONNXActQuantProxyHandler,
         StdQCDQCastONNXTruncQuantProxyHandler,
         StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
         StdQCDQCastONNXQuantLSTMLayerHandler]
@@ -44,6 +47,7 @@ class StdQCDQONNXManager(StdONNXBaseManager):
     custom_fns = [
         DebugMarkerFunction,
         QuantizeLinearFn,
+        DynamicQuantizeLinearFn,
         DequantizeLinearFn,
         IntClipFn,
         LSTMCellFn,]
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index a650a7755..9d15f3bba 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -169,6 +169,19 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
             return QuantTensor(x, training=self.training)


+class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector):
+
+    def scale(self, force_eval=True):
+        raise RuntimeError("Scale for Dynamic Act Quant is input-dependent")
+
+    def zero_point(self, force_eval=True):
+        raise RuntimeError("Zero point for Dynamic Act Quant is input-dependent")
+
+    def bit_width(self):
+        bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
+        return bit_width
+
+
 class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

     def forward(self, x: QuantTensor):
diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 1d316d9a9..2ae9b5f4a 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -34,14 +34,15 @@
 from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
 from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
 from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
-from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerGroupFloat
-from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerRowFloat
-from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerTensorFloat
 from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloat
 from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloatMSE
+from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat
+from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
+from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
 from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
 from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloatMSE
+from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant

 WEIGHT_QUANT_MAP = {
@@ -108,11 +109,12 @@
             'float_scale': {
                 'stats': {
                     'per_tensor': {
-                        'sym': Int8ActDynamicPerTensorFloat},
+                        'sym': Int8DynamicActPerTensorFloat,
+                        'asym': ShiftedUint8DynamicActPerTensorFloat},
                     'per_row': {
-                        'sym': Int8ActDynamicPerRowFloat},
+                        'sym': Int8DynamicActPerRowFloat},
                     'per_group': {
-                        'sym': Int8ActDynamicPerGroupFloat},}}}},
+                        'sym': Int8DynamicActPerGroupFloat},}}}},
     'float': {
         'static': {
             'float_scale': {
diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py
index 5c7e82513..e667a857d 100644
--- a/src/brevitas_examples/common/generative/quantizers.py
+++ b/src/brevitas_examples/common/generative/quantizers.py
@@ -15,6 +15,7 @@
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import this
 from brevitas.inject import value
+from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
 from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE
@@ -65,6 +66,10 @@ def reshaped_scaling_shape(module):
     block_size = None


+class DynamicActProxyMixin(ExtendedInjector):
+    proxy_class = DynamicActQuantProxyFromInjector
+
+
 class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat):
     """
     Block / group / vector signed symmetric int weight quantizer with float scales.
@@ -118,32 +123,44 @@ class ShiftedUint8ActPerRowFloatMSE(ShiftedUint8ActPerTensorFloatMSE):
     scaling_per_output_channel = True


-class Int8ActDynamicPerTensorFloat(Int8ActPerTensorFloat):
+class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
     """
     Symmetric quantizer with per tensor dynamic scale.
""" scaling_impl = RuntimeDynamicStatsScaling - scaling_stats_input_view_shape_impl = OverBatchOverTensorView - scaling_stats_op = 'max' + scaling_stats_input_view_shape_impl = OverTensorView + scaling_stats_op = 'min_max' + dynamic_scaling_broadcastable_shape = this.scaling_shape -class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat): +class Int8DynamicActPerRowFloat(DynamicActProxyMixin, Int8ActPerRowFloat): """ Symmetric quantizer with per row dynamic scale. """ scaling_impl = RuntimeDynamicStatsScaling scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView - scaling_stats_op = 'max' + scaling_stats_op = 'min_max' -class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat): +class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerRowFloat): """ Symmetric quantizer with per group scale. """ scaling_impl = RuntimeDynamicGroupStatsScaling keepdim = True - scaling_stats_op = 'max' + scaling_stats_op = 'min_max' @value def stats_reduce_dim(group_dim): return group_dim + 1 + + +class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8ActPerTensorFloat): + """ + Symmetric quantizer with per tensor dynamic scale. + """ + scaling_impl = RuntimeDynamicStatsScaling + scaling_stats_input_view_shape_impl = OverTensorView + scaling_stats_op = 'min_max' + zero_point_stats_impl = NegativeMinOrZero + dynamic_scaling_broadcastable_shape = this.scaling_shape diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index c05fd59b9..a7f87cbef 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -16,10 +16,7 @@ from brevitas.nn import QuantConv2d from brevitas.nn import QuantConvTranspose1d from brevitas.nn import QuantConvTranspose2d -from brevitas.nn import QuantIdentity from brevitas.nn import QuantLinear -from brevitas.nn import QuantLSTM -from brevitas.nn import TruncAvgPool2d from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint @@ -30,6 +27,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat SEED = 123456 OUT_CH = 16 @@ -57,7 +55,9 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): 'a2q': (A2QWeightQuantizerForTests, Int8ActPerTensorFloat), 'symmetric_per_tensor_fixed_point': (Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint), 'symmetric_per_channel_fixed_point': - (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)} + (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint), + 'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float': + (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat)} LSTM_QUANTIZERS = { 'asymmetric_per_tensor_float': (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat), @@ -114,7 +114,8 @@ def recursive_allclose(ort_output, brevitas_output, tolerance): def is_brevitas_ort_close( model, np_input, export_name, export_type, tolerance=None, first_output_only=False): input_t = torch.from_numpy(np_input) - brevitas_output = model(input_t) + with torch.no_grad(): + brevitas_output = model(input_t) if tolerance is not None and export_type == 'qcdq': tolerance = tolerance * brevitas_output.scale # Float 
diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py
index b12f75b8b..7361c1231 100644
--- a/tests/brevitas_ort/quant_module_cases.py
+++ b/tests/brevitas_ort/quant_module_cases.py
@@ -5,6 +5,9 @@
 from pytest_cases import set_case_id
 from torch import nn

+from brevitas.nn.quant_activation import QuantIdentity
+from brevitas.nn.quant_avg_pool import TruncAvgPool2d
+from brevitas.nn.quant_rnn import QuantLSTM
 from brevitas.quant.scaled_int import Int32Bias

 from .common import *
diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py
index a4f7e7f5c..0b1277686 100644
--- a/tests/brevitas_ort/test_quant_module.py
+++ b/tests/brevitas_ort/test_quant_module.py
@@ -27,6 +27,8 @@ def test_ort_wbiol(model, export_type, current_cases):
     impl = case_id.split('-')[
         -2]  # Inverse list of definition, 'export_type' is -1, 'impl' is -2, etc.
     quantizer = case_id.split('-')[-6]
+    o_bit_width = case_id.split('-')[-5]
+    i_bit_width = case_id.split('-')[-3]

     if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop':
         pytest.skip('Export of ConvTranspose is not supported for QOperation')
@@ -34,6 +36,9 @@
         pytest.skip('Per-channel zero-point is not well supported in ORT.')
     if 'QuantLinear' in impl and 'asymmetric' in quantizer:
         pytest.skip('ORT execution is unreliable and fails randomly on a subset of cases.')
+    if 'dynamic' in quantizer and ((o_bit_width != "o8" or i_bit_width != "i8") or
+                                   export_type != "qcdq"):
+        pytest.skip('Dynamic Act Quant supported only for 8bit and QCDQ export')
     if impl in ('QuantLinear'):
         in_size = (1, IN_CH)
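With the patch applied, one plausible end-to-end use of the new path is sketched below: a QuantLinear with a static symmetric weight quantizer and the new dynamic asymmetric input quantizer, exported through QCDQ. Layer sizes and the output file name are illustrative; export_onnx_qcdq is Brevitas' existing QCDQ export entry point, and the quantizer pairing mirrors the new test case added above.

import torch

from brevitas.export import export_onnx_qcdq
from brevitas.nn import QuantLinear
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

# Mirrors the 'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float' test case:
# static per-tensor weight quant plus dynamic, asymmetric per-tensor input quant.
layer = QuantLinear(
    16,
    16,
    bias=True,
    weight_quant=Int8WeightPerTensorFloat,
    input_quant=ShiftedUint8DynamicActPerTensorFloat,
    return_quant_tensor=False)
layer.eval()

# Dynamic act quant is only supported at 8b through QCDQ export; the input quant proxy is
# lowered to an ONNX DynamicQuantizeLinear / DequantizeLinear pair with runtime scale.
with torch.no_grad():
    export_onnx_qcdq(layer, torch.randn(1, 16), export_path='quant_linear_dynamic_act.onnx')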