Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decoupled PerChannel/PerTensor quantization #1025

Merged
merged 11 commits into from
Oct 8, 2024
Merged
93 changes: 76 additions & 17 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,62 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
class PerChannelL2Norm(ExtendedInjector):
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
normalize_stats_impl = L2Norm


class PerChannelL1Norm(ExtendedInjector):
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
normalize_stats_impl = L1Norm


class PerChannelPreNorm(ExtendedInjector):
pre_scaling_impl = ParameterPreScalingWeightNorm
scaling_stats_input_view_shape_impl = OverOutputChannelView
scaling_impl = (this << 1).scaling_impl
normalize_stats_impl = (this << 1).normalize_stats_impl
tracked_parameter_list = (this << 1).tracked_parameter_list
pre_scaling_shape = (this << 1).pre_scaling_shape
permute_dims = (this << 1).permute_dims


class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm):

pre_scaling_impl = AccumulatorAwareParameterPreScaling
accumulator_bit_width = (this << 1).accumulator_bit_width
accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl


class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm):

pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.pre_scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_shape = (this << 1).scaling_shape


class SolvePostScaleGranularity(ExtendedInjector):

@value
def scaling_stats_input_view_shape_impl(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS

@value
def stats_reduce_dim(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return SCALING_STATS_REDUCE_DIM


class WeightNormPerChannelFloatDecoupled(SolvePostScaleGranularity,
SolveStatsReduceDimFromEnum,
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveParameterScalingShape,
Expand All @@ -359,6 +414,8 @@ def scaling_init(scaling_init_impl, bit_width):
scales = scaling_init_impl.parameter_list_stats() / (pow(2., bit_width - 1.) - 1.)
return scales

per_channel_pre_norm = PerChannelPreNorm

proxy_class = DecoupledWeightQuantProxyFromInjector
tensor_quant = DecoupledRescalingIntQuant
decoupled_int_quant = DecoupledIntQuant
Expand All @@ -367,22 +424,23 @@ def scaling_init(scaling_init_impl, bit_width):
scaling_init_impl = StatsFromParameterScaling
restrict_scaling_impl = LogFloatRestrictValue
scaling_stats_impl = AbsMax
pre_scaling_impl = ParameterPreScalingWeightNorm
restrict_pre_scaling_impl = LogFloatRestrictValue
normalize_stats_impl = L2Norm
normalize_stats_impl = PerChannelL2Norm.normalize_stats_impl
scaling_per_output_type = ScalingPerOutputType.CHANNEL
pre_scaling_shape = this.scaling_shape # TODO: decouple pre_scaling_shape from scaling_shape
pre_scaling_shape = this.scaling_per_output_channel_shape
int_scaling_impl = SingleArgStatelessBuffer(1.)
zero_point_impl = ZeroZeroPoint
pre_zero_point_impl = ZeroZeroPoint
bit_width_impl = BitWidthConst
narrow_range = True
signed = True
scaling_stats_input_view_shape_impl = OverOutputChannelView
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_min_val = 1e-10
pre_scaling_min_val = 1e-10

@value
def pre_scaling_impl():
return this.per_channel_pre_norm.pre_scaling_impl


class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
"""Experimental accumulator-aware weight quantizer based on `Quantized Neural Networks
Expand All @@ -401,16 +459,16 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
details on the arithmetic, see `AccumulatorAwareParameterPreScalingWeightNorm`. For further
details on accumulator-aware quantization (A2Q) technique, see the referenced paper."""

@value
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)

proxy_class = DecoupledWeightQuantWithInputProxyFromInjector
tensor_quant = DecoupledRescalingIntQuantWithInput
pre_scaling_impl = AccumulatorAwareParameterPreScaling
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
normalize_stats_impl = L1Norm # required to align with derivations in paper
per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm
normalize_stats_impl = PerChannelL1Norm.normalize_stats_impl # required to align with derivations in paper
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits

@value
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)


class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
Expand All @@ -421,10 +479,11 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
(1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
(2) a more relaxed l1-norm bound that is derived in the referenced paper
"""
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm

@value
def pre_zero_point_impl():
return this.per_channel_pre_norm.pre_zero_point_impl


class MSESubInjectorBase(ExtendedInjector):
Expand Down
25 changes: 23 additions & 2 deletions tests/brevitas/export/quant_module_fixture.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a code style question, any reason why we don't use the standard pytest way of sharing fixtures across multiple files (i.e., conftest.py)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a lot of duplication for various reasons and in general tests would need to be restructured a bit.
All this to say that I wouldn't be sure at his point of the cleanest way to do it myself, but if you have a suggestion, all ears.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from torch import nn

from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
Expand All @@ -20,6 +21,7 @@
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
Expand All @@ -39,21 +41,33 @@
KERNEL_SIZE = 3
TOLERANCE = 1


class Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat(
Int8AccumulatorAwareZeroCenterWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


A2Q_QUANTIZERS = {
'a2q_per_channel_float': (Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat),
'a2q_plus_per_tensor_float':
(Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat, Int8ActPerTensorFloat)}

QUANTIZERS = {
'asymmetric_per_tensor_float':
(ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat),
'symmetric_per_tensor_float': (Int8WeightPerTensorFloat, Int8ActPerTensorFloat),
'asymmetric_per_channel_float':
(ShiftedUint8WeightPerChannelFloat, ShiftedUint8ActPerTensorFloat),
'symmetric_per_channel_float': (Int8WeightPerChannelFloat, Int8ActPerTensorFloat),
'a2q': (Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat),
'symmetric_per_tensor_fixed_point': (Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint),
'symmetric_per_channel_fixed_point':
(Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)}
(Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint),
**A2Q_QUANTIZERS}

BIAS_QUANTIZERS = {
'bias_external_scale': (Int32Bias,),
'bias_internal_scale': (Int8BiasPerTensorFloatInternalScaling,)}

QUANT_WBIOL_IMPL = [
QuantLinear,
QuantConv1d,
Expand All @@ -62,6 +76,7 @@
QuantConvTranspose1d,
QuantConvTranspose2d,
QuantConvTranspose3d,]

BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8
BIAS_BIT_WIDTHS = [8, 16, 32]

Expand Down Expand Up @@ -102,6 +117,12 @@ def weight_act_quantizers(quantizers):
return quantizers


@fixture
@parametrize('quantizers', A2Q_QUANTIZERS.items(), ids=list(A2Q_QUANTIZERS.keys()))
def a2q_weight_act_quantizers(quantizers):
return quantizers


@fixture
@parametrize('quantizer', BIAS_QUANTIZERS.items(), ids=list(BIAS_QUANTIZERS.keys()))
def bias_quantizer(quantizer):
Expand Down
37 changes: 36 additions & 1 deletion tests/brevitas/export/test_qonnx_export.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import os

import torch

from brevitas.export import enable_debug
Expand All @@ -9,14 +11,15 @@
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantLinear
from brevitas.nn import QuantReLU
from brevitas.nn import TruncAvgPool2d
from brevitas.quant.scaled_int import Int4WeightPerTensorFloatDecoupled
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int16Bias
from brevitas_examples import imagenet_classification
from tests.marker import jit_disabled_for_export

from .quant_module_fixture import *

OUT_CH = 50
IN_CH = 40
TOLERANCE = 1.1
Expand Down Expand Up @@ -48,6 +51,7 @@ def forward(self, x):
model(inp) # collect scale factors
model.eval()
export_qonnx(model, inp, export_path='generic_quant_linear.onnx')
os.remove('generic_quant_linear.onnx')


@jit_disabled_for_export()
Expand Down Expand Up @@ -79,6 +83,37 @@ def forward(self, x):
export_qonnx(model, inp, export_path='generic_decoupled_quant_linear.onnx')


@jit_disabled_for_export()
def test_a2q_quant_linear_export(a2q_weight_act_quantizers):
IN_SIZE = (2, IN_CH)

_, (weight_quant, io_quant) = a2q_weight_act_quantizers

class Model(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = QuantLinear(
out_features=OUT_CH,
in_features=IN_CH,
bias=True,
input_quant=io_quant,
output_quant=io_quant,
weight_quant=weight_quant,
bias_quant=Int16Bias,
return_quant_tensor=False)
self.linear.weight.data.uniform_(-0.1, 0.1)

def forward(self, x):
return self.linear(x)

inp = torch.randn(IN_SIZE)
model = Model()
model(inp) # collect scale factors
model.eval()
export_qonnx(model, inp, export_path='a2q_quant_linear.onnx')


@jit_disabled_for_export()
def test_generic_quant_conv_export():
IN_SIZE = (2, IN_CH, IN_CH, IN_CH)
Expand Down
20 changes: 19 additions & 1 deletion tests/brevitas/nn/nn_quantizers_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from brevitas import torch_version
import brevitas.config as config
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
Expand Down Expand Up @@ -48,20 +49,37 @@
EMBED_DIM = 9
NUM_HEADS = 3


class Int8WeightNormL2PerChannelPerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint):
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Int8AccumulatorAwareWeightQuantPerTensorFloat(Int8AccumulatorAwareWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat(
Int8AccumulatorAwareZeroCenterWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


LSTM_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat}

A2Q_WBIOL_WEIGHT_QUANTIZER = {
'quant_a2q': Int8AccumulatorAwareWeightQuant,
'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant}
'quant_a2q_per_tensor': Int8AccumulatorAwareWeightQuantPerTensorFloat,
'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant,
'quant_a2q_plus_per_tensor': Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat}

WBIOL_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,
'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
'quant_decoupled_per_tensor': Int8WeightNormL2PerChannelPerTensorFixedPoint,
'quant_mx': MXInt8Weight,
'quant_float': Fp8e4m3WeightPerTensorFloat,
**A2Q_WBIOL_WEIGHT_QUANTIZER}
Expand Down
13 changes: 12 additions & 1 deletion tests/brevitas/nn/test_a2q.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def calc_a2q_plus_acc_bit_width(
return min_bit_width


calc_fnc = {"quant_a2q": calc_a2q_acc_bit_width, "quant_a2q_plus": calc_a2q_plus_acc_bit_width}
calc_fnc = {
"quant_a2q": calc_a2q_acc_bit_width,
"quant_a2q_per_tensor": calc_a2q_acc_bit_width,
"quant_a2q_plus": calc_a2q_plus_acc_bit_width,
"quant_a2q_plus_per_tensor": calc_a2q_plus_acc_bit_width}


@pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q)
Expand Down Expand Up @@ -94,6 +98,13 @@ def test_quant_wbiol_a2q(model_input, current_cases):

# the tensor quantizer requires a QuantTensor with specified bit-width and sign
quant_weight = model.conv.quant_weight(quant_input)

# test that the scaling factor is per-tensor or per-channel
if kwargs['weight_quant'].endswith('per_tensor'):
assert quant_weight.scale.numel() == 1
else:
assert quant_weight.scale.numel() == model.conv.out_channels

quant_weight = quant_weight.int().float()
if kwargs['model_type'] == 'QuantLinear': # shape = (out_features, in_features)
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=1)
Expand Down
Loading