Decoupled PerChannel/PerTensor quantization #1025

Merged 11 commits on Oct 8, 2024
93 changes: 76 additions & 17 deletions src/brevitas/quant/base.py
@@ -335,7 +335,62 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
class PerChannelL2Norm(ExtendedInjector):
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
normalize_stats_impl = L2Norm


class PerChannelL1Norm(ExtendedInjector):
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
normalize_stats_impl = L1Norm


class PerChannelPreNorm(ExtendedInjector):
pre_scaling_impl = ParameterPreScalingWeightNorm
scaling_stats_input_view_shape_impl = OverOutputChannelView
scaling_impl = (this << 1).scaling_impl
normalize_stats_impl = (this << 1).normalize_stats_impl
tracked_parameter_list = (this << 1).tracked_parameter_list
pre_scaling_shape = (this << 1).pre_scaling_shape
permute_dims = (this << 1).permute_dims


class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm):

pre_scaling_impl = AccumulatorAwareParameterPreScaling
accumulator_bit_width = (this << 1).accumulator_bit_width
accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl


class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm):

pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.pre_scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_shape = (this << 1).scaling_shape


class SolvePostScaleGranularity(ExtendedInjector):

@value
def scaling_stats_input_view_shape_impl(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS

@value
def stats_reduce_dim(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return SCALING_STATS_REDUCE_DIM


class WeightNormPerChannelFloatDecoupled(SolvePostScaleGranularity,
SolveStatsReduceDimFromEnum,
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveParameterScalingShape,
@@ -359,6 +414,8 @@ def scaling_init(scaling_init_impl, bit_width):
scales = scaling_init_impl.parameter_list_stats() / (pow(2., bit_width - 1.) - 1.)
return scales

per_channel_pre_norm = PerChannelPreNorm

proxy_class = DecoupledWeightQuantProxyFromInjector
tensor_quant = DecoupledRescalingIntQuant
decoupled_int_quant = DecoupledIntQuant
@@ -367,22 +424,23 @@ def scaling_init(scaling_init_impl, bit_width):
scaling_init_impl = StatsFromParameterScaling
restrict_scaling_impl = LogFloatRestrictValue
scaling_stats_impl = AbsMax
pre_scaling_impl = ParameterPreScalingWeightNorm
restrict_pre_scaling_impl = LogFloatRestrictValue
normalize_stats_impl = L2Norm
normalize_stats_impl = PerChannelL2Norm.normalize_stats_impl
scaling_per_output_type = ScalingPerOutputType.CHANNEL
pre_scaling_shape = this.scaling_shape # TODO: decouple pre_scaling_shape from scaling_shape
pre_scaling_shape = this.scaling_per_output_channel_shape
int_scaling_impl = SingleArgStatelessBuffer(1.)
zero_point_impl = ZeroZeroPoint
pre_zero_point_impl = ZeroZeroPoint
bit_width_impl = BitWidthConst
narrow_range = True
signed = True
scaling_stats_input_view_shape_impl = OverOutputChannelView
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_min_val = 1e-10
pre_scaling_min_val = 1e-10

@value
def pre_scaling_impl():
return this.per_channel_pre_norm.pre_scaling_impl


class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
"""Experimental accumulator-aware weight quantizer based on `Quantized Neural Networks
@@ -401,16 +459,16 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
details on the arithmetic, see `AccumulatorAwareParameterPreScalingWeightNorm`. For further
details on accumulator-aware quantization (A2Q) technique, see the referenced paper."""

@value
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)

proxy_class = DecoupledWeightQuantWithInputProxyFromInjector
tensor_quant = DecoupledRescalingIntQuantWithInput
pre_scaling_impl = AccumulatorAwareParameterPreScaling
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
normalize_stats_impl = L1Norm # required to align with derivations in paper
per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm
normalize_stats_impl = PerChannelL1Norm.normalize_stats_impl # required to align with derivations in paper
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits

@value
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)


class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
@@ -421,10 +479,11 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
(1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
(2) a more relaxed l1-norm bound that is derived in the referenced paper
"""
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm

@value
def pre_zero_point_impl():
return this.per_channel_pre_norm.pre_zero_point_impl


class MSESubInjectorBase(ExtendedInjector):
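Note (not part of this diff): a rough sketch of what the new granularity switch enables downstream. Overriding `scaling_per_output_type` on the decoupled weight-norm quantizer should be enough for per-tensor scaling, since `SolvePostScaleGranularity` then resolves the stats view to `OVER_TENSOR` and drops the reduce dim. The subclass name, layer sizes, and the `brevitas.quant.scaled_int` import path are assumptions for illustration.

from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv2d
from brevitas.quant.scaled_int import Int8WeightNormL2PerChannelFixedPoint


class Int8WeightNormL2PerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint):
    # Hypothetical subclass: SolvePostScaleGranularity picks the OVER_TENSOR stats
    # view and a None reduce dim, so a single scale is learned for the whole tensor.
    scaling_per_output_type = ScalingPerOutputType.TENSOR


conv = QuantConv2d(4, 8, kernel_size=3, weight_quant=Int8WeightNormL2PerTensorFixedPoint)
assert conv.quant_weight().scale.numel() == 1  # one per-tensor scale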
20 changes: 19 additions & 1 deletion tests/brevitas/nn/nn_quantizers_fixture.py
@@ -11,6 +11,7 @@

from brevitas import torch_version
import brevitas.config as config
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
@@ -48,20 +49,37 @@
EMBED_DIM = 9
NUM_HEADS = 3


class Int8WeightNormL2PerChannelPerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint):
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Int8AccumulatorAwareWeightQuantPerTensorFloat(Int8AccumulatorAwareWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Int8AccumulatorAwareZeroCenterWeightQuantPerTensorFloat(
Int8AccumulatorAwareZeroCenterWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


LSTM_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat}

A2Q_WBIOL_WEIGHT_QUANTIZER = {
'quant_a2q': Int8AccumulatorAwareWeightQuant,
'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant}
'quant_a2q_per_tensor': Int8AccumulatorAwareWeightQuantPerTensorFloat,
'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant,
'quant_a2q_plus_per_tensor': Int8AccumulatorAwareZeroCenterWeightQuantPerTensorFloat}

WBIOL_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,
'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
'quant_decoupled_per_tensor': Int8WeightNormL2PerChannelPerTensorFixedPoint,
'quant_mx': MXInt8Weight,
'quant_float': Fp8e4m3WeightPerTensorFloat,
**A2Q_WBIOL_WEIGHT_QUANTIZER}
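As a quick usage sketch of the per-tensor A2Q fixture defined above (layer sizes are illustrative, not from this PR): A2Q weight quantizers need the quantized input so the accumulator constraint can be resolved, which is why the weight is quantized against it, mirroring the updated test below.

import torch

from brevitas.nn import QuantConv2d
from brevitas.nn import QuantIdentity

inp_quant = QuantIdentity(return_quant_tensor=True)
conv = QuantConv2d(4, 8, kernel_size=3, weight_quant=Int8AccumulatorAwareWeightQuantPerTensorFloat)

# the A2Q tensor quantizer requires a QuantTensor with specified bit-width and sign
quant_input = inp_quant(torch.randn(2, 4, 8, 8))
quant_weight = conv.quant_weight(quant_input)
assert quant_weight.scale.numel() == 1  # per-tensor, vs. conv.out_channels scales per-channel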
13 changes: 12 additions & 1 deletion tests/brevitas/nn/test_a2q.py
@@ -63,7 +63,11 @@ def calc_a2q_plus_acc_bit_width(
return min_bit_width


calc_fnc = {"quant_a2q": calc_a2q_acc_bit_width, "quant_a2q_plus": calc_a2q_plus_acc_bit_width}
calc_fnc = {
"quant_a2q": calc_a2q_acc_bit_width,
"quant_a2q_per_tensor": calc_a2q_acc_bit_width,
"quant_a2q_plus": calc_a2q_plus_acc_bit_width,
"quant_a2q_plus_per_tensor": calc_a2q_plus_acc_bit_width}


@pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q)
@@ -94,6 +98,13 @@ def test_quant_wbiol_a2q(model_input, current_cases):

# the tensor quantizer requires a QuantTensor with specified bit-width and sign
quant_weight = model.conv.quant_weight(quant_input)

# test that the scaling factor is per-tensor or per-channel
if kwargs['weight_quant'].endswith('per_tensor'):
assert quant_weight.scale.numel() == 1
else:
assert quant_weight.scale.numel() == model.conv.out_channels

quant_weight = quant_weight.int().float()
if kwargs['model_type'] == 'QuantLinear': # shape = (out_features, in_features)
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=1)