Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Per-Row po2 float ocp #1102

Merged
merged 4 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/brevitas_examples/common/generative/quant_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad

Expand All @@ -19,16 +20,31 @@ def __init__(
self,
scaling_stats_impl: nn.Module,
dynamic_scaling_broadcastable_fn: Callable,
scaling_stats_input_view_shape_impl: nn.Module) -> None:
scaling_stats_input_view_shape_impl: nn.Module,
restrict_scaling_impl: nn.Module,
restrict_threshold_impl: nn.Module = None,
scaling_min_val=None) -> None:
super(RuntimeDynamicStatsScaling, self).__init__()
# Ensure retro-compatibility with shared threshold/scaling restrict
if restrict_threshold_impl is None:
restrict_threshold_impl = restrict_scaling_impl
self.scaling_stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
self.stats_impl = scaling_stats_impl
self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn
self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
self.restrict_clamp_scaling = _RestrictClampValue(
scaling_min_val=scaling_min_val, restrict_value_impl=restrict_scaling_impl)
self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
self.restrict_clamp_threshold = _RestrictClampValue(
restrict_value_impl=restrict_threshold_impl)

def forward(self, x, threshold) -> Tensor:
    """Compute the dynamic scale of ``x`` relative to ``threshold``.

    The threshold goes through ``restrict_threshold_pre`` and then
    ``restrict_clamp_threshold``. The input is reshaped by
    ``scaling_stats_input_view_shape_impl``, reduced by ``stats_impl``, and the
    resulting statistics go through ``restrict_scaling_pre`` and then
    ``restrict_clamp_scaling``. The restricted statistics are divided by the
    restricted threshold and handed to ``dynamic_scaling_broadcastable_fn``
    together with the original shape of ``x``.

    Args:
        x: Input whose statistics define the scale. Only ``x.shape`` is read
            directly; everything else is delegated to the injected modules.
        threshold: Value the restricted statistics are divided by.

    Returns:
        The value returned by ``dynamic_scaling_broadcastable_fn`` for the
        computed scale and the original input shape.
    """
    shape = x.shape
    threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
    x = self.scaling_stats_input_view_shape_impl(x)
    # Statistics are collected exactly once and restricted/clamped *before* the
    # division. The stale pre-change statement `x = self.stats_impl(x) / threshold`
    # is dropped: kept alongside the lines below it would run stats_impl on its own
    # output and divide by the threshold twice.
    x = self.stats_impl(x)
    x = self.restrict_clamp_scaling(self.restrict_scaling_pre(x))
    x = x / threshold

    x = self.dynamic_scaling_broadcastable_fn(x, shape)
    return x
Expand Down
6 changes: 6 additions & 0 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
Expand Down Expand Up @@ -170,6 +172,8 @@
'sym': Int8DynamicActPerGroupFloat}}},
'po2_scale': {
'stats': {
'per_row': {
'sym': Int8DynamicActPerRowFixedPoint,},
'per_group': {
'sym': MXInt8Act}}}}},
'float': {
Expand All @@ -194,6 +198,8 @@
'dynamic': {
'po2_scale': {
'stats': {
'per_row': {
'sym': FP8e4m3OCPDynamicActPerRowFixedPoint},
'per_group': {
'sym': MXFloat8e4m3Act}}}}},
'float_fnuz': {
Expand Down
21 changes: 21 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from torch import nn

from brevitas.core.function_wrapper.ops_ste import FloorSte
from brevitas.core.function_wrapper.shape import OverOutputFeaturesView
from brevitas.core.function_wrapper.shape import OverTensorView
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
Expand All @@ -16,7 +17,9 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import this
from brevitas.inject import value
from brevitas.inject.enum import RestrictValueType
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_float_parameter_quant import \
GroupwiseWeightFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector
Expand Down Expand Up @@ -78,6 +81,11 @@ class Int8DynamicActPerRowFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
scaling_per_output_channel = True


# Variant of Int8DynamicActPerRowFloat that overrides only the two scale-restriction
# attributes below; every other setting is inherited from the parent class.
# The quantizer lookup table in quantize.py registers this class under
# 'po2_scale' -> 'stats' -> 'per_row' -> 'sym'.
class Int8DynamicActPerRowFixedPoint(Int8DynamicActPerRowFloat):
# Restrict the scale to the POWER_OF_TWO member of RestrictValueType.
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
# Float-to-int step used by the restriction; presumably a floor with a
# straight-through gradient (imported from function_wrapper.ops_ste) -- TODO confirm.
restrict_value_float_to_int_impl = FloorSte


class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
"""
Symmetric quantizer with per group scale.
Expand Down Expand Up @@ -120,3 +128,16 @@ class Fp8e4m3DynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3ActPerTensorFl
scaling_impl = RuntimeDynamicGroupStatsScaling
scaling_per_output_type = ScalingPerOutputType.GROUP
scaling_stats_op = 'min_max'


# The quantizer lookup table in quantize.py registers this class under
# 'po2_scale' -> 'stats' -> 'per_row' -> 'sym' in its float 'dynamic' section.
# NOTE(review): the name says "OCP" and is spelled "FP8", while the base class is
# Fp8e4m3ActPerTensorFloat and sibling quantizers are spelled "Fp8"; the import of the
# base class is not visible here -- confirm it is the intended (OCP) variant.
# NOTE(review): the other dynamic quantizers visible in this file mix in
# DynamicActProxyMixin; this one sets proxy_class explicitly instead -- confirm intended.
class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3ActPerTensorFloat):
"""
Symmetric quantizer with per row dynamic scale.
"""
# Scale is computed at runtime by RuntimeDynamicStatsScaling.
scaling_impl = RuntimeDynamicStatsScaling
# View applied to the input before statistics are taken.
scaling_stats_input_view_shape_impl = OverOutputFeaturesView
scaling_stats_op = 'min_max'
scaling_per_output_channel = True
# Restrict the scale to the POWER_OF_TWO member of RestrictValueType.
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
# Float-to-int step used by the restriction; presumably a floor with a
# straight-through gradient -- TODO confirm.
restrict_value_float_to_int_impl = FloorSte
proxy_class = ActFloatQuantProxyFromInjector