From bddfe1e2409140d7dcda1e63cee6bfc7275d79fb Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 23:06:34 +0000 Subject: [PATCH] Feat (brevitas_examples): Po2 per channel float OCP weight quantization --- src/brevitas/quant/base.py | 3 ++- src/brevitas_examples/common/generative/quantize.py | 3 +++ src/brevitas_examples/common/generative/quantizers.py | 9 +++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index f2af79f5f..0fd5e683a 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -250,8 +250,9 @@ class PerChannelPoTScaling8bit(ExtendedInjector): """ """ scaling_per_output_type = ScalingPerOutputType.CHANNEL - restrict_scaling_type = RestrictValueType.FP + restrict_scaling_type = RestrictValueType.POWER_OF_TWO bit_width = 8 + restrict_value_float_to_int_impl = CeilSte class PerTensorPoTScaling8bit(ExtendedInjector): diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index f11c42efd..b125718a5 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -59,6 +59,7 @@ from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint +from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFixedPoint @@ -131,6 +132,8 @@ 'per_group': { 'sym': MXFloat8e4m3Weight}}, 'mse': { + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFixedPointMSE}, 'per_group': { 'sym': MXFloat8e4m3WeightMSE}}}}, 'float_fnuz': { diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 8c1ae119c..25c37eb17 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -27,9 +27,12 @@ from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.quant.base import HQOWeightZeroPoint +from brevitas.quant.base import MSESymmetricScale +from brevitas.quant.base import PerChannelPoTScaling8bit from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8WeightPerChannelFloat from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO @@ -141,3 +144,9 @@ class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3ActPerTensorFloat): restrict_scaling_type = RestrictValueType.POWER_OF_TWO restrict_value_float_to_int_impl = FloorSte proxy_class = ActFloatQuantProxyFromInjector + + +class Fp8e4m3OCPWeightPerChannelFixedPointMSE(MSESymmetricScale, + PerChannelPoTScaling8bit, + Fp8e4m3OCPWeightPerChannelFloat): + pass