From 5fa4e572b7dcca9fa1ba4ce6b1a323e4bfb530ce Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 26 Nov 2024 23:33:54 +0000 Subject: [PATCH] Per Row Int po2 kernel --- src/brevitas_examples/common/generative/quantize.py | 3 +++ src/brevitas_examples/common/generative/quantizers.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index a45168246..7ec94f507 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -61,6 +61,7 @@ from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat +from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFixedPoint from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant @@ -171,6 +172,8 @@ 'sym': Int8DynamicActPerGroupFloat}}}, 'po2_scale': { 'stats': { + 'per_row': { + 'sym': Int8DynamicActPerRowFixedPoint,}, 'per_group': { 'sym': MXInt8Act}}}}}, 'float': { diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index ecca9f861..adf2e5bd9 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -81,6 +81,11 @@ class Int8DynamicActPerRowFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): scaling_per_output_channel = True +class Int8DynamicActPerRowFixedPoint(Int8DynamicActPerRowFloat): + restrict_scaling_type = RestrictValueType.POWER_OF_TWO + restrict_value_float_to_int_impl = FloorSte + + class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per group scale.