diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 33ee5fbb4..92e3da2bf 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -31,110 +31,6 @@ from brevitas.quant_tensor import _unpack_quant_tensor -class gpfq_mode(gpxq_mode): - """ - Apply GPFQ algorithm. - - Args: - model (Module): The model to quantize with GPFQ - group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group - of layer names that can be optimized in parallel. Default: None - inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True - create_weight_orig (bool): If True, store the original floating point weights before applying - gpfq. These weights will be used anytime quantization is disabled. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPFQ. Default: False - p (float): The percentage of processed inputs to use. Default: 1.0 - return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the - forward call inside the context manager returns None. Default: False - act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False - - Example: - >>> with torch.no_grad(): - >>> with gpfq_mode(model) as gpfq: - >>> gpfq_model = gpfq.model - >>> for i in tqdm(range(gpfq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gpfq_model(img) - >>> gpfq.update() - """ - - def __init__( - self, - model: nn.Module, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - create_weight_orig: bool = True, - use_quant_activations: bool = True, - p: float = 1.0, - return_forward_output: bool = False, - act_order: bool = False, - gpfq_class: Optional[nn.Module] = None) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - create_weight_orig, - use_quant_activations, - act_order, - return_forward_output) - - self.p = p - if gpfq_class is None: - gpfq_class = GPFQ - self.gpfq_class = gpfq_class - assert issubclass(gpfq_class, GPxQ), \ - "Error: expected `gpfq_class` to be derived from `brevitas.graph.gpxq.GPxQ`." - - def catch_stopfwd(self, *args, **kwargs): - # Collect quant input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - - # Disable quantization - self.return_quant_tensor_state = disable_return_quant_tensor(self.model) - self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) - self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) - # Collect float input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - - # Re-enable quantization. If activation quantization is disabled, - # we also disable bias quantization - self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) - if self.use_quant_activations: - self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) - else: - self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) - restore_return_quant_tensor(self.model, self.return_quant_tensor_state) - - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - def initialize_module_optimizer( - self, layer, name, act_order, len_parallel_layers, create_weight_orig): - return self.gpfq_class( - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - p=self.p) - - class GPFQ(GPxQ): """ Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main @@ -243,7 +139,11 @@ def single_layer_update(self): self.float_input = self.float_input.to(dev) self.quant_input = self.quant_input.to(dev) U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + weight.shape[0], + weight.shape[1], + self.float_input.shape[1], + device=dev, + dtype=torch.float32) # We don't need full Hessian, we just need the diagonal # Summing over batch dimension H_diag = self.quant_input.transpose(2, 1).square().sum(2) @@ -261,7 +161,8 @@ def single_layer_update(self): for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( - weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1), + weight[group_index, :, + permutation_list[group_index][t]].unsqueeze(1).to(torch.float32), self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] norm = torch.linalg.norm( @@ -272,11 +173,11 @@ def single_layer_update(self): else: q_arg = torch.zeros_like(U[group_index, :, 0]) - weight[group_index, :, permutation_list[group_index][t]] = q_arg + weight[group_index, :, permutation_list[group_index][t]] = q_arg.to(dtype) q = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): U[group_index] -= torch.matmul( - q[group_index].unsqueeze(1), + q[group_index].unsqueeze(1).to(torch.float32), self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0)) del self.float_input @@ -360,7 +261,7 @@ def update_batch(self, module, input, current_layer): # if quant is not enabled, then it is the float input; if it is a float input # then a quant input has already happened and we can update G if not is_quant_enabled: - # Computing the normalized H matrix using CPU buffer + # Computing the normalized G matrix using CPU buffer self.B.copy_(self.quant_input.bmm(inp_processed.transpose(2, 1))) self.G += self.B self.quant_input = None # NOTE: set back to None now that we've used it @@ -401,6 +302,8 @@ def _get_permutation_list(self, weight: Tensor): def single_layer_update(self, percdamp: float = 0.01): assert not self.layer.weight_quant.requires_quant_input, \ "Error: GPFQ does not support weight quantizers that require quantized inputs." + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) weight = self.layer.weight.data dev = weight.device dtype = weight.dtype @@ -446,14 +349,17 @@ def single_layer_update(self, percdamp: float = 0.01): permutation_list = self._get_permutation_list(weight) U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, - dtype=dtype) # [Groups, OC/groups, Samples] + weight.shape[0], + weight.shape[1], + self.float_input.shape[1], + device=dev, + dtype=torch.float32) # [Groups, OC/groups, Samples] for t in range(weight.shape[-1]): for group_index in range(self.groups): i = permutation_list[group_index][t] U[group_index] += torch.matmul( - weight[group_index, :, i].unsqueeze(1), + weight[group_index, :, i].unsqueeze(1).to(torch.float32), self.float_input[group_index, :, i].unsqueeze(0), ) # [OC/Groups, 1] * [1, INSHAPE[1]] norm = norms[group_index, i] @@ -461,13 +367,116 @@ def single_layer_update(self, percdamp: float = 0.01): q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm else: q_arg = torch.zeros_like(U[group_index, :, 0]) - weight[group_index, :, i] = q_arg + weight[group_index, :, i] = q_arg.to(dtype) q_groups = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): U[group_index] -= torch.matmul( - q_groups[group_index].unsqueeze(1), + q_groups[group_index].unsqueeze(1).to(torch.float32), self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0), ) - + if hasattr(self.layer, 'offload_params'): + self.layer.offload_params(self.layer) del self.float_input del self.quant_input + + +class gpfq_mode(gpxq_mode): + """ + Apply GPFQ algorithm. + + Args: + model (Module): The model to quantize with GPFQ + group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group + of layer names that can be optimized in parallel. Default: None + inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True + create_weight_orig (bool): If True, store the original floating point weights before applying + gpfq. These weights will be used anytime quantization is disabled. Default: True + use_quant_activations (bool): Wheter to leave quantize activations enabled while performing + GPFQ. Default: False + p (float): The percentage of processed inputs to use. Default: 1.0 + return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the + forward call inside the context manager returns None. Default: False + act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False + gpfq_class (GPFQ): The uninitialized class to perform GPFQ. Default: `brevitas.graph.gpfq.GPFQv2`, + which is the memory-efficient formulation + + Example: + >>> with torch.no_grad(): + >>> with gpfq_mode(model) as gpfq: + >>> gpfq_model = gpfq.model + >>> for i in tqdm(range(gpfq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gpfq_model(img) + >>> gpfq.update() + """ + + def __init__( + self, + model: nn.Module, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + create_weight_orig: bool = True, + use_quant_activations: bool = True, + p: float = 1.0, + return_forward_output: bool = False, + act_order: bool = False, + gpfq_class: GPFQ = GPFQv2) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + create_weight_orig, + use_quant_activations, + act_order, + return_forward_output) + + self.p = p + self.gpfq_class = gpfq_class + + def catch_stopfwd(self, *args, **kwargs): + # Collect quant input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + + # Disable quantization + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) + self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) + self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) + # Collect float input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + + # Re-enable quantization. If activation quantization is disabled, + # we also disable bias quantization + self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) + if self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) + else: + self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) + + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + def initialize_module_optimizer( + self, layer, name, act_order, len_parallel_layers, create_weight_orig): + return self.gpfq_class( + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + p=self.p) diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index a1380da4e..667e47d40 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -23,85 +23,6 @@ import brevitas.nn as qnn -class gptq_mode(gpxq_mode): - """ - Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. - - Args: - model (Module): The model to quantize with GPTQ - group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group - of layer names that can be optimized in parallel. Default: None - inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True - create_weight_orig (bool): If True, store the original floating point weights before applying - gptq. These weights will be used anytime quantization is disabled. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPTQ. Default: False - num_blocks (int): The number of sub-blocks to use to speed-up GPTQ computation. Default: 100 - act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False - return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the - forward call inside the context manager returns None. Default: False - - Example: - >>> with torch.no_grad(): - >>> with gptq_mode(model) as gptq: - >>> gptq_model = gptq.model - >>> for i in tqdm(range(gptq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gptq_model(img) - >>> gptq.update() - """ - - def __init__( - self, - model, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - create_weight_orig: bool = True, - use_quant_activations: bool = True, - num_blocks: int = 100, - return_forward_output: bool = False, - act_order: bool = False) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - create_weight_orig, - use_quant_activations, - act_order, - return_forward_output) - - # How many subblock to use during GPTQ for each layer - self.num_blocks = num_blocks - - def catch_stopfwd(self, *args, **kwargs): - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - finally: - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - def initialize_module_optimizer( - self, layer, name, act_order, len_parallel_layers, create_weight_orig): - return GPTQ( - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - num_blocks=self.num_blocks) - - class GPTQ(GPxQ): """ Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: @@ -275,7 +196,7 @@ def single_layer_update(self, percdamp=.01): q_groups = self.get_quant_weights(i, i1, permutation_list) # [groups, OC/groups] for group_index in range(self.groups): perm = permutation_list[group_index] - q = q_groups[group_index] # [OC/groups] + q = q_groups[group_index].to(torch.float32) # [OC/groups] w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] d = h_inv_block[group_index, i, i] # [1] error = (w - q) / d # [OC/groups] @@ -292,3 +213,85 @@ def single_layer_update(self, percdamp=.01): i2:].to(dev))).to(dtype) if hasattr(self.layer, 'offload_params'): self.layer.offload_params(self.layer) + + +class gptq_mode(gpxq_mode): + """ + Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. + + Args: + model (Module): The model to quantize with GPTQ + group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group + of layer names that can be optimized in parallel. Default: None + inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True + create_weight_orig (bool): If True, store the original floating point weights before applying + gptq. These weights will be used anytime quantization is disabled. Default: True + use_quant_activations (bool): Wheter to leave quantize activations enabled while performing + GPTQ. Default: False + num_blocks (int): The number of sub-blocks to use to speed-up GPTQ computation. Default: 100 + act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False + return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the + forward call inside the context manager returns None. Default: False + gptq_class (GPTQ): The uninitialized class to perform GPTQ. Default: `brevitas.graph.gptq.GPTQ` + + Example: + >>> with torch.no_grad(): + >>> with gptq_mode(model) as gptq: + >>> gptq_model = gptq.model + >>> for i in tqdm(range(gptq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gptq_model(img) + >>> gptq.update() + """ + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + create_weight_orig: bool = True, + use_quant_activations: bool = True, + num_blocks: int = 100, + return_forward_output: bool = False, + act_order: bool = False, + gptq_class: GPTQ = GPTQ) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + create_weight_orig, + use_quant_activations, + act_order, + return_forward_output) + + # How many subblock to use during GPTQ for each layer + self.num_blocks = num_blocks + self.gptq_class = gptq_class + + def catch_stopfwd(self, *args, **kwargs): + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + finally: + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + def initialize_module_optimizer( + self, layer, name, act_order, len_parallel_layers, create_weight_orig): + return self.gptq_class( + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + num_blocks=self.num_blocks) diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py new file mode 100644 index 000000000..39e22535d --- /dev/null +++ b/src/brevitas_examples/common/axe.py @@ -0,0 +1,436 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import math +import warnings + +import numpy as np +import torch +from torch import Tensor + +try: + from torch.linalg import LinAlgError +except: + LinAlgError = RuntimeError + +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ +from brevitas.graph.gpxq import SUPPORTED_CONV_OP +from brevitas.graph.gpxq import SUPPORTED_TCONV_OP + + +def _get_average_of_nonzero_magnitudes(vec: np.ndarray, radius: float = 1.0): + assert radius > 0, "Error: radius needs to be strictly positive." + assert vec.ndim == 1, "Error: projection assumes a vector, not a matrix." + assert vec.min() >= 0, "Error: assuming a vector of non-negative numbers." + n_elems = vec.shape[0] + # if we are already within the simplex, then the best projection is itself + if vec.sum() <= radius: + return 0.0 + # using algorithm detailed in "Efficient Projections onto the L1-Ball for Learning in High Dimensions" + v = vec + u = np.sort(v)[::-1] + cumsum_u = np.cumsum(u) + rho = np.nonzero(u * np.arange(1, n_elems + 1) > (cumsum_u - radius))[0][-1] + theta = float(cumsum_u[rho] - radius) / (rho + 1) + return theta + + +def calc_average_nonzero_mag(weight: Tensor, lim: Tensor) -> Tensor: + thetas = torch.zeros(weight.shape[0], device=weight.device) + for i in range(weight.shape[0]): + l = lim[i].item() if lim.ndim > 0 else lim.item() + w = weight[i].cpu().detach().numpy() + t = _get_average_of_nonzero_magnitudes(np.abs(w), l) + thetas[i] = t + return thetas + + +def pad_tensor_with_zeros(tensor: Tensor, tile_size: int) -> Tensor: + pad_size = tile_size - (tensor.shape[1] % tile_size) + if pad_size == tile_size: + return tensor + padding = torch.zeros((tensor.shape[0], pad_size), device=tensor.device) + pad_tensor = torch.concat([tensor, padding], axis=1) + return pad_tensor + + +class A2GPTQ(GPTQ): + """ + Accumulator-aware GPTQ as proposed in https://arxiv.org/pdf/2409.17092 + """ + + def __init__( + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + num_blocks, + max_accumulator_bit_width, + max_accumulator_tile_size) -> None: + super().__init__( + layer, name, act_order, len_parallel_layers, create_weight_orig, num_blocks) + self.max_accumulator_bit_width = max_accumulator_bit_width + self.max_accumulator_tile_size = max_accumulator_tile_size + if self.max_accumulator_tile_size is None: + self.max_accumulator_tile_size = self.columns + assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2." + assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2." + + def single_layer_update(self, percdamp=0.01): + assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." + if self.quant_metadata is None: + raise ValueError( + "Expected self.quant_metadata to calculate accumualtor bounds, but recevied None. " + "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. " + "Also, check if `use_quant_activations=True` in `gptq_mode` when `max_accumulator_bit_width` is specified. " + ) + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) + weight = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. + # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # TODO: add support for signed input activations + if self.quant_metadata.signed: + raise NotImplementedError("Signed inputs not yet supported.") + + # TODO: currently assuming round-to-zero; need to handle other rounding functions + rounding_mode = self.layer.weight_quant.rounding_mode + if rounding_mode.lower() != "round": + raise NotImplementedError(f"{rounding_mode} not yet supported.") + + n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) + scales: Tensor = self.layer.weight_quant.scale() + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + scales = scales.transpose(1, 0) # This performs a view + scales = scales.flatten(1) + P = torch.tensor(self.max_accumulator_bit_width) + N = self.quant_metadata.bit_width + # NOTE: using sign-magnitude here, which is sufficient to support both + # sign-magnitude and 2s complement accumulators + self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) # A + self.lower_lim = -self.upper_lim # B + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) # l1-norm lim for zero-centered weight vector + # translating into the quantized range; need to pad to get these thresholds + wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] + thresholds = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] + thresholds = thresholds.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] + del wT + # supporting groupwise quantization where each tile has its own scaling factor + if self.layer.weight_quant.is_groupwise: + scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] + scales = scales[:, 0] # [Groups * OC * Tiles, 1] + scales = scales.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] + # else each tile has the same scaling factor (per-tensor or per-channel) + else: + scales = scales.view(self.groups, 1, -1) # [Groups, 1, OC/Groups] + scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups] + thresholds *= scales # translating centers back to the float range + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + + # List with permutation tensors for the Hessian and weight matrix. + # If act_order is False, the tensors will be ordered indexes. + # For groupwise convolution, we have one tensor per group, + # thus len(permutation_list) is always equal to self.groups. + # We do not explicity permute the weight matrix, only the Hessian. + permutation_list = [] + weight = weight.view(self.groups, -1, weight.shape[-1]) + # For groupwise convolution, these operations are groupwise so we iterate + for i in range(self.groups): + # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding + # column in the weight matrix. + # The diagonal element is set to 1 to avoid division-by-zero + dead = torch.diag(self.H[i, :, :]) == 0 + self.H[i, dead, dead] = 1 + # If the diagonal of activations is zero, we set the weight to zero + weight[i, :, dead] = 0 + if self.act_order: + # Re-order Hessian so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True) + self.H[i, :, :] = self.H[i, perm, :][:, perm] + else: + # No permutation, permutation tensor is a ordered index + perm = torch.tensor(range(self.H.shape[-1]), device=dev) + permutation_list.append(perm) + + # Try/Except in case the inverse Hessian cannot be computed + try: + for i in range(self.groups): + damp = percdamp * torch.mean(torch.diag(self.H[i, :, :])) + diag = torch.arange(self.columns, device='cpu') + self.H[i, diag, diag] += damp + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) + self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) + h_inv = self.H + except LinAlgError: + warnings.warn( + f'Failed to compute the inverse of the Hessian for layer {self.name} ' + f'GPTQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H, self.B + + # initialize cumulative l1-norm + a = torch.zeros_like(thresholds, device=dev) # positive limits + b = torch.zeros_like(thresholds, device=dev) # negative limits + + for i1 in range(0, self.columns, self.blocksize): + i2 = min(i1 + self.blocksize, self.columns) + count = i2 - i1 + error_block = torch.zeros_like( + weight[:, :, permutation_list[-1][i1:i2]], + dtype=torch.float32) # [groups, OC/groups, i2-i1] + + h_inv_block = h_inv[:, i1:i2, i1:i2] + for i in range(count): + # need to apply soft thresholding and clamping before quantization + for group_index in range(self.groups): + perm = permutation_list[group_index] + bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index + # calculate the q_max and q_min for the right group and right block + q_max = scales[group_index, bx, :] * torch.clamp_min( + self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) # [OC/groups] + q_min = scales[group_index, bx, :] * torch.clamp_max( + self.lower_lim - b[group_index, bx, :] + 0.5, 0.0) # [OC/groups] + q_arg = weight[group_index, :, perm[i1:i2][i]] # [OC/groups] + # soft thresholding then clamping + q_arg = q_arg.sign() * torch.relu( + q_arg.abs() - thresholds[group_index, bx]) # [OC/groups] + q_arg.clamp_(q_min, q_max) # clamping to bounds + weight[group_index, :, perm[i1:i2][i]] = q_arg + q_groups = self.get_quant_weights(i, i1, permutation_list) # [Groups, OC/groups] + for group_index in range(self.groups): + perm = permutation_list[group_index] + q = q_groups[group_index].to(torch.float32) # [OC/groups] + w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] + d = h_inv_block[group_index, i, i] # [1] + error = (w - q) / d # [OC/groups] + error_block[group_index, :, i] = error + # We need to update the original weights + weight[group_index, :, perm[i1:i2][i:]] -= ( + error.unsqueeze(1).matmul( + h_inv_block[group_index, i, i:].unsqueeze(0).to(dev))).to(dtype) + # update the tracking mechanisms + for group_index in range(self.groups): + perm = permutation_list[group_index] + bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index + q = q_groups[group_index] / scales[group_index, bx] # [OC/groups] + # increment cumulative l1-norm + a[group_index, bx, q >= 0] += q[q >= 0] + b[group_index, bx, q <= 0] += q[q <= 0] + assert (a <= self.upper_lim).all() and (a >= 0).all() + assert (b >= self.lower_lim).all() and (b <= 0).all() + + for group_index in range(self.groups): + perm = permutation_list[group_index] + weight[group_index, :, perm[i2:]] -= ( + error_block[group_index].matmul(h_inv[group_index, i1:i2, + i2:].to(dev))).to(dtype) + if hasattr(self.layer, "offload_params"): + self.layer.offload_params(self.layer) + + del thresholds, scales # memory management + + +class A2GPFQ(GPFQv2): + """ + Memory-efficient, accumulator-aware GPFQ as proposed in https://arxiv.org/pdf/2409.17092 + """ + + def __init__( + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + p, + max_accumulator_bit_width, + max_accumulator_tile_size) -> None: + super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig, p) + self.max_accumulator_bit_width = max_accumulator_bit_width + self.max_accumulator_tile_size = max_accumulator_tile_size + if self.max_accumulator_tile_size is None: + self.max_accumulator_tile_size = self.columns + assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2." + assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2." + + def single_layer_update(self, percdamp=0.01): + assert not self.layer.weight_quant.requires_quant_input, \ + "Error: GPFQ does not support weight quantizers that require quantized inputs." + if self.quant_metadata is None: + raise ValueError( + "Expected self.quant_metadata to calculate accumualtor bounds, but recevied None. " + "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. " + "Also, check if `use_quant_activations=True` in `gpfq_mode` when `max_accumulator_bit_width` is specified. " + ) + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) + weight: Tensor = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. + # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # TODO: add support for signed input activations + if self.quant_metadata.signed: + raise NotImplementedError("Signed inputs not yet supported.") + + # TODO: currently assuming round-to-zero; need to handle other rounding functions + rounding_mode = self.layer.weight_quant.rounding_mode + if rounding_mode.lower() != "round": + raise NotImplementedError(f"{rounding_mode} not yet supported.") + + n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) + scales: Tensor = self.layer.weight_quant.scale() + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + scales = scales.transpose(1, 0) # This performs a view + scales = scales.flatten(1) + P = torch.tensor(self.max_accumulator_bit_width) + N = self.quant_metadata.bit_width + # NOTE: using sign-magnitude here, which is sufficient to support both + # sign-magnitude and 2s complement accumulators + self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) # A + self.lower_lim = -self.upper_lim # B + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) # l1-norm lim for zero-centered weight vector + # translating into the quantized range; need to pad to get these thresholds + wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] + thresholds = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] + thresholds = thresholds.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] + del wT + # supporting groupwise quantization where each tile has its own scaling factor + if self.layer.weight_quant.is_groupwise: + scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] + scales = scales[:, 0] # [Groups * OC * Tiles, 1] + scales = scales.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] + # else each tile has the same scaling factor (per-tensor or per-channel) + else: + scales = scales.view(self.groups, 1, -1) # [Groups, 1, OC/Groups] + scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups] + thresholds *= scales # translating centers back to the float range + + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + + # initialize cumulative l1-norm + a = torch.zeros_like(thresholds, device=dev) # positive limit + b = torch.zeros_like(thresholds, device=dev) # negative limit + + # Try/Except in case the square root of H cannot be computed + try: + norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32) + self.H = self.H.to(dev) + diag = torch.arange(self.columns, device='cpu') + for i in range(self.groups): + # stablize H with a dampening factor and then square root the matrix + damp = percdamp * self.H[i].diag().mean() + self.H[i, diag, diag] += damp + norms[i] = self.H[i].diag() # set the norms post-dampening + eigvals, eigvecs = torch.linalg.eigh(self.H[i]) + eigvals.clamp_min_(0.0).sqrt_() # should be positive-definite + self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t() + del eigvecs, eigvals, diag + self.quant_input = self.H # NOTE: do this here for the `get_permutation_list` function + except LinAlgError: + warnings.warn( + f'Failed to compute the matrix square root of H for layer {self.name} ' + f'GPFQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return + + # Try/Except in case the inverse of H cannot be computed + try: + self.float_input = self.H.clone() # going to calculate H^{-1} here + for i in range(self.groups): + # from our matrix sqrt, we know G is symmetric and positive-definite, so we + # can use Cholesky decomposition as an efficient, numerically stable inverse + L = torch.linalg.cholesky(self.float_input[i]) + self.float_input[i] = torch.cholesky_inverse(L) + self.float_input = torch.bmm(self.float_input.to(dev), self.G.to(dev)) + del L # memory management + except LinAlgError: + warnings.warn( + f'Failed to compute the inverse of H for layer {self.name} ' + f'GPFQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H, self.G, self.B # memory management + + permutation_list = self._get_permutation_list(weight) + + U = torch.zeros( + weight.shape[0], + weight.shape[1], + self.float_input.shape[1], + device=dev, + dtype=torch.float32) # [Groups, OC/groups, Samples] + + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + i = permutation_list[group_index][t] + U[group_index] += torch.matmul( + weight[group_index, :, i].unsqueeze(1).to(torch.float32), + self.float_input[group_index, :, i].unsqueeze(0)) + norm = norms[group_index, i] + if norm > 0: + q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + bx = i // self.max_accumulator_tile_size # block index + q_arg = q_arg.sign() * torch.relu( + q_arg.abs() - thresholds[group_index, bx, :]) # soft thresholding + q_max = scales[group_index, bx] * torch.clamp_min( + self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) + q_min = scales[group_index, bx] * torch.clamp_max( + self.lower_lim - b[group_index, bx, :] + 0.5, 0.0) + q_arg.clamp_(q_min, q_max) + weight[group_index, :, i] = q_arg.to(dtype) + q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list) + for group_index in range(self.groups): + i = permutation_list[group_index][t] + U[group_index] -= torch.matmul( + q_groups[group_index].unsqueeze(1).to(torch.float32), + self.quant_input[group_index, :, i].unsqueeze(0)) + bx = i // self.max_accumulator_tile_size # block index + q = q_groups[group_index] / scales[group_index, bx] # [OC/groups] + # increment cumulative l1-norm + a[group_index, bx, q >= 0] += q[q >= 0] + b[group_index, bx, q <= 0] += q[q <= 0] + assert (a <= self.upper_lim).all() and (a >= 0).all() + assert (b >= self.lower_lim).all() and (b <= 0).all() + + del self.quant_input, self.float_input diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 0151c9232..38ed85678 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -1,11 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from copy import deepcopy +from functools import partial import math import torch -import torch.backends.cudnn as cudnn from tqdm import tqdm from brevitas.core.function_wrapper.shape import OverBatchOverTensorView @@ -16,6 +15,8 @@ from brevitas.graph.calibrate import norm_correction_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize @@ -60,7 +61,6 @@ from brevitas.quant.scaled_int import Int32Bias from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFixedPoint from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat -from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO @@ -68,6 +68,8 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE +from brevitas_examples.common.axe import A2GPFQ +from brevitas_examples.common.axe import A2GPTQ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator @@ -574,12 +576,32 @@ def apply_act_equalization(model, calib_loader, layerwise): model(images) -def apply_gptq(calib_loader, model, act_order=False): +def apply_gptq( + calib_loader, + model, + act_order=False, + use_quant_activations=False, + create_weight_orig=False, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gptq_class = partial( + A2GPTQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gptq_class = GPTQ model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq: + with gptq_mode(model, + act_order=act_order, + use_quant_activations=use_quant_activations, + create_weight_orig=create_weight_orig, + gptq_class=gptq_class) as gptq: gptq_model = gptq.model for i in tqdm(range(gptq.num_layers)): for i, (images, target) in enumerate(calib_loader): @@ -593,21 +615,27 @@ def apply_gpfq( calib_loader, model, act_order, - p=1.0, - use_gpfa2q=False, - accumulator_bit_width=None, - compression_rate=0.0): + create_weight_orig=False, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gpfq_class = partial( + A2GPFQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gpfq_class = GPFQv2 with torch.no_grad(): with gpfq_mode(model, - p=p, + create_weight_orig=create_weight_orig, use_quant_activations=True, act_order=act_order, - use_gpfa2q=use_gpfa2q, - accumulator_bit_width=accumulator_bit_width, - compression_rate=compression_rate) as gpfq: + gpfq_class=gpfq_class) as gpfq: gpfq_model = gpfq.model for i in tqdm(range(gpfq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 8a70e29ba..34bdfbc96 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -233,10 +233,15 @@ def validate_args(args): type=int, help='Exponent bit width used with float quantization for activations (default: 3)') parser.add_argument( - '--accumulator-bit-width', + '--gpxq-accumulator-bit-width', default=None, type=int, - help='Accumulator Bit Width for GPFA2Q (default: None)') + help='Accumulator Bit Width for GPxQ (default: None)') +parser.add_argument( + '--gpxq-accumulator-tile-size', + default=None, + type=int, + help='Accumulator tile size for GPxQ (default: None)') parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version') parser.add_argument( '--channel-splitting-ratio', @@ -245,17 +250,20 @@ def validate_args(args): help= 'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)' ) -parser.add_argument( - '--compression-rate', - default=0.0, - type=float, - help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.' -) add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') -add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)') add_bool_arg( parser, 'gpxq-act-order', default=False, help='GPxQ Act order heuristic (default: disabled)') +add_bool_arg( + parser, + 'gptq-use-quant-activations', + default=False, + help='Use quant activations for GPTQ (default: disabled)') +add_bool_arg( + parser, + 'gpxq-create-weight-orig', + default=False, + help='Maintain original weights for non-quant forward pass (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)') add_bool_arg( @@ -270,7 +278,7 @@ def validate_args(args): help='Merge BN layers before quantizing the model (default: enabled)') add_bool_arg( parser, - 'uint_sym_act_for_unsigned_values', + 'uint-sym-act-for-unsigned-values', default=True, help='Use unsigned act quant when possible (default: enabled)') add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)') @@ -312,7 +320,6 @@ def main(): f"w{args.weight_bit_width}_" f"{'gptq_' if args.gptq else ''}" f"{'gpfq_' if args.gpfq else ''}" - f"{'gpfa2q_' if args.gpfa2q else ''}" f"{'gpxq_act_order_' if args.gpxq_act_order else ''}" f"{'learned_round_' if args.learned_round else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" @@ -335,10 +342,8 @@ def main(): f"Weight bit width: {args.weight_bit_width} - " f"GPTQ: {args.gptq} - " f"GPFQ: {args.gpfq} - " - f"GPFA2Q: {args.gpfa2q} - " - f"GPFQ P: {args.gpfq_p} - " f"GPxQ Act Order: {args.gpxq_act_order} - " - f"GPFA2Q Accumulator Bit Width: {args.accumulator_bit_width} - " + f"GPxQ Accumulator Bit Width: {args.gpxq_accumulator_bit_width} - " f"Learned Round: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " f"Bias bit width: {args.bias_bit_width} - " @@ -412,7 +417,9 @@ def main(): if args.act_equalization is not None: print("Applying activation equalization:") apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise') + device = next(iter(model.parameters())).device + # Define the quantized model quant_model = quantize_model( model, @@ -452,24 +459,21 @@ def main(): apply_gpfq( calib_loader, quant_model, - p=args.gpfq_p, act_order=args.gpxq_act_order, - compression_rate=args.compression_rate) + create_weight_orig=args.gpxq_create_weight_orig, + max_accumulator_bit_width=args.gpxq_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_accumulator_tile_size) - if args.gpfa2q: - print("Performing GPFA2Q:") - apply_gpfq( + if args.gptq: + print("Performing GPTQ:") + apply_gptq( calib_loader, quant_model, - p=args.gpfq_p, act_order=args.gpxq_act_order, - use_gpfa2q=args.gpfa2q, - accumulator_bit_width=args.accumulator_bit_width, - compression_rate=args.compression_rate) - - if args.gptq: - print("Performing GPTQ:") - apply_gptq(calib_loader, quant_model, act_order=args.gpxq_act_order) + use_quant_activations=args.gptq_use_quant_activations, + create_weight_orig=args.gpxq_create_weight_orig, + max_accumulator_bit_width=args.gpxq_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_accumulator_tile_size) if args.learned_round: print("Applying Learned Round:") diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 457c74804..5cd067e64 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -34,6 +34,9 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--input-quant-granularity {per_tensor,per_row,per_group}] [--input-group-size INPUT_GROUP_SIZE] [--quantize-input-zero-point] [--quantize-last-layer] [--gptq] + [--gpfq] [--gpxq-act-order] [--gpxq-use-quant-activations] [--gpxq-create-weight-orig] + [--gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH] + [--gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE] [--act-calibration] [--bias-corr] [--ln-affine-merge] [--no-quantize] [--no-float16] [--replace-mha] [--weight-equalization] @@ -105,6 +108,16 @@ options: --quantize-last-layer Quantize last nn.Linear layer. --gptq Apply GPTQ. + --gpfq Apply GPFQ. + --gpxq-act-order Apply GPxQ activation ordering. + --gpxq-use-quant-activations + Use quantized activations in GPxQ. + --gpxq-create-weight-orig + Create weight_orig in GPxQ. + --gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH + Maximum accumulator bit width for GPxQ using AXE. + --gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE + Maximum accumulator tile size for GPxQ using AXE. --act-calibration Apply activation calibration. --bias-corr Apply bias correction. --ln-affine-merge Merge LN affine params. diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index 44b99772f..5e61306d4 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -1,9 +1,8 @@ -""" -Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -""" from copy import deepcopy +from functools import partial from accelerate.utils.operations import send_to_device import torch @@ -13,9 +12,13 @@ from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode from brevitas.graph.gpxq import StopFwdException from brevitas.utils.python_utils import recurse_getattr +from brevitas_examples.common.axe import A2GPFQ +from brevitas_examples.common.axe import A2GPTQ @torch.no_grad() @@ -109,20 +112,33 @@ def apply_gptq( use_quant_activations=False, create_weight_orig=False, group_of_parallel_layers=None, - block_name=None): + block_name=None, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gptq_class = partial( + A2GPTQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gptq_class = GPTQ if block_name is not None: context_manager_kwargs = { 'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': create_weight_orig, - 'use_quant_activations': use_quant_activations} + 'use_quant_activations': use_quant_activations, + 'gptq_class': gptq_class} block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs) else: with gptq_mode(model, use_quant_activations=use_quant_activations, group_of_parallel_layers=group_of_parallel_layers, act_order=act_order, - create_weight_orig=create_weight_orig) as gptq: + create_weight_orig=create_weight_orig, + gptq_class=gptq_class) as gptq: gptq_model = gptq.model for _ in tqdm(range(gptq.num_layers)): for inps in dataloader: @@ -131,14 +147,36 @@ def apply_gptq( @torch.no_grad() -def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): +def apply_gpfq( + model, + dataloader, + act_order=True, + group_of_parallel_layers=None, + block_name=None, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gpfq_class = partial( + A2GPFQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gpfq_class = GPFQv2 if block_name is not None: - raise RuntimeError("Block optimization not support for GPFQ at the moment") + context_manager_kwargs = { + 'act_order': act_order, + 'group_of_parallel_layers': group_of_parallel_layers, + 'create_weight_orig': True, + 'gpfq_class': gpfq_class} + block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs) else: with gpfq_mode(model, act_order=act_order, group_of_parallel_layers=group_of_parallel_layers, - create_weight_orig=True) as gpfq: + create_weight_orig=True, + gpfq_class=gpfq_class) as gpfq: gpfq_model = gpfq.model for _ in tqdm(range(gpfq.num_layers)): for inps in dataloader: diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index bf995a426..4a87f5a1a 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -1,7 +1,5 @@ -""" -Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -""" import argparse import sys @@ -74,6 +72,20 @@ def validate(args): if not args.no_quantize: if args.gptq and args.gpfq: warn("Both GPTQ and GPFQ are enabled.") + if args.gpxq_max_accumulator_bit_width is not None: + assert args.weight_quant_format == 'int', "AXE only supports integer formats." + assert args.input_quant_format == 'int', "AXE only supports integer formats." + assert args.input_bit_width is not None, \ + "Specify input bit width; activation quantization is required to guarantee accumulator bounds." + if not (args.gptq or args.gpfq): + warn("Max accumulator bit width is specified, but no GPxQ is enabled.") + if args.gpxq_max_accumulator_tile_size is not None: + if args.weight_quant_granularity == 'per_group': + assert args.gpxq_max_accumulator_tile_size == args.weight_group_size, \ + "Group size must be equal to tile size with per_group quantization." + if args.input_quant_granularity == 'per_group': + assert args.gpxq_max_accumulator_tile_size == args.input_group_size, \ + "Group size must be equal to tile size with per_group quantization." if args.export_target is not None: assert args.input_quant_format == 'int', "Only integer quantization supported for export currently." if args.export_target is not None and args.input_bit_width is not None: @@ -158,8 +170,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=args.fuse_sequences, - ) + fuse_sequences=args.fuse_sequences) validation_loader = get_dataset_for_model( args.model, @@ -171,8 +182,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=args.fuse_sequences, - ) + fuse_sequences=args.fuse_sequences) device = next(iter(model.parameters())).device print("Data loaded.") @@ -287,7 +297,9 @@ def main(args): act_order=args.gpxq_act_order, use_quant_activations=args.gpxq_use_quant_activations, create_weight_orig=args.gpxq_create_weight_orig, - block_name=args.gpxq_block_name) + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPTQ applied.") if args.gpfq: @@ -296,7 +308,9 @@ def main(args): model, calibration_loader, act_order=args.gpxq_act_order, - block_name=args.gpxq_block_name) + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPFQ applied.") if args.bias_corr: @@ -304,7 +318,7 @@ def main(args): apply_bias_correction(model, calibration_loader) print("Bias correction applied.") - if args.eval: + if args.eval and not args.no_quantize: print("Model eval...") quant_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) @@ -455,13 +469,23 @@ def parse_args(args): parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') parser.add_argument('--gpfq', action='store_true', help='Apply GPFQ.') parser.add_argument( - '--gpxq-act-order', action='store_true', help='Apply GPXQ activation ordering.') + '--gpxq-act-order', action='store_true', help='Apply GPxQ activation ordering.') parser.add_argument( '--gpxq-use-quant-activations', action='store_true', - help='Use quantized activations in GPXQ.') + help='Use quantized activations in GPxQ.') parser.add_argument( - '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPXQ.') + '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPxQ.') + parser.add_argument( + '--gpxq-max-accumulator-bit-width', + type=int, + default=None, + help='Maximum accumulator bit width for GPxQ using AXE.') + parser.add_argument( + '--gpxq-max-accumulator-tile-size', + type=int, + default=None, + help='Maximum accumulator tile size for GPxQ using AXE.') parser.add_argument( '--act-calibration', action='store_true', help='Apply activation calibration.') parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')