Adding gptq and quant tests
Giuseppe5 committed Aug 28, 2024
1 parent e2cd495 commit ad54986
Showing 10 changed files with 74 additions and 31 deletions.
47 changes: 31 additions & 16 deletions src/brevitas/graph/gpxq.py
@@ -11,14 +11,17 @@
from typing import List, Optional, Set
import warnings

import torch
from torch.fx import GraphModule as TorchGraphModule

from brevitas.fx import GraphModule
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.utils import is_conv_transposed
import brevitas.nn as qnn
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

SUPPORTED_CONV_OP = (
@@ -194,26 +197,29 @@ def __init__(
self.layer = layer
self.name = name
self.act_order = act_order
if self.layer.weight_quant.is_groupwise:
weight = self.layer.weight_quant.apply_input_view(self.layer.weight)
weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_scaling_shape)
self.layer.weight.data = weight.data
self.layer.in_channels = weight.shape[1] if is_conv_transposed(
self.layer) else weight.shape[0]

weight = layer.weight.data
weight_shape = list(layer.weight.shape)

if create_weight_orig and not hasattr(self.layer, 'weight_orig'):
self.layer.register_buffer('weight_orig', layer.weight.detach().clone())

# By default, use groups = 1
self.groups = 1
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
if is_conv_transposed(self.layer):
weight_shape[1], weight_shape[0] = weight_shape[0], weight_shape[1]
self.groups = self.layer.groups

# Number of rows is equal to the output channels (OC)
self.rows = weight.shape[0]
self.rows = weight_shape[0]
# Number of columns is equal to the input channels (IC)
self.columns = weight.shape[1]
self.columns = torch.prod(weight_shape[1:])
self.len_parallel_layers = len_parallel_layers

self.disable_pre_forward_hook = False
@@ -232,7 +238,8 @@ def process_input(self, inp):
if isinstance(inp, IntQuantTensor):
if is_quant_enabled and self.quant_metadata is None:
self.quant_metadata = _CachedIO(inp, metadata_only=True)
inp = inp.value
if isinstance(inp, QuantTensor):
inp = inp.value

# If input is unbatched, add batch_size = 1
if len(inp.shape) == 1:
@@ -262,17 +269,25 @@ def get_quant_weights(self, i, i1, permutation_list):
# For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
# of quantizing only a subset of the entire matrix speeding up the computation of GPxQ
if isinstance(self.layer, qnn.QuantLinear):
index = permutation_list[0][i]
subtensor_slice_list = [None, (index, index + 1)]
q = self.layer.quant_weight(
subtensor_slice_list=subtensor_slice_list,
quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1]
if self.layer.weight_quant.is_groupwise:
# No slicing, not optimized
index = permutation_list[0][i]
q = self.layer.quant_weight(quant_input=self.quant_metadata).value.unsqueeze(
0) # [1, OC, 1]
q = q[:, :, i:i + 1] # [groups, OC/groups, 1]
else:
index = permutation_list[0][i]
subtensor_slice_list = [None, (index, index + 1)]
q = self.layer.quant_weight(
subtensor_slice_list=subtensor_slice_list,
quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1]
elif isinstance(self.layer, SUPPORTED_CONV_OP):
# For depthwise and ConvTranspose we fall back to quantizing the entire matrix.
# For all other cases, we create a mask that represents the slicing we will perform on the weight matrix
# and we quantize only the selected dimensions.
if self.groups > 1 or (self.groups == 1 and isinstance(
self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):
if self.layer.weight_quant.is_groupwise or self.groups > 1 or (
self.groups == 1 and
isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):

quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata)
quant_weight = quant_weight.value
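For context on the hunk above: the GPxQ bookkeeping now derives its matrix dimensions from the weight shape list rather than from the flattened tensor. A minimal standalone sketch of that arithmetic for a convolution weight of shape [OC, IC, kH, kW] (not Brevitas code; torch.prod is given a tensor here so the snippet runs as written):

import torch

weight = torch.randn(64, 32, 3, 3)  # hypothetical conv weight [OC, IC, kH, kW]
weight_shape = list(weight.shape)

rows = weight_shape[0]  # number of rows = output channels (OC)
columns = int(torch.prod(torch.tensor(weight_shape[1:])))  # IC * kH * kW

assert (rows, columns) == (64, 32 * 3 * 3)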
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -14,6 +14,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return GroupwiseFloatQuantTensor(
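The same apply_input_view override is repeated in the three groupwise proxies below, so one illustration covers them all. A standalone sketch (assuming group_dim=1 and group_size=8; not the Brevitas API) of what the flatten does: it merges the (num_groups, group_size) pair of axes produced by the groupwise view back into a single axis.

import torch

oc, ic, group_size = 4, 32, 8
grouped = torch.randn(oc, ic // group_size, group_size)  # [OC, num_groups, group_size]

group_dim = 1
start_dim = group_dim if group_dim != -1 else -2
flat = grouped.flatten(start_dim, start_dim + 1)  # back to [OC, IC]

assert flat.shape == (oc, ic)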
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -19,6 +19,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
self,
qt_args: Tuple[Any],
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -14,6 +14,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: List[Any]) -> GroupwiseIntQuantTensor:
out, scale, zero_point, bit_width = qt_args
return GroupwiseIntQuantTensor(
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -19,6 +19,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
self,
qt_args: Tuple[Any],
2 changes: 1 addition & 1 deletion src/brevitas/proxy/parameter_quant.py
@@ -126,7 +126,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
self._cached_weight = self.cache_class(
out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
out = x
out = self.apply_input_view(x)
return out


5 changes: 4 additions & 1 deletion src/brevitas/proxy/quant_proxy.py
@@ -11,10 +11,10 @@

from brevitas import config
from brevitas.common import ExportMixin
from brevitas.core.scaling import ScalingPerOutputType
from brevitas.core.utils import StatelessBuffer
from brevitas.inject import BaseInjector as Injector
from brevitas.utils.quant_utils import float_to_int_impl_to_enum
from brevitas.core.scaling import ScalingPerOutputType

__all__ = [
'QuantProxyProtocol',
@@ -121,6 +121,9 @@ def add_tracked_module(self, module: nn.Module) -> None:
else:
raise RuntimeError("Trying to add None as a parent module.")

def apply_input_view(self, x):
return self.quant_injector.input_view_impl(x)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
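For illustration only, a toy version of the delegation added above (hypothetical classes, not the Brevitas injector API): the base proxy forwards the view to whatever input_view_impl its injector was configured with, identity for plain per-tensor quantizers and a groupwise reshape for MX-style ones.

class FakeInjector:
    @staticmethod
    def input_view_impl(x):
        return x  # identity view for a non-groupwise quantizer

class FakeProxy:
    quant_injector = FakeInjector

    def apply_input_view(self, x):
        return self.quant_injector.input_view_impl(x)

assert FakeProxy().apply_input_view([1, 2, 3]) == [1, 2, 3]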
4 changes: 3 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
@@ -160,7 +160,9 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
elif not self.is_quant_enabled:
# A tuple helps later with control flows
# The second None value is used later
y = (self.fused_activation_quant_proxy.activation_impl(y), None)
# If quant is not enabled, we still apply input_view in the case of groupwise + padding
y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
y = (y, None)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
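This hunk, together with the parameter_quant.py change above, keeps the input view on the non-quantized path. The comment's rationale: with groupwise quantization the view may pad the grouped axis, so both branches must hand downstream code tensors of the same shape. A self-contained sketch under that assumption (hypothetical helper, not Brevitas code):

import torch
import torch.nn.functional as F

def groupwise_input_view(w, group_size=8):
    # stand-in for input_view_impl + flatten: pad the last axis to a multiple of
    # group_size, view it as [OC, num_groups, group_size], then flatten back
    pad = (-w.shape[-1]) % group_size
    w = F.pad(w, (0, pad))
    w = w.view(w.shape[0], -1, group_size)
    return w.flatten(1, 2)

w = torch.randn(4, 30)  # 30 input channels, not a multiple of 8
print(groupwise_input_view(w).shape)  # torch.Size([4, 32]), whether or not quant is enabled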
17 changes: 5 additions & 12 deletions tests/brevitas/graph/test_gpxq.py
@@ -76,26 +76,20 @@ def identity_layer_filter_func(layer: nn.Module) -> bool:
return True


filter_func_dict = {"identity": identity_layer_filter_func, "ignore_input": custom_layer_filter_fnc}

apply_gpxq_func_map = {"gpfq": apply_gpfq, "gptq": apply_gptq}


@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("use_quant_activations", [True, False])
@pytest.mark.parametrize("acc_bit_width", [32, 24, 16, 12])
@pytest.mark.parametrize("filter_func_str", filter_func_dict.keys())
@pytest.mark.parametrize("apply_gpxq_tuple", apply_gpxq_func_map.items())
def test_toymodels(
toy_quant_model,
act_order,
use_quant_activations,
acc_bit_width,
filter_func_str,
apply_gpxq_tuple,
toy_quant_model, act_order, use_quant_activations, acc_bit_width, apply_gpxq_tuple,
request):

test_id = request.node.callspec.id
if ('MXFloat' in test_id or 'MXInt' in test_id) and acc_bit_width < 32:
pytest.skip("MX quant does not support accumulator-aware quantization.")

torch.manual_seed(SEED)

@@ -105,7 +99,7 @@ def test_toymodels(
pytest.skip("GPTQ does not support accumulator-aware quantization.")

if name == 'gpfq':
filter_func = filter_func_dict[filter_func_str]
filter_func = custom_layer_filter_fnc
apply_gpxq = partial(
apply_gpxq, accumulator_bit_width=acc_bit_width, a2q_layer_filter_fnc=filter_func)

@@ -129,8 +123,7 @@ def test_toymodels(
act_order=act_order,
use_quant_activations=use_quant_activations)

elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or
filter_func_str == 'identity'):
elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations):
# GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will
# raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will
# happen when `use_quant_activations=False` or when the input to a model is not quantized
10 changes: 10 additions & 0 deletions tests/brevitas/nn/nn_quantizers_fixture.py
@@ -21,6 +21,8 @@
from brevitas.nn.quant_mha import QuantMultiheadAttention
from brevitas.nn.quant_rnn import QuantLSTM
from brevitas.nn.quant_rnn import QuantRNN
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
@@ -58,11 +60,13 @@
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,
'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
'quant_mx': MXInt8Weight,
**A2Q_WBIOL_WEIGHT_QUANTIZER}

WBIOL_IO_QUANTIZER = {
'None': None,
'batch_quant': (Int8ActPerTensorFloatBatchQuant1d, Int8ActPerTensorFloatBatchQuant2d),
'quant_mx': MXInt8Act,
'quant_sym': Int8ActPerTensorFloat,
'quant_asym': ShiftedUint8ActPerTensorFloat}

@@ -121,6 +125,12 @@ def build_case_model(
pytest.skip(
"A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor."
)
if (weight_quantizer == MXInt8Weight and
io_quantizer != MXInt8Act) or (weight_quantizer != MXInt8Weight and
io_quantizer == MXInt8Act):
pytest.skip("MX requires input and weights quantization to be aligned")
elif weight_quantizer == MXInt8Weight:
bias_quantizer = None

impl = module.__name__
# BatchQuant has dimension specific quantizers
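The new skip condition in build_case_model is an exclusive-or over the two quantizer choices. A small self-contained sketch of the rule it encodes (plain booleans rather than the fixture's quantizer classes): MX weight quantization and MX activation quantization must be enabled together.

def mx_config_is_aligned(weight_is_mx: bool, act_is_mx: bool) -> bool:
    # exactly one of the two being MX is the misaligned case the fixture skips
    return weight_is_mx == act_is_mx

assert mx_config_is_aligned(True, True)
assert mx_config_is_aligned(False, False)
assert not mx_config_is_aligned(True, False)  # would be skipped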
