diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py index a313ffcaf..d18150a10 100644 --- a/src/brevitas/quant/shifted_scaled_int.py +++ b/src/brevitas/quant/shifted_scaled_int.py @@ -1,14 +1,12 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from brevitas.inject.enum import ScalingPerOutputType +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.quant.base import * from brevitas.quant.base import HQOActZeroPoint -from brevitas.quant.base import HQOAsymmetricScale from brevitas.quant.base import HQOZeroPoint -from brevitas.quant.scaled_int import Int8WeightPerTensorFloat from brevitas.quant.solver.act import ActQuantSolver -from brevitas.quant.solver.bias import BiasQuantSolver -from brevitas.quant.solver.trunc import TruncQuantSolver from brevitas.quant.solver.weight import WeightQuantSolver __all__ = [ @@ -150,7 +148,7 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale, class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat): """ 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer - zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses. + zero point. Zero-point is initialized from HQO local loss. Examples: >>> from brevitas.nn import QuantLinear @@ -162,7 +160,7 @@ class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTen class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat): """ 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer - zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses. + zero point. Zero-point is initialized from HQO local loss. 
Examples: >>> from brevitas.nn import QuantLinear @@ -171,10 +169,23 @@ class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerCh quantize_zero_point = False +class ShiftedUint8WeightPerGroupFloatHQO(ShiftedUint8WeightPerChannelFloatHQO): + """ + 8-bit per-group unsigned int weight quantizer with floating-point per-group scale factor and integer + zero point. Zero-point is initialized from HQO local loss. + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerGroupFloatHQO) + """ + group_size = 32 + scaling_per_output_type = ScalingPerOutputType.GROUP + proxy_class = GroupwiseWeightQuantProxyFromInjector + + class ShiftedUint8ActPerTensorFloatHQO(HQOActZeroPoint, ShiftedUint8ActPerTensorFloat): """ 8-bit per-tensor unsigned int activations quantizer with floating-point scale factor and - integer zero point. Both zero-point and scale factors are learned parameters initialized from + integer zero point. Zero-point is a learned parameter initialized from HQO local loss. Examples: @@ -182,3 +193,11 @@ class ShiftedUint8ActPerTensorFloatHQO(HQOActZeroPoint, ShiftedUint8ActPerTensor >>> act = QuantReLU(act_quant=ShiftedUint8ActPerTensorFloatHQO) """ quantize_zero_point = False + + +class ShiftedUint8WeightGroupQuantFloat(ShiftedUint8WeightPerChannelFloat): + """ + Block / group / vector unsigned asymmetric weight quantizer with float scales and zero-points. 
+ """ + proxy_class = GroupwiseWeightQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index fa1522130..f3293b631 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -44,9 +44,11 @@ from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightGroupQuantFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerGroupFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE @@ -60,8 +62,6 @@ from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat -from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant -from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuantHQO WEIGHT_QUANT_MAP = { 'int': { @@ -73,7 +73,7 @@ 'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}, 'per_group': { 'sym': IntWeightSymmetricGroupQuant, - 'asym': ShiftedUintWeightAsymmetricGroupQuant}}, 
+ 'asym': ShiftedUint8WeightGroupQuantFloat}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFloatMSE, @@ -89,7 +89,7 @@ 'sym': Int8WeightPerChannelFloatHQO, 'asym': ShiftedUint8WeightPerChannelFloatHQO}, 'per_group': { - 'asym': ShiftedUintWeightAsymmetricGroupQuantHQO}},}, + 'asym': ShiftedUint8WeightPerGroupFloatHQO}},}, 'po2_scale': { 'stats': { 'per_tensor': { diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index e9963c239..c3c99a96f 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -58,23 +58,6 @@ class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat): scaling_per_output_type = ScalingPerOutputType.GROUP -class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat): - """ - Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. - """ - proxy_class = GroupwiseWeightQuantProxyFromInjector - scaling_per_output_type = ScalingPerOutputType.GROUP - - -class ShiftedUintWeightAsymmetricGroupQuantHQO(HQOWeightZeroPoint, - ShiftedUint8WeightPerChannelFloat): - """ - Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. - """ - proxy_class = GroupwiseWeightQuantProxyFromInjector - scaling_per_output_type = ScalingPerOutputType.GROUP - - class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per tensor dynamic scale. 
diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index ac3091db4..cacbb14c3 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -10,7 +10,6 @@ import torch.nn as nn from brevitas import torch_version -from brevitas.inject.enum import ScalingPerOutputType from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d from brevitas.nn import QuantConv3d @@ -22,7 +21,6 @@ from brevitas.nn.quant_mha import QuantMultiheadAttention from brevitas.nn.quant_rnn import QuantLSTM from brevitas.nn.quant_rnn import QuantRNN -from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant from brevitas.quant.scaled_int import Int8ActPerTensorFloat @@ -34,7 +32,7 @@ from brevitas.quant.scaled_int import Int16Bias from brevitas.quant.scaled_int import Uint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat -from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerGroupFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant_tensor import IntQuantTensor @@ -46,20 +44,6 @@ EMBED_DIM = 9 NUM_HEADS = 3 - -class ShiftedUint8WeightPerGroupFloatHQO(ShiftedUint8WeightPerChannelFloatHQO): - """ - 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer - zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses. 
- Examples: - >>> from brevitas.nn import QuantLinear - >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO) - """ - group_size = 4 - scaling_per_output_type = ScalingPerOutputType.GROUP - proxy_class = GroupwiseWeightQuantProxyFromInjector - - LSTM_WEIGHT_QUANTIZER = { 'None': None, 'quant_sym': Int8WeightPerTensorFloat,