diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index 84ee9f355..f1dfc7796 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -16,6 +16,7 @@ from brevitas.function.shape import over_output_channels from brevitas.function.shape import over_output_features from brevitas.function.shape import over_tensor +from brevitas.utils.torch_utils import padding class PermuteDims(brevitas.jit.ScriptModule): @@ -154,17 +155,19 @@ def forward(self, x: torch.Tensor): class OverSubChannelBlockView(brevitas.jit.ScriptModule): - __constants__ = ['expanded_scaling_shape'] + __constants__ = ['expanded_groupwise_shape', 'group_size', 'group_dim'] - def __init__(self, expanded_scaling_shape, padding) -> None: + def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None: super(OverSubChannelBlockView, self).__init__() - self.expanded_scaling_shape = expanded_scaling_shape - self.padding = padding + self.expanded_groupwise_shape = expanded_groupwise_shape + self.group_dim = group_dim + self.group_size = group_size @brevitas.jit.script_method def forward(self, x: torch.Tensor): - y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0) - y = y.view(self.expanded_scaling_shape) + y = torch.nn.functional.pad( + x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.) + y = y.view(self.expanded_groupwise_shape) return y @@ -181,12 +184,9 @@ def forward(self, x): tensor_shape = x.shape tensor_shape_list = list(tensor_shape) - padding = [0, 0] * len(tensor_shape_list) - if tensor_shape_list[self.group_dim] % self.group_size != 0: - padding[2 * self.group_dim] = self.group_size - tensor_shape_list[ - self.group_dim] % self.group_size - padding = list(reversed(padding)) - x = torch.nn.functional.pad(x, padding, mode='constant', value=0) + pad = padding(x, self.group_size, self.group_dim) + + x = torch.nn.functional.pad(x, pad, mode='constant', value=0.) tensor_shape = x.shape tensor_shape_list = list(tensor_shape) diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 6751ab69c..74da08e19 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -189,7 +189,7 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: return value -@brevitas.jit.script +@brevitas.jit.ignore def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias max_mantissa = torch.sum(( diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index fdbaee52f..75992e8fe 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -11,14 +11,17 @@ from typing import List, Optional, Set import warnings +import torch from torch.fx import GraphModule as TorchGraphModule from brevitas.fx import GraphModule from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor +from brevitas.graph.utils import is_conv_transposed import brevitas.nn as qnn from brevitas.quant_tensor import IntQuantTensor +from brevitas.quant_tensor.base_quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO SUPPORTED_CONV_OP = ( @@ -194,8 +197,14 @@ def __init__( self.layer = layer self.name = name self.act_order = act_order + if self.layer.weight_quant.is_groupwise: + weight = self.layer.weight_quant.apply_input_view(self.layer.weight) + weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape) + self.layer.weight.data = weight.data + self.layer.in_channels = weight.shape[1] if is_conv_transposed( + self.layer) else weight.shape[0] - weight = layer.weight.data + weight_shape = torch.tensor(layer.weight.shape) if create_weight_orig and not hasattr(self.layer, 'weight_orig'): self.layer.register_buffer('weight_orig', layer.weight.detach().clone()) @@ -203,17 +212,14 @@ def __init__( # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): - weight = weight.transpose(1, 0) # This performs a view - weight = weight.flatten(1) + if is_conv_transposed(self.layer): + weight_shape[1], weight_shape[0] = weight_shape[0], weight_shape[1] self.groups = self.layer.groups # Number of rows is equal to the output channels (OC) - self.rows = weight.shape[0] + self.rows = weight_shape[0] # Number of columns is equal to the input channels (IC) - self.columns = weight.shape[1] + self.columns = torch.prod(weight_shape[1:]) self.len_parallel_layers = len_parallel_layers self.disable_pre_forward_hook = False @@ -262,17 +268,25 @@ def get_quant_weights(self, i, i1, permutation_list): # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility # of quantizing only a subset of the entire matrix speeding up the computation of GPxQ if isinstance(self.layer, qnn.QuantLinear): - index = permutation_list[0][i] - subtensor_slice_list = [None, (index, index + 1)] - q = self.layer.quant_weight( - subtensor_slice_list=subtensor_slice_list, - quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1] + if self.layer.weight_quant.is_groupwise: + # No slicing, not optimized + index = permutation_list[0][i] + q = self.layer.quant_weight(quant_input=self.quant_metadata).value.unsqueeze( + 0) # [1, OC, 1] + q = q[:, :, i:i + 1] # [groups, OC/groups, 1] + else: + index = permutation_list[0][i] + subtensor_slice_list = [None, (index, index + 1)] + q = self.layer.quant_weight( + subtensor_slice_list=subtensor_slice_list, + quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1] elif isinstance(self.layer, SUPPORTED_CONV_OP): # For depthwise and ConvTranspose we fall back to quantizing the entire martix. # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix # and we quantize only the selected dimensions. - if self.groups > 1 or (self.groups == 1 and isinstance( - self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))): + if self.layer.weight_quant.is_groupwise or self.groups > 1 or ( + self.groups == 1 and + isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))): quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata) quant_weight = quant_weight.value diff --git a/src/brevitas/graph/utils.py b/src/brevitas/graph/utils.py index ed00de3eb..27b18edf0 100644 --- a/src/brevitas/graph/utils.py +++ b/src/brevitas/graph/utils.py @@ -26,13 +26,13 @@ 'get_output_channels', 'get_output_channel_dim'] -CONV_TRANSPOSED = [ +CONV_TRANSPOSED = ( nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, - qnn.QuantConvTranspose3d] + qnn.QuantConvTranspose3d) def module_class_name(m: torch.nn.Module): @@ -146,7 +146,7 @@ def matches_module_pattern(pattern: Iterable, node: Node, modules: Dict[str, Any def is_conv_transposed(module): - return isinstance(module, tuple(CONV_TRANSPOSED)) + return isinstance(module, CONV_TRANSPOSED) def get_output_channel_dim(module): diff --git a/src/brevitas/jit.py b/src/brevitas/jit.py index 0719e1017..6acf43728 100644 --- a/src/brevitas/jit.py +++ b/src/brevitas/jit.py @@ -14,6 +14,7 @@ def _disabled(fn): script_method = torch.jit.script_method script = torch.jit.script + ignore = torch.jit.ignore ScriptModule = torch.jit.ScriptModule Attribute = torch.jit.Attribute @@ -21,5 +22,6 @@ def _disabled(fn): script_method = _disabled script = _disabled + ignore = _disabled ScriptModule = torch.nn.Module Attribute = lambda val, type: val diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 10b80d8a6..9957848c1 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -22,6 +22,11 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size + def apply_input_view(self, x): + x = super().apply_input_view(x) + start_dim = self.group_dim if self.group_dim != -1 else -2 + return x.flatten(start_dim, start_dim + 1) + def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args return GroupwiseFloatQuantTensor( diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index c98ff0eaf..835ebdd5d 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -21,6 +21,11 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size + def apply_input_view(self, x): + x = super().apply_input_view(x) + start_dim = self.group_dim if self.group_dim != -1 else -2 + return x.flatten(start_dim, start_dim + 1) + def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index b4049cb55..3c79a723b 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -22,6 +22,11 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size + def apply_input_view(self, x): + x = super().apply_input_view(x) + start_dim = self.group_dim if self.group_dim != -1 else -2 + return x.flatten(start_dim, start_dim + 1) + def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: out, scale, zero_point, bit_width = qt_args return GroupwiseIntQuantTensor( diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 42e595fd0..ec9418e19 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -21,6 +21,11 @@ def group_dim(self): def group_size(self): return self.quant_injector.group_size + def apply_input_view(self, x): + x = super().apply_input_view(x) + start_dim = self.group_dim if self.group_dim != -1 else -2 + return x.flatten(start_dim, start_dim + 1) + def create_quant_tensor( self, qt_args: Union[torch.Tensor, Tuple[Any]], diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index e7818359e..77a806ee8 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -132,7 +132,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: self._cached_weight = self.cache_class( out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled - out = x + out = self.apply_input_view(x) return out diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 2d2ed10c1..9c4255773 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -11,6 +11,7 @@ from brevitas import config from brevitas.common import ExportMixin +from brevitas.core.scaling import ScalingPerOutputType from brevitas.core.utils import StatelessBuffer from brevitas.inject import BaseInjector as Injector from brevitas.utils.quant_utils import float_to_int_impl_to_enum @@ -21,10 +22,7 @@ def _is_groupwise(quant_injector): - if 'group_size' in quant_injector: - return True - else: - return False + return 'scaling_per_output' in quant_injector and quant_injector.scaling_per_output == ScalingPerOutputType.GROUP def _is_narrow_range(quant_injector): @@ -123,6 +121,9 @@ def add_tracked_module(self, module: nn.Module) -> None: else: raise RuntimeError("Trying to add None as a parent module.") + def apply_input_view(self, x): + return self.quant_injector.input_view_impl(x) + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 4ec52e47c..511f914e6 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -168,7 +168,9 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: elif not self.is_quant_enabled: # A tuple helps later with control flows # The second None value is used later - y = (self.fused_activation_quant_proxy.activation_impl(y), None) + # If quant is not enabled, we still apply input_view in the case of groupwise + padding + y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) + y = (y, None) else: y = self.fused_activation_quant_proxy(y) # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 76d3f6f3e..ae2c7cfc0 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -111,45 +111,49 @@ def scaling_impl(scaling_impl_type): class SolveParameterScalingShape(ExtendedInjector): @value - def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None): + def scaling_shape(scaling_per_output, expanded_groupwise_shape=None, group_dim=None): if scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE elif scaling_per_output == ScalingPerOutputType.CHANNEL: return this.scaling_per_output_channel_shape elif scaling_per_output == ScalingPerOutputType.GROUP: - assert group_size is not None, "Per Group scaling requires group size" - assert group_dim is not None, "Per Group scaling requires group dim" - size = list(module.weight.shape) - size[group_dim] = (size[group_dim] + group_size - 1) // group_size - size.insert(group_dim + 1, 1) - return size + # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1 + assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured" + assert group_dim is not None, "Per Group scaling not correctly configured" + size = list(expanded_groupwise_shape) + size[group_dim + 1] = 1 + return tuple(size) @value - def reshaped_scaling_shape(module): - return module.weight.shape + def reshaped_groupwise_shape(expanded_groupwise_shape, group_dim, group_size): + new_shape = list(expanded_groupwise_shape) + del new_shape[group_dim + 1] # delete the group_size shape + # Expand the group_dim shape, accounting for padding + new_shape[group_dim] = new_shape[group_dim] * group_size + return new_shape @value - def expanded_scaling_shape(module, group_dim, group_size=None): - assert group_size is not None, "Per Group scaling requires group size" - size = list(module.weight.shape) + def expanded_groupwise_shape(tracked_parameter_list, group_dim, group_size=None): + # expanded_groupwise_shape will be called always to create scaling_shape, but it is only needed + # for groupwise quantization. All other groupwise shape infos are derived from this. + + # If conditions do not allow for groupwise quantization, early exit and return None + if group_size is None: + return + + # If group_size is specified and shared quantization is used, raise an error. + assert len(tracked_parameter_list) == 1, "Shared groupwise quantization is not currently supported" + + weight_shape = tracked_parameter_list[0].shape + size = list(weight_shape) size[group_dim] = (size[group_dim] + group_size - 1) // group_size size.insert(group_dim + 1, group_size) - return size - - @value - def padding(module, group_dim, group_size): - padding = [0, 0] * len(module.weight.shape) - size = list(module.weight.shape) - if size[group_dim] % group_size != 0: - # Padding is done on the left side - padding[2 * group_dim] = group_size - size[group_dim] % group_size - # Padding takes a list of 2 values per dim in reverse order (N_DIM, N_DIM-1,...,0) - # so we need to reverse the order - padding = list(reversed(padding)) - return padding + return tuple(size) @value def group_dim(module, group_size=None): + # group_dim will be called always to create scaling_shape, but it is only needed + # for groupwise quantization. if group_size is not None: return 1 if not hasattr(module, 'transposed') or not module.transposed else 0 diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 96fafebee..2f0d34fba 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch from torch.nn import Sequential @@ -102,3 +102,14 @@ def float_internal_scale( internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min) internal_scale = torch.exp2(internal_scale) return internal_scale + + +@brevitas.jit.ignore +def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]: + # Given a tensor X, compute the padding aloing group_dim so that groupwise shaping is possible + padding = [0, 0] * len(x.shape) + size = x.shape + if size[group_dim] % group_size != 0: + padding[2 * group_dim] = group_size - size[group_dim] % group_size + padding = list(reversed(padding)) + return padding diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py index 5716c6f50..9f068bddd 100644 --- a/src/brevitas_examples/llm/llm_quant/export.py +++ b/src/brevitas_examples/llm/llm_quant/export.py @@ -86,8 +86,8 @@ class WeightBlockQuantProxyHandler(WeightBlockQuantHandlerBase): def __init__(self): super().__init__() - self.expanded_scaling_shape = None - self.reshaped_scaling_shape = None + self.expanded_groupwise_shape = None + self.reshaped_groupwise_shape = None self.expanded_zero_point_shape = None self.reshaped_zero_point_shape = None @@ -101,8 +101,8 @@ def prepare_for_export(self, module): self.int_dtype = torch.int8 if signed else torch.uint8 self.dtype = quant_weight.value.dtype self.scale = self.export_scale(module, self.bit_width).detach() - self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape - self.reshaped_scaling_shape = self.scaling_impl(module).reshaped_scaling_shape + self.expanded_groupwise_shape = self.scaling_impl(module).expanded_groupwise_shape + self.reshaped_groupwise_shape = self.scaling_impl(module).reshaped_groupwise_shape if (quant_weight.zero_point != 0.).any(): self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach() self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape @@ -124,22 +124,22 @@ def forward(self, x): zero_point = self.zero_point # QCDQ - x = x.view(self.expanded_scaling_shape) + x = x.view(self.expanded_groupwise_shape) x = torch.round((x / scale) + zero_point).type(self.int_dtype) if self.clip_kwargs is not None: x = torch.clip(x, min=self.clip_kwargs['min_val'], max=self.clip_kwargs['max_val']) x = (x.type(self.dtype) - zero_point) * scale # Fix shape post quantization - scale = scale.expand(self.expanded_scaling_shape).contiguous().view( - self.reshaped_scaling_shape) + scale = scale.expand(self.expanded_groupwise_shape).contiguous().view( + self.reshaped_groupwise_shape) # If zero_point is not defined, propagate same shape as scale if self.zero_point is None: zero_point = torch.zeros_like(scale).type(self.int_dtype) else: zero_point = zero_point.expand(self.expanded_zero_point_shape).contiguous().view( self.reshaped_zero_point_shape).type(self.int_dtype) - x = x.view(self.reshaped_scaling_shape) + x = x.view(self.reshaped_groupwise_shape) return x, scale, zero_point, bit_width diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 985986789..2719b48a0 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -13,6 +13,11 @@ from brevitas.graph.equalize import _cross_layer_equalization import brevitas.nn as qnn from brevitas.quant import Int8ActPerTensorFloat +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight +from brevitas.quant.scaled_int import Int8WeightPerTensorFloat SEED = 123456 ATOL = 1e-3 @@ -379,8 +384,11 @@ def forward(self, x): [('layer4.0.bn2', 'layer4.0.downsample.1', 'layer4.1.bn2'), ('fc', 'layer4.1.conv1')],] +input_quant, weight_quant = pytest_cases.param_fixtures("input_quant, weight_quant", [(None, Int8WeightPerTensorFloat), (Int8ActPerTensorFloat, Int8WeightPerTensorFloat), (MXInt8Act, MXInt8Weight), (MXFloat8e4m3Act, MXFloat8e4m3Weight)]) + + @pytest_cases.fixture -def quant_conv_with_input_quant_model(): +def quant_conv_with_input_quant_model(input_quant, weight_quant): class QuantConvModel(nn.Module): @@ -388,7 +396,8 @@ def __init__(self) -> None: super().__init__() self.conv_0 = qnn.QuantConv2d( 3, 16, kernel_size=3) # gpxq tests assume no quant on first layer - self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat) + self.conv_1 = qnn.QuantConv2d( + 16, 32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant) def forward(self, x): x = self.conv_0(x) @@ -420,15 +429,17 @@ def forward(self, x): @pytest_cases.fixture -def quant_residual_model(): +def quant_residual_model(input_quant, weight_quant): class QuantResidualModel(nn.Module): def __init__(self) -> None: super().__init__() - self.conv = qnn.QuantConv2d(3, 16, kernel_size=1) - self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1) - self.relu = qnn.QuantReLU(return_quant_tensor=True) + self.conv = qnn.QuantConv2d( + 3, 16, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant) + self.conv_0 = qnn.QuantConv2d( + 16, 3, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant) + self.relu = qnn.QuantReLU(return_quant_tensor=input_quant != None) def forward(self, x): start = x @@ -436,21 +447,32 @@ def forward(self, x): x = self.relu(x) x = self.conv_0(x) x = start + x + return x return QuantResidualModel @pytest_cases.fixture -def quant_convtranspose_model(): +def quant_convtranspose_model(input_quant, weight_quant): class QuantConvTransposeModel(nn.Module): def __init__(self) -> None: super().__init__() - self.relu = qnn.QuantReLU(return_quant_tensor=True) - self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3) - self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3) + self.relu = qnn.QuantReLU(return_quant_tensor=input_quant != None) + self.conv_0 = qnn.QuantConvTranspose2d( + in_channels=3, + out_channels=8, + kernel_size=3, + input_quant=input_quant, + weight_quant=weight_quant) + self.conv_1 = qnn.QuantConvTranspose2d( + in_channels=8, + out_channels=32, + kernel_size=3, + input_quant=input_quant, + weight_quant=weight_quant) def forward(self, x): x = self.conv_0(x) diff --git a/tests/brevitas/graph/test_gpxq.py b/tests/brevitas/graph/test_gpxq.py index 49d470402..4293b8582 100644 --- a/tests/brevitas/graph/test_gpxq.py +++ b/tests/brevitas/graph/test_gpxq.py @@ -72,30 +72,23 @@ def custom_layer_filter_fnc(layer: nn.Module) -> bool: return True -def identity_layer_filter_func(layer: nn.Module) -> bool: - return True - - -filter_func_dict = {"identity": identity_layer_filter_func, "ignore_input": custom_layer_filter_fnc} - apply_gpxq_func_map = {"gpfq": apply_gpfq, "gptq": apply_gptq} @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("use_quant_activations", [True, False]) @pytest.mark.parametrize("acc_bit_width", [32, 24, 16, 12]) -@pytest.mark.parametrize("filter_func_str", filter_func_dict.keys()) @pytest.mark.parametrize("apply_gpxq_tuple", apply_gpxq_func_map.items()) def test_toymodels( - toy_quant_model, - act_order, - use_quant_activations, - acc_bit_width, - filter_func_str, - apply_gpxq_tuple, + toy_quant_model, act_order, use_quant_activations, acc_bit_width, apply_gpxq_tuple, request): test_id = request.node.callspec.id + input_quant = test_id.split('-')[1] + weight_quant = test_id.split('-')[2] + + if ('MXFloat' in input_quant or 'MXInt' in weight_quant) and acc_bit_width < 32: + pytest.skip("MX quant does not support accumulator-aware quantization.") torch.manual_seed(SEED) @@ -105,7 +98,7 @@ def test_toymodels( pytest.skip("GPTQ does not support accumulator-aware quantization.") if name == 'gpfq': - filter_func = filter_func_dict[filter_func_str] + filter_func = custom_layer_filter_fnc apply_gpxq = partial( apply_gpxq, accumulator_bit_width=acc_bit_width, a2q_layer_filter_fnc=filter_func) @@ -130,11 +123,10 @@ def test_toymodels( use_quant_activations=use_quant_activations) elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or - filter_func_str == 'identity'): + input_quant == 'None'): # GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will # raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will # happen when `use_quant_activations=False` or when the input to a model is not quantized - # and `a2q_layer_filter_fnc` does not properly handle it. with pytest.raises(ValueError): apply_gpxq( calib_loader=calib_loader, diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 7b3183b94..a6b1c05af 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -10,6 +10,7 @@ import torch.nn as nn from brevitas import torch_version +import brevitas.config as config from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d from brevitas.nn import QuantConv3d @@ -21,6 +22,10 @@ from brevitas.nn.quant_mha import QuantMultiheadAttention from brevitas.nn.quant_rnn import QuantLSTM from brevitas.nn.quant_rnn import QuantRNN +from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat +from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant from brevitas.quant.scaled_int import Int8ActPerTensorFloat @@ -34,7 +39,6 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant_tensor import IntQuantTensor -from brevitas.quant_tensor import QuantTensor SEED = 123456 OUT_CH = 16 @@ -58,11 +62,15 @@ 'quant_sym': Int8WeightPerTensorFloat, 'quant_asym': ShiftedUint8WeightPerTensorFloat, 'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint, + 'quant_mx': MXInt8Weight, + 'quant_float': Fp8e4m3WeightPerTensorFloat, **A2Q_WBIOL_WEIGHT_QUANTIZER} WBIOL_IO_QUANTIZER = { 'None': None, 'batch_quant': (Int8ActPerTensorFloatBatchQuant1d, Int8ActPerTensorFloatBatchQuant2d), + 'quant_mx': MXInt8Act, + 'quant_float': Fp8e4m3ActPerTensorFloat, 'quant_sym': Int8ActPerTensorFloat, 'quant_asym': ShiftedUint8ActPerTensorFloat} @@ -113,14 +121,26 @@ def build_case_model( is_training, accumulator_bit_width=32): - k, weight_quantizer = weight_quantizer - _, bias_quantizer = bias_quantizer - _, io_quantizer = io_quantizer + weight_quant_name, weight_quantizer = weight_quantizer + bias_quant_name, bias_quantizer = bias_quantizer + io_quant_name, io_quantizer = io_quantizer - if io_quantizer is None and not input_quantized and k in A2Q_WBIOL_WEIGHT_QUANTIZER: + if ((io_quantizer is None and not input_quantized) or + 'float' in io_quant_name) and weight_quant_name in A2Q_WBIOL_WEIGHT_QUANTIZER: pytest.skip( "A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor." ) + if ('mx' in weight_quant_name and + 'mx' not in io_quant_name) or ('mx' not in weight_quant_name and 'mx' in io_quant_name): + pytest.skip("MX requires input and weights quantization to be aligned") + elif weight_quantizer == MXInt8Weight: + if bias_quant_name != 'quant_internal': + pytest.skip("MX quant does not support external scaled bias") + elif weight_quantizer == Fp8e4m3WeightPerTensorFloat or io_quantizer == Fp8e4m3ActPerTensorFloat: + if bias_quant_name != 'quant_internal': + pytest.skip("Float quant does not support external scaled bias") + if return_quant_tensor and ('float' in io_quant_name or io_quantizer is None): + pytest.skip("Float quant requires output quant to generate quant tensor") impl = module.__name__ # BatchQuant has dimension specific quantizers @@ -618,16 +638,18 @@ def case_mha( # Change the case_id based on current value of Parameters set_case_id(request.node.callspec.id, case_mha) - k, weight_quantizer = weight_quantizer + weight_quant_name, weight_quantizer = weight_quantizer _, bias_quantizer = bias_quantizer _, io_quantizer = io_quantizer - if io_quantizer is None and k in A2Q_WBIOL_WEIGHT_QUANTIZER: + if io_quantizer is None and weight_quant_name in A2Q_WBIOL_WEIGHT_QUANTIZER: # Can't rely on a QuantTensor input for quant_mha at this point pytest.skip( "A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor." ) - + # TODO: restore compatibility + if ('mx' in weight_quant_name or 'float' in weight_quant_name): + pytest.skip("MX/Float quant not supported for MHA") # BatchQuant1d works over 3d input but not 2d, so we have a separate quantizer for out_proj if isinstance(io_quantizer, tuple): io_quantizer, out_proj_io_quantizer = io_quantizer diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index 55dc42be2..db4f21e02 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -172,6 +172,7 @@ def test_quant_mha(model_input, current_cases): case_id = get_case_id(cases_generator_func) args = case_id.split('-')[1:] # Exclude first argument kwargs = parse_args(args) + is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] if (not is_input_quanttensor or kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':