From 5b168c5ab7b33cc87a43b8288fa52c6990d697d4 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 11 Apr 2024 17:03:48 +0100 Subject: [PATCH 01/44] placeholder version --- src/brevitas/export/common/handler/base.py | 5 ++- src/brevitas/export/common/handler/qcdq.py | 33 ++++++++++++++ .../export/onnx/standard/qcdq/handler.py | 39 ++++++++++++++++ .../export/onnx/standard/qcdq/manager.py | 2 + src/brevitas/proxy/__init__.py | 1 + src/brevitas/proxy/parameter_quant.py | 45 +++++++++++++++++++ src/brevitas/quant/experimental/float_base.py | 4 +- .../quant/experimental/scaled_float.py | 16 +++++++ tests/brevitas/export/test_onnx_fp8.py | 19 ++++++++ 9 files changed, 161 insertions(+), 3 deletions(-) create mode 100644 src/brevitas/quant/experimental/scaled_float.py create mode 100644 tests/brevitas/export/test_onnx_fp8.py diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index 6136a4cdc..2c7ec12b2 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -12,7 +12,7 @@ from brevitas.function.ops import max_int from brevitas.function.ops import min_int -__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin'] +__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin'] class BaseHandler(Module, ABC): @@ -112,6 +112,9 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor): return -cls.validate_scalar_int_exponent(scale) +class FloatZeroPointHandlerMixin(ABC): + pass + class ZeroPointHandlerMixin(ABC): @classmethod diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index c4659ac87..50b6000f3 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -15,6 +15,7 @@ from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector from brevitas.proxy import WeightQuantProxyFromInjector +from brevitas.proxy import FloatWeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector @@ -22,6 +23,7 @@ from .base import ClipMixin from .base import QuantAxisMixin from .base import ZeroPointHandlerMixin +from .base import FloatZeroPointHandlerMixin class DQMixin(ABC): @@ -65,6 +67,8 @@ class CDQCastMixin(DQCastMixin, ABC): def clip_fn(self, x, min_val, max_val): pass +class FloatQMixin(ABC): + pass class QMixin(BitWidthHandlerMixin, ABC): @@ -110,6 +114,11 @@ def quantize_fn(self, x, dtype): pass +class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC): + + def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): + raise NotImplementedError() + class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC): def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): @@ -133,6 +142,30 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): 'scale_orig_shape': scale_orig_shape} +class QCDQCastFloatWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): + handled_layer = FloatWeightQuantProxyFromInjector + + def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): + raise NotImplementedError() + + def 
prepare_quantize_from_floating_point(self, module): + raise NotImplementedError() + + def prepare_quantize_from_integer(self, module): + raise NotImplementedError() + + def prepare_for_export(self, module): + raise NotImplementedError() + + def quantize_from_floating_point(self, x: Tensor): + raise NotImplementedError() + + def quantize_from_integer(self, x: Tensor): + raise NotImplementedError() + + def symbolic_execution(self, x: Tensor): + raise NotImplementedError() + class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin): handled_layer = WeightQuantProxyFromInjector diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index a8f3c507b..a2ec42a00 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -16,7 +16,9 @@ QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastFloatWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin +from brevitas.export.common.handler.qcdq import FloatQMixin from brevitas.export.onnx.handler import ONNXBaseHandler from brevitas.export.onnx.handler import QuantLSTMLayerHandler @@ -27,6 +29,26 @@ from ..function import QuantizeLinearFn + +class StdFloatDQCastONNXMixin(DQCastMixin, ABC): + + def dequantize_fn(self, x, scale, zero_point, axis): + raise NotImplementedError() + + def cast_fn(self, x, dtype): + raise NotImplementedError() + + @property + def flatten_dequantize_params(self): + raise NotImplementedError() + + @property + def itemize_quantize_scalar_params(self): + raise NotImplementedError() + + def validate(self, module): + raise NotImplementedError() + class StdDQCastONNXMixin(DQCastMixin, ABC): def dequantize_fn(self, x, scale, zero_point, axis): @@ -47,11 +69,22 @@ def validate(self, module): assert module.bit_width() > 1., 'Binary quant not supported' +class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC): + + def clip_fn(self, x, min_val, max_val): + return IntClipFn.apply(x, min_val, max_val) + class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC): def clip_fn(self, x, min_val, max_val): return IntClipFn.apply(x, min_val, max_val) +class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC): + def validate(self, module): + raise NotImplementedError() + + def quantize_fn(self, x, scale, zero_point, dtype, axis): + raise NotImplementedError() class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC): @@ -112,6 +145,12 @@ def quantize_fn(self, x, dtype): return DynamicQuantizeLinearFn.apply(x, dtype) +class StdQCDQCastONNXFloatWeightQuantProxyHandler(StdFloatQCDQCastONNXMixin, + QCDQCastFloatWeightQuantProxyHandlerMixin, + ONNXBaseHandler): + _export_q_node = False + + class StdQCDQCastONNXWeightQuantProxyHandler(StdQCDQCastONNXMixin, QCDQCastWeightQuantProxyHandlerMixin, ONNXBaseHandler): diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py index b1b05f4ad..dc006226c 100644 --- a/src/brevitas/export/onnx/standard/qcdq/manager.py +++ b/src/brevitas/export/onnx/standard/qcdq/manager.py @@ -23,6 +23,7 @@ from .handler import StdQCDQCastONNXQuantLSTMLayerHandler from .handler import StdQCDQCastONNXTruncQuantProxyHandler from .handler 
import StdQCDQCastONNXWeightQuantProxyHandler
+from .handler import StdQCDQCastONNXFloatWeightQuantProxyHandler
 
 
 class StdQCDQONNXManager(StdONNXBaseManager):
@@ -36,6 +37,7 @@ class StdQCDQONNXManager(StdONNXBaseManager):
 
     handlers = [
         StdQCDQCastONNXWeightQuantProxyHandler,
+        StdQCDQCastONNXFloatWeightQuantProxyHandler,
         StdCDQCastONNXBiasQuantProxyHandler,
         StdQCDQCastONNXActQuantProxyHandler,
         StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py
index ecf98afc8..7157cf0d6 100644
--- a/src/brevitas/proxy/__init__.py
+++ b/src/brevitas/proxy/__init__.py
@@ -5,6 +5,7 @@
 from .parameter_quant import DecoupledWeightQuantProxyFromInjector
 from .parameter_quant import DecoupledWeightQuantWithInputProxyFromInjector
 from .parameter_quant import WeightQuantProxyFromInjector
+from .parameter_quant import FloatWeightQuantProxyFromInjector
 from .runtime_quant import ActQuantProxyFromInjector
 from .runtime_quant import ClampQuantProxyFromInjector
 from .runtime_quant import TruncQuantProxyFromInjector
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 4a95b5e0f..0d8f3f9d3 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -24,6 +24,7 @@
 
 __all__ = [
     'WeightQuantProxyFromInjector',
+    'FloatWeightQuantProxyFromInjector',
     'BiasQuantProxyFromInjector',
     'WeightQuantProxyProtocol',
     'BiasQuantProxyProtocol']
@@ -76,6 +77,50 @@ def max_uint_value(self, bit_width):
         return max_int(False, self.is_narrow_range, bit_width)
 
 
+class FloatWeightQuantProxyFromInjector(ParameterQuantProxyFromInjector, WeightQuantProxyProtocol):
+
+    @property
+    def tracked_parameter_list(self):
+        return [m.weight for m in self.tracked_module_list if m.weight is not None]
+
+    @property
+    def requires_quant_input(self):
+        return False
+
+    def scale(self):
+        if not self.is_quant_enabled:
+            return None
+        scale = self.__call__(self.tracked_parameter_list[0]).scale
+        return scale
+
+    def zero_point(self):
+        if not self.is_quant_enabled:
+            return None
+        zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
+        return zero_point
+
+    def exponent_bit_width(self):
+        if not self.is_quant_enabled:
+            return None
+        exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width
+        return exponent_bit_width
+
+    def mantissa_bit_width(self):
+        if not self.is_quant_enabled:
+            return None
+        mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width
+        return mantissa_bit_width
+
+    def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
+        if self.is_quant_enabled:
+            warn("This code should be replaced with FloatQuantTensor when it becomes available")
+            impl = self.export_handler if self.export_mode else self.tensor_quant
+            out, scale, zero_point, bit_width = impl(x)
+            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+        else:  # quantization disabled
+            return x
+
+
 class WeightQuantProxyFromInjector(ParameterQuantProxyFromInjector, WeightQuantProxyProtocol):
 
     @property
diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py
index 9a2893039..abe2d3250 100644
--- a/src/brevitas/quant/experimental/float_base.py
+++ b/src/brevitas/quant/experimental/float_base.py
@@ -7,7 +7,7 @@
 from brevitas.core.scaling.float_scaling import FloatScaling
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import value
-from brevitas.proxy.parameter_quant import 
WeightQuantProxyFromInjector +from brevitas.proxy.parameter_quant import FloatWeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.quant.solver import ActQuantSolver from brevitas.quant.solver import WeightQuantSolver @@ -28,7 +28,7 @@ def exponent_bias(exponent_bit_width): class FloatWeightBase(FloatBase): - proxy_class = WeightQuantProxyFromInjector + proxy_class = FloatWeightQuantProxyFromInjector class FloatActBase(FloatBase): diff --git a/src/brevitas/quant/experimental/scaled_float.py b/src/brevitas/quant/experimental/scaled_float.py new file mode 100644 index 000000000..5c9b655d9 --- /dev/null +++ b/src/brevitas/quant/experimental/scaled_float.py @@ -0,0 +1,16 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from brevitas.quant.base import MaxStatsScaling +from brevitas.quant.base import PerTensorFloatScaling8bit +from brevitas.quant.experimental.float import Fp8e4m3Weight +from brevitas.quant.solver.weight import WeightQuantSolver + +__all__ = ['Fp8e4m3OCPWeightPerTensorFloat'] + + +class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3Weight, + MaxStatsScaling, + PerTensorFloatScaling8bit, + WeightQuantSolver): + pass diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py new file mode 100644 index 000000000..48da70b9f --- /dev/null +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch + +from brevitas.export import export_onnx_qcdq +import brevitas.nn as qnn +from brevitas.quant.experimental.scaled_float import Fp8e4m3OCPWeightPerTensorFloat + + +def test_simple_fp8_export(): + model = qnn.QuantLinear(3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat) + export_onnx_qcdq(model, torch.randn(1, 3), 'test.onnx', export_weight_q_node=True) + assert True + + +if __name__ == "__main__": + test_simple_fp8_export() + print("Done") From d2b7d2dba39a132861b1f66e67cf99060430f148 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:12:24 +0100 Subject: [PATCH 02/44] checkpoint commit --- src/brevitas/core/quant/float.py | 2 +- src/brevitas/export/common/handler/base.py | 16 +++- src/brevitas/export/common/handler/qcdq.py | 88 +++++++++++++++++-- .../export/onnx/standard/qcdq/handler.py | 6 +- src/brevitas/proxy/parameter_quant.py | 9 +- 5 files changed, 107 insertions(+), 14 deletions(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index d5e3d06d9..bd2c66727 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -89,4 +89,4 @@ def forward(self, x): y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) y = self.dequantize(y, scale) # This is to respect the current interface of proxies - return y, scale, self.zero_point_impl(), self.bit_width() + return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width() diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index 2c7ec12b2..d90a4b976 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -12,6 +12,8 @@ from brevitas.function.ops import max_int from brevitas.function.ops import min_int +from warnings import warn + __all__ = ['BaseHandler', 'BitWidthHandlerMixin', 
'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin'] @@ -37,6 +39,11 @@ def quant_axis(cls, scale): return i return None +class FloatClipMixin(ABC): + @classmethod + def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width): + warn("Not implemented for floating point") + return None class ClipMixin(ABC): @@ -113,7 +120,14 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor): class FloatZeroPointHandlerMixin(ABC): - pass + @classmethod + def zero_point_with_dtype(cls, signed, exponent_bit_width, mantissa_bit_width, zero_point): + if exponent_bit_width == 4 and mantissa_bit_width == 3: + return zero_point.type(torch.float8_e4m3fn) + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + return zero_point.type(torch.float8_e5m2) + else: + return zero_point.type(torch.float32) class ZeroPointHandlerMixin(ABC): diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 50b6000f3..5bbc1e28e 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -21,6 +21,7 @@ from .base import BitWidthHandlerMixin from .base import ClipMixin +from .base import FloatClipMixin from .base import QuantAxisMixin from .base import ZeroPointHandlerMixin from .base import FloatZeroPointHandlerMixin @@ -68,7 +69,21 @@ def clip_fn(self, x, min_val, max_val): pass class FloatQMixin(ABC): - pass + @abstractmethod + def quantize_fn(self, x, scale, zero_point, dtype, axis): + raise NotImplementedError() + + @classmethod + def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_signed): + if exponent_bit_width is None or mantissa_bit_width is None: + return None + if exponent_bit_width == 4 and mantissa_bit_width == 3: + dtype = torch.float8_e4m3fn + elif exponent_bit_width == 5 and mantissa_bit_width == 2: + dtype = torch.float8_e5m2 + else: + dtype = torch.float32 + return dtype class QMixin(BitWidthHandlerMixin, ABC): @@ -114,10 +129,27 @@ def quantize_fn(self, x, dtype): pass -class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC): +class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatClipMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC): - def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): - raise NotImplementedError() + def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed): + scale_orig_shape = scale.shape + axis = cls.quant_axis(scale) + if cls.flatten_dequantize_params: + scale = scale.flatten() + scale = to_0dim_if_scalar(scale) + if cls.flatten_dequantize_params: + zero_point = zero_point.flatten() + zp = to_0dim_if_scalar(zero_point) + zp = zp.expand_as(scale) + zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantisssa_bit_width, zp) + return { + 'scale': scale, + 'zero_point': zp, + 'axis': axis, + # We save only the scale original shape + # as zero-point is being expanded to the same + # size as the scale + 'scale_orig_shape': scale_orig_shape} class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC): @@ -145,17 +177,57 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): class QCDQCastFloatWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): handled_layer = FloatWeightQuantProxyFromInjector - def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): - raise NotImplementedError() + def 
quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): + # compute axis before redefining scale + axis = cls.quant_axis(scale) + scale = to_0dim_if_scalar(scale.flatten()) + zp = to_0dim_if_scalar(zero_point.flatten()) + # expand_as must go after 0-dim check + zp = zp.expand_as(scale) + zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantissa_bit_width, zp) + if cls.itemize_quantize_scalar_params: + scale = to_item_if_0dim(scale) + zp = to_item_if_0dim(zp) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_signed) + return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} def prepare_quantize_from_floating_point(self, module): - raise NotImplementedError() + quant_weight = module.tracked_module_list[0].quant_weight() + scale = quant_weight.scale + self.scale_dtype = scale.dtype + if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16: + scale = self.cast_fn(scale, torch.float32) + self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs( + scale, quant_weight.zero_point, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width, module.is_signed) def prepare_quantize_from_integer(self, module): raise NotImplementedError() def prepare_for_export(self, module): - raise NotImplementedError() + if module.is_quant_enabled: + self.validate(module) + if self._export_q_node: + self.prepare_quantize_from_floating_point(module) + else: + self.prepare_quantize_from_integer(module) + # Get the first quant weight as representative + quant_weight = module.tracked_module_list[0].quant_weight() + + # (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype + # of the scale and we cast it to float32. 
+            # The original dtype is then restored during the forward pass
+            scale = quant_weight.scale
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
+            self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
+            self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs(
+                module.is_narrow_range, module.is_signed, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width)
+            self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
+                scale, quant_weight.zero_point, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width, module.is_signed)
+        else:
+            self.symbolic_kwargs = None
 
     def quantize_from_floating_point(self, x: Tensor):
         raise NotImplementedError()
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index a2ec42a00..da13526d1 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -28,6 +28,7 @@
 from ..function import IntClipFn
 from ..function import QuantizeLinearFn
 
+from warnings import warn
 
 class StdFloatDQCastONNXMixin(DQCastMixin, ABC):
 
@@ -44,7 +45,7 @@ def flatten_dequantize_params(self):
 
     @property
     def itemize_quantize_scalar_params(self):
-        raise NotImplementedError()
+        return False
 
     def validate(self, module):
         raise NotImplementedError()
@@ -81,7 +82,8 @@ def clip_fn(self, x, min_val, max_val):
 
 class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):
     def validate(self, module):
-        raise NotImplementedError()
+        warn("Needs to be implemented")
+        pass
 
     def quantize_fn(self, x, scale, zero_point, dtype, axis):
         raise NotImplementedError()
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 0d8f3f9d3..7a48a70ba 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -112,11 +112,16 @@ def mantissa_bit_width(self):
         return mantissa_bit_width
 
     def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
+        #!!!!!!!!! PLACEHOLDER CODE !!!!!!!!!
         if self.is_quant_enabled:
             warn("This code should be replaced with FloatQuantTensor when it becomes available")
             impl = self.export_handler if self.export_mode else self.tensor_quant
-            out, scale, zero_point, bit_width = impl(x)
-            return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            out, scale, zero_point, exponent_bit_width, mantissa_bit_width = impl(x)
+            qt = QuantTensor(out, scale, zero_point, exponent_bit_width+mantissa_bit_width, self.is_signed, self.training)
+            qt.exponent_bit_width = exponent_bit_width
+            qt.mantissa_bit_width = mantissa_bit_width
+            #!!!!!!!!! PLACEHOLDER CODE !!!!!!!!! 
+ return qt else: # quantization disabled return x From e10e63040236443d0e4d9f4041b65df81d8df938 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:53:26 +0100 Subject: [PATCH 03/44] first working flow end to end --- src/brevitas/export/common/handler/base.py | 6 --- src/brevitas/export/common/handler/qcdq.py | 50 +++++++++++++++---- .../export/onnx/standard/qcdq/handler.py | 25 ++-------- 3 files changed, 45 insertions(+), 36 deletions(-) diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index d90a4b976..ec0923fd1 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -39,12 +39,6 @@ def quant_axis(cls, scale): return i return None -class FloatClipMixin(ABC): - @classmethod - def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width): - warn("Not implemented for floating point") - return None - class ClipMixin(ABC): @classmethod diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 5bbc1e28e..572d2c241 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -21,7 +21,6 @@ from .base import BitWidthHandlerMixin from .base import ClipMixin -from .base import FloatClipMixin from .base import QuantAxisMixin from .base import ZeroPointHandlerMixin from .base import FloatZeroPointHandlerMixin @@ -129,7 +128,7 @@ def quantize_fn(self, x, dtype): pass -class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatClipMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC): +class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC): def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed): scale_orig_shape = scale.shape @@ -221,22 +220,55 @@ def prepare_for_export(self, module): if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16: scale = self.cast_fn(scale, torch.float32) - self.symbolic_kwargs['bit_width'] = quant_weight.bit_width - self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( - module.is_narrow_range, module.is_signed, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width) + self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width + self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( scale, quant_weight.zero_point, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width, module.is_signed) else: self.symbolic_kwargs = None - def quantize_from_floating_point(self, x: Tensor): - raise NotImplementedError() + def quantize_from_floating_point(self, x: Tensor, zp): + quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] + quantize_symbolic_kwargs['zero_point'] = zp + # Before quantization, cast input to float32 + if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: + x = self.cast_fn(x, torch.float32) + x = self.quantize_fn(x, *quantize_symbolic_kwargs.values()) + return x def quantize_from_integer(self, x: Tensor): - raise NotImplementedError() + int_weights = { + tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False) + for tm in module.tracked_module_list} + self.symbolic_kwargs['int_weights'] = int_weights def symbolic_execution(self, x: Tensor): - raise 
NotImplementedError() + assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' + + # Copy dict to allow for popping kwargs even on shared quantizers + dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs']) + scale = dequantize_symbolic_kwargs['scale'] + zero_point = dequantize_symbolic_kwargs['zero_point'] + + if self._export_q_node: + x = self.quantize_from_floating_point(x, zero_point) + else: + x = self.quantize_from_integer(x) + + exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] + mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] + scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') + # Workaround to trick the tracer into believing all return values are used + self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width) + x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) + # After dequantization, cast both input and scale to the correct dtype + if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: + x = self.cast_fn(x, self.scale_dtype) + scale = self.cast_fn(scale, self.scale_dtype) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) + return x, scale, zero_point, exponent_bit_width, mantissa_bit_width class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin): handled_layer = WeightQuantProxyFromInjector diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index da13526d1..9e5b62fd5 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -30,25 +30,6 @@ from warnings import warn - -class StdFloatDQCastONNXMixin(DQCastMixin, ABC): - - def dequantize_fn(self, x, scale, zero_point, axis): - raise NotImplementedError() - - def cast_fn(self, x, dtype): - raise NotImplementedError() - - @property - def flatten_dequantize_params(self): - raise NotImplementedError() - - @property - def itemize_quantize_scalar_params(self): - return False - - def validate(self, module): - raise NotImplementedError() class StdDQCastONNXMixin(DQCastMixin, ABC): @@ -69,6 +50,9 @@ def itemize_quantize_scalar_params(self): def validate(self, module): assert module.bit_width() > 1., 'Binary quant not supported' +class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC): + def validate(self, module): + pass class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC): @@ -82,11 +66,10 @@ def clip_fn(self, x, min_val, max_val): class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC): def validate(self, module): - warn("Needs to be implemented") pass def quantize_fn(self, x, scale, zero_point, dtype, axis): - raise NotImplementedError() + return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis) class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC): From 84e70f7741d2d5855926c13278dbdf4554aa3503 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:56:03 +0100 Subject: [PATCH 04/44] formatting --- src/brevitas/export/common/handler/base.py | 9 ++++--- src/brevitas/export/common/handler/qcdq.py | 27 ++++++++++++++----- .../export/onnx/standard/qcdq/handler.py | 14 +++++++--- .../export/onnx/standard/qcdq/manager.py | 2 +- src/brevitas/proxy/__init__.py | 2 +- src/brevitas/proxy/parameter_quant.py | 8 +++++- 6 
files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index ec0923fd1..609f11dc1 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -4,6 +4,7 @@ from abc import ABC from abc import abstractmethod import math +from warnings import warn import torch from torch import Tensor @@ -12,9 +13,8 @@ from brevitas.function.ops import max_int from brevitas.function.ops import min_int -from warnings import warn - -__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin'] +__all__ = [ + 'BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin'] class BaseHandler(Module, ABC): @@ -39,6 +39,7 @@ def quant_axis(cls, scale): return i return None + class ClipMixin(ABC): @classmethod @@ -114,6 +115,7 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor): class FloatZeroPointHandlerMixin(ABC): + @classmethod def zero_point_with_dtype(cls, signed, exponent_bit_width, mantissa_bit_width, zero_point): if exponent_bit_width == 4 and mantissa_bit_width == 3: @@ -123,6 +125,7 @@ def zero_point_with_dtype(cls, signed, exponent_bit_width, mantissa_bit_width, z else: return zero_point.type(torch.float32) + class ZeroPointHandlerMixin(ABC): @classmethod diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 572d2c241..cc05d5596 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -14,16 +14,16 @@ from brevitas.proxy import BiasQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector -from brevitas.proxy import WeightQuantProxyFromInjector from brevitas.proxy import FloatWeightQuantProxyFromInjector +from brevitas.proxy import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector from .base import BitWidthHandlerMixin from .base import ClipMixin +from .base import FloatZeroPointHandlerMixin from .base import QuantAxisMixin from .base import ZeroPointHandlerMixin -from .base import FloatZeroPointHandlerMixin class DQMixin(ABC): @@ -67,7 +67,9 @@ class CDQCastMixin(DQCastMixin, ABC): def clip_fn(self, x, min_val, max_val): pass + class FloatQMixin(ABC): + @abstractmethod def quantize_fn(self, x, scale, zero_point, dtype, axis): raise NotImplementedError() @@ -84,6 +86,7 @@ def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_signed): dtype = torch.float32 return dtype + class QMixin(BitWidthHandlerMixin, ABC): @classmethod @@ -130,7 +133,8 @@ def quantize_fn(self, x, dtype): class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC): - def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed): + def dequantize_symbolic_kwargs( + cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed): scale_orig_shape = scale.shape axis = cls.quant_axis(scale) if cls.flatten_dequantize_params: @@ -150,6 +154,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, manti # size as the scale 'scale_orig_shape': scale_orig_shape} + class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC): def 
dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): @@ -176,7 +181,8 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): class QCDQCastFloatWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): handled_layer = FloatWeightQuantProxyFromInjector - def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): + def quantize_symbolic_kwargs( + cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): # compute axis before redefining scale axis = cls.quant_axis(scale) scale = to_0dim_if_scalar(scale.flatten()) @@ -197,7 +203,11 @@ def prepare_quantize_from_floating_point(self, module): if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16: scale = self.cast_fn(scale, torch.float32) self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs( - scale, quant_weight.zero_point, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width, module.is_signed) + scale, + quant_weight.zero_point, + quant_weight.exponent_bit_width, + quant_weight.mantissa_bit_width, + module.is_signed) def prepare_quantize_from_integer(self, module): raise NotImplementedError() @@ -223,7 +233,11 @@ def prepare_for_export(self, module): self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( - scale, quant_weight.zero_point, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width, module.is_signed) + scale, + quant_weight.zero_point, + quant_weight.exponent_bit_width, + quant_weight.mantissa_bit_width, + module.is_signed) else: self.symbolic_kwargs = None @@ -270,6 +284,7 @@ def symbolic_execution(self, x: Tensor): zero_point = zero_point.view_as(scale) return x, scale, zero_point, exponent_bit_width, mantissa_bit_width + class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin): handled_layer = WeightQuantProxyFromInjector diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 9e5b62fd5..7087019df 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import ABC +from warnings import warn import torch @@ -10,15 +11,15 @@ from brevitas.export.common.handler.qcdq import DQCastMixin from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DynamicQMixin +from brevitas.export.common.handler.qcdq import FloatQMixin from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import \ QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastFloatWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQCastFloatWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin -from brevitas.export.common.handler.qcdq import FloatQMixin from 
brevitas.export.onnx.handler import ONNXBaseHandler
 from brevitas.export.onnx.handler import QuantLSTMLayerHandler
 
@@ -28,9 +29,7 @@
 from ..function import IntClipFn
 from ..function import QuantizeLinearFn
 
-from warnings import warn
-
 
 class StdDQCastONNXMixin(DQCastMixin, ABC):
 
     def dequantize_fn(self, x, scale, zero_point, axis):
@@ -50,27 +49,34 @@ def itemize_quantize_scalar_params(self):
     def validate(self, module):
         assert module.bit_width() > 1., 'Binary quant not supported'
 
+
 class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC):
+
     def validate(self, module):
         pass
 
+
 class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC):
 
     def clip_fn(self, x, min_val, max_val):
         return IntClipFn.apply(x, min_val, max_val)
 
+
 class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC):
 
     def clip_fn(self, x, min_val, max_val):
         return IntClipFn.apply(x, min_val, max_val)
 
+
 class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):
+
     def validate(self, module):
         pass
 
     def quantize_fn(self, x, scale, zero_point, dtype, axis):
         return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
 
+
 class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):
 
     @classmethod
diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py
index dc006226c..186be6c87 100644
--- a/src/brevitas/export/onnx/standard/qcdq/manager.py
+++ b/src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -20,10 +20,10 @@
 from .handler import StdQCDQCastONNXActQuantProxyHandler
 from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
 from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
+from .handler import StdQCDQCastONNXFloatWeightQuantProxyHandler
 from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
 from .handler import StdQCDQCastONNXTruncQuantProxyHandler
 from .handler import StdQCDQCastONNXWeightQuantProxyHandler
-from .handler import StdQCDQCastONNXFloatWeightQuantProxyHandler
 
 
 class StdQCDQONNXManager(StdONNXBaseManager):
diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py
index 7157cf0d6..ecf9754aa 100644
--- a/src/brevitas/proxy/__init__.py
+++ b/src/brevitas/proxy/__init__.py
@@ -4,8 +4,8 @@
 from .parameter_quant import BiasQuantProxyFromInjector
 from .parameter_quant import DecoupledWeightQuantProxyFromInjector
 from .parameter_quant import DecoupledWeightQuantWithInputProxyFromInjector
-from .parameter_quant import WeightQuantProxyFromInjector
 from .parameter_quant import FloatWeightQuantProxyFromInjector
+from .parameter_quant import WeightQuantProxyFromInjector
 from .runtime_quant import ActQuantProxyFromInjector
 from .runtime_quant import ClampQuantProxyFromInjector
 from .runtime_quant import TruncQuantProxyFromInjector
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 7a48a70ba..44b7bf16e 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -117,7 +117,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
             warn("This code should be replaced with FloatQuantTensor when it becomes available")
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, exponent_bit_width, mantissa_bit_width = impl(x)
-            qt = QuantTensor(out, scale, zero_point, exponent_bit_width+mantissa_bit_width, self.is_signed, self.training)
+            qt = QuantTensor(
+                out,
+                scale,
+                zero_point,
+                exponent_bit_width + mantissa_bit_width,
+                self.is_signed,
+                self.training) 
qt.exponent_bit_width = exponent_bit_width qt.mantissa_bit_width = mantissa_bit_width #!!!!!!!!! PLACEHOLDER CODE !!!!!!!!! From ef4c73762b840314740a8233792de2231b8c2008 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 11:42:38 +0100 Subject: [PATCH 05/44] changes to tests --- src/brevitas/export/common/handler/qcdq.py | 5 ++++- tests/brevitas/core/test_float_quant.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index cc05d5596..15be1e04f 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -210,7 +210,10 @@ def prepare_quantize_from_floating_point(self, module): module.is_signed) def prepare_quantize_from_integer(self, module): - raise NotImplementedError() + int_weights = { + tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False) + for tm in module.tracked_module_list} + self.symbolic_kwargs['int_weights'] = int_weights def prepare_for_export(self, module): if module.is_quant_enabled: diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 5c1b0cdc3..24d6beb7b 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -202,7 +202,7 @@ def test_inner_scale(inp, minifloat_format, scale): # dequantize manually out = val_fp_quant * scale - expected_out, expected_scale, _, _ = float_quant(inp) + expected_out, expected_scale, _, _, _ = float_quant(inp) assert scale == expected_scale if scale == 0.0: From 4aa4b2142a4eb5c98c328161754b8a797bf5140b Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 13:14:49 +0100 Subject: [PATCH 06/44] added version check for test --- src/brevitas/export/common/handler/qcdq.py | 2 +- tests/brevitas/export/test_onnx_fp8.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 15be1e04f..85375726e 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -72,7 +72,7 @@ class FloatQMixin(ABC): @abstractmethod def quantize_fn(self, x, scale, zero_point, dtype, axis): - raise NotImplementedError() + pass @classmethod def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_signed): diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index 48da70b9f..e685874e2 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -1,14 +1,20 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from packaging import version +import pytest import torch +from brevitas import torch_version from brevitas.export import export_onnx_qcdq import brevitas.nn as qnn from brevitas.quant.experimental.scaled_float import Fp8e4m3OCPWeightPerTensorFloat def test_simple_fp8_export(): + if torch_version < version.parse('2.1.0'): + pytest.skip(f"OCP FP8 types not supported by {torch_version}") + model = qnn.QuantLinear(3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat) export_onnx_qcdq(model, torch.randn(1, 3), 'test.onnx', export_weight_q_node=True) assert True From 3b0588360d62d2fc28f221088bbd10d2027b5e7d Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 13:54:53 +0100 Subject: [PATCH 07/44] using existing functionality over homespun --- src/brevitas/quant/experimental/scaled_float.py | 16 ---------------- tests/brevitas/export/test_onnx_fp8.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) delete mode 100644 src/brevitas/quant/experimental/scaled_float.py diff --git a/src/brevitas/quant/experimental/scaled_float.py b/src/brevitas/quant/experimental/scaled_float.py deleted file mode 100644 index 5c9b655d9..000000000 --- a/src/brevitas/quant/experimental/scaled_float.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from brevitas.quant.base import MaxStatsScaling -from brevitas.quant.base import PerTensorFloatScaling8bit -from brevitas.quant.experimental.float import Fp8e4m3Weight -from brevitas.quant.solver.weight import WeightQuantSolver - -__all__ = ['Fp8e4m3OCPWeightPerTensorFloat'] - - -class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3Weight, - MaxStatsScaling, - PerTensorFloatScaling8bit, - WeightQuantSolver): - pass diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index e685874e2..fc60d9af8 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -8,7 +8,7 @@ from brevitas import torch_version from brevitas.export import export_onnx_qcdq import brevitas.nn as qnn -from brevitas.quant.experimental.scaled_float import Fp8e4m3OCPWeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat def test_simple_fp8_export(): From cad580221ad759648a94791eca205b2b392a7de2 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:37:26 +0100 Subject: [PATCH 08/44] corrected mistake in copying and restored FloatClipMixin --- src/brevitas/export/common/handler/base.py | 7 +++++++ src/brevitas/export/common/handler/qcdq.py | 21 +++++++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index 609f11dc1..09d90fcf5 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -40,6 +40,13 @@ def quant_axis(cls, scale): return None +class FloatClipMixin(ABC): + + @classmethod + def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width): + return None + + class ClipMixin(ABC): @classmethod diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 85375726e..0db60a425 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -21,6 
+21,7 @@ from .base import BitWidthHandlerMixin from .base import ClipMixin +from .base import FloatClipMixin from .base import FloatZeroPointHandlerMixin from .base import QuantAxisMixin from .base import ZeroPointHandlerMixin @@ -131,7 +132,11 @@ def quantize_fn(self, x, dtype): pass -class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC): +class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, + FloatClipMixin, + FloatZeroPointHandlerMixin, + CDQCastMixin, + ABC): def dequantize_symbolic_kwargs( cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed): @@ -235,6 +240,11 @@ def prepare_for_export(self, module): self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width + self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( + module.is_narrow_range, + module.is_signed, + quant_weight.exponent_bit_width, + quant_weight.mantissa_bit_width) self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( scale, quant_weight.zero_point, @@ -271,12 +281,14 @@ def symbolic_execution(self, x: Tensor): x = self.quantize_from_floating_point(x, zero_point) else: x = self.quantize_from_integer(x) - + clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') # Workaround to trick the tracer into believing all return values are used self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width) + if clip_symbolic_kwargs is not None: + x = self.clip_fn(x, *clip_symbolic_kwargs.values()) x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) # After dequantization, cast both input and scale to the correct dtype if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: @@ -315,10 +327,7 @@ def prepare_quantize_from_floating_point(self, module): scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed) def prepare_quantize_from_integer(self, module): - int_weights = { - tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False) - for tm in module.tracked_module_list} - self.symbolic_kwargs['int_weights'] = int_weights + return self.symbolic_kwargs['int_weights'][x.data_ptr()] def prepare_for_export(self, module): if module.is_quant_enabled: From 4848248c7bca054fb6bbc04e7a9dc0f73f686450 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:41:43 +0100 Subject: [PATCH 09/44] fixed mistake --- src/brevitas/export/common/handler/qcdq.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 0db60a425..58feb5527 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -264,10 +264,7 @@ def quantize_from_floating_point(self, x: Tensor, zp): return x def quantize_from_integer(self, x: Tensor): - int_weights = { - tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False) - for tm in module.tracked_module_list} - self.symbolic_kwargs['int_weights'] = int_weights + return self.symbolic_kwargs['int_weights'][x.data_ptr()] def symbolic_execution(self, x: Tensor): assert self.symbolic_kwargs is not None, 'Symbolic execution requires 
quant to be enabled' @@ -327,7 +324,10 @@ def prepare_quantize_from_floating_point(self, module): scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed) def prepare_quantize_from_integer(self, module): - return self.symbolic_kwargs['int_weights'][x.data_ptr()] + int_weights = { + tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False) + for tm in module.tracked_module_list} + self.symbolic_kwargs['int_weights'] = int_weights def prepare_for_export(self, module): if module.is_quant_enabled: From 5188aa6f73d8dd23f28ff7a6054134016214272d Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 16 Apr 2024 09:59:03 +0100 Subject: [PATCH 10/44] first pass activation fp8 export --- src/brevitas/export/common/handler/qcdq.py | 90 +++++++++++++++++++ .../export/onnx/standard/qcdq/handler.py | 7 ++ .../export/onnx/standard/qcdq/manager.py | 2 + src/brevitas/proxy/__init__.py | 1 + src/brevitas/proxy/runtime_quant.py | 87 +++++++++++++++++- src/brevitas/quant/experimental/float_base.py | 4 +- tests/brevitas/export/test_onnx_fp8.py | 12 ++- 7 files changed, 198 insertions(+), 5 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 58feb5527..6390c96a6 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -14,6 +14,7 @@ from brevitas.proxy import BiasQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector +from brevitas.proxy import FloatActQuantProxyFromInjector from brevitas.proxy import FloatWeightQuantProxyFromInjector from brevitas.proxy import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector @@ -255,6 +256,7 @@ def prepare_for_export(self, module): self.symbolic_kwargs = None def quantize_from_floating_point(self, x: Tensor, zp): + # Workaround for equal_cpu RuntimeError quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] quantize_symbolic_kwargs['zero_point'] = zp # Before quantization, cast input to float32 @@ -415,6 +417,94 @@ def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_ return super().symbolic_execution(x) +class FloatQCDQCastActQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin, ABC): + handled_layer = FloatActQuantProxyFromInjector + + def quantize_symbolic_kwargs( + cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): + # compute axis before redefining scale + axis = cls.quant_axis(scale) + scale = to_0dim_if_scalar(scale.flatten()) + zp = to_0dim_if_scalar(zero_point.flatten()) + # expand_as must go after 0-dim check + zp = zp.expand_as(scale) + zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantissa_bit_width, zp) + if cls.itemize_quantize_scalar_params: + scale = to_item_if_0dim(scale) + zp = to_item_if_0dim(zp) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_signed) + return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.validate(module) + self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width() + self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width() + + # (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype + # of the scale and we cast it to 
float32. + # The original dtype is then restored during the forward pass + scale = module.scale() + self.scale_dtype = scale.dtype + if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16: + scale = self.cast_fn(scale, torch.float32) + + self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs( + scale, + module.zero_point(), + module.exponent_bit_width(), + module.mantissa_bit_width(), + module.is_signed) + self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( + scale, + module.zero_point(), + module.exponent_bit_width(), + module.mantissa_bit_width(), + module.is_signed) + self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( + module.is_narrow_range, + module.is_signed, + module.exponent_bit_width(), + module.mantissa_bit_width()) + + else: + self.symbolic_kwargs = None + + def symbolic_execution(self, x: Tensor): + assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' + + # Copy dict to allow for popping kwargs even on shared quantizers + dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs']) + scale = dequantize_symbolic_kwargs['scale'] + zero_point = dequantize_symbolic_kwargs['zero_point'] + scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') + + quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] + clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] + exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] + mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] + + # Workaround to trick the tracer into believing all return values are used + quantize_symbolic_kwargs['zero_point'] = zero_point + + self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width) + # If original dtype of the input is (b)float16, cast the input to float32 + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: + x = self.cast_fn(x, torch.float32) + x = self.quantize_fn(x, *quantize_symbolic_kwargs.values()) + if clip_symbolic_kwargs is not None: + x = self.clip_fn(x, *clip_symbolic_kwargs.values()) + x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) + # After dequantization, cast both output and scale to the correct dtype + if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: + x = self.cast_fn(x, self.scale_dtype) + scale = self.cast_fn(scale, self.scale_dtype) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) + return x, scale, zero_point, exponent_bit_width, mantissa_bit_width + + class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC): handled_layer = ActQuantProxyFromInjector diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 7087019df..d61d290f0 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -11,6 +11,7 @@ from brevitas.export.common.handler.qcdq import DQCastMixin from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DynamicQMixin +from brevitas.export.common.handler.qcdq import FloatQCDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import FloatQMixin from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin from 
brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin @@ -160,6 +161,12 @@ class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler( _export_q_node = False +class StdQCDQCastONNXFloatActQuantProxyHandler(StdFloatQCDQCastONNXMixin, + FloatQCDQCastActQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin, QCDQCastActQuantProxyHandlerMixin, ONNXBaseHandler): diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py index 186be6c87..d947309bd 100644 --- a/src/brevitas/export/onnx/standard/qcdq/manager.py +++ b/src/brevitas/export/onnx/standard/qcdq/manager.py @@ -20,6 +20,7 @@ from .handler import StdQCDQCastONNXActQuantProxyHandler from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler +from .handler import StdQCDQCastONNXFloatActQuantProxyHandler from .handler import StdQCDQCastONNXFloatWeightQuantProxyHandler from .handler import StdQCDQCastONNXQuantLSTMLayerHandler from .handler import StdQCDQCastONNXTruncQuantProxyHandler @@ -40,6 +41,7 @@ class StdQCDQONNXManager(StdONNXBaseManager): StdQCDQCastONNXFloatWeightQuantProxyHandler, StdCDQCastONNXBiasQuantProxyHandler, StdQCDQCastONNXActQuantProxyHandler, + StdQCDQCastONNXFloatActQuantProxyHandler, StdQCDQCastONNXDecoupledWeightQuantProxyHandler, StdDynamicQDQCastONNXActQuantProxyHandler, StdQCDQCastONNXTruncQuantProxyHandler, diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py index ecf9754aa..2ee2389c6 100644 --- a/src/brevitas/proxy/__init__.py +++ b/src/brevitas/proxy/__init__.py @@ -8,4 +8,5 @@ from .parameter_quant import WeightQuantProxyFromInjector from .runtime_quant import ActQuantProxyFromInjector from .runtime_quant import ClampQuantProxyFromInjector +from .runtime_quant import FloatActQuantProxyFromInjector from .runtime_quant import TruncQuantProxyFromInjector diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 4dd8417a9..9eec16bef 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -20,6 +20,7 @@ 'ActQuantProxyProtocol', 'AccQuantProxyProtocol', 'ActQuantProxyFromInjector', + 'FloatActQuantProxyFromInjector', 'TruncQuantProxyFromInjector', 'ClampQuantProxyFromInjector'] @@ -80,8 +81,8 @@ def __init__(self, activation_impl, tensor_quant): @brevitas.jit.script_method def forward(self, x): x = self.activation_impl(x) - x, output_scale, output_zp, output_bit_width = self.tensor_quant(x) - return x, output_scale, output_zp, output_bit_width + x, output_scale, output_zp, output_exponent_bit_width, output_mantissa_bit_width = self.tensor_quant(x) + return x, output_scale, output_zp, output_exponent_bit_width, output_mantissa_bit_width class ActQuantProxyFromInjector(QuantProxyFromInjector, ActQuantProxyProtocol): @@ -205,6 +206,88 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: return out +class FloatActQuantProxyFromInjector(ActQuantProxyFromInjector): + + def exponent_bit_width(self, force_eval=True): + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.exponent_bit_width + elif self._cached_act is not None: + return self._cached_act.exponent_bit_width + elif self._cached_act is None: + return None + + def 
mantissa_bit_width(self, force_eval=True): + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.mantissa_bit_width + elif self._cached_act is not None: + return self._cached_act.mantissa_bit_width + elif self._cached_act is None: + return None + + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + out = x + if self.fused_activation_quant_proxy is not None: + y = x + if isinstance(y, QuantTensor): + y = y.value + + if self.export_mode: + y = self.fused_activation_quant_proxy.activation_impl(y) + y = self.export_handler(y) + elif not self.is_quant_enabled: + y = self.fused_activation_quant_proxy.activation_impl(y) + else: + y = self.fused_activation_quant_proxy(y) + # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor + if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): + # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! + value, scale, zero_point, exponent_bit_width, mantissa_bit_width = y + out = QuantTensor( + value, + scale, + zero_point, + exponent_bit_width + mantissa_bit_width, + signed=self.is_signed, + training=self.training) + out.exponent_bit_width = exponent_bit_width + out.mantissa_bit_width = mantissa_bit_width + # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, QuantTensor): + # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! + out = QuantTensor( + y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + out.exponent_bit_width = x.exponent_bit_width + out.mantissa_bit_width = x.mantissa_bit_width + # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! 
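A note on the arithmetic in the two placeholder blocks above: the storage width of a signed minifloat format is sign + exponent + mantissa, so packing exponent_bit_width + mantissa_bit_width into bit_width under-counts by the sign bit — 4 + 3 = 7 rather than 8 for the FP8 E4M3 format these patches target. A one-line sanity check, with the E4M3 widths hard-coded:

    # FP8 E4M3 layout: 1 sign bit, 4 exponent bits, 3 mantissa bits
    sign_bits, exponent_bits, mantissa_bits = 1, 4, 3
    assert sign_bits + exponent_bits + mantissa_bits == 8  # total storage width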
+ else: + out = y + else: + if isinstance(y, tuple): + y = y[0] + out = y + else: + # If fused activation quant proxy is not enabled, return the input + out = x + if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): + cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out + + class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector): def scale(self, force_eval=True): diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index abe2d3250..09b3bb671 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -8,7 +8,7 @@ from brevitas.inject import ExtendedInjector from brevitas.inject import value from brevitas.proxy.parameter_quant import FloatWeightQuantProxyFromInjector -from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.proxy.runtime_quant import FloatActQuantProxyFromInjector from brevitas.quant.solver import ActQuantSolver from brevitas.quant.solver import WeightQuantSolver from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum @@ -32,7 +32,7 @@ class FloatWeightBase(FloatBase): class FloatActBase(FloatBase): - proxy_class = ActQuantProxyFromInjector + proxy_class = FloatActQuantProxyFromInjector class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver): diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index fc60d9af8..09ab8a563 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -8,6 +8,7 @@ from brevitas import torch_version from brevitas.export import export_onnx_qcdq import brevitas.nn as qnn +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat @@ -20,6 +21,15 @@ def test_simple_fp8_export(): assert True +def test_fp8_export_activation(): + if torch_version < version.parse('2.1.0'): + pytest.skip(f"OCP FP8 types not supported by {torch_version}") + + model = qnn.QuantLinear(3, 16, input_quant=Fp8e4m3OCPActPerTensorFloat) + export_onnx_qcdq(model, torch.randn(1, 3), 'test.onnx', export_weight_q_node=True) + assert True + + if __name__ == "__main__": - test_simple_fp8_export() + test_fp8_export_activation() print("Done") From 29cb95205ea90dabc8a1fde2a61e9cb851cb8053 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:26:50 +0100 Subject: [PATCH 11/44] beginnings of activation fp8 export and change name of QCDQCastFloatWeightQuantProxyHandlerMixin to be more consistent --- src/brevitas/export/common/handler/qcdq.py | 2 +- src/brevitas/export/onnx/manager.py | 1 + src/brevitas/export/onnx/standard/qcdq/handler.py | 4 ++-- tests/brevitas/export/test_onnx_fp8.py | 15 +++++++++++++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 6390c96a6..634dad04a 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -184,7 +184,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): 'scale_orig_shape': scale_orig_shape} -class QCDQCastFloatWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): +class 
FloatQCDQCastWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): handled_layer = FloatWeightQuantProxyFromInjector def quantize_symbolic_kwargs( diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index 1bacb461e..e3b693618 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -127,6 +127,7 @@ def export_onnx( else: model_bytes = BytesIO() export_target = model_bytes + onnx_export_kwargs['verbose'] = True torch.onnx.export(module, args, export_target, **onnx_export_kwargs) # restore the model to previous properties diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index d61d290f0..7858911ad 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -12,12 +12,12 @@ from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DynamicQMixin from brevitas.export.common.handler.qcdq import FloatQCDQCastActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import FloatQCDQCastWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import FloatQMixin from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import \ QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQCastFloatWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin @@ -138,7 +138,7 @@ def quantize_fn(self, x, dtype): class StdQCDQCastONNXFloatWeightQuantProxyHandler(StdFloatQCDQCastONNXMixin, - QCDQCastFloatWeightQuantProxyHandlerMixin, + FloatQCDQCastWeightQuantProxyHandlerMixin, ONNXBaseHandler): _export_q_node = False diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index 09ab8a563..6ea29b4a2 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -17,7 +17,7 @@ def test_simple_fp8_export(): pytest.skip(f"OCP FP8 types not supported by {torch_version}") model = qnn.QuantLinear(3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat) - export_onnx_qcdq(model, torch.randn(1, 3), 'test.onnx', export_weight_q_node=True) + export_onnx_qcdq(model, torch.randn(1, 3), 'weight_fp8.onnx', export_weight_q_node=True) assert True @@ -26,10 +26,21 @@ def test_fp8_export_activation(): pytest.skip(f"OCP FP8 types not supported by {torch_version}") model = qnn.QuantLinear(3, 16, input_quant=Fp8e4m3OCPActPerTensorFloat) - export_onnx_qcdq(model, torch.randn(1, 3), 'test.onnx', export_weight_q_node=True) + export_onnx_qcdq(model, torch.randn(1, 3), 'act_fp8.onnx', export_weight_q_node=True) + assert True + + +def test_fp8_export_export_activation(): + if torch_version < version.parse('2.1.0'): + pytest.skip(f"OCP FP8 types not supported by {torch_version}") + + model = qnn.QuantLinear( + 3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat) + export_onnx_qcdq(model, torch.randn(1, 3), 'weight_act_fp8.onnx', export_weight_q_node=True) assert True if __name__ == "__main__": test_fp8_export_activation() 
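The tests added above exercise the new FP8 weight and activation handlers end to end. For reference, a standalone sketch of the same flow with a graph-inspection step added (quantizer and export names as in this patch series; the opset_version=19 kwarg is an assumption, since the float8 Q/DQ types require opset 19, as the later test changes also reflect):

    import torch
    import onnx
    import brevitas.nn as qnn
    from brevitas.export import export_onnx_qcdq
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat

    model = qnn.QuantLinear(
        3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat)
    export_onnx_qcdq(
        model, torch.randn(1, 3), 'fp8_inspect.onnx', export_weight_q_node=True, opset_version=19)
    # Print the set of op types that ended up in the exported graph;
    # QuantizeLinear/DequantizeLinear pairs are the expected QCDQ pattern.
    print({node.op_type for node in onnx.load('fp8_inspect.onnx').graph.node})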
+ #test_fp8_export_export_activation() print("Done") From 9bf9240cb5b75fe6cde9f3bf1a4d34ef88a56c60 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:35:03 +0100 Subject: [PATCH 12/44] more changes to make naming scheme more consistent --- src/brevitas/export/onnx/standard/qcdq/handler.py | 4 ++-- src/brevitas/export/onnx/standard/qcdq/manager.py | 8 ++++---- tests/brevitas/export/test_onnx_fp8.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 7858911ad..55003c147 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -137,7 +137,7 @@ def quantize_fn(self, x, dtype): return DynamicQuantizeLinearFn.apply(x, dtype) -class StdQCDQCastONNXFloatWeightQuantProxyHandler(StdFloatQCDQCastONNXMixin, +class StdFloatQCDQCastONNXWeightQuantProxyHandler(StdFloatQCDQCastONNXMixin, FloatQCDQCastWeightQuantProxyHandlerMixin, ONNXBaseHandler): _export_q_node = False @@ -161,7 +161,7 @@ class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler( _export_q_node = False -class StdQCDQCastONNXFloatActQuantProxyHandler(StdFloatQCDQCastONNXMixin, +class StdFloatQCDQCastONNXActQuantProxyHandler(StdFloatQCDQCastONNXMixin, FloatQCDQCastActQuantProxyHandlerMixin, ONNXBaseHandler): pass diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py index d947309bd..e43d97e6d 100644 --- a/src/brevitas/export/onnx/standard/qcdq/manager.py +++ b/src/brevitas/export/onnx/standard/qcdq/manager.py @@ -17,11 +17,11 @@ from ..manager import StdONNXBaseManager from .handler import StdCDQCastONNXBiasQuantProxyHandler from .handler import StdDynamicQDQCastONNXActQuantProxyHandler +from .handler import StdFloatQCDQCastONNXActQuantProxyHandler +from .handler import StdFloatQCDQCastONNXWeightQuantProxyHandler from .handler import StdQCDQCastONNXActQuantProxyHandler from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler -from .handler import StdQCDQCastONNXFloatActQuantProxyHandler -from .handler import StdQCDQCastONNXFloatWeightQuantProxyHandler from .handler import StdQCDQCastONNXQuantLSTMLayerHandler from .handler import StdQCDQCastONNXTruncQuantProxyHandler from .handler import StdQCDQCastONNXWeightQuantProxyHandler @@ -38,10 +38,10 @@ class StdQCDQONNXManager(StdONNXBaseManager): handlers = [ StdQCDQCastONNXWeightQuantProxyHandler, - StdQCDQCastONNXFloatWeightQuantProxyHandler, + StdFloatQCDQCastONNXWeightQuantProxyHandler, StdCDQCastONNXBiasQuantProxyHandler, StdQCDQCastONNXActQuantProxyHandler, - StdQCDQCastONNXFloatActQuantProxyHandler, + StdFloatQCDQCastONNXActQuantProxyHandler, StdQCDQCastONNXDecoupledWeightQuantProxyHandler, StdDynamicQDQCastONNXActQuantProxyHandler, StdQCDQCastONNXTruncQuantProxyHandler, diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index 6ea29b4a2..b460bff24 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -41,6 +41,6 @@ def test_fp8_export_export_activation(): if __name__ == "__main__": - test_fp8_export_activation() - #test_fp8_export_export_activation() + #test_fp8_export_activation() + test_fp8_export_export_activation() print("Done") From f9406f13ec7a6c0a61f888ebad5cdbbb51303940 Mon Sep 17 
00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:44:07 +0100 Subject: [PATCH 13/44] added FloatFusedActivationQuantProxy --- src/brevitas/proxy/runtime_quant.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 9eec16bef..153c9c582 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -78,6 +78,20 @@ def __init__(self, activation_impl, tensor_quant): self.activation_impl = activation_impl self.tensor_quant = tensor_quant + @brevitas.jit.script_method + def forward(self, x): + x = self.activation_impl(x) + x, output_scale, output_zp, output_bit_width = self.tensor_quant(x) + return x, output_scale, output_zp, output_bit_width + + +class FloatFusedActivationQuantProxy(brevitas.jit.ScriptModule): + + def __init__(self, activation_impl, tensor_quant): + super(FloatFusedActivationQuantProxy, self).__init__() + self.activation_impl = activation_impl + self.tensor_quant = tensor_quant + @brevitas.jit.script_method def forward(self, x): x = self.activation_impl(x) @@ -208,6 +222,27 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: class FloatActQuantProxyFromInjector(ActQuantProxyFromInjector): + def init_tensor_quant(self): + tensor_quant = self.quant_injector.tensor_quant + if 'act_impl' in self.quant_injector: + act_impl = self.quant_injector.act_impl + else: + act_impl = None + is_act_enabled = _is_act_enabled(act_impl, tensor_quant) + is_quant_enabled = tensor_quant is not None + self.is_quant_enabled = is_quant_enabled + if is_act_enabled and is_quant_enabled: + self.fused_activation_quant_proxy = FloatFusedActivationQuantProxy( + act_impl, tensor_quant) + elif is_act_enabled and not is_quant_enabled: + self.fused_activation_quant_proxy = FloatFusedActivationQuantProxy( + act_impl, _TensorQuantDisabledIdentity()) + elif not is_act_enabled and is_quant_enabled: + self.fused_activation_quant_proxy = FloatFusedActivationQuantProxy( + Identity(), tensor_quant) + else: + self.fused_activation_quant_proxy = None + def exponent_bit_width(self, force_eval=True): if self.is_quant_enabled: current_status = self.training From 991ddb797678d8ca531c8befdc1bd07b60c459f4 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:18:04 +0100 Subject: [PATCH 14/44] replaced zero_point workaround with placeholder implementation of fp8 equal_cpu --- src/brevitas/export/common/handler/qcdq.py | 8 ++------ src/brevitas/export/onnx/manager.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 634dad04a..5d1d0edd2 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -255,10 +255,9 @@ def prepare_for_export(self, module): else: self.symbolic_kwargs = None - def quantize_from_floating_point(self, x: Tensor, zp): + def quantize_from_floating_point(self, x: Tensor): # Workaround for equal_cpu RuntimeError quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] - quantize_symbolic_kwargs['zero_point'] = zp # Before quantization, cast input to float32 if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16: x = self.cast_fn(x, torch.float32) @@ -277,7 +276,7 @@ def symbolic_execution(self, x: Tensor): 
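This patch drops the zero-point threading trick in favour of fixing the root cause where torch.onnx.export runs: on CPU, PyTorch 2.1 ships no float8 kernel for aten::equal, hence the equal_cpu RuntimeError named in the comment above. A minimal reproduction of the failure (requires torch >= 2.1; the exact error text may vary by version):

    import torch

    a = torch.zeros(4, dtype=torch.float8_e4m3fn)
    b = torch.zeros(4, dtype=torch.float8_e4m3fn)
    try:
        torch.equal(a, b)  # dispatches to aten::equal, which has no CPU kernel for float8
    except RuntimeError as e:
        print(e)  # e.g. "equal_cpu" not implemented for 'Float8_e4m3fn'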
zero_point = dequantize_symbolic_kwargs['zero_point'] if self._export_q_node: - x = self.quantize_from_floating_point(x, zero_point) + x = self.quantize_from_floating_point(x) else: x = self.quantize_from_integer(x) clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] @@ -484,9 +483,6 @@ def symbolic_execution(self, x: Tensor): exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] - # Workaround to trick the tracer into believing all return values are used - quantize_symbolic_kwargs['zero_point'] = zero_point - self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width) # If original dtype of the input is (b)float16, cast the input to float32 if x.dtype == torch.float16 or x.dtype == torch.bfloat16: diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index e3b693618..af04a99da 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -8,6 +8,7 @@ import warnings from packaging import version +import torch.library as tlib try: import onnx @@ -128,6 +129,18 @@ def export_onnx( model_bytes = BytesIO() export_target = model_bytes onnx_export_kwargs['verbose'] = True + + # workaround for fp8 not having many operators implemented + lib = tlib.Library("aten", "IMPL") + + def equal_cpu(self, other): + if self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + return torch.tensor([True]) + else: + return self.__eq__(other) + + lib.impl("equal", equal_cpu, "CPU") + torch.onnx.export(module, args, export_target, **onnx_export_kwargs) # restore the model to previous properties From 520db85183a4681ea14e45b605cfa64f343f08db Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:27:04 +0100 Subject: [PATCH 15/44] removed verbose flag --- src/brevitas/export/onnx/manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index af04a99da..e8de21f33 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -128,7 +128,6 @@ def export_onnx( else: model_bytes = BytesIO() export_target = model_bytes - onnx_export_kwargs['verbose'] = True # workaround for fp8 not having many operators implemented lib = tlib.Library("aten", "IMPL") From 2bb2895d2965709948456c3df71ff3e4406d70b8 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:05:44 +0100 Subject: [PATCH 16/44] created context manager for fp8 workaround --- src/brevitas/export/onnx/manager.py | 37 +++++++++++++++++++---------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index e8de21f33..fe90ee008 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -8,7 +8,6 @@ import warnings from packaging import version -import torch.library as tlib try: import onnx @@ -31,6 +30,28 @@ from ..manager import ExportContext +# workaround for fp8 not having many operators implemented +class Fp8Workaround(): + + def __init__(self): + pass + + def __enter__(self): + if torch_version >= version.parse('2.1.0'): + self.lib = torch.library.Library("aten", "IMPL") + + def equal_cpu(self, other): + if self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + return torch.tensor([True]) + else: + return self.__eq__(other) + + self.lib.impl("equal", 
equal_cpu, "CPU") + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.lib = None + + class ONNXBaseManager(BaseManager, ABC): model_transforms = [] @@ -129,18 +150,8 @@ def export_onnx( model_bytes = BytesIO() export_target = model_bytes - # workaround for fp8 not having many operators implemented - lib = tlib.Library("aten", "IMPL") - - def equal_cpu(self, other): - if self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - return torch.tensor([True]) - else: - return self.__eq__(other) - - lib.impl("equal", equal_cpu, "CPU") - - torch.onnx.export(module, args, export_target, **onnx_export_kwargs) + with Fp8Workaround(): + torch.onnx.export(module, args, export_target, **onnx_export_kwargs) # restore the model to previous properties module.apply(lambda m: _restore_act_caching_mode(m)) From 8ffce48930828f6cc087aa61e95283f6c363bb48 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:16:18 +0100 Subject: [PATCH 17/44] added check that objects being compared are tensors in the fp8 workaround --- src/brevitas/export/onnx/manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index fe90ee008..a3243b6cd 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -41,7 +41,10 @@ def __enter__(self): self.lib = torch.library.Library("aten", "IMPL") def equal_cpu(self, other): - if self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + if (isinstance(self, Tensor) and + self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)) or ( + isinstance(other, Tensor) and + other.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)): return torch.tensor([True]) else: return self.__eq__(other) From 7edf5bd96e6327a888855bf3d696ad34d41019ac Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 14 May 2024 16:53:10 +0100 Subject: [PATCH 18/44] General equal implementation --- src/brevitas/export/onnx/manager.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index a3243b6cd..d23a9b2d6 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -47,12 +47,22 @@ def equal_cpu(self, other): other.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)): return torch.tensor([True]) else: - return self.__eq__(other) + res = True + if not isinstance(self, Tensor): + self = torch.tensor(self) + if not isinstance(other, Tensor): + other = torch.tensor(other) + if self.dim() > 0: + for x, y in zip(self.flatten(), other.flatten()): + res &= x == y + else: + res = self.item() == other.item() + return torch.tensor([res]) self.lib.impl("equal", equal_cpu, "CPU") def __exit__(self, exc_type, exc_value, exc_traceback): - self.lib = None + self.lib._destroy() class ONNXBaseManager(BaseManager, ABC): From bbd5362dc1f43cf445117f8fe9f519ea94b384a8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 14 May 2024 16:56:04 +0100 Subject: [PATCH 19/44] fallback to fp32 if fp8 --- src/brevitas/export/onnx/manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index d23a9b2d6..ddca07a3d 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -45,7 +45,9 @@ def equal_cpu(self, other): self.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)) or ( 
isinstance(other, Tensor) and other.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)): - return torch.tensor([True]) + self = self.to(torch.float32) + other = other.to(torch.float32) + return torch.equal(self, other) else: res = True if not isinstance(self, Tensor): From 4bc126daaaa47d4eb1d2591df4490bef9623f9f2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 14 May 2024 17:02:03 +0100 Subject: [PATCH 20/44] Fix for PT < 2.1 --- src/brevitas/export/onnx/manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index ddca07a3d..ad4ef40e2 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -34,7 +34,7 @@ class Fp8Workaround(): def __init__(self): - pass + self.lib = None def __enter__(self): if torch_version >= version.parse('2.1.0'): @@ -64,7 +64,8 @@ def equal_cpu(self, other): self.lib.impl("equal", equal_cpu, "CPU") def __exit__(self, exc_type, exc_value, exc_traceback): - self.lib._destroy() + if self.lib is not None: + self.lib._destroy() class ONNXBaseManager(BaseManager, ABC): From a55dcd06878710f0989c4e45a3f0256bb95fa9bc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 14 May 2024 17:13:55 +0100 Subject: [PATCH 21/44] Remove non existent destroy --- src/brevitas/export/onnx/manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index ad4ef40e2..bf35e471d 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -64,8 +64,7 @@ def equal_cpu(self, other): self.lib.impl("equal", equal_cpu, "CPU") def __exit__(self, exc_type, exc_value, exc_traceback): - if self.lib is not None: - self.lib._destroy() + self.lib = None class ONNXBaseManager(BaseManager, ABC): From fabc8ae01c03fa055a46b9f58fcc273465711f47 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 May 2024 13:45:27 +0100 Subject: [PATCH 22/44] Remove import --- src/brevitas/proxy/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py index 126783b27..2ddbdf784 100644 --- a/src/brevitas/proxy/__init__.py +++ b/src/brevitas/proxy/__init__.py @@ -5,7 +5,6 @@ from .parameter_quant import BiasQuantProxyFromInjectorBase from .parameter_quant import DecoupledWeightQuantProxyFromInjector from .parameter_quant import DecoupledWeightQuantWithInputProxyFromInjector -from .parameter_quant import FloatWeightQuantProxyFromInjector from .parameter_quant import WeightQuantProxyFromInjector from .parameter_quant import WeightQuantProxyFromInjectorBase from .runtime_quant import ActQuantProxyFromInjector From 74b65a9b2685d8a3165b960cfcf8862baadadd4e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 May 2024 14:15:48 +0100 Subject: [PATCH 23/44] Fixed imports --- src/brevitas/export/common/handler/qcdq.py | 8 +- src/brevitas/proxy/__init__.py | 3 +- src/brevitas/proxy/runtime_quant.py | 103 --------------------- 3 files changed, 6 insertions(+), 108 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 5d1d0edd2..32a0e77ac 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -10,12 +10,12 @@ from brevitas.export.common import to_0dim_if_scalar from brevitas.export.common import to_item_if_0dim +from brevitas.proxy import ActFloatQuantProxyFromInjector 
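Taken together, patches 16 through 21 above leave the aten::equal workaround in the following shape (a consolidated sketch of those diffs, with Brevitas' torch_version helper replaced by a direct parse of torch.__version__):

    from packaging import version
    import torch
    from torch import Tensor

    FLOAT8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

    class Fp8Workaround:
        """Route aten::equal on CPU through a float8-aware fallback during export."""

        def __init__(self):
            self.lib = None

        def __enter__(self):
            if version.parse(torch.__version__) >= version.parse('2.1.0'):
                self.lib = torch.library.Library("aten", "IMPL")

                def equal_cpu(self, other):
                    if (isinstance(self, Tensor) and self.dtype in FLOAT8_DTYPES) or (
                            isinstance(other, Tensor) and other.dtype in FLOAT8_DTYPES):
                        # No CPU kernel for float8 equality: compare as float32.
                        # This re-enters equal_cpu, but with float32 inputs it
                        # falls through to the elementwise branch below.
                        return torch.equal(self.to(torch.float32), other.to(torch.float32))
                    res = True
                    if not isinstance(self, Tensor):
                        self = torch.tensor(self)
                    if not isinstance(other, Tensor):
                        other = torch.tensor(other)
                    if self.dim() > 0:
                        for x, y in zip(self.flatten(), other.flatten()):
                            res &= x == y
                    else:
                        res = self.item() == other.item()
                    return torch.tensor([res])

                self.lib.impl("equal", equal_cpu, "CPU")

        def __exit__(self, exc_type, exc_value, exc_traceback):
            # Dropping the Library handle releases the registration,
            # which is the behaviour the patches above rely on.
            self.lib = None

    # Used exactly as in onnx/manager.py:
    #     with Fp8Workaround():
    #         torch.onnx.export(module, args, export_target, **onnx_export_kwargs)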
from brevitas.proxy import ActQuantProxyFromInjector from brevitas.proxy import BiasQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector -from brevitas.proxy import FloatActQuantProxyFromInjector -from brevitas.proxy import FloatWeightQuantProxyFromInjector +from brevitas.proxy import WeightFloatQuantProxyFromInjector from brevitas.proxy import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector @@ -185,7 +185,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): class FloatQCDQCastWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): - handled_layer = FloatWeightQuantProxyFromInjector + handled_layer = WeightFloatQuantProxyFromInjector def quantize_symbolic_kwargs( cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): @@ -417,7 +417,7 @@ def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_ class FloatQCDQCastActQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin, ABC): - handled_layer = FloatActQuantProxyFromInjector + handled_layer = ActFloatQuantProxyFromInjector def quantize_symbolic_kwargs( cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py index 2ddbdf784..ebdc6403c 100644 --- a/src/brevitas/proxy/__init__.py +++ b/src/brevitas/proxy/__init__.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from .float_parameter_quant import WeightFloatQuantProxyFromInjector +from .float_runtime_quant import ActFloatQuantProxyFromInjector from .parameter_quant import BiasQuantProxyFromInjector from .parameter_quant import BiasQuantProxyFromInjectorBase from .parameter_quant import DecoupledWeightQuantProxyFromInjector @@ -10,5 +12,4 @@ from .runtime_quant import ActQuantProxyFromInjector from .runtime_quant import ActQuantProxyFromInjectorBase from .runtime_quant import ClampQuantProxyFromInjector -from .runtime_quant import FloatActQuantProxyFromInjector from .runtime_quant import TruncQuantProxyFromInjector diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 0064980c1..aeb9fd0b7 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -210,109 +210,6 @@ def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTen return out -class FloatActQuantProxyFromInjector(ActQuantProxyFromInjector): - - def init_tensor_quant(self): - tensor_quant = self.quant_injector.tensor_quant - if 'act_impl' in self.quant_injector: - act_impl = self.quant_injector.act_impl - else: - act_impl = None - is_act_enabled = _is_act_enabled(act_impl, tensor_quant) - is_quant_enabled = tensor_quant is not None - self.is_quant_enabled = is_quant_enabled - if is_act_enabled and is_quant_enabled: - self.fused_activation_quant_proxy = FloatFusedActivationQuantProxy( - act_impl, tensor_quant) - elif is_act_enabled and not is_quant_enabled: - self.fused_activation_quant_proxy = FloatFusedActivationQuantProxy( - act_impl, _TensorQuantDisabledIdentity()) - elif not is_act_enabled and is_quant_enabled: - self.fused_activation_quant_proxy = FloatFusedActivationQuantProxy( - Identity(), tensor_quant) - else: - 
self.fused_activation_quant_proxy = None - - def exponent_bit_width(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.exponent_bit_width - elif self._cached_act is not None: - return self._cached_act.exponent_bit_width - elif self._cached_act is None: - return None - - def mantissa_bit_width(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.mantissa_bit_width - elif self._cached_act is not None: - return self._cached_act.mantissa_bit_width - elif self._cached_act is None: - return None - - def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - out = x - if self.fused_activation_quant_proxy is not None: - y = x - if isinstance(y, QuantTensor): - y = y.value - - if self.export_mode: - y = self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) - elif not self.is_quant_enabled: - y = self.fused_activation_quant_proxy.activation_impl(y) - else: - y = self.fused_activation_quant_proxy(y) - # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, - # otherwise return a simple Tensor - if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): - # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! - value, scale, zero_point, exponent_bit_width, mantissa_bit_width = y - out = QuantTensor( - value, - scale, - zero_point, - exponent_bit_width + mantissa_bit_width, - signed=self.is_signed, - training=self.training) - out.exponent_bit_width = exponent_bit_width - out.mantissa_bit_width = mantissa_bit_width - # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! - elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant - if isinstance(y, tuple): - y = y[0] - if isinstance(x, QuantTensor): - # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! - out = QuantTensor( - y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) - out.exponent_bit_width = x.exponent_bit_width - out.mantissa_bit_width = x.mantissa_bit_width - # !!!!!!!!! PLACEHOLDER CODE !!!!!!!!!!!!!!!!!!!! 
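The forward pass being deleted here stapled exponent and mantissa metadata onto an integer QuantTensor after the fact; patch 25 below instead builds a FloatQuantTensor straight from the quantizer's output tuple. Judging from the return signatures used throughout these patches, that container carries roughly the following fields (an illustrative NamedTuple only, not the actual Brevitas class):

    from typing import NamedTuple, Optional, Tuple

    from torch import Tensor

    class FloatQuantTensorSketch(NamedTuple):
        value: Tensor
        scale: Tensor
        zero_point: Tensor
        exponent_bit_width: Tensor
        mantissa_bit_width: Tensor
        exponent_bias: Tensor
        saturating: bool
        inf_values: Optional[Tuple[str, ...]]  # e.g. None when no encodings are reserved for inf
        nan_values: Optional[Tuple[str, ...]]
        signed: bool = True
        training: bool = False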
- else: - out = y - else: - if isinstance(y, tuple): - y = y[0] - out = y - else: - # If fused activation quant proxy is not enabled, return the input - out = x - if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): - cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only) - self._cached_act = cached_out - return out - - class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector): def scale(self, force_eval=True): From cf1ea02d785e30198affa697459f6b4e45ff7141 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 May 2024 14:20:21 +0100 Subject: [PATCH 24/44] Fixed imports --- src/brevitas/quant/experimental/float_base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 09b3bb671..1b7191037 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -7,8 +7,8 @@ from brevitas.core.scaling.float_scaling import FloatScaling from brevitas.inject import ExtendedInjector from brevitas.inject import value -from brevitas.proxy.parameter_quant import FloatWeightQuantProxyFromInjector -from brevitas.proxy.runtime_quant import FloatActQuantProxyFromInjector +from brevitas.proxy import ActFloatQuantProxyFromInjector +from brevitas.proxy import WeightFloatQuantProxyFromInjector from brevitas.quant.solver import ActQuantSolver from brevitas.quant.solver import WeightQuantSolver from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum @@ -28,11 +28,11 @@ def exponent_bias(exponent_bit_width): class FloatWeightBase(FloatBase): - proxy_class = FloatWeightQuantProxyFromInjector + proxy_class = WeightFloatQuantProxyFromInjector class FloatActBase(FloatBase): - proxy_class = FloatActQuantProxyFromInjector + proxy_class = ActFloatQuantProxyFromInjector class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver): From cda7f1f4c4bbfd39a28c2796915dad8df4cab131 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 May 2024 17:39:21 +0100 Subject: [PATCH 25/44] Fix export --- src/brevitas/export/common/handler/qcdq.py | 20 +++++++- src/brevitas/proxy/float_runtime_quant.py | 56 ++++++++-------------- src/brevitas/proxy/runtime_quant.py | 53 ++++++++------------ 3 files changed, 59 insertions(+), 70 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 32a0e77ac..6eb9b1e4e 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -241,6 +241,10 @@ def prepare_for_export(self, module): self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width + self.symbolic_kwargs['exponent_bias'] = quant_weight.exponent_bias + self.symbolic_kwargs['saturating'] = quant_weight.saturating + self.symbolic_kwargs['inf_values'] = quant_weight.inf_values + self.symbolic_kwargs['nan_values'] = quant_weight.nan_values self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( module.is_narrow_range, module.is_signed, @@ -282,6 +286,10 @@ def symbolic_execution(self, x: Tensor): clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] + exponent_bias = self.symbolic_kwargs['exponent_bias'] + saturating = 
self.symbolic_kwargs['saturating'] + inf_values = self.symbolic_kwargs['inf_values'] + nan_values = self.symbolic_kwargs['nan_values'] scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') # Workaround to trick the tracer into believing all return values are used self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width) @@ -295,7 +303,7 @@ def symbolic_execution(self, x: Tensor): # Restore the original shapes to guarantee correct shape propagation downstream scale = scale.view(scale_orig_shape) zero_point = zero_point.view_as(scale) - return x, scale, zero_point, exponent_bit_width, mantissa_bit_width + return x, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin): @@ -439,6 +447,10 @@ def prepare_for_export(self, module): self.validate(module) self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width() self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width() + self.symbolic_kwargs['exponent_bias'] = module.exponent_bias() + self.symbolic_kwargs['saturating'] = module.saturating() + self.symbolic_kwargs['inf_values'] = module.inf_values() + self.symbolic_kwargs['nan_values'] = module.nan_values() # (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype # of the scale and we cast it to float32. @@ -482,6 +494,10 @@ def symbolic_execution(self, x: Tensor): clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] + exponent_bias = self.symbolic_kwargs['exponent_bias'] + saturating = self.symbolic_kwargs['saturating'] + inf_values = self.symbolic_kwargs['inf_values'] + nan_values = self.symbolic_kwargs['nan_values'] self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width) # If original dtype of the input is (b)float16, cast the input to float32 @@ -498,7 +514,7 @@ def symbolic_execution(self, x: Tensor): # Restore the original shapes to guarantee correct shape propagation downstream scale = scale.view(scale_orig_shape) zero_point = zero_point.view_as(scale) - return x, scale, zero_point, exponent_bit_width, mantissa_bit_width + return x, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC): diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 5fc1f2411..4151bc555 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -14,43 +14,28 @@ class ActFloatQuantProxyFromInjector(ActQuantProxyFromInjectorBase): def scale(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.scale - elif self._cached_act is not None: - return self._cached_act.scale - elif self._cached_act is None: - return None + return self.retrieve_attribute('scale', force_eval) def zero_point(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.zero_point - elif self._cached_act is not None: - return self._cached_act.zero_point - elif 
self._cached_act is None: - return None + return self.retrieve_attribute('zero_point', force_eval) - def bit_width(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.bit_width - elif self._cached_act is not None: - return self._cached_act.bit_width - elif self._cached_act is None: - return None + def exponent_bit_width(self, force_eval=True): + return self.retrieve_attribute('exponent_bit_width', force_eval) + + def mantissa_bit_width(self, force_eval=True): + return self.retrieve_attribute('mantissa_bit_width', force_eval) + + def exponent_bias(self, force_eval=True): + return self.retrieve_attribute('exponent_bias', force_eval) + + def saturating(self, force_eval=True): + return self.retrieve_attribute('saturating', force_eval) + + def inf_values(self, force_eval=True): + return self.retrieve_attribute('inf_values', force_eval) + + def nan_values(self, force_eval=True): + return self.retrieve_attribute('nan_values', force_eval) def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]: out = x @@ -68,7 +53,8 @@ def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuan y = self.fused_activation_quant_proxy(y) # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): + # We exclude the last two values (inf_values and nan_values) + if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index aeb9fd0b7..9e538eb28 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -96,6 +96,23 @@ def __init__(self, quant_layer, quant_injector): self.cache_inference_quant_act = False self.cache_quant_io_metadata_only = True + def internal_forward(self, force_eval): + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out + + def retrieve_attribute(self, attribute, force_eval): + if self.is_quant_enabled: + out = self.internal_forward(force_eval) + return getattr(out, attribute) + elif self._cached_act is not None: + return getattr(self._cached_act, attribute) + elif self._cached_act is None: + return None + @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant @@ -133,43 +150,13 @@ def init_tensor_quant(self): class ActQuantProxyFromInjector(ActQuantProxyFromInjectorBase): def scale(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.scale - elif self._cached_act is not None: - return self._cached_act.scale - elif self._cached_act is None: - return None + return self.retrieve_attribute('scale', force_eval) def zero_point(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.zero_point - elif self._cached_act is not 
None: - return self._cached_act.zero_point - elif self._cached_act is None: - return None + return self.retrieve_attribute('zero_point', force_eval) def bit_width(self, force_eval=True): - if self.is_quant_enabled: - current_status = self.training - if force_eval: - self.eval() - out = self.__call__(self._zero_hw_sentinel()) - self.train(current_status) - return out.bit_width - elif self._cached_act is not None: - return self._cached_act.bit_width - elif self._cached_act is None: - return None + return self.retrieve_attribute('bit_width', force_eval) def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]: out = x From 8349391b9820b88642e98989f1e17af33b0bc44a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 24 May 2024 00:00:17 +0100 Subject: [PATCH 26/44] more testing --- src/brevitas/export/common/handler/qcdq.py | 88 ++++++++++++++++++++-- tests/brevitas_ort/__init__.py | 9 --- tests/brevitas_ort/common.py | 14 +++- tests/brevitas_ort/quant_module_cases.py | 14 +++- tests/brevitas_ort/test_quant_module.py | 2 +- 5 files changed, 103 insertions(+), 24 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 6eb9b1e4e..9f8e5141b 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -17,6 +17,7 @@ from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector from brevitas.proxy import WeightFloatQuantProxyFromInjector from brevitas.proxy import WeightQuantProxyFromInjector +from brevitas.proxy.float_parameter_quant import BiasFloatQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector @@ -216,10 +217,10 @@ def prepare_quantize_from_floating_point(self, module): module.is_signed) def prepare_quantize_from_integer(self, module): - int_weights = { - tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False) + minifloat_weight = { + tm.weight.data_ptr(): tm.quant_weight().minifloat(float_datatype=True) for tm in module.tracked_module_list} - self.symbolic_kwargs['int_weights'] = int_weights + self.symbolic_kwargs['minifloat_weight'] = minifloat_weight def prepare_for_export(self, module): if module.is_quant_enabled: @@ -269,7 +270,7 @@ def quantize_from_floating_point(self, x: Tensor): return x def quantize_from_integer(self, x: Tensor): - return self.symbolic_kwargs['int_weights'][x.data_ptr()] + return self.symbolic_kwargs['minifloat_weight'][x.data_ptr()] def symbolic_execution(self, x: Tensor): assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' @@ -623,6 +624,79 @@ def symbolic_execution(self, x: Tensor): return x, scale, zero_point, bit_width +class FloatCDQCastBiasQuantProxyHandlerMixin(DQCastMixin, + QuantAxisMixin, + FloatZeroPointHandlerMixin, + ABC): + handled_layer = BiasFloatQuantProxyFromInjector + + def validate(self, module): + if module.bit_width() is not None: + assert module.bit_width() > 1., 'Binary quant not supported' + assert module.is_signed, 'Unsigned bias not supported.' + assert module.rounding_mode == 'ROUND', 'Only round to nearest even supported.' 
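The bias handler below resolves which pre-quantized tensor belongs to which parameter the same way prepare_quantize_from_integer does for weights earlier in the series: a dict keyed on Tensor.data_ptr(), which works because the tensor seen at symbolic-execution time is the very parameter that was cached. The pattern in isolation (toy stand-in quantizer, illustrative names):

    import torch
    import torch.nn as nn

    tracked_modules = [nn.Linear(4, 4), nn.Linear(4, 4)]

    def fake_quant(w: torch.Tensor) -> torch.Tensor:
        return w.detach().round()  # stand-in for the real minifloat quantizer

    # prepare step: one entry per tracked module, keyed by the parameter's storage pointer
    cache = {m.weight.data_ptr(): fake_quant(m.weight) for m in tracked_modules}

    # symbolic step: the incoming tensor is the parameter itself, so the lookup succeeds
    def symbolic_lookup(x: torch.Tensor) -> torch.Tensor:
        return cache[x.data_ptr()]

    quantized = symbolic_lookup(tracked_modules[0].weight)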
+ + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.validate(module) + int_biases = { + tm.bias.data_ptr(): tm.quant_bias().minifloat(float_datatype=False) + for tm in module.tracked_module_list} + self.symbolic_kwargs = { + 'int_biases': int_biases, + 'scale': module.scale(), + 'zero_point': module.zero_point(), + 'exponent_bit_width': module.exponent_bit_width(), + 'mantissa_bit_width': module.mantissa_bit_width(), + 'exponent_bias': module.exponent_bias(), + 'saturating': module.saturating(), + 'inf_values': module.inf_values(), + 'nan_values': module.nan_values()} + + else: + self.symbolic_kwargs = None + + def symbolic_execution(self, x: Tensor, input_scale=None): + assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' + int_bias = self.symbolic_kwargs['int_biases'][x.data_ptr()] + scale = self.symbolic_kwargs['scale'] + zero_point = self.symbolic_kwargs['zero_point'] + exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] + mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] + exponent_bias = self.symbolic_kwargs['exponent_bias'] + saturating = self.symbolic_kwargs['saturating'] + inf_values = self.symbolic_kwargs['inf_values'] + nan_values = self.symbolic_kwargs['nan_values'] + + assert scale is not None or input_scale is not None, 'Input scale required for bias export' + if input_scale is not None: + scale = input_scale + scale_orig_shape = scale.shape + + quant_axis = self.quant_axis(scale) + if self.flatten_dequantize_params: + scale = scale.flatten() + zero_point = zero_point.flatten() + scale = to_0dim_if_scalar(scale) + zero_point = to_0dim_if_scalar(zero_point).expand_as(scale) + zero_point = self.zero_point_with_dtype( + True, exponent_bit_width, mantissa_bit_width, zero_point) # assume signed is True + # If original dtype of scale is (b)float16, store the original dtype + # and cast the scale to float32 + scale_dtype = scale.dtype + if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16: + scale = self.cast_fn(scale, torch.float32) + y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis) + # After dequantization, cast both output and scale to the correct dtype + if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16: + y = self.cast_fn(y, scale_dtype) + scale = self.cast_fn(scale, scale_dtype) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) + return y, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values + + class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC): handled_layer = BiasQuantProxyFromInjector @@ -646,19 +720,17 @@ def prepare_for_export(self, module): else: self.symbolic_kwargs = None - def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): + def symbolic_execution(self, x: Tensor, input_scale=None): assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' int_bias = self.symbolic_kwargs['int_biases'][x.data_ptr()] scale = self.symbolic_kwargs['scale'] bit_width = self.symbolic_kwargs['bit_width'] zero_point = self.symbolic_kwargs['zero_point'] assert scale is not None or input_scale is not None, 'Input scale required for bias export' - assert bit_width is not None or input_bit_width is not None, 'Input bit width required for bias export' if input_scale is not None: scale = 
input_scale scale_orig_shape = scale.shape - if input_bit_width is not None: - bit_width = input_bit_width + quant_axis = self.quant_axis(scale) if self.flatten_dequantize_params: scale = scale.flatten() diff --git a/tests/brevitas_ort/__init__.py b/tests/brevitas_ort/__init__.py index 78315f9c7..b10a7efee 100644 --- a/tests/brevitas_ort/__init__.py +++ b/tests/brevitas_ort/__init__.py @@ -1,11 +1,2 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause - -try: - import torch - - # Avoid fast algorithms that might introduce extra error during fake-quantization - torch.use_deterministic_algorithms(True) -except: - # Introduced in 1.8.0 - pass diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 4d1e16679..7e287a7f7 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -18,6 +18,8 @@ from brevitas.nn import QuantConvTranspose2d from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantLinear +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint @@ -59,7 +61,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): 'symmetric_per_channel_fixed_point': (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint), 'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float': - (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat)} + (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat), + 'fp8_per_tensor_float': (None, Fp8e4m3OCPActPerTensorFloat)} LSTM_QUANTIZERS = { 'asymmetric_per_tensor_float': (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat), @@ -143,9 +146,12 @@ def is_brevitas_ort_close( ort_output = odict[exported_model.graph.output[0].name] else: if export_type == 'qcdq': - export_onnx_qcdq(model, input_t, export_path=export_name) - elif export_type == 'qcdq_opset14': - export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name) + export_onnx_qcdq( + model, + input_t, + export_path=export_name, + export_weight_q_node=True, + opset_version=19) elif export_type == 'qonnx_opset14': export_qonnx(model, input_t, opset_version=14, export_path=export_name) else: diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py index 7361c1231..8eab547d2 100644 --- a/tests/brevitas_ort/quant_module_cases.py +++ b/tests/brevitas_ort/quant_module_cases.py @@ -27,12 +27,22 @@ def case_quant_wbiol( set_case_id(request.node.callspec.id, QuantWBIOLCases.case_quant_wbiol) weight_quant, io_quant = quantizers + if weight_quant == Fp8e4m3OCPWeightPerTensorFloat: + if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8: + pytest.skip('FP8 export requires total bitwidth equal to 8') + torch.use_deterministic_algorithms(False) + else: + torch.use_deterministic_algorithms(True) + if impl is QuantLinear: layer_kwargs = {'in_features': IN_CH, 'out_features': OUT_CH} else: layer_kwargs = { 'in_channels': IN_CH, 'out_channels': OUT_CH, 'kernel_size': KERNEL_SIZE} + bias_quantizer = None if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else Int32Bias + return_quant_tensor = False if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else True + class 
Model(nn.Module): def __init__(self): @@ -46,8 +56,8 @@ def __init__(self): weight_bit_width=weight_bit_width, input_bit_width=input_bit_width, output_bit_width=output_bit_width, - bias_quant=Int32Bias, - return_quant_tensor=True) + bias_quant=bias_quantizer, + return_quant_tensor=return_quant_tensor) self.conv.weight.data.uniform_(-0.01, 0.01) def forward(self, x): diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 2b7b6b1cf..e13353a01 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -71,7 +71,7 @@ def test_ort_avgpool(model, current_cases): @parametrize_with_cases('model', cases=QuantRecurrentCases) -@pytest.mark.parametrize('export_type', ['qcdq_opset14', 'qonnx_opset14']) +@pytest.mark.parametrize('export_type', ['qcdq', 'qonnx_opset14']) @requires_pt_ge('1.10') def test_ort_lstm(model, export_type, current_cases): cases_generator_func = current_cases['model'][1] From 11387d311a5b8bd3e245790727a3b6442107f836 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 24 May 2024 14:49:39 +0100 Subject: [PATCH 27/44] Fix --- tests/brevitas_ort/common.py | 2 +- tests/brevitas_ort/test_quant_module.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 7e287a7f7..a5bd740df 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -62,7 +62,7 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint), 'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float': (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat), - 'fp8_per_tensor_float': (None, Fp8e4m3OCPActPerTensorFloat)} + 'fp8_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat)} LSTM_QUANTIZERS = { 'asymmetric_per_tensor_float': (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat), diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index e13353a01..c183dfb21 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -37,6 +37,8 @@ def test_ort_wbiol(model, export_type, current_cases): if 'dynamic' in quantizer and ((o_bit_width != "o8" or i_bit_width != "i8") or export_type != "qcdq"): pytest.skip('Dynamic Act Quant supported only for 8bit and QCDQ export') + if export_type == 'qonnx' and 'fp8' in quantizer: + pytest.skip('FP8 export requires QCDQ') if impl in ('QuantLinear'): in_size = (1, IN_CH) From 592ccd39ce07e6059631f28a847f676c86d60225 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 24 May 2024 15:47:50 +0100 Subject: [PATCH 28/44] Fix --- tests/brevitas_ort/common.py | 13 ++++++++++--- tests/brevitas_ort/test_quant_module.py | 21 ++++++++++++++++----- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index a5bd740df..ceaf789f4 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -123,7 +123,14 @@ def recursive_allclose(ort_output, brevitas_output, tolerance): def is_brevitas_ort_close( - model, np_input, export_name, export_type, tolerance=None, first_output_only=False): + model, + np_input, + export_name, + export_type, + tolerance=None, + first_output_only=False, + onnx_opset=14, + export_q_weight=False): input_t = torch.from_numpy(np_input) with torch.no_grad(): brevitas_output = model(input_t) @@ 
-150,8 +157,8 @@ def is_brevitas_ort_close( model, input_t, export_path=export_name, - export_weight_q_node=True, - opset_version=19) + export_weight_q_node=export_q_weight, + opset_version=onnx_opset) elif export_type == 'qonnx_opset14': export_qonnx(model, input_t, opset_version=14, export_path=export_name) else: diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index c183dfb21..14271cb4a 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -3,13 +3,14 @@ from functools import reduce from operator import mul -import os +from packaging.version import parse import pytest from pytest_cases import get_case_id from pytest_cases import parametrize_with_cases import torch +from brevitas import torch_version from tests.marker import requires_pt_ge from .common import * @@ -20,7 +21,6 @@ @parametrize_with_cases('model', cases=QuantWBIOLCases) @pytest.mark.parametrize('export_type', ['qcdq', 'qonnx']) -@requires_pt_ge('1.8.1') def test_ort_wbiol(model, export_type, current_cases): cases_generator_func = current_cases['model'][1] case_id = get_case_id(cases_generator_func) @@ -29,7 +29,8 @@ def test_ort_wbiol(model, export_type, current_cases): quantizer = case_id.split('-')[-6] o_bit_width = case_id.split('-')[-5] i_bit_width = case_id.split('-')[-3] - + onnx_opset = 14 + export_q_weight = False if 'per_channel' in quantizer and 'asymmetric' in quantizer: pytest.skip('Per-channel zero-point is not well supported in ORT.') if 'QuantLinear' in impl and 'asymmetric' in quantizer: @@ -39,6 +40,11 @@ def test_ort_wbiol(model, export_type, current_cases): pytest.skip('Dynamic Act Quant supported only for 8bit and QCDQ export') if export_type == 'qonnx' and 'fp8' in quantizer: pytest.skip('FP8 export requires QCDQ') + if torch_version < parse('2.1') and 'fp8' in quantizer: + pytest.skip('FP8 requires PyTorch 2.1 or higher') + else: + onnx_opset = 19 + export_q_weight = True if impl in ('QuantLinear'): in_size = (1, IN_CH) @@ -57,11 +63,16 @@ def test_ort_wbiol(model, export_type, current_cases): model.eval() export_name = f'qcdq_qop_export_{case_id}.onnx' assert is_brevitas_ort_close( - model, inp, export_name, export_type, tolerance=INT_TOLERANCE, first_output_only=True) + model, + inp, + export_name, + export_type, + tolerance=INT_TOLERANCE, + first_output_only=True, + onnx_opset=onnx_opset) @parametrize_with_cases('model', cases=QuantAvgPoolCases) -@requires_pt_ge('1.8.1') def test_ort_avgpool(model, current_cases): in_size = (1, IN_CH, FEATURES, FEATURES) inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size) From 1fc56425a1cdf6e10204658f6b86ef9e0dc8c1da Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 25 May 2024 01:47:16 +0100 Subject: [PATCH 29/44] fix --- tests/brevitas_ort/test_quant_module.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 14271cb4a..0ff6b5a1f 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -42,7 +42,7 @@ def test_ort_wbiol(model, export_type, current_cases): pytest.skip('FP8 export requires QCDQ') if torch_version < parse('2.1') and 'fp8' in quantizer: pytest.skip('FP8 requires PyTorch 2.1 or higher') - else: + elif torch_version >= parse('2.1') and 'fp8' in quantizer: onnx_opset = 19 export_q_weight = True @@ -69,7 +69,8 @@ def test_ort_wbiol(model, export_type, current_cases): export_type, 
tolerance=INT_TOLERANCE, first_output_only=True, - onnx_opset=onnx_opset) + onnx_opset=onnx_opset, + export_q_weight=export_q_weight) @parametrize_with_cases('model', cases=QuantAvgPoolCases) From 58f46bc4e95926ee5de4d66bb89f08d35e95bc20 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 25 May 2024 14:52:42 +0100 Subject: [PATCH 30/44] Fix minifloat check --- src/brevitas/quant_tensor/float_quant_tensor.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index c2bb99900..f6208687a 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -3,6 +3,7 @@ import torch +from brevitas.function.ops_ste import floor_ste from brevitas.function.ops_ste import round_ste from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import FloatQuantTensorBase @@ -90,17 +91,23 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def tensor(self): return self.value + def internal_scale(self): + internal_scale = floor_ste(torch.log2(torch.abs(self.value))) - self.mantissa_bit_width + internal_scale = torch.clamp_min( + internal_scale, 1. - self.exponent_bias - self.mantissa_bit_width) + internal_scale = torch.exp2(internal_scale) + return internal_scale + @property def _pre_round_float_value(self): value = self.value scale = self.scale - zero_point = self.zero_point if self.scale.dtype == torch.bfloat16: value = self.value.type(torch.float32) scale = self.scale.type(torch.float32) - zero_point = self.zero_point.type(torch.float32) minifloat_value = value / scale - minifloat_value = minifloat_value + zero_point + int_scale = self.internal_scale() + minifloat_value = minifloat_value / int_scale return minifloat_value @property @@ -130,6 +137,7 @@ def device(self): return value_device def minifloat(self, float_datatype=True): + # TODO: Check if OCP and cast to proper data-type if matching assert float_datatype, "Minifloat quant returns only higher precision dtype" if self.is_valid: From bd657b8b4a3db29e670a4a3740c149d7b98aebba Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 25 May 2024 16:56:10 +0100 Subject: [PATCH 31/44] Last fix --- tests/brevitas_ort/test_quant_module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index 0ff6b5a1f..693b9274d 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -21,6 +21,7 @@ @parametrize_with_cases('model', cases=QuantWBIOLCases) @pytest.mark.parametrize('export_type', ['qcdq', 'qonnx']) +@requires_pt_ge('1.10') def test_ort_wbiol(model, export_type, current_cases): cases_generator_func = current_cases['model'][1] case_id = get_case_id(cases_generator_func) @@ -74,6 +75,7 @@ def test_ort_wbiol(model, export_type, current_cases): @parametrize_with_cases('model', cases=QuantAvgPoolCases) +@requires_pt_ge('1.10') def test_ort_avgpool(model, current_cases): in_size = (1, IN_CH, FEATURES, FEATURES) inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size) From 630a3e379c429f341628bb15278dda2d54ed3d99 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 27 May 2024 08:55:11 +0100 Subject: [PATCH 32/44] Fix minifloat --- src/brevitas/quant_tensor/float_quant_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py 
b/src/brevitas/quant_tensor/float_quant_tensor.py index f6208687a..b6abc3fac 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -141,7 +141,7 @@ def minifloat(self, float_datatype=True): assert float_datatype, "Minifloat quant returns only higher precision dtype" if self.is_valid: - float_value = self._pre_round_float_value + float_value = torch.round(self._pre_round_float_value) * self.internal_scale() return float_value.type(self.scale.dtype) else: raise RuntimeError(f"FloatQuantTensor not valid.") From 38a37fb229bbd0b91612adb1ec6c56edd271f0d8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 28 May 2024 12:53:57 +0100 Subject: [PATCH 33/44] Review --- src/brevitas/export/common/handler/base.py | 2 +- src/brevitas/export/common/handler/qcdq.py | 51 ++++++++----------- src/brevitas/export/onnx/manager.py | 4 +- .../export/onnx/standard/qcdq/handler.py | 21 ++++++-- src/brevitas/proxy/runtime_quant.py | 1 - 5 files changed, 43 insertions(+), 36 deletions(-) diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index 09d90fcf5..bf3f69ed4 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -124,7 +124,7 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor): class FloatZeroPointHandlerMixin(ABC): @classmethod - def zero_point_with_dtype(cls, signed, exponent_bit_width, mantissa_bit_width, zero_point): + def zero_point_with_dtype(cls, exponent_bit_width, mantissa_bit_width, zero_point): if exponent_bit_width == 4 and mantissa_bit_width == 3: return zero_point.type(torch.float8_e4m3fn) elif exponent_bit_width == 5 and mantissa_bit_width == 2: diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 9f8e5141b..bf7b531be 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -78,7 +78,7 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis): pass @classmethod - def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_signed): + def signed_dtype(cls, exponent_bit_width, mantissa_bit_width): if exponent_bit_width is None or mantissa_bit_width is None: return None if exponent_bit_width == 4 and mantissa_bit_width == 3: @@ -140,8 +140,7 @@ class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, CDQCastMixin, ABC): - def dequantize_symbolic_kwargs( - cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed): + def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): scale_orig_shape = scale.shape axis = cls.quant_axis(scale) if cls.flatten_dequantize_params: @@ -151,7 +150,7 @@ def dequantize_symbolic_kwargs( zero_point = zero_point.flatten() zp = to_0dim_if_scalar(zero_point) zp = zp.expand_as(scale) - zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantisssa_bit_width, zp) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) return { 'scale': scale, 'zero_point': zp, @@ -188,19 +187,18 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): class FloatQCDQCastWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin): handled_layer = WeightFloatQuantProxyFromInjector - def quantize_symbolic_kwargs( - cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): + def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): # 
compute axis before redefining scale axis = cls.quant_axis(scale) scale = to_0dim_if_scalar(scale.flatten()) zp = to_0dim_if_scalar(zero_point.flatten()) # expand_as must go after 0-dim check zp = zp.expand_as(scale) - zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantissa_bit_width, zp) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) if cls.itemize_quantize_scalar_params: scale = to_item_if_0dim(scale) zp = to_item_if_0dim(zp) - dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_signed) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width) return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} def prepare_quantize_from_floating_point(self, module): @@ -213,14 +211,10 @@ def prepare_quantize_from_floating_point(self, module): scale, quant_weight.zero_point, quant_weight.exponent_bit_width, - quant_weight.mantissa_bit_width, - module.is_signed) + quant_weight.mantissa_bit_width) - def prepare_quantize_from_integer(self, module): - minifloat_weight = { - tm.weight.data_ptr(): tm.quant_weight().minifloat(float_datatype=True) - for tm in module.tracked_module_list} - self.symbolic_kwargs['minifloat_weight'] = minifloat_weight + def prepare_quantize_from_minifloat(self, module): + raise NotImplementedError def prepare_for_export(self, module): if module.is_quant_enabled: @@ -228,7 +222,7 @@ def prepare_for_export(self, module): if self._export_q_node: self.prepare_quantize_from_floating_point(module) else: - self.prepare_quantize_from_integer(module) + self.prepare_quantize_from_minifloat(module) # Get the first quant weight as representative quant_weight = module.tracked_module_list[0].quant_weight() @@ -255,8 +249,7 @@ def prepare_for_export(self, module): scale, quant_weight.zero_point, quant_weight.exponent_bit_width, - quant_weight.mantissa_bit_width, - module.is_signed) + quant_weight.mantissa_bit_width) else: self.symbolic_kwargs = None @@ -269,8 +262,8 @@ def quantize_from_floating_point(self, x: Tensor): x = self.quantize_fn(x, *quantize_symbolic_kwargs.values()) return x - def quantize_from_integer(self, x: Tensor): - return self.symbolic_kwargs['minifloat_weight'][x.data_ptr()] + def quantize_from_minifloat(self, x: Tensor): + raise NotImplementedError def symbolic_execution(self, x: Tensor): assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' @@ -283,7 +276,7 @@ def symbolic_execution(self, x: Tensor): if self._export_q_node: x = self.quantize_from_floating_point(x) else: - x = self.quantize_from_integer(x) + x = self.quantize_from_minifloat(x) clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] exponent_bit_width = self.symbolic_kwargs['exponent_bit_width'] mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width'] @@ -428,19 +421,18 @@ def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_ class FloatQCDQCastActQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin, ABC): handled_layer = ActFloatQuantProxyFromInjector - def quantize_symbolic_kwargs( - cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed): + def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width): # compute axis before redefining scale axis = cls.quant_axis(scale) scale = to_0dim_if_scalar(scale.flatten()) zp = to_0dim_if_scalar(zero_point.flatten()) # expand_as must go after 0-dim check zp = zp.expand_as(scale) - zp = cls.zero_point_with_dtype(is_signed, 
exponent_bit_width, mantissa_bit_width, zp) + zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp) if cls.itemize_quantize_scalar_params: scale = to_item_if_0dim(scale) zp = to_item_if_0dim(zp) - dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_signed) + dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width) return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} def prepare_for_export(self, module): @@ -465,14 +457,12 @@ def prepare_for_export(self, module): scale, module.zero_point(), module.exponent_bit_width(), - module.mantissa_bit_width(), - module.is_signed) + module.mantissa_bit_width()) self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( scale, module.zero_point(), module.exponent_bit_width(), - module.mantissa_bit_width(), - module.is_signed) + module.mantissa_bit_width()) self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs( module.is_narrow_range, module.is_signed, @@ -628,6 +618,9 @@ class FloatCDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, FloatZeroPointHandlerMixin, ABC): + # TODO: We do not have any bias quantizer so this is not wired to anything. + # Currently we do not support Minifloat -> DQ export for minifloat. + # This has to be rewritten to be QDQ handled_layer = BiasFloatQuantProxyFromInjector def validate(self, module): diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index bf35e471d..ae3270cc9 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -31,7 +31,7 @@ # workaround for fp8 not having many operators implemented -class Fp8Workaround(): +class PatchFp8Ops(): def __init__(self): self.lib = None @@ -165,7 +165,7 @@ def export_onnx( model_bytes = BytesIO() export_target = model_bytes - with Fp8Workaround(): + with PatchFp8Ops(): torch.onnx.export(module, args, export_target, **onnx_export_kwargs) # restore the model to previous properties diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 55003c147..873febac7 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -54,13 +54,13 @@ def validate(self, module): class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC): def validate(self, module): - pass + super().validate(module) class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC): def clip_fn(self, x, min_val, max_val): - return IntClipFn.apply(x, min_val, max_val) + raise NotImplementedError class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC): @@ -71,8 +71,23 @@ def clip_fn(self, x, min_val, max_val): class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC): + def is_ocp(self, module): + is_e4m3 = module.mantissa_bit_width() == 3 and module.exponent_bit_width() == 4 + + is_ocp_e4m3 = is_e4m3 and module.inf_values() is None and module.nan_values() == (('111',)) + + is_e5m2 = module.mantissa_bit_width() == 5 and module.exponent_bit_width() == 2 + + is_ocp_e5m2 = is_e5m2 and module.inf_values() == ( + ('00',)) and module.nan_values() == ('01', '11', '10') + + return is_ocp_e4m3 or is_ocp_e5m2 + def validate(self, module): - pass + assert self.is_ocp(module), 'Only OCP Standard is supported for FP8 export' + if getattr(self, '_export_q_node', True): + assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported' + super().validate(module) def 
quantize_fn(self, x, scale, zero_point, dtype, axis):
         return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
 
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 9e538eb28..4ef93cad6 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -22,7 +22,6 @@
     'ActQuantProxyProtocol',
     'AccQuantProxyProtocol',
     'ActQuantProxyFromInjector',
-    'FloatActQuantProxyFromInjector',
     'TruncQuantProxyFromInjector',
     'ClampQuantProxyFromInjector']
 
From 76b3193b57000c858d5c920b8e27e1f25254dcd1 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 28 May 2024 13:22:26 +0100
Subject: [PATCH 34/44] Review 2

---
 src/brevitas/core/quant/float.py              | 12 +++------
 .../export/onnx/standard/qcdq/handler.py      | 27 +++++++++----------
 src/brevitas/function/ops.py                  | 12 +++++++--
 .../quant_tensor/float_quant_tensor.py        | 11 +++-----
 tests/brevitas_ort/quant_module_cases.py      |  1 -
 5 files changed, 29 insertions(+), 34 deletions(-)

diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 371c5551c..743561650 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -10,7 +10,7 @@
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.scaling import ConstScaling
 from brevitas.core.utils import StatelessBuffer
-from brevitas.function.ops import max_float
+from brevitas.function.ops import float_internal_scale
 from brevitas.function.ops_ste import floor_ste
 
 
@@ -59,13 +59,6 @@ def __init__(
         self.scaling_impl = scaling_impl
         self.float_clamp_impl = float_clamp_impl
 
-    @brevitas.jit.script_method
-    def internal_scale(self, x):
-        internal_scale = floor_ste(torch.log2(torch.abs(x))) - self.mantissa_bit_width()
-        internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min())
-        internal_scale = torch.exp2(internal_scale)
-        return internal_scale
-
     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
         scaling_impl_value = self.scaling_impl(x)
@@ -73,7 +66,8 @@ def quantize(self, x: torch.Tensor):
             self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
         scale = scaling_impl_value / float_scaling_impl_value
         scaled_x = x / scale
-        internal_scale = self.internal_scale(scaled_x)
+        internal_scale = float_internal_scale(
+            scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min())
         val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
         return val_fp_quant, scale
 
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index 873febac7..9f4184071 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -53,8 +53,20 @@ def validate(self, module):
 
 class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC):
 
+    def is_ocp(self, module):
+        is_e4m3 = module.mantissa_bit_width() == 3 and module.exponent_bit_width() == 4
+
+        is_ocp_e4m3 = is_e4m3 and module.inf_values() is None and module.nan_values() == (('111',))
+
+        is_e5m2 = module.exponent_bit_width() == 5 and module.mantissa_bit_width() == 2
+
+        is_ocp_e5m2 = is_e5m2 and module.inf_values() == (
+            ('00',)) and module.nan_values() == ('01', '11', '10')
+
+        return is_ocp_e4m3 or is_ocp_e5m2
+
     def validate(self, module):
-        super().validate(module)
+        assert self.is_ocp(module), 'Only OCP Standard is supported for FP8 export'
 
 
 class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC):
@@ -71,8 +83,7 @@ def clip_fn(self, x, min_val, 
max_val): class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC): - def is_ocp(self, module): - is_e4m3 = module.mantissa_bit_width() == 3 and module.exponent_bit_width() == 4 - - is_ocp_e4m3 = is_e4m3 and module.inf_values() is None and module.nan_values() == (('111',)) - - is_e5m2 = module.mantissa_bit_width() == 5 and module.exponent_bit_width() == 2 - - is_ocp_e5m2 = is_e5m2 and module.inf_values() == ( - ('00',)) and module.nan_values() == ('01', '11', '10') - - return is_ocp_e4m3 or is_ocp_e5m2 - def validate(self, module): - assert self.is_ocp(module), 'Only OCP Standard is supported for FP8 export' if getattr(self, '_export_q_node', True): assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported' super().validate(module) diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 10717774c..9a6f6db6f 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -5,13 +5,12 @@ Implementation of various core operations often performed as part of quantization. The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler. """ -from typing import List, Optional, Tuple import torch from torch import Tensor import brevitas -from brevitas.utils.float_quant_utils import get_minifloat_value +from brevitas.function.ops_ste import floor_ste @brevitas.jit.script @@ -219,3 +218,12 @@ def get_upper_bound_on_l1_norm( max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1 max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) return max_accumulator_mag * max_input_mag_inverse + + +@brevitas.jit.script +def float_internal_scale( + x: Tensor, mantissa_bit_width: Tensor, fp_internal_scale_min: Tensor) -> Tensor: + internal_scale = floor_ste(torch.log2(torch.abs(x))) - mantissa_bit_width + internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) + internal_scale = torch.exp2(internal_scale) + return internal_scale diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index b6abc3fac..3fa8fecf1 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -3,6 +3,7 @@ import torch +from brevitas.function.ops import float_internal_scale from brevitas.function.ops_ste import floor_ste from brevitas.function.ops_ste import round_ste from brevitas.quant_tensor import _unpack_quant_tensor @@ -91,13 +92,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def tensor(self): return self.value - def internal_scale(self): - internal_scale = floor_ste(torch.log2(torch.abs(self.value))) - self.mantissa_bit_width - internal_scale = torch.clamp_min( - internal_scale, 1. - self.exponent_bias - self.mantissa_bit_width) - internal_scale = torch.exp2(internal_scale) - return internal_scale - @property def _pre_round_float_value(self): value = self.value @@ -106,7 +100,8 @@ def _pre_round_float_value(self): value = self.value.type(torch.float32) scale = self.scale.type(torch.float32) minifloat_value = value / scale - int_scale = self.internal_scale() + fp_internal_scale = 1. 
- self.exponent_bias - self.mantissa_bit_width + int_scale = float_internal_scale(self.value, self.mantissa_bit_width(), fp_internal_scale) minifloat_value = minifloat_value / int_scale return minifloat_value diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py index 8eab547d2..d4e607975 100644 --- a/tests/brevitas_ort/quant_module_cases.py +++ b/tests/brevitas_ort/quant_module_cases.py @@ -41,7 +41,6 @@ def case_quant_wbiol( 'in_channels': IN_CH, 'out_channels': OUT_CH, 'kernel_size': KERNEL_SIZE} bias_quantizer = None if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else Int32Bias - return_quant_tensor = False if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else True class Model(nn.Module): From f2f8969cd1e4093269af63ef48c30e71fc621544 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 28 May 2024 13:34:52 +0100 Subject: [PATCH 35/44] fix --- src/brevitas/function/ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 9a6f6db6f..d378ff899 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -10,7 +10,6 @@ from torch import Tensor import brevitas -from brevitas.function.ops_ste import floor_ste @brevitas.jit.script @@ -223,6 +222,9 @@ def get_upper_bound_on_l1_norm( @brevitas.jit.script def float_internal_scale( x: Tensor, mantissa_bit_width: Tensor, fp_internal_scale_min: Tensor) -> Tensor: + # Avoid circular import + from brevitas.function.ops_ste import floor_ste + internal_scale = floor_ste(torch.log2(torch.abs(x))) - mantissa_bit_width internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) internal_scale = torch.exp2(internal_scale) From 44579f86bce4232e3cfd89bf7bebeadbe1766873 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 28 May 2024 13:51:23 +0100 Subject: [PATCH 36/44] Typo --- tests/brevitas_ort/quant_module_cases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py index d4e607975..d8eb6e7aa 100644 --- a/tests/brevitas_ort/quant_module_cases.py +++ b/tests/brevitas_ort/quant_module_cases.py @@ -56,7 +56,7 @@ def __init__(self): input_bit_width=input_bit_width, output_bit_width=output_bit_width, bias_quant=bias_quantizer, - return_quant_tensor=return_quant_tensor) + return_quant_tensor=True) self.conv.weight.data.uniform_(-0.01, 0.01) def forward(self, x): From 038cba91cfc1781a2298dc2afb6625b211b11cd2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 28 May 2024 14:17:30 +0100 Subject: [PATCH 37/44] fix tests --- tests/brevitas/core/test_float_quant.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 43c090e6c..090e1c751 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -12,6 +12,7 @@ from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling +from brevitas.function.ops import float_internal_scale from brevitas.function.ops import max_float from tests.brevitas.hyp_helper import float_st from tests.brevitas.hyp_helper import float_tensor_random_shape_st @@ -192,7 +193,8 @@ def test_inner_scale(inp, minifloat_format, scale): max_value = max_val if max_available_float is None else torch.min( max_value, max_available_float) # call internal scale 
-    internal_scale = float_quant.internal_scale(scaled_inp)
+    internal_scale = float_internal_scale(
+        scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min())
     val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale)
     if signed:
         val_fp_quant = torch.clip(val_fp_quant, -1. * max_val, max_val)
 
From 198c5afdc77c056b8e37a5864f982e49d6ce8286 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 28 May 2024 14:54:00 +0100
Subject: [PATCH 38/44] Typo

---
 src/brevitas/quant_tensor/float_quant_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py
index 3fa8fecf1..99831f526 100644
--- a/src/brevitas/quant_tensor/float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/float_quant_tensor.py
@@ -101,7 +101,7 @@ def _pre_round_float_value(self):
             scale = self.scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        int_scale = float_internal_scale(self.value, self.mantissa_bit_width(), fp_internal_scale)
+        int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
 
From c3d7d3c6488ea5a9e7d8cecd7baa4b873ddf72f2 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 28 May 2024 15:11:11 +0100
Subject: [PATCH 39/44] fix

---
 src/brevitas/export/common/handler/qcdq.py | 4 ++--
 tests/brevitas_ort/quant_module_cases.py   | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index bf7b531be..9a91d1f5a 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -286,7 +286,7 @@ def symbolic_execution(self, x: Tensor):
         nan_values = self.symbolic_kwargs['nan_values']
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         # Workaround to trick the tracer into believing all return values are used
-        self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width)
+        self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width, exponent_bias)
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
@@ -490,7 +490,7 @@ def symbolic_execution(self, x: Tensor):
         inf_values = self.symbolic_kwargs['inf_values']
         nan_values = self.symbolic_kwargs['nan_values']
 
-        self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width)
+        self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width, exponent_bias)
         # If original dtype of the input is (b)float16, cast the input to float32
         if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = self.cast_fn(x, torch.float32)
diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py
index d8eb6e7aa..9bad4e89c 100644
--- a/tests/brevitas_ort/quant_module_cases.py
+++ b/tests/brevitas_ort/quant_module_cases.py
@@ -41,6 +41,8 @@ def case_quant_wbiol(
                 'in_channels': IN_CH, 'out_channels': OUT_CH, 'kernel_size': KERNEL_SIZE}
 
     bias_quantizer = None if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else Int32Bias
+    # Required because of numpy error with FP8 data type. Export itself works fine. 
+ return_quant_tensor = False if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else True class Model(nn.Module): @@ -56,7 +58,7 @@ def __init__(self): input_bit_width=input_bit_width, output_bit_width=output_bit_width, bias_quant=bias_quantizer, - return_quant_tensor=True) + return_quant_tensor=return_quant_tensor) self.conv.weight.data.uniform_(-0.01, 0.01) def forward(self, x): From fef531dc55611142a01d68823c45d329d0f83670 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 28 May 2024 15:54:40 +0100 Subject: [PATCH 40/44] last fix --- src/brevitas/quant_tensor/float_quant_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 99831f526..6b71866a3 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -136,7 +136,9 @@ def minifloat(self, float_datatype=True): assert float_datatype, "Minifloat quant returns only higher precision dtype" if self.is_valid: - float_value = torch.round(self._pre_round_float_value) * self.internal_scale() + fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width + int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + float_value = torch.round(self._pre_round_float_value) * int_scale return float_value.type(self.scale.dtype) else: raise RuntimeError(f"FloatQuantTensor not valid.") From 6431882c0f243ab7a154fc40b986b7658f85f2c2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 29 May 2024 09:42:08 +0100 Subject: [PATCH 41/44] Fix JIT --- src/brevitas/core/quant/float.py | 3 +-- src/brevitas/function/ops.py | 12 ------------ src/brevitas/quant_tensor/float_quant_tensor.py | 4 +--- src/brevitas/quant_tensor/float_torch_handler.py | 6 ------ src/brevitas/utils/torch_utils.py | 14 ++++++++++++++ 5 files changed, 16 insertions(+), 23 deletions(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 9d5b6cb3e..929024c63 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -10,8 +10,7 @@ from brevitas.core.function_wrapper import RoundSte from brevitas.core.scaling import ConstScaling from brevitas.core.utils import StatelessBuffer -from brevitas.function.ops import float_internal_scale -from brevitas.function.ops_ste import floor_ste +from brevitas.utils.torch_utils import float_internal_scale class FloatQuant(brevitas.jit.ScriptModule): diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index d378ff899..6751ab69c 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -217,15 +217,3 @@ def get_upper_bound_on_l1_norm( max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. 
# 2^{P-1}-1 max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) return max_accumulator_mag * max_input_mag_inverse - - -@brevitas.jit.script -def float_internal_scale( - x: Tensor, mantissa_bit_width: Tensor, fp_internal_scale_min: Tensor) -> Tensor: - # Avoid circular import - from brevitas.function.ops_ste import floor_ste - - internal_scale = floor_ste(torch.log2(torch.abs(x))) - mantissa_bit_width - internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) - internal_scale = torch.exp2(internal_scale) - return internal_scale diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 6b71866a3..b06466d2d 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -3,12 +3,10 @@ import torch -from brevitas.function.ops import float_internal_scale -from brevitas.function.ops_ste import floor_ste -from brevitas.function.ops_ste import round_ste from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import FloatQuantTensorBase from brevitas.quant_tensor import QuantTensor +from brevitas.utils.torch_utils import float_internal_scale from .float_torch_handler import FLOAT_QUANT_TENSOR_FN_HANDLER from .torch_handler import QUANT_TENSOR_FN_HANDLER diff --git a/src/brevitas/quant_tensor/float_torch_handler.py b/src/brevitas/quant_tensor/float_torch_handler.py index 05386733a..7fb4507c1 100644 --- a/src/brevitas/quant_tensor/float_torch_handler.py +++ b/src/brevitas/quant_tensor/float_torch_handler.py @@ -1,14 +1,8 @@ import functools -import math -import warnings import torch import torch.nn.functional as F -from brevitas.function.ops import max_int -from brevitas.function.ops_ste import ceil_ste -from brevitas.utils.torch_utils import compute_channel_view_shape - FLOAT_QUANT_TENSOR_FN_HANDLER = {} diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 9392c001d..f7dbe9ef3 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -7,6 +7,9 @@ import torch from torch.nn import Sequential +import brevitas +from brevitas.function.ops_ste import floor_ste + class TupleSequential(Sequential): @@ -86,3 +89,14 @@ def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int): broadcast_shape = [1] * len(tensor.size()) broadcast_shape[channel_dim] = -1 return tuple(broadcast_shape) + + +@brevitas.jit.script +def float_internal_scale( + x: torch.Tensor, mantissa_bit_width: torch.Tensor, + fp_internal_scale_min: torch.Tensor) -> torch.Tensor: + + internal_scale = floor_ste(torch.log2(torch.abs(x))) - mantissa_bit_width + internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) + internal_scale = torch.exp2(internal_scale) + return internal_scale From 4b78543f66473665aec9e99ca07f1f196057030a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 29 May 2024 10:31:38 +0100 Subject: [PATCH 42/44] Fix import --- tests/brevitas/core/test_float_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 090e1c751..021365239 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -12,8 +12,8 @@ from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling -from brevitas.function.ops import float_internal_scale from 
brevitas.function.ops import max_float +from brevitas.utils.torch_utils import float_internal_scale from tests.brevitas.hyp_helper import float_st from tests.brevitas.hyp_helper import float_tensor_random_shape_st from tests.brevitas.hyp_helper import random_minifloat_format From d762c991a0a77f9b37a90513ac97a6887188c8c9 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 29 May 2024 13:54:52 +0100 Subject: [PATCH 43/44] Last fix --- tests/brevitas/export/test_onnx_fp8.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index b460bff24..140485256 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -10,8 +10,10 @@ import brevitas.nn as qnn from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat +from tests.marker import jit_disabled_for_mock +@jit_disabled_for_mock() def test_simple_fp8_export(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -21,6 +23,7 @@ def test_simple_fp8_export(): assert True +@jit_disabled_for_mock() def test_fp8_export_activation(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -30,6 +33,7 @@ def test_fp8_export_activation(): assert True +@jit_disabled_for_mock() def test_fp8_export_export_activation(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -38,9 +42,3 @@ def test_fp8_export_export_activation(): 3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat) export_onnx_qcdq(model, torch.randn(1, 3), 'weight_act_fp8.onnx', export_weight_q_node=True) assert True - - -if __name__ == "__main__": - #test_fp8_export_activation() - test_fp8_export_export_activation() - print("Done") From ac5e58c2bc3de7ce9a7f37df7b3e1282ac1940e2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 29 May 2024 13:57:32 +0100 Subject: [PATCH 44/44] correct skip --- tests/brevitas/export/test_onnx_fp8.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index 140485256..b7e017484 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -10,10 +10,10 @@ import brevitas.nn as qnn from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat -from tests.marker import jit_disabled_for_mock +from tests.marker import jit_disabled_for_export -@jit_disabled_for_mock() +@jit_disabled_for_export() def test_simple_fp8_export(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -23,7 +23,7 @@ def test_simple_fp8_export(): assert True -@jit_disabled_for_mock() +@jit_disabled_for_export() def test_fp8_export_activation(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -33,7 +33,7 @@ def test_fp8_export_activation(): assert True -@jit_disabled_for_mock() +@jit_disabled_for_export() def test_fp8_export_export_activation(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}")
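
A minimal sketch of the FP8 export flow this series enables, mirroring the tests above. It assumes PyTorch >= 2.1 (for the torch.float8_e4m3fn dtype) and the Brevitas JIT disabled, as the jit_disabled_for_export marker in the final patch suggests; the layer sizes and output path are illustrative only:

    import torch

    import brevitas.nn as qnn
    from brevitas.export import export_onnx_qcdq
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
    from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat

    # QuantLinear with OCP FP8 (E4M3) quantizers on both weights and inputs,
    # as exercised in tests/brevitas/export/test_onnx_fp8.py.
    model = qnn.QuantLinear(
        3,
        16,
        weight_quant=Fp8e4m3OCPWeightPerTensorFloat,
        input_quant=Fp8e4m3OCPActPerTensorFloat)
    model.eval()

    # QCDQ export. export_weight_q_node=True keeps the QuantizeLinear node on the
    # weights, and opset 19 is what tests/brevitas_ort/common.py selects for the
    # fp8 cases, since float8 QuantizeLinear/DequantizeLinear require opset >= 19.
    export_onnx_qcdq(
        model,
        torch.randn(1, 3),
        'weight_act_fp8.onnx',
        export_weight_q_node=True,
        opset_version=19)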