Skip to content

Commit

Permalink
Reshuffling imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 21, 2024
1 parent 670d444 commit 68bba51
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 45 deletions.
33 changes: 26 additions & 7 deletions src/brevitas/quant/shifted_scaled_int.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.inject.enum import ScalingPerOutputType
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.quant.base import *
from brevitas.quant.base import HQOActZeroPoint
from brevitas.quant.base import HQOAsymmetricScale
from brevitas.quant.base import HQOZeroPoint
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
from brevitas.quant.solver.trunc import TruncQuantSolver
from brevitas.quant.solver.weight import WeightQuantSolver

__all__ = [
Expand Down Expand Up @@ -150,7 +148,7 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale,
class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat):
"""
8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer
zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses.
zero point. Zero-point is initialized from HQO local loss.
Examples:
>>> from brevitas.nn import QuantLinear
Expand All @@ -162,7 +160,7 @@ class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTen
class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat):
"""
8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer
zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses.
zero point. Zero-point is initialized from HQO local loss.
Examples:
>>> from brevitas.nn import QuantLinear
Expand All @@ -171,14 +169,35 @@ class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerCh
quantize_zero_point = False


class ShiftedUint8WeightPerGroupFloatHQO(ShiftedUint8WeightPerChannelFloatHQO):
    """
    8-bit per-group unsigned int weight quantizer with floating-point per-group scale factor and integer
    zero point. Zero-point is initialized from HQO local loss.

    Examples:
        >>> from brevitas.nn import QuantLinear
        >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerGroupFloatHQO)
    """
    # Default number of weight elements sharing one scale/zero-point group.
    group_size = 32
    # Switch scaling granularity from per-channel (inherited) to per-group.
    scaling_per_output_type = ScalingPerOutputType.GROUP
    # Groupwise proxy is required so scale/zero-point are expanded per group.
    proxy_class = GroupwiseWeightQuantProxyFromInjector


class ShiftedUint8ActPerTensorFloatHQO(HQOActZeroPoint, ShiftedUint8ActPerTensorFloat):
    """
    8-bit per-tensor unsigned int activations quantizer with floating-point scale factor and
    integer zero point. Zero-point is a learned parameter initialized from HQO local loss.

    Examples:
        >>> from brevitas.nn import QuantReLU
        >>> act = QuantReLU(act_quant=ShiftedUint8ActPerTensorFloatHQO)
    """
    # Keep the zero-point in floating point rather than rounding it to an integer grid.
    quantize_zero_point = False


class ShiftedUint8WeightGroupQuantFloat(ShiftedUint8WeightPerChannelFloat):
    """
    Block / group / vector unsigned asymmetric weight quantizer with float scales and zero-points.
    """
    # Groupwise proxy is required so scale/zero-point are expanded per group.
    proxy_class = GroupwiseWeightQuantProxyFromInjector
    # Switch scaling granularity from per-channel (inherited) to per-group.
    scaling_per_output_type = ScalingPerOutputType.GROUP
8 changes: 4 additions & 4 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@
from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightGroupQuantFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerGroupFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
Expand All @@ -60,8 +62,6 @@
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuantHQO

WEIGHT_QUANT_MAP = {
'int': {
Expand All @@ -73,7 +73,7 @@
'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
'per_group': {
'sym': IntWeightSymmetricGroupQuant,
'asym': ShiftedUintWeightAsymmetricGroupQuant}},
'asym': ShiftedUint8WeightGroupQuantFloat}},
'mse': {
'per_tensor': {
'sym': Int8WeightPerTensorFloatMSE,
Expand All @@ -89,7 +89,7 @@
'sym': Int8WeightPerChannelFloatHQO,
'asym': ShiftedUint8WeightPerChannelFloatHQO},
'per_group': {
'asym': ShiftedUintWeightAsymmetricGroupQuantHQO}},},
'asym': ShiftedUint8WeightPerGroupFloatHQO}},},
'po2_scale': {
'stats': {
'per_tensor': {
Expand Down
17 changes: 0 additions & 17 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,6 @@ class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat):
scaling_per_output_type = ScalingPerOutputType.GROUP


class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat):
    """
    Block / group / vector unsigned asymmetric weight quantizer with float scales and zero-points.
    """
    # Groupwise proxy is required so scale/zero-point are expanded per group.
    proxy_class = GroupwiseWeightQuantProxyFromInjector
    # Switch scaling granularity from per-channel (inherited) to per-group.
    scaling_per_output_type = ScalingPerOutputType.GROUP


class ShiftedUintWeightAsymmetricGroupQuantHQO(HQOWeightZeroPoint,
                                               ShiftedUint8WeightPerChannelFloat):
    """
    Block / group / vector unsigned asymmetric weight quantizer with float scales and
    zero-points, with the zero-point initialized from HQO local loss.
    """
    # Groupwise proxy is required so scale/zero-point are expanded per group.
    proxy_class = GroupwiseWeightQuantProxyFromInjector
    # Switch scaling granularity from per-channel (inherited) to per-group.
    scaling_per_output_type = ScalingPerOutputType.GROUP


class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
"""
Symmetric quantizer with per tensor dynamic scale.
Expand Down
18 changes: 1 addition & 17 deletions tests/brevitas/nn/nn_quantizers_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch.nn as nn

from brevitas import torch_version
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
Expand All @@ -22,7 +21,6 @@
from brevitas.nn.quant_mha import QuantMultiheadAttention
from brevitas.nn.quant_rnn import QuantLSTM
from brevitas.nn.quant_rnn import QuantRNN
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
Expand All @@ -34,7 +32,7 @@
from brevitas.quant.scaled_int import Int16Bias
from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerGroupFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas.quant_tensor import IntQuantTensor

Expand All @@ -46,20 +44,6 @@
EMBED_DIM = 9
NUM_HEADS = 3


class ShiftedUint8WeightPerGroupFloatHQO(ShiftedUint8WeightPerChannelFloatHQO):
    """
    8-bit per-group unsigned int weight quantizer with floating-point per-group scale factor and integer
    zero point. Zero-point is initialized from HQO local loss.

    Examples:
        >>> from brevitas.nn import QuantLinear
        >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerGroupFloatHQO)
    """
    # Small group size keeps test tensors tiny while still exercising groupwise quantization.
    group_size = 4
    # Switch scaling granularity from per-channel (inherited) to per-group.
    scaling_per_output_type = ScalingPerOutputType.GROUP
    # Groupwise proxy is required so scale/zero-point are expanded per group.
    proxy_class = GroupwiseWeightQuantProxyFromInjector


LSTM_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
Expand Down

0 comments on commit 68bba51

Please sign in to comment.