diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index 776f1f6b2..d355cf704 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -8,6 +8,7 @@ from torch import Tensor import torch.nn as nn +from brevitas.core.restrict_val import _RestrictClampValue from brevitas.core.zero_point import _ScaleShiftZeroPoint from brevitas.function.ops_ste import abs_binary_sign_grad @@ -19,16 +20,31 @@ def __init__( self, scaling_stats_impl: nn.Module, dynamic_scaling_broadcastable_fn: Callable, - scaling_stats_input_view_shape_impl: nn.Module) -> None: + scaling_stats_input_view_shape_impl: nn.Module, + restrict_scaling_impl: nn.Module, + restrict_threshold_impl: nn.Module = None, + scaling_min_val=None) -> None: super(RuntimeDynamicStatsScaling, self).__init__() + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl self.scaling_stats_input_view_shape_impl = scaling_stats_input_view_shape_impl self.stats_impl = scaling_stats_impl self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn + self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() + self.restrict_clamp_scaling = _RestrictClampValue( + scaling_min_val=scaling_min_val, restrict_value_impl=restrict_scaling_impl) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) def forward(self, x, threshold) -> Tensor: shape = x.shape + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) x = self.scaling_stats_input_view_shape_impl(x) - x = self.stats_impl(x) / threshold + x = self.stats_impl(x) + x = self.restrict_clamp_scaling(self.restrict_scaling_pre(x)) + x = x / threshold x = self.dynamic_scaling_broadcastable_fn(x, shape) return x diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 08543d4e4..f11c42efd 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -58,8 +58,10 @@ from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat +from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat +from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFixedPoint from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant @@ -170,6 +172,8 @@ 'sym': Int8DynamicActPerGroupFloat}}}, 'po2_scale': { 'stats': { + 'per_row': { + 'sym': Int8DynamicActPerRowFixedPoint,}, 'per_group': { 'sym': MXInt8Act}}}}}, 'float': { @@ -194,6 +198,8 @@ 'dynamic': { 'po2_scale': { 'stats': { + 'per_row': { + 'sym': FP8e4m3OCPDynamicActPerRowFixedPoint}, 'per_group': { 'sym': MXFloat8e4m3Act}}}}}, 'float_fnuz': { diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index c3c99a96f..8c1ae119c 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -5,6 +5,7 @@ from torch import nn +from brevitas.core.function_wrapper.ops_ste import FloorSte from brevitas.core.function_wrapper.shape import OverOutputFeaturesView from brevitas.core.function_wrapper.shape import OverTensorView from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling @@ -16,7 +17,9 @@ from brevitas.inject import ExtendedInjector from brevitas.inject import this from brevitas.inject import value +from brevitas.inject.enum import RestrictValueType from brevitas.inject.enum import ScalingPerOutputType +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector from brevitas.proxy.groupwise_float_parameter_quant import \ GroupwiseWeightFloatQuantProxyFromInjector from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector @@ -78,6 +81,11 @@ class Int8DynamicActPerRowFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): scaling_per_output_channel = True +class Int8DynamicActPerRowFixedPoint(Int8DynamicActPerRowFloat): + restrict_scaling_type = RestrictValueType.POWER_OF_TWO + restrict_value_float_to_int_impl = FloorSte + + class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per group scale. @@ -120,3 +128,16 @@ class Fp8e4m3DynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3ActPerTensorFl scaling_impl = RuntimeDynamicGroupStatsScaling scaling_per_output_type = ScalingPerOutputType.GROUP scaling_stats_op = 'min_max' + + +class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3ActPerTensorFloat): + """ + Symmetric quantizer with per row dynamic scale. + """ + scaling_impl = RuntimeDynamicStatsScaling + scaling_stats_input_view_shape_impl = OverOutputFeaturesView + scaling_stats_op = 'min_max' + scaling_per_output_channel = True + restrict_scaling_type = RestrictValueType.POWER_OF_TWO + restrict_value_float_to_int_impl = FloorSte + proxy_class = ActFloatQuantProxyFromInjector