From 00c45f48bfded06c65804e09ffe4fb8af2646fad Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 23 Oct 2024 23:28:43 +0100 Subject: [PATCH 1/5] Feat (export): qonnx minifloat export --- src/brevitas/export/onnx/qonnx/function.py | 46 ++++++++++++ src/brevitas/export/onnx/qonnx/handler.py | 81 ++++++++++++++++++++++ src/brevitas/export/onnx/qonnx/manager.py | 6 +- tests/brevitas/export/test_onnx_fp8.py | 12 ++++ 4 files changed, 144 insertions(+), 1 deletion(-) diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index 3e7faad0e..56701f6e8 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -59,6 +59,52 @@ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding return y +class BrevitasFloatQuantFn(Function): + + @staticmethod + def symbolic( + g, + x, + scale, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + has_inf, + has_nan, + has_subnormal, + rounding_mode, + max_val): + ret = g.op( + f'{DOMAIN_STRING}::FloatQuant', + x, + scale, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + has_inf_i=int(has_inf), + has_nan_i=int(has_nan), + has_subnormal_i=int(has_subnormal), + rounding_mode_s=rounding_mode, + max_val_f=max_val) + ret.setType(x.type()) + return ret + + @staticmethod + def forward( + g, + x, + scale, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + has_inf, + has_nan, + has_subnormal, + rounding_mode, + max_val): + return x + + class BrevitasTruncFn(Function): @staticmethod diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index c42714346..1f4d6780e 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -14,14 +14,95 @@ from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector from brevitas.proxy import WeightQuantProxyFromInjector +from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector from .function import BrevitasBinaryQuantFn +from .function import BrevitasFloatQuantFn from .function import BrevitasQuantFn from .function import BrevitasQuantLSTMCellFn from .function import BrevitasTruncFn +class BrevitasFloatQuantProxyHandler(ONNXBaseHandler, ABC): + + def validate(self, module): + assert not module.is_groupwise, "Export with Per Group quantization not supported" + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.validate(module) + self.symbolic_kwargs = { + 'scale': module.scale(), + 'exponent_bit_width': module.exponent_bit_width(), + 'mantissa_bit_width': module.mantissa_bit_width(), + 'exponent_bias': module.exponent_bias(), + 'has_inf': module.inf_values() is not None, + 'has_nan': module.nan_values() is not None, + 'has_subnormal': True, # Currently we only support subnormal + 'rounding_mode': module.rounding_mode, + 'max_float': module.quant_injector.max_available_float} + self.return_args = { + 'scale': module.scale(), + 'zero_point': torch.zeros_like(module.scale()), + 'exponent_bit_width': module.exponent_bit_width(), + 'mantissa_bit_width': module.mantissa_bit_width(), + 'exponent_bias': module.exponent_bias(), + 'saturating': module.is_saturating(), + 'inf_values': module.inf_values(), + 'nan_values': 
module.nan_values(),} + + def symbolic_execution(self, x: Tensor): + xx = tuple(self.symbolic_kwargs.values()) + scale = self.symbolic_kwargs['scale'] + print(self.symbolic_kwargs.values()) + x = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values()) + return x, *self.return_args.values() + + +class BrevitasWeightFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler): + handled_layer = WeightFloatQuantProxyFromInjector + + def __init__(self): + super().__init__() + self.quant_weights = None + + def validate(self, zero_point): + assert zero_point == 0, "Zero-point not supported for binary quant." + + def prepare_for_export(self, module: WeightQuantProxyFromInjector): + if module.is_quant_enabled: + first_qweight = module.tracked_module_list[0].quant_weight() + self.validate(first_qweight.zero_point) + self.symbolic_kwargs = { + 'scale': first_qweight.scale, + 'exponent_bit_width': first_qweight.exponent_bit_width, + 'mantissa_bit_width': first_qweight.mantissa_bit_width, + 'exponent_bias': first_qweight.exponent_bias, + 'has_inf': first_qweight.inf_values is not None, + 'has_nan': first_qweight.nan_values is not None, + 'has_subnormal': True, # Currently we only support subnormal + 'rounding_mode': module.rounding_mode, + 'max_float': module.quant_injector.max_available_float,} + self.return_args = { + 'scale': first_qweight.scale, + 'zero_point': torch.zeros_like(first_qweight.scale), + 'exponent_bit_width': first_qweight.exponent_bit_width, + 'mantissa_bit_width': first_qweight.mantissa_bit_width, + 'exponent_bias': first_qweight.exponent_bias, + 'saturating': first_qweight.saturating, + 'inf_values': first_qweight.inf_values, + 'nan_values': first_qweight.nan_values,} + + def symbolic_execution(self, x: Tensor): + return super().symbolic_execution(x) + + +class BrevitasActFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler): + handled_layer = ActFloatQuantProxyFromInjector + + class BrevitasQuantProxyHandler(ONNXBaseHandler, ABC): def validate(self, module): diff --git a/src/brevitas/export/onnx/qonnx/manager.py b/src/brevitas/export/onnx/qonnx/manager.py index 4975dced3..b7419f489 100644 --- a/src/brevitas/export/onnx/qonnx/manager.py +++ b/src/brevitas/export/onnx/qonnx/manager.py @@ -17,12 +17,14 @@ from .function import BrevitasQuantFn from .function import BrevitasQuantLSTMCellFn from .function import BrevitasTruncFn +from .handler import BrevitasActFloatQuantProxyHandler from .handler import BrevitasActQuantProxyHandler from .handler import BrevitasBiasQuantProxyHandler from .handler import BrevitasDecoupledWeightQuantProxyHandler from .handler import BrevitasDecoupledWeightQuantWithInputProxyHandler from .handler import BrevitasQuantLSTMLayerHandler from .handler import BrevitasTruncQuantProxyHandler +from .handler import BrevitasWeightFloatQuantProxyHandler from .handler import BrevitasWeightQuantProxyHandler @@ -42,7 +44,9 @@ class QONNXManager(ONNXBaseManager): BrevitasDecoupledWeightQuantProxyHandler, BrevitasDecoupledWeightQuantWithInputProxyHandler, BrevitasTruncQuantProxyHandler, - BrevitasQuantLSTMLayerHandler] + BrevitasQuantLSTMLayerHandler, + BrevitasWeightFloatQuantProxyHandler, + BrevitasActFloatQuantProxyHandler] custom_fns = [ DebugMarkerFunction, diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index b7e017484..3f7f470b1 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -7,6 +7,7 @@ from brevitas import torch_version from brevitas.export import export_onnx_qcdq 
+from brevitas.export import export_qonnx import brevitas.nn as qnn from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat @@ -23,6 +24,17 @@ def test_simple_fp8_export(): assert True +@jit_disabled_for_export() +def test_qonnx_simple_fp8_export(): + if torch_version < version.parse('2.1.0'): + pytest.skip(f"OCP FP8 types not supported by {torch_version}") + + model = qnn.QuantLinear( + 3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat) + export_qonnx(model, torch.randn(1, 3), 'qonnx_act_weight_fp8.onnx') + assert True + + @jit_disabled_for_export() def test_fp8_export_activation(): if torch_version < version.parse('2.1.0'): From 5cc8b95baeb88fda624d3049bcc92057b2ca8905 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 30 Oct 2024 14:13:17 +0000 Subject: [PATCH 2/5] Review --- src/brevitas/export/onnx/qonnx/function.py | 5 ++++- src/brevitas/export/onnx/qonnx/handler.py | 7 +++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index 56701f6e8..5160572ef 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -71,6 +71,7 @@ def symbolic( exponent_bias, has_inf, has_nan, + saturating, has_subnormal, rounding_mode, max_val): @@ -81,11 +82,12 @@ def symbolic( exponent_bit_width, mantissa_bit_width, exponent_bias, + max_val, has_inf_i=int(has_inf), has_nan_i=int(has_nan), has_subnormal_i=int(has_subnormal), rounding_mode_s=rounding_mode, - max_val_f=max_val) + saturation_i=saturating) ret.setType(x.type()) return ret @@ -99,6 +101,7 @@ def forward( exponent_bias, has_inf, has_nan, + saturating, has_subnormal, rounding_mode, max_val): diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 1f4d6780e..3b9a2d41c 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -40,6 +40,7 @@ def prepare_for_export(self, module): 'exponent_bias': module.exponent_bias(), 'has_inf': module.inf_values() is not None, 'has_nan': module.nan_values() is not None, + 'saturating': module.is_saturating(), 'has_subnormal': True, # Currently we only support subnormal 'rounding_mode': module.rounding_mode, 'max_float': module.quant_injector.max_available_float} @@ -54,9 +55,6 @@ def prepare_for_export(self, module): 'nan_values': module.nan_values(),} def symbolic_execution(self, x: Tensor): - xx = tuple(self.symbolic_kwargs.values()) - scale = self.symbolic_kwargs['scale'] - print(self.symbolic_kwargs.values()) x = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values()) return x, *self.return_args.values() @@ -69,7 +67,7 @@ def __init__(self): self.quant_weights = None def validate(self, zero_point): - assert zero_point == 0, "Zero-point not supported for binary quant." + assert zero_point == 0, "Zero-point not supported for minifloat quant." 
def prepare_for_export(self, module: WeightQuantProxyFromInjector): if module.is_quant_enabled: @@ -82,6 +80,7 @@ def prepare_for_export(self, module: WeightQuantProxyFromInjector): 'exponent_bias': first_qweight.exponent_bias, 'has_inf': first_qweight.inf_values is not None, 'has_nan': first_qweight.nan_values is not None, + 'saturating': first_qweight.saturating, 'has_subnormal': True, # Currently we only support subnormal 'rounding_mode': module.rounding_mode, 'max_float': module.quant_injector.max_available_float,} From c5ac133c844a07d0e5b619587d870d3e5d8c35b8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 30 Oct 2024 14:17:03 +0000 Subject: [PATCH 3/5] return args and precommit --- src/brevitas/export/onnx/qonnx/handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 3b9a2d41c..939e3d32d 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -56,7 +56,8 @@ def prepare_for_export(self, module): def symbolic_execution(self, x: Tensor): x = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values()) - return x, *self.return_args.values() + return_args = (x, *self.return_args.values()) + return return_args class BrevitasWeightFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler): From 1ddea1c8feed2d9482fd4c18523c41767c13b5de Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 30 Oct 2024 16:20:44 +0100 Subject: [PATCH 4/5] Update handler.py --- src/brevitas/export/onnx/qonnx/handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 939e3d32d..83cc82b03 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -43,7 +43,7 @@ def prepare_for_export(self, module): 'saturating': module.is_saturating(), 'has_subnormal': True, # Currently we only support subnormal 'rounding_mode': module.rounding_mode, - 'max_float': module.quant_injector.max_available_float} + 'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(module.scale())} self.return_args = { 'scale': module.scale(), 'zero_point': torch.zeros_like(module.scale()), @@ -84,7 +84,7 @@ def prepare_for_export(self, module: WeightQuantProxyFromInjector): 'saturating': first_qweight.saturating, 'has_subnormal': True, # Currently we only support subnormal 'rounding_mode': module.rounding_mode, - 'max_float': module.quant_injector.max_available_float,} + 'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(first_qweight.scale)} self.return_args = { 'scale': first_qweight.scale, 'zero_point': torch.zeros_like(first_qweight.scale), From f27071d9ed960f711babcd8d9fc8055c427d230a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 5 Nov 2024 13:14:02 +0000 Subject: [PATCH 5/5] precommit --- src/brevitas/export/onnx/qonnx/handler.py | 61 +++++++++++++++-------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 83cc82b03..5468cd1aa 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -34,16 +34,26 @@ def prepare_for_export(self, module): if module.is_quant_enabled: self.validate(module) self.symbolic_kwargs = { - 'scale': module.scale(), - 'exponent_bit_width': module.exponent_bit_width(), - 
'mantissa_bit_width': module.mantissa_bit_width(), - 'exponent_bias': module.exponent_bias(), - 'has_inf': module.inf_values() is not None, - 'has_nan': module.nan_values() is not None, - 'saturating': module.is_saturating(), - 'has_subnormal': True, # Currently we only support subnormal - 'rounding_mode': module.rounding_mode, - 'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(module.scale())} + 'scale': + module.scale(), + 'exponent_bit_width': + module.exponent_bit_width(), + 'mantissa_bit_width': + module.mantissa_bit_width(), + 'exponent_bias': + module.exponent_bias(), + 'has_inf': + module.inf_values() is not None, + 'has_nan': + module.nan_values() is not None, + 'saturating': + module.is_saturating(), + 'has_subnormal': + True, # Currently we only support subnormal + 'rounding_mode': + module.rounding_mode, + 'max_float': + torch.tensor(module.quant_injector.max_available_float).type_as(module.scale())} self.return_args = { 'scale': module.scale(), 'zero_point': torch.zeros_like(module.scale()), @@ -75,16 +85,27 @@ def prepare_for_export(self, module: WeightQuantProxyFromInjector): first_qweight = module.tracked_module_list[0].quant_weight() self.validate(first_qweight.zero_point) self.symbolic_kwargs = { - 'scale': first_qweight.scale, - 'exponent_bit_width': first_qweight.exponent_bit_width, - 'mantissa_bit_width': first_qweight.mantissa_bit_width, - 'exponent_bias': first_qweight.exponent_bias, - 'has_inf': first_qweight.inf_values is not None, - 'has_nan': first_qweight.nan_values is not None, - 'saturating': first_qweight.saturating, - 'has_subnormal': True, # Currently we only support subnormal - 'rounding_mode': module.rounding_mode, - 'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(first_qweight.scale)} + 'scale': + first_qweight.scale, + 'exponent_bit_width': + first_qweight.exponent_bit_width, + 'mantissa_bit_width': + first_qweight.mantissa_bit_width, + 'exponent_bias': + first_qweight.exponent_bias, + 'has_inf': + first_qweight.inf_values is not None, + 'has_nan': + first_qweight.nan_values is not None, + 'saturating': + first_qweight.saturating, + 'has_subnormal': + True, # Currently we only support subnormal + 'rounding_mode': + module.rounding_mode, + 'max_float': + torch.tensor(module.quant_injector.max_available_float + ).type_as(first_qweight.scale)} self.return_args = { 'scale': first_qweight.scale, 'zero_point': torch.zeros_like(first_qweight.scale),
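
Note (usage): the test added in PATCH 1/5 (test_qonnx_simple_fp8_export) doubles as a minimal usage example for the new minifloat QONNX export path. The sketch below restates it as a standalone script; it assumes PyTorch >= 2.1 (which the test itself requires for OCP FP8 support), and the final onnx inspection step is an illustrative addition, not part of the patch series.

    import torch

    import brevitas.nn as qnn
    from brevitas.export import export_qonnx
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat

    # Linear layer with OCP FP8 (e4m3) minifloat quantization on both weights and inputs,
    # exactly as in the test added by this series.
    model = qnn.QuantLinear(
        3,
        16,
        weight_quant=Fp8e4m3OCPWeightPerTensorFloat,
        input_quant=Fp8e4m3OCPActPerTensorFloat)

    # The new handlers map each minifloat proxy to a custom FloatQuant node that carries
    # scale, exponent/mantissa bit widths, exponent bias, and max_val as node inputs, with
    # the rounding mode and the inf/nan/subnormal/saturation flags as node attributes
    # (see BrevitasFloatQuantFn.symbolic after PATCH 2/5).
    export_qonnx(model, torch.randn(1, 3), 'qonnx_act_weight_fp8.onnx')

    # Optional sanity check (assumes the onnx package is installed): the exported graph
    # should contain FloatQuant nodes for the weight and activation quantizers.
    import onnx
    exported = onnx.load('qonnx_act_weight_fp8.onnx')
    print([n.op_type for n in exported.graph.node])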