From d7d88c6d8ac9eb108db486a16ea05c4f48f0d182 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 7 Nov 2024 15:39:20 +0000
Subject: [PATCH] Feat (export): qonnx minifloat export (#1070)

---
 src/brevitas/export/onnx/qonnx/function.py |  49 ++++++++++
 src/brevitas/export/onnx/qonnx/handler.py  | 102 +++++++++++++++++++++
 src/brevitas/export/onnx/qonnx/manager.py  |   6 +-
 tests/brevitas/export/test_onnx_fp8.py     |  12 +++
 4 files changed, 168 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py
index 3e7faad0e..5160572ef 100644
--- a/src/brevitas/export/onnx/qonnx/function.py
+++ b/src/brevitas/export/onnx/qonnx/function.py
@@ -59,6 +59,55 @@ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding
         return y
 
 
+class BrevitasFloatQuantFn(Function):
+
+    @staticmethod
+    def symbolic(
+            g,
+            x,
+            scale,
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            has_inf,
+            has_nan,
+            saturating,
+            has_subnormal,
+            rounding_mode,
+            max_val):
+        ret = g.op(
+            f'{DOMAIN_STRING}::FloatQuant',
+            x,
+            scale,
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            max_val,
+            has_inf_i=int(has_inf),
+            has_nan_i=int(has_nan),
+            has_subnormal_i=int(has_subnormal),
+            rounding_mode_s=rounding_mode,
+            saturation_i=saturating)
+        ret.setType(x.type())
+        return ret
+
+    @staticmethod
+    def forward(
+            g,
+            x,
+            scale,
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            has_inf,
+            has_nan,
+            saturating,
+            has_subnormal,
+            rounding_mode,
+            max_val):
+        return x
+
+
 class BrevitasTruncFn(Function):
 
     @staticmethod
diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py
index c42714346..5468cd1aa 100644
--- a/src/brevitas/export/onnx/qonnx/handler.py
+++ b/src/brevitas/export/onnx/qonnx/handler.py
@@ -14,14 +14,116 @@
 from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
 from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
 from brevitas.proxy import WeightQuantProxyFromInjector
+from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
+from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector
 
 from .function import BrevitasBinaryQuantFn
+from .function import BrevitasFloatQuantFn
 from .function import BrevitasQuantFn
 from .function import BrevitasQuantLSTMCellFn
 from .function import BrevitasTruncFn
 
 
+class BrevitasFloatQuantProxyHandler(ONNXBaseHandler, ABC):
+
+    def validate(self, module):
+        assert not module.is_groupwise, "Export with Per Group quantization not supported"
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.validate(module)
+            self.symbolic_kwargs = {
+                'scale':
+                    module.scale(),
+                'exponent_bit_width':
+                    module.exponent_bit_width(),
+                'mantissa_bit_width':
+                    module.mantissa_bit_width(),
+                'exponent_bias':
+                    module.exponent_bias(),
+                'has_inf':
+                    module.inf_values() is not None,
+                'has_nan':
+                    module.nan_values() is not None,
+                'saturating':
+                    module.is_saturating(),
+                'has_subnormal':
+                    True,  # Currently we only support subnormal
+                'rounding_mode':
+                    module.rounding_mode,
+                'max_float':
+                    torch.tensor(module.quant_injector.max_available_float).type_as(module.scale())}
+            self.return_args = {
+                'scale': module.scale(),
+                'zero_point': torch.zeros_like(module.scale()),
+                'exponent_bit_width': module.exponent_bit_width(),
+                'mantissa_bit_width': module.mantissa_bit_width(),
+                'exponent_bias': module.exponent_bias(),
+                'saturating': module.is_saturating(),
+                'inf_values': module.inf_values(),
+                'nan_values': module.nan_values(),}
+
+    def symbolic_execution(self, x: Tensor):
+        x = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values())
+        return_args = (x, *self.return_args.values())
+        return return_args
+
+
+class BrevitasWeightFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler):
+    handled_layer = WeightFloatQuantProxyFromInjector
+
+    def __init__(self):
+        super().__init__()
+        self.quant_weights = None
+
+    def validate(self, zero_point):
+        assert zero_point == 0, "Zero-point not supported for minifloat quant."
+
+    def prepare_for_export(self, module: WeightQuantProxyFromInjector):
+        if module.is_quant_enabled:
+            first_qweight = module.tracked_module_list[0].quant_weight()
+            self.validate(first_qweight.zero_point)
+            self.symbolic_kwargs = {
+                'scale':
+                    first_qweight.scale,
+                'exponent_bit_width':
+                    first_qweight.exponent_bit_width,
+                'mantissa_bit_width':
+                    first_qweight.mantissa_bit_width,
+                'exponent_bias':
+                    first_qweight.exponent_bias,
+                'has_inf':
+                    first_qweight.inf_values is not None,
+                'has_nan':
+                    first_qweight.nan_values is not None,
+                'saturating':
+                    first_qweight.saturating,
+                'has_subnormal':
+                    True,  # Currently we only support subnormal
+                'rounding_mode':
+                    module.rounding_mode,
+                'max_float':
+                    torch.tensor(module.quant_injector.max_available_float
+                                ).type_as(first_qweight.scale)}
+            self.return_args = {
+                'scale': first_qweight.scale,
+                'zero_point': torch.zeros_like(first_qweight.scale),
+                'exponent_bit_width': first_qweight.exponent_bit_width,
+                'mantissa_bit_width': first_qweight.mantissa_bit_width,
+                'exponent_bias': first_qweight.exponent_bias,
+                'saturating': first_qweight.saturating,
+                'inf_values': first_qweight.inf_values,
+                'nan_values': first_qweight.nan_values,}
+
+    def symbolic_execution(self, x: Tensor):
+        return super().symbolic_execution(x)
+
+
+class BrevitasActFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler):
+    handled_layer = ActFloatQuantProxyFromInjector
+
+
 class BrevitasQuantProxyHandler(ONNXBaseHandler, ABC):
 
     def validate(self, module):
diff --git a/src/brevitas/export/onnx/qonnx/manager.py b/src/brevitas/export/onnx/qonnx/manager.py
index 4975dced3..b7419f489 100644
--- a/src/brevitas/export/onnx/qonnx/manager.py
+++ b/src/brevitas/export/onnx/qonnx/manager.py
@@ -17,12 +17,14 @@
 from .function import BrevitasQuantFn
 from .function import BrevitasQuantLSTMCellFn
 from .function import BrevitasTruncFn
+from .handler import BrevitasActFloatQuantProxyHandler
 from .handler import BrevitasActQuantProxyHandler
 from .handler import BrevitasBiasQuantProxyHandler
 from .handler import BrevitasDecoupledWeightQuantProxyHandler
 from .handler import BrevitasDecoupledWeightQuantWithInputProxyHandler
 from .handler import BrevitasQuantLSTMLayerHandler
 from .handler import BrevitasTruncQuantProxyHandler
+from .handler import BrevitasWeightFloatQuantProxyHandler
 from .handler import BrevitasWeightQuantProxyHandler
 
 
@@ -42,7 +44,9 @@ class QONNXManager(ONNXBaseManager):
         BrevitasDecoupledWeightQuantProxyHandler,
         BrevitasDecoupledWeightQuantWithInputProxyHandler,
         BrevitasTruncQuantProxyHandler,
-        BrevitasQuantLSTMLayerHandler]
+        BrevitasQuantLSTMLayerHandler,
+        BrevitasWeightFloatQuantProxyHandler,
+        BrevitasActFloatQuantProxyHandler]
 
     custom_fns = [
         DebugMarkerFunction,
diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py
index b7e017484..3f7f470b1 100644
--- a/tests/brevitas/export/test_onnx_fp8.py
+++ b/tests/brevitas/export/test_onnx_fp8.py
@@ -7,6 +7,7 @@
 
 from brevitas import torch_version
 from brevitas.export import export_onnx_qcdq
+from brevitas.export import export_qonnx
 import brevitas.nn as qnn
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
@@ -23,6 +24,17 @@ def test_simple_fp8_export():
     assert True
 
 
+@jit_disabled_for_export()
+def test_qonnx_simple_fp8_export():
+    if torch_version < version.parse('2.1.0'):
+        pytest.skip(f"OCP FP8 types not supported by {torch_version}")
+
+    model = qnn.QuantLinear(
+        3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat)
+    export_qonnx(model, torch.randn(1, 3), 'qonnx_act_weight_fp8.onnx')
+    assert True
+
+
 @jit_disabled_for_export()
 def test_fp8_export_activation():
     if torch_version < version.parse('2.1.0'):
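
Usage note: the sketch below shows how the new export path can be exercised end to end. It is a minimal example mirroring test_qonnx_simple_fp8_export above, not an additional test in this patch; the input shape and output filename are illustrative, and torch >= 2.1.0 is assumed for the OCP FP8 quantizers.

    # Minimal sketch: export a small FP8-quantized layer to QONNX via the
    # handlers added in this patch.
    import torch

    import brevitas.nn as qnn
    from brevitas.export import export_qonnx
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat

    # OCP FP8 e4m3 per-tensor quantizers on both weights and input activations.
    model = qnn.QuantLinear(
        3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat)
    model.eval()

    # Weight and activation quantizers are lowered to custom FloatQuant nodes in
    # the QONNX domain by BrevitasWeightFloatQuantProxyHandler and
    # BrevitasActFloatQuantProxyHandler respectively.
    export_qonnx(model, torch.randn(1, 3), 'qonnx_act_weight_fp8.onnx')

As set up in BrevitasFloatQuantFn.symbolic, each FloatQuant node carries scale, exponent/mantissa bit widths, exponent bias, and max_val as inputs, with inf/NaN/subnormal support, rounding mode, and saturation encoded as node attributes.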