From 2de7c794b44da29b87ff124392bd61b8b76f49c0 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 15 Oct 2024 15:18:13 +0000 Subject: [PATCH 01/19] Feat (gptq): parameterize gptq_class --- src/brevitas/graph/gptq.py | 163 +++++++++++++++++++------------------ 1 file changed, 84 insertions(+), 79 deletions(-) diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index a1380da4e..b48cde64d 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -23,85 +23,6 @@ import brevitas.nn as qnn -class gptq_mode(gpxq_mode): - """ - Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. - - Args: - model (Module): The model to quantize with GPTQ - group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group - of layer names that can be optimized in parallel. Default: None - inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True - create_weight_orig (bool): If True, store the original floating point weights before applying - gptq. These weights will be used anytime quantization is disabled. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPTQ. Default: False - num_blocks (int): The number of sub-blocks to use to speed-up GPTQ computation. Default: 100 - act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False - return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the - forward call inside the context manager returns None. Default: False - - Example: - >>> with torch.no_grad(): - >>> with gptq_mode(model) as gptq: - >>> gptq_model = gptq.model - >>> for i in tqdm(range(gptq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gptq_model(img) - >>> gptq.update() - """ - - def __init__( - self, - model, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - create_weight_orig: bool = True, - use_quant_activations: bool = True, - num_blocks: int = 100, - return_forward_output: bool = False, - act_order: bool = False) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - create_weight_orig, - use_quant_activations, - act_order, - return_forward_output) - - # How many subblock to use during GPTQ for each layer - self.num_blocks = num_blocks - - def catch_stopfwd(self, *args, **kwargs): - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - finally: - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - def initialize_module_optimizer( - self, layer, name, act_order, len_parallel_layers, create_weight_orig): - return GPTQ( - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - num_blocks=self.num_blocks) - - class GPTQ(GPxQ): """ Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: @@ -292,3 +213,87 @@ def single_layer_update(self, percdamp=.01): i2:].to(dev))).to(dtype) if hasattr(self.layer, 'offload_params'): self.layer.offload_params(self.layer) + + +class gptq_mode(gpxq_mode): + """ + Apply GPTQ algorithm 
https://arxiv.org/abs/2210.17323. + + Args: + model (Module): The model to quantize with GPTQ + group_of_parallel_layers (Optional, List[str]): List of lists where each inner list is a group + of layer names that can be optimized in parallel. Default: None + inplace (bool): Whether to apply GPTQ inplace or perform a deepcopy. Default: True + create_weight_orig (bool): If True, store the original floating point weights before applying + gptq. These weights will be used anytime quantization is disabled. Default: True + use_quant_activations (bool): Whether to leave quantized activations enabled while performing + GPTQ. Default: True + num_blocks (int): The number of sub-blocks to use to speed up GPTQ computation. Default: 100 + act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False + return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the + forward call inside the context manager returns None. Default: False + + Example: + >>> with torch.no_grad(): + >>> with gptq_mode(model) as gptq: + >>> gptq_model = gptq.model + >>> for i in tqdm(range(gptq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gptq_model(img) + >>> gptq.update() + """ + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + create_weight_orig: bool = True, + use_quant_activations: bool = True, + num_blocks: int = 100, + return_forward_output: bool = False, + act_order: bool = False, + gptq_class: Optional[GPxQ] = None) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + create_weight_orig, + use_quant_activations, + act_order, + return_forward_output) + + # How many sub-blocks to use during GPTQ for each layer + self.num_blocks = num_blocks + if gptq_class is None: + gptq_class = GPTQ + self.gptq_class = gptq_class + + def catch_stopfwd(self, *args, **kwargs): + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + finally: + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + def initialize_module_optimizer( + self, layer, name, act_order, len_parallel_layers, create_weight_orig): + return self.gptq_class( + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + num_blocks=self.num_blocks) + From 87ccb88cf27d7101c04f6aa7a2c3f4a398011c93 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 15 Oct 2024 15:19:09 +0000 Subject: [PATCH 02/19] Pre-commit fixes --- src/brevitas/graph/gptq.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index b48cde64d..ab1384654 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -296,4 +296,3 @@ def initialize_module_optimizer( len_parallel_layers=len_parallel_layers, create_weight_orig=create_weight_orig, num_blocks=self.num_blocks) - From 9c57beab737beed9a231509df23b52005304b699 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 15 Oct 2024 15:20:42 +0000 Subject: [PATCH 03/19] Feat (gpfq): parameter allocation/offloading --- src/brevitas/graph/gpfq.py | 9 +++++---- 1 
file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 33ee5fbb4..2af517df3 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -70,7 +70,7 @@ def __init__( p: float = 1.0, return_forward_output: bool = False, act_order: bool = False, - gpfq_class: Optional[nn.Module] = None) -> None: + gpfq_class: Optional[GPxQ] = None) -> None: if not inplace: model = deepcopy(model) super().__init__( @@ -86,8 +86,6 @@ def __init__( if gpfq_class is None: gpfq_class = GPFQ self.gpfq_class = gpfq_class - assert issubclass(gpfq_class, GPxQ), \ - "Error: expected `gpfq_class` to be derived from `brevitas.graph.gpxq.GPxQ`." def catch_stopfwd(self, *args, **kwargs): # Collect quant input @@ -401,6 +399,8 @@ def _get_permutation_list(self, weight: Tensor): def single_layer_update(self, percdamp: float = 0.01): assert not self.layer.weight_quant.requires_quant_input, \ "Error: GPFQ does not support weight quantizers that require quantized inputs." + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) weight = self.layer.weight.data dev = weight.device dtype = weight.dtype @@ -468,6 +468,7 @@ def single_layer_update(self, percdamp: float = 0.01): q_groups[group_index].unsqueeze(1), self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0), ) - + if hasattr(self.layer, 'offload_params'): + self.layer.offload_params(self.layer) del self.float_input del self.quant_input From 26d7c95b40a8b4e9d4dee664304eb0c8a07368d1 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 15 Oct 2024 15:25:24 +0000 Subject: [PATCH 04/19] Feat (axe): adding accumulator-aware extensions for GPxQ --- src/brevitas_examples/llm/llm_quant/axe.py | 382 ++++++++++++++++++++ src/brevitas_examples/llm/llm_quant/gpxq.py | 51 ++- src/brevitas_examples/llm/main.py | 36 +- 3 files changed, 448 insertions(+), 21 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/axe.py diff --git a/src/brevitas_examples/llm/llm_quant/axe.py b/src/brevitas_examples/llm/llm_quant/axe.py new file mode 100644 index 000000000..c7c96af67 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/axe.py @@ -0,0 +1,382 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import math +import warnings + +import numpy as np +import torch +from torch import Tensor + +try: + from torch.linalg import LinAlgError +except: + LinAlgError = RuntimeError + +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ +from brevitas.graph.gpxq import SUPPORTED_CONV_OP +from brevitas.graph.gpxq import SUPPORTED_TCONV_OP + + +def _get_average_of_nonzero_magnitudes(vec: np.ndarray, radius: float = 1.0): + assert radius > 0, "Error: radius needs to be strictly positive." + assert vec.ndim == 1, "Error: projection assumes a vector, not a matrix." + assert vec.min() >= 0, "Error: assuming a vector of non-negative numbers." 
+ n_elems = vec.shape[0] + # if we are already within the simplex, then the best projection is itself + if vec.sum() <= radius: + return 0.0 + # using algorithm detailed in "Efficient Projections onto the L1-Ball for Learning in High Dimensions" + v = vec + u = np.sort(v)[::-1] + cumsum_u = np.cumsum(u) + rho = np.nonzero(u * np.arange(1, n_elems + 1) > (cumsum_u - radius))[0][-1] + theta = float(cumsum_u[rho] - radius) / (rho + 1) + return theta + + +def calc_average_nonzero_mag(weight: Tensor, lim: Tensor) -> Tensor: + thetas = torch.zeros(weight.shape[0], device=weight.device) + for i in range(weight.shape[0]): + l = lim[i].item() if lim.ndim > 0 else lim.item() + w = weight[i].cpu().detach().numpy() + t = _get_average_of_nonzero_magnitudes(np.abs(w), l) + thetas[i] = t + return thetas + + +def pad_tensor_with_zeros(tensor: Tensor, tile_size: int) -> Tensor: + pad_size = tile_size - (tensor.shape[1] % tile_size) + if pad_size == tile_size: + return tensor + padding = torch.zeros((tensor.shape[0], pad_size), device=tensor.device) + pad_tensor = torch.concat([tensor, padding], axis=1) + return pad_tensor + + +class A2GPTQ(GPTQ): + """ + Accumulator-aware GPTQ as proposed in https://arxiv.org/pdf/2409.17092 + """ + + def __init__( + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + num_blocks, + max_accumulator_bit_width, + max_accumulator_tile_size) -> None: + super().__init__( + layer, name, act_order, len_parallel_layers, create_weight_orig, num_blocks) + self.max_accumulator_bit_width = max_accumulator_bit_width + self.max_accumulator_tile_size = max_accumulator_tile_size + if self.max_accumulator_tile_size is None: + self.max_accumulator_tile_size = self.columns + assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2." + assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2." + + def single_layer_update(self, percdamp=0.01): + assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." + if self.quant_metadata is None: + raise ValueError( + "Expected self.quant_metadata to calculate accumulator bounds, but received None. " + "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. " + "Also, check if `use_quant_activations=True` in `gptq_mode` when `max_accumulator_bit_width` is specified. " + ) + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) + weight = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. 
+ # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # TODO: add support for signed input activations + assert not self.quant_metadata.signed + + n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) + + s = self.layer.weight_quant.scale() + P = torch.tensor(self.max_accumulator_bit_width) + N = self.quant_metadata.bit_width + # TODO: add support for two's complement accumulator representation + A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) + B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1) + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) + # translating into the quantized range; need to pad to get these thresholds + wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] + T = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [OC * Tiles, IC / Tiles] + T = T.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] + s = s.view(self.groups, -1) # [Groups, OC/Groups] + T *= s # translating centers back to the float range + + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + + # List with permutation tensors for the Hessian and weight matrix. + # If act_order is False, the tensors will be ordered indexes. + # For groupwise convolution, we have one tensor per group, + # thus len(permutation_list) is always equal to self.groups. + # We do not explicity permute the weight matrix, only the Hessian. + permutation_list = [] + weight = weight.view(self.groups, -1, weight.shape[-1]) + # For groupwise convolution, these operations are groupwise so we iterate + for i in range(self.groups): + # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding + # column in the weight matrix. + # The diagonal element is set to 1 to avoid division-by-zero + dead = torch.diag(self.H[i, :, :]) == 0 + self.H[i, dead, dead] = 1 + # If the diagonal of activations is zero, we set the weight to zero + weight[i, :, dead] = 0 + if self.act_order: + # Re-order Hessian so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True) + self.H[i, :, :] = self.H[i, perm, :][:, perm] + else: + # No permutation, permutation tensor is a ordered index + perm = torch.tensor(range(self.H.shape[-1]), device=dev) + permutation_list.append(perm) + + # Try/Except in case the inverse Hessian cannot be computed + try: + for i in range(self.groups): + damp = percdamp * torch.mean(torch.diag(self.H[i, :, :])) + diag = torch.arange(self.columns, device='cpu') + self.H[i, diag, diag] += damp + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) + self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) + h_inv = self.H + except LinAlgError: + warnings.warn( + f'Failed to compute the inverse of the Hessian for layer {self.name} ' + f'GPTQ will not be applied. 
' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H, self.B + + # initialize cumulative l1-norm + a = torch.zeros_like(T, device=dev) # pos + b = torch.zeros_like(T, device=dev) # neg + + for i1 in range(0, self.columns, self.blocksize): + i2 = min(i1 + self.blocksize, self.columns) + count = i2 - i1 + error_block = torch.zeros_like( + weight[:, :, permutation_list[-1][i1:i2]], + dtype=torch.float32, + ) # [groups, OC/groups, i2-i1] + + h_inv_block = h_inv[:, i1:i2, i1:i2] + for i in range(count): + # need to apply soft thresholding and clamping before quantization + for group_index in range(self.groups): + perm = permutation_list[group_index] + bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index + # calculate the q_max and q_min for the right group and right block + # TODO: currently assuming round-to-zero; need to handle other rounding functions + q_max = s[group_index, :] * torch.clamp_min( + A - a[group_index, bx, :] - 0.5, 0.0) # [OC/groups] + q_min = s[group_index, :] * torch.clamp_max( + B - b[group_index, bx, :] + 0.5, 0.0) # [OC/groups] + q_arg = weight[group_index, :, perm[i1:i2][i]] # [OC/groups] + # soft thresholding then clamping + q_arg = q_arg.sign() * torch.relu( + q_arg.abs() - T[group_index, bx]) # [OC/groups] + q_arg.clamp_(q_min, q_max) # clamping to bounds + weight[group_index, :, perm[i1:i2][i]] = q_arg + q_groups = self.get_quant_weights(i, i1, permutation_list) # [Groups, OC/groups] + for group_index in range(self.groups): + perm = permutation_list[group_index] + q = q_groups[group_index] # [OC/groups] + w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] + d = h_inv_block[group_index, i, i] # [1] + error = (w - q) / d # [OC/groups] + error_block[group_index, :, i] = error + # We need to update the original weights + weight[group_index, :, perm[i1:i2][i:]] -= ( + error.unsqueeze(1).matmul( + h_inv_block[group_index, i, i:].unsqueeze(0).to(dev))).to(dtype) + # update the tracking mechanisms + # TODO: need to handle non-zero zero points + for group_index in range(self.groups): + perm = permutation_list[group_index] + bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index + q = q_groups[group_index] / s[group_index] # [OC/groups] + # increment cumulative l1-norm + a[group_index, bx, q >= 0] += q[q >= 0] + b[group_index, bx, q <= 0] += q[q <= 0] + assert (a <= A).all() and (a >= 0).all() + assert (b >= B).all() and (b <= 0).all() + + for group_index in range(self.groups): + perm = permutation_list[group_index] + weight[group_index, :, perm[i2:]] -= ( + error_block[group_index].matmul(h_inv[group_index, i1:i2, + i2:].to(dev))).to(dtype) + if hasattr(self.layer, "offload_params"): + self.layer.offload_params(self.layer) + + +class A2GPFQ(GPFQv2): + """ + Memory-efficient, accumulator-aware GPFQ as proposed in https://arxiv.org/pdf/2409.17092 + """ + + def __init__( + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + p, + max_accumulator_bit_width, + max_accumulator_tile_size) -> None: + super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig, p) + self.max_accumulator_bit_width = max_accumulator_bit_width + self.max_accumulator_tile_size = max_accumulator_tile_size + if self.max_accumulator_tile_size is None: + self.max_accumulator_tile_size = self.columns + assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2." 
+ assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2." + + def single_layer_update(self, percdamp=0.01): + assert not self.layer.weight_quant.requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs." + if self.quant_metadata is None: + raise ValueError( + "Expected self.quant_metadata to calculate accumulator bounds, but received None. " + "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. " + "Also, check if `use_quant_activations=True` in `gpfq_mode` when `max_accumulator_bit_width` is specified. " + ) + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) + weight: Tensor = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. + # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # TODO: add support for signed input activations + assert not self.quant_metadata.signed + + n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) + + s = self.layer.weight_quant.scale() + P = torch.tensor(self.max_accumulator_bit_width) + N = self.quant_metadata.bit_width + # TODO: add support for two's complement accumulator representation + A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) + B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1) + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) + # translating into the quantized range; need to pad to get these thresholds + wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] + T = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [OC * Tiles, IC / Tiles] + T = T.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] + s = s.view(self.groups, -1) # [Groups, OC/Groups] + T *= s # translating centers back to the float range + + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + + # initialize cumulative l1-norm + a = torch.zeros_like(T, device=dev) # pos + b = torch.zeros_like(T, device=dev) # neg + + # stabilize G with a dampening factor and then square root the matrix + norms = torch.zeros((self.groups, self.columns), device=dev, dtype=dtype) + self.H = self.H.to(dev) + diag = torch.arange(self.columns, device='cpu') + for i in range(self.groups): + damp = percdamp * self.H[i].diag().mean() + self.H[i, diag, diag] += damp + norms[i] = self.H[i].diag() # set the norms post-dampening + eigvals, eigvecs = torch.linalg.eigh(self.H[i]) + eigvals.clamp_min_(0.0).sqrt_() # should be positive-definite + self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t() + del eigvecs, eigvals, diag + self.quant_input = self.H # NOTE: do this here for the `get_permutation_list` function + + # Try/Except in case the inverse of H cannot be computed + try: + self.float_input = self.H.clone() # going to calculate H^{-1} here + for i in range(self.groups): + # from our matrix sqrt, we know G is symmetric and positive-definite, so we + # can use Cholesky decomposition as an efficient, numerically stable inverse + L = torch.linalg.cholesky(self.float_input[i]) + self.float_input[i] = torch.cholesky_inverse(L) + self.float_input = 
torch.bmm(self.float_input.to(dev), self.G.to(dev)) + del L # memory management + except LinAlgError: + warnings.warn( + f'Failed to compute the inverse of H for layer {self.name} ' + f'GPFQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H, self.G, self.B # memory management + + permutation_list = self._get_permutation_list(weight) + + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, + dtype=dtype) # [Groups, OC/groups, Samples] + + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + i = permutation_list[group_index][t] + U[group_index] += torch.matmul( + weight[group_index, :, i].unsqueeze(1), + self.float_input[group_index, :, i].unsqueeze(0)) + norm = norms[group_index, i] + if norm > 0: + q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + bx = i // self.max_accumulator_tile_size # block index + q_arg = q_arg.sign() * torch.relu( + q_arg.abs() - T[group_index, bx, :]) # soft thresholding + + # TODO: assuming round to nearest; need to generally support other rounding + q_max = s[group_index] * torch.clamp_min(A - a[group_index, bx, :] - 0.5, 0.0) + q_min = s[group_index] * torch.clamp_max(B - b[group_index, bx, :] + 0.5, 0.0) + q_arg.clamp_(q_min, q_max) + weight[group_index, :, i] = q_arg + q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list) + for group_index in range(self.groups): + i = permutation_list[group_index][t] + U[group_index] -= torch.matmul( + q_groups[group_index].unsqueeze(1), + self.quant_input[group_index, :, i].unsqueeze(0)) + bx = i // self.max_accumulator_tile_size # block index + q = q_groups[group_index] / s[group_index] # [OC/groups] + # increment cumulative l1-norm + a[group_index, bx, q >= 0] += q[q >= 0] + b[group_index, bx, q <= 0] += q[q <= 0] + assert (a <= A).all() and (a >= 0).all() + assert (b >= B).all() and (b <= 0).all() + + del self.quant_input, self.float_input diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index 44b99772f..b9f6d5ddd 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -1,7 +1,7 @@ -""" -Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -""" + +from functools import partial from copy import deepcopy @@ -13,10 +13,15 @@ from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode from brevitas.graph.gpxq import StopFwdException from brevitas.utils.python_utils import recurse_getattr +from .axe import A2GPFQ +from .axe import A2GPTQ + @torch.no_grad() def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs): @@ -109,20 +114,33 @@ def apply_gptq( use_quant_activations=False, create_weight_orig=False, group_of_parallel_layers=None, - block_name=None): + block_name=None, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gptq_class = partial( + A2GPTQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gptq_class = GPTQ if block_name is not None: context_manager_kwargs = { 'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': create_weight_orig, - 'use_quant_activations': use_quant_activations} + 'use_quant_activations': use_quant_activations, + 'gptq_class': gptq_class} block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs) else: with gptq_mode(model, use_quant_activations=use_quant_activations, group_of_parallel_layers=group_of_parallel_layers, act_order=act_order, - create_weight_orig=create_weight_orig) as gptq: + create_weight_orig=create_weight_orig, + gptq_class=gptq_class) as gptq: gptq_model = gptq.model for _ in tqdm(range(gptq.num_layers)): for inps in dataloader: @@ -131,14 +149,31 @@ def apply_gptq( @torch.no_grad() -def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None): +def apply_gpfq( + model, + dataloader, + act_order=True, + group_of_parallel_layers=None, + block_name=None, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gpfq_class = partial( + A2GPFQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gpfq_class = GPFQv2 if block_name is not None: raise RuntimeError("Block optimization not support for GPFQ at the moment") else: with gpfq_mode(model, act_order=act_order, group_of_parallel_layers=group_of_parallel_layers, - create_weight_orig=True) as gpfq: + create_weight_orig=True, + gpfq_class=gpfq_class) as gpfq: gpfq_model = gpfq.model for _ in tqdm(range(gpfq.num_layers)): for inps in dataloader: diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index bf995a426..c66d7c15a 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -1,7 +1,5 @@ -""" -Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -""" import argparse import sys @@ -158,8 +156,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=args.fuse_sequences, - ) + fuse_sequences=args.fuse_sequences) validation_loader = get_dataset_for_model( args.model, @@ -171,8 +168,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=args.fuse_sequences, - ) + fuse_sequences=args.fuse_sequences) device = next(iter(model.parameters())).device print("Data loaded.") @@ -287,7 +283,9 @@ def main(args): act_order=args.gpxq_act_order, use_quant_activations=args.gpxq_use_quant_activations, create_weight_orig=args.gpxq_create_weight_orig, - block_name=args.gpxq_block_name) + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPTQ applied.") if args.gpfq: @@ -296,7 +294,9 @@ def main(args): model, calibration_loader, act_order=args.gpxq_act_order, - block_name=args.gpxq_block_name) + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPFQ applied.") if args.bias_corr: @@ -304,7 +304,7 @@ def main(args): apply_bias_correction(model, calibration_loader) print("Bias correction applied.") - if args.eval: + if args.eval and not args.no_quantize: print("Model eval...") quant_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) @@ -455,13 +455,23 @@ def parse_args(args): parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') parser.add_argument('--gpfq', action='store_true', help='Apply GPFQ.') parser.add_argument( - '--gpxq-act-order', action='store_true', help='Apply GPXQ activation ordering.') + '--gpxq-act-order', action='store_true', help='Apply GPxQ activation ordering.') parser.add_argument( '--gpxq-use-quant-activations', action='store_true', - help='Use quantized activations in GPXQ.') + help='Use quantized activations in GPxQ.') parser.add_argument( - '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPXQ.') + '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPxQ.') + parser.add_argument( + '--gpxq-max-accumulator-bit-width', + type=int, + default=None, + help='Maximum accumulator bit width for GPxQ using AXE.') + parser.add_argument( + '--gpxq-max-accumulator-tile-size', + type=int, + default=128, + help='Maximum accumulator tile size for GPxQ using AXE.') parser.add_argument( '--act-calibration', action='store_true', help='Apply activation calibration.') parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.') From 3a6d2e44a1e520782b3708be826fe8c17ef1ecd1 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 18 Oct 2024 05:26:50 +0000 Subject: [PATCH 05/19] Fix (gpfq): code comment --- src/brevitas/graph/gpfq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 2af517df3..4da9e3a14 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -358,7 +358,7 @@ def update_batch(self, module, input, current_layer): # if quant is not enabled, then it is the float input; if it is a float input # then a quant input has already happened and we can update G if not is_quant_enabled: - # Computing the normalized H matrix using CPU buffer + # Computing the normalized 
G matrix using CPU buffer self.B.copy_(self.quant_input.bmm(inp_processed.transpose(2, 1))) self.G += self.B self.quant_input = None # NOTE: set back to None now that we've used it From bcf2f498fce8262bb52b1b55adce0237c216b2ca Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 18 Oct 2024 05:27:53 +0000 Subject: [PATCH 06/19] Fix (gptq): upcast q_groups --- src/brevitas/graph/gptq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index ab1384654..d80ee1069 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -196,7 +196,7 @@ def single_layer_update(self, percdamp=.01): q_groups = self.get_quant_weights(i, i1, permutation_list) # [groups, OC/groups] for group_index in range(self.groups): perm = permutation_list[group_index] - q = q_groups[group_index] # [OC/groups] + q = q_groups[group_index].to(torch.float32) # [OC/groups] w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] d = h_inv_block[group_index, i, i] # [1] error = (w - q) / d # [OC/groups] From b161fb8ea18b1067f165fe6c1e7472ad448d6fb2 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 18 Oct 2024 05:29:21 +0000 Subject: [PATCH 07/19] Pre-commit fixes --- src/brevitas_examples/llm/llm_quant/gpxq.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index b9f6d5ddd..595b47e3c 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -1,9 +1,8 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from functools import partial - from copy import deepcopy +from functools import partial from accelerate.utils.operations import send_to_device import torch @@ -150,10 +149,10 @@ def apply_gptq( @torch.no_grad() def apply_gpfq( - model, - dataloader, - act_order=True, - group_of_parallel_layers=None, + model, + dataloader, + act_order=True, + group_of_parallel_layers=None, block_name=None, max_accumulator_bit_width=None, max_accumulator_tile_size=128): @@ -167,7 +166,12 @@ def apply_gpfq( else: gpfq_class = GPFQv2 if block_name is not None: - raise RuntimeError("Block optimization not support for GPFQ at the moment") + context_manager_kwargs = { + 'act_order': act_order, + 'group_of_parallel_layers': group_of_parallel_layers, + 'create_weight_orig': True, + 'gpfq_class': gpfq_class} + block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs) else: with gpfq_mode(model, act_order=act_order, From ac735f3517be3ed23eb9a03e3339e7af2cabd5a7 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 18 Oct 2024 21:52:13 +0000 Subject: [PATCH 08/19] Initial AXE support for imagnet --- .../{llm/llm_quant => common}/axe.py | 33 ++++++----- .../imagenet_classification/ptq/ptq_common.py | 54 ++++++++++++----- .../ptq/ptq_evaluate.py | 58 ++++++++++--------- src/brevitas_examples/llm/llm_quant/gpxq.py | 5 +- 4 files changed, 94 insertions(+), 56 deletions(-) rename src/brevitas_examples/{llm/llm_quant => common}/axe.py (94%) diff --git a/src/brevitas_examples/llm/llm_quant/axe.py b/src/brevitas_examples/common/axe.py similarity index 94% rename from src/brevitas_examples/llm/llm_quant/axe.py rename to src/brevitas_examples/common/axe.py index c7c96af67..b156ddc88 100644 --- a/src/brevitas_examples/llm/llm_quant/axe.py +++ b/src/brevitas_examples/common/axe.py @@ -103,16 +103,18 @@ 
def single_layer_update(self, percdamp=0.01): weight = weight.flatten(1) # TODO: add support for signed input activations - assert not self.quant_metadata.signed + if self.quant_metadata.signed: + raise NotImplementedError("Signed inputs not yet supported.") n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) s = self.layer.weight_quant.scale() P = torch.tensor(self.max_accumulator_bit_width) N = self.quant_metadata.bit_width - # TODO: add support for two's complement accumulator representation + # NOTE: using sign-magnitude here, which is sufficient to support both + # sign-magnitude and 2s complement accumulators A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) - B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1) + B = -A Z = (pow(2, P) - 2) / float(pow(2, N) - 1) # translating into the quantized range; need to pad to get these thresholds wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( @@ -203,7 +205,7 @@ def single_layer_update(self, percdamp=0.01): q_groups = self.get_quant_weights(i, i1, permutation_list) # [Groups, OC/groups] for group_index in range(self.groups): perm = permutation_list[group_index] - q = q_groups[group_index] # [OC/groups] + q = q_groups[group_index].to(torch.float32) # [OC/groups] w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] d = h_inv_block[group_index, i, i] # [1] error = (w - q) / d # [OC/groups] @@ -280,16 +282,18 @@ def single_layer_update(self, percdamp=0.01): weight = weight.flatten(1) # TODO: add support for signed input activations - assert not self.quant_metadata.signed + if self.quant_metadata.signed: + raise NotImplementedError("Signed inputs not yet supported.") n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) s = self.layer.weight_quant.scale() P = torch.tensor(self.max_accumulator_bit_width) N = self.quant_metadata.bit_width - # TODO: add support for two's complement accumulator representation + # NOTE: using sign-magnitude here, which is sufficient to support both + # sign-magnitude and 2s complement accumulators A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) - B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1) + B = -A Z = (pow(2, P) - 2) / float(pow(2, N) - 1) # translating into the quantized range; need to pad to get these thresholds wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( @@ -307,7 +311,7 @@ def single_layer_update(self, percdamp=0.01): b = torch.zeros_like(T, device=dev) # neg # stablize G with a dampening factor and then square root the matrix - norms = torch.zeros((self.groups, self.columns), device=dev, dtype=dtype) + norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32) self.H = self.H.to(dev) diag = torch.arange(self.columns, device='cpu') for i in range(self.groups): @@ -342,14 +346,17 @@ def single_layer_update(self, percdamp=0.01): permutation_list = self._get_permutation_list(weight) U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, - dtype=dtype) # [Groups, OC/groups, Samples] + weight.shape[0], + weight.shape[1], + self.float_input.shape[1], + device=dev, + dtype=torch.float32) # [Groups, OC/groups, Samples] for t in range(weight.shape[-1]): for group_index in range(self.groups): i = permutation_list[group_index][t] U[group_index] += torch.matmul( - weight[group_index, :, i].unsqueeze(1), + weight[group_index, :, i].unsqueeze(1).to(torch.float32), self.float_input[group_index, :, i].unsqueeze(0)) norm = norms[group_index, i] if norm > 0: 
@@ -364,12 +371,12 @@ def single_layer_update(self, percdamp=0.01): q_max = s[group_index] * torch.clamp_min(A - a[group_index, bx, :] - 0.5, 0.0) q_min = s[group_index] * torch.clamp_max(B - b[group_index, bx, :] + 0.5, 0.0) q_arg.clamp_(q_min, q_max) - weight[group_index, :, i] = q_arg + weight[group_index, :, i] = q_arg.to(dtype) q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): i = permutation_list[group_index][t] U[group_index] -= torch.matmul( - q_groups[group_index].unsqueeze(1), + q_groups[group_index].unsqueeze(1).to(torch.float32), self.quant_input[group_index, :, i].unsqueeze(0)) bx = i // self.max_accumulator_tile_size # block index q = q_groups[group_index] / s[group_index] # [OC/groups] diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 0151c9232..38ed85678 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -1,11 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from copy import deepcopy +from functools import partial import math import torch -import torch.backends.cudnn as cudnn from tqdm import tqdm from brevitas.core.function_wrapper.shape import OverBatchOverTensorView @@ -16,6 +15,8 @@ from brevitas.graph.calibrate import norm_correction_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize @@ -60,7 +61,6 @@ from brevitas.quant.scaled_int import Int32Bias from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFixedPoint from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat -from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO @@ -68,6 +68,8 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE +from brevitas_examples.common.axe import A2GPFQ +from brevitas_examples.common.axe import A2GPTQ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator @@ -574,12 +576,32 @@ def apply_act_equalization(model, calib_loader, layerwise): model(images) -def apply_gptq(calib_loader, model, act_order=False): +def apply_gptq( + calib_loader, + model, + act_order=False, + use_quant_activations=False, + create_weight_orig=False, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gptq_class = 
partial( + A2GPTQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gptq_class = GPTQ model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq: + with gptq_mode(model, + act_order=act_order, + use_quant_activations=use_quant_activations, + create_weight_orig=create_weight_orig, + gptq_class=gptq_class) as gptq: gptq_model = gptq.model for i in tqdm(range(gptq.num_layers)): for i, (images, target) in enumerate(calib_loader): @@ -593,21 +615,27 @@ def apply_gpfq( calib_loader, model, act_order, - p=1.0, - use_gpfa2q=False, - accumulator_bit_width=None, - compression_rate=0.0): + create_weight_orig=False, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gpfq_class = partial( + A2GPFQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gpfq_class = GPFQv2 with torch.no_grad(): with gpfq_mode(model, - p=p, + create_weight_orig=create_weight_orig, use_quant_activations=True, act_order=act_order, - use_gpfa2q=use_gpfa2q, - accumulator_bit_width=accumulator_bit_width, - compression_rate=compression_rate) as gpfq: + gpfq_class=gpfq_class) as gpfq: gpfq_model = gpfq.model for i in tqdm(range(gpfq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index fd5e5c386..b1d9821d8 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -228,10 +228,15 @@ def parse_type(v, default_type): type=int, help='Exponent bit width used with float quantization for activations (default: 3)') parser.add_argument( - '--accumulator-bit-width', + '--gpxq-accumulator-bit-width', default=None, type=int, - help='Accumulator Bit Width for GPFA2Q (default: None)') + help='Accumulator Bit Width for GPxQ (default: None)') +parser.add_argument( + '--gpxq-accumulator-tile-size', + default=None, + type=int, + help='Accumulator tile size for GPxQ (default: None)') parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version') parser.add_argument( '--channel-splitting-ratio', @@ -240,17 +245,20 @@ def parse_type(v, default_type): help= 'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)' ) -parser.add_argument( - '--compression-rate', - default=0.0, - type=float, - help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.' 
-) add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') -add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)') add_bool_arg( parser, 'gpxq-act-order', default=False, help='GPxQ Act order heuristic (default: disabled)') +add_bool_arg( + parser, + 'gptq-use-quant-activations', + default=False, + help='Use quant activations for GPTQ (default: disabled)') +add_bool_arg( + parser, + 'gpxq-create-weight-orig', + default=False, + help='Maintain original weights for non-quant forward pass (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)') add_bool_arg( @@ -265,7 +273,7 @@ def parse_type(v, default_type): help='Merge BN layers before quantizing the model (default: enabled)') add_bool_arg( parser, - 'uint_sym_act_for_unsigned_values', + 'uint-sym-act-for-unsigned-values', default=True, help='Use unsigned act quant when possible (default: enabled)') add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)') @@ -306,7 +314,6 @@ def main(): f"w{args.weight_bit_width}_" f"{'gptq_' if args.gptq else ''}" f"{'gpfq_' if args.gpfq else ''}" - f"{'gpfa2q_' if args.gpfa2q else ''}" f"{'gpxq_act_order_' if args.gpxq_act_order else ''}" f"{'learned_round_' if args.learned_round else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" @@ -329,10 +336,8 @@ def main(): f"Weight bit width: {args.weight_bit_width} - " f"GPTQ: {args.gptq} - " f"GPFQ: {args.gpfq} - " - f"GPFA2Q: {args.gpfa2q} - " - f"GPFQ P: {args.gpfq_p} - " f"GPxQ Act Order: {args.gpxq_act_order} - " - f"GPFA2Q Accumulator Bit Width: {args.accumulator_bit_width} - " + f"GPxQ Accumulator Bit Width: {args.gpxq_accumulator_bit_width} - " f"Learned Round: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " f"Bias bit width: {args.bias_bit_width} - " @@ -406,7 +411,9 @@ def main(): if args.act_equalization is not None: print("Applying activation equalization:") apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise') + device = next(iter(model.parameters())).device + # Define the quantized model quant_model = quantize_model( model, @@ -446,24 +453,21 @@ def main(): apply_gpfq( calib_loader, quant_model, - p=args.gpfq_p, act_order=args.gpxq_act_order, - compression_rate=args.compression_rate) + create_weight_orig=args.gpxq_create_weight_orig, + max_accumulator_bit_width=args.gpxq_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_accumulator_tile_size) - if args.gpfa2q: - print("Performing GPFA2Q:") - apply_gpfq( + if args.gptq: + print("Performing GPTQ:") + apply_gptq( calib_loader, quant_model, - p=args.gpfq_p, act_order=args.gpxq_act_order, - use_gpfa2q=args.gpfa2q, - accumulator_bit_width=args.accumulator_bit_width, - compression_rate=args.compression_rate) - - if args.gptq: - print("Performing GPTQ:") - apply_gptq(calib_loader, quant_model, act_order=args.gpxq_act_order) + use_quant_activations=args.gptq_use_quant_activations, + create_weight_orig=args.gpxq_create_weight_orig, + max_accumulator_bit_width=args.gpxq_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_accumulator_tile_size) if args.learned_round: print("Applying Learned Round:") diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py 
b/src/brevitas_examples/llm/llm_quant/gpxq.py index 595b47e3c..5e61306d4 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -17,9 +17,8 @@ from brevitas.graph.gptq import gptq_mode from brevitas.graph.gpxq import StopFwdException from brevitas.utils.python_utils import recurse_getattr - -from .axe import A2GPFQ -from .axe import A2GPTQ +from brevitas_examples.common.axe import A2GPFQ +from brevitas_examples.common.axe import A2GPTQ @torch.no_grad() From 6b7dd28ca096c542decbf1960ff9daa89c0af001 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 22 Oct 2024 05:27:06 +0000 Subject: [PATCH 09/19] Fixi (gpfq): support for fp16 weights --- src/brevitas/graph/gpfq.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 4da9e3a14..c1826bb6d 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -241,7 +241,7 @@ def single_layer_update(self): self.float_input = self.float_input.to(dev) self.quant_input = self.quant_input.to(dev) U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=torch.float32) # We don't need full Hessian, we just need the diagonal # Summing over batch dimension H_diag = self.quant_input.transpose(2, 1).square().sum(2) @@ -259,7 +259,7 @@ def single_layer_update(self): for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( - weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1), + weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1).to(torch.float32), self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] norm = torch.linalg.norm( @@ -270,11 +270,11 @@ def single_layer_update(self): else: q_arg = torch.zeros_like(U[group_index, :, 0]) - weight[group_index, :, permutation_list[group_index][t]] = q_arg + weight[group_index, :, permutation_list[group_index][t]] = q_arg.to(dtype) q = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): U[group_index] -= torch.matmul( - q[group_index].unsqueeze(1), + q[group_index].unsqueeze(1).to(torch.float32), self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0)) del self.float_input @@ -447,13 +447,13 @@ def single_layer_update(self, percdamp: float = 0.01): U = torch.zeros( weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, - dtype=dtype) # [Groups, OC/groups, Samples] + dtype=torch.float32) # [Groups, OC/groups, Samples] for t in range(weight.shape[-1]): for group_index in range(self.groups): i = permutation_list[group_index][t] U[group_index] += torch.matmul( - weight[group_index, :, i].unsqueeze(1), + weight[group_index, :, i].unsqueeze(1).to(torch.float32), self.float_input[group_index, :, i].unsqueeze(0), ) # [OC/Groups, 1] * [1, INSHAPE[1]] norm = norms[group_index, i] @@ -461,11 +461,11 @@ def single_layer_update(self, percdamp: float = 0.01): q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm else: q_arg = torch.zeros_like(U[group_index, :, 0]) - weight[group_index, :, i] = q_arg + weight[group_index, :, i] = q_arg.to(dtype) q_groups = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): U[group_index] -= torch.matmul( - q_groups[group_index].unsqueeze(1), + 
q_groups[group_index].unsqueeze(1).to(torch.float32), self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0), ) if hasattr(self.layer, 'offload_params'): From c2f0ba5a4f6ff0eb547871ca7970d717a5b3d3ac Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 22 Oct 2024 06:35:49 +0000 Subject: [PATCH 10/19] Feat (axe): adding support for per-group quantization --- src/brevitas_examples/common/axe.py | 142 +++++++++++++++++----------- 1 file changed, 86 insertions(+), 56 deletions(-) diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py index b156ddc88..05bc99c22 100644 --- a/src/brevitas_examples/common/axe.py +++ b/src/brevitas_examples/common/axe.py @@ -107,24 +107,34 @@ def single_layer_update(self, percdamp=0.01): raise NotImplementedError("Signed inputs not yet supported.") n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) - - s = self.layer.weight_quant.scale() + scales: Tensor = self.layer.weight_quant.scale() + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + scales = scales.transpose(1, 0) # This performs a view + scales = scales.flatten(1) P = torch.tensor(self.max_accumulator_bit_width) N = self.quant_metadata.bit_width # NOTE: using sign-magnitude here, which is sufficient to support both # sign-magnitude and 2s complement accumulators - A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) - B = -A - Z = (pow(2, P) - 2) / float(pow(2, N) - 1) + self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) # A + self.lower_lim = -self.upper_lim # B + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) # l1-norm lim for zero-centered weight vector # translating into the quantized range; need to pad to get these thresholds - wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( + wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view( -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] - T = calc_average_nonzero_mag( - wT - wT.mean(axis=1, keepdim=True), Z) # [OC * Tiles, IC / Tiles] - T = T.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] - s = s.view(self.groups, -1) # [Groups, OC/Groups] - T *= s # translating centers back to the float range - + thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] + thresholds = thresholds.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] + del wT + # supporting groupwise quantization where each tile has its own scaling factor + if self.layer.weight_quant.is_groupwise: + scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] + scales = scales[:,0] # [Groups * OC * Tiles, 1] + scales = scales.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups] + # else each tile has the same scaling factor (per-tensor or per-channel) + else: + scales = scales.view(self.groups, 1 , -1) # [Groups, 1, OC/Groups] + scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups] + thresholds *= scales # translating centers back to the float range weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] # List with permutation tensors for the Hessian and weight matrix. 
@@ -173,16 +183,15 @@ def single_layer_update(self, percdamp=0.01): del self.H, self.B # initialize cumulative l1-norm - a = torch.zeros_like(T, device=dev) # pos - b = torch.zeros_like(T, device=dev) # neg + a = torch.zeros_like(thresholds, device=dev) # positive limits + b = torch.zeros_like(thresholds, device=dev) # negative limits for i1 in range(0, self.columns, self.blocksize): i2 = min(i1 + self.blocksize, self.columns) count = i2 - i1 error_block = torch.zeros_like( weight[:, :, permutation_list[-1][i1:i2]], - dtype=torch.float32, - ) # [groups, OC/groups, i2-i1] + dtype=torch.float32) # [groups, OC/groups, i2-i1] h_inv_block = h_inv[:, i1:i2, i1:i2] for i in range(count): @@ -192,14 +201,14 @@ def single_layer_update(self, percdamp=0.01): bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index # calculate the q_max and q_min for the right group and right block # TODO: currently assuming round-to-zero; need to handle other rounding functions - q_max = s[group_index, :] * torch.clamp_min( - A - a[group_index, bx, :] - 0.5, 0.0) # [OC/groups] - q_min = s[group_index, :] * torch.clamp_max( - B - b[group_index, bx, :] + 0.5, 0.0) # [OC/groups] + q_max = scales[group_index, bx, :] * torch.clamp_min( + self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) # [OC/groups] + q_min = scales[group_index, bx, :] * torch.clamp_max( + self.lower_lim - b[group_index, bx, :] + 0.5, 0.0) # [OC/groups] q_arg = weight[group_index, :, perm[i1:i2][i]] # [OC/groups] # soft thresholding then clamping q_arg = q_arg.sign() * torch.relu( - q_arg.abs() - T[group_index, bx]) # [OC/groups] + q_arg.abs() - thresholds[group_index, bx]) # [OC/groups] q_arg.clamp_(q_min, q_max) # clamping to bounds weight[group_index, :, perm[i1:i2][i]] = q_arg q_groups = self.get_quant_weights(i, i1, permutation_list) # [Groups, OC/groups] @@ -219,12 +228,12 @@ def single_layer_update(self, percdamp=0.01): for group_index in range(self.groups): perm = permutation_list[group_index] bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index - q = q_groups[group_index] / s[group_index] # [OC/groups] + q = q_groups[group_index] / scales[group_index, bx] # [OC/groups] # increment cumulative l1-norm a[group_index, bx, q >= 0] += q[q >= 0] b[group_index, bx, q <= 0] += q[q <= 0] - assert (a <= A).all() and (a >= 0).all() - assert (b >= B).all() and (b <= 0).all() + assert (a <= self.upper_lim).all() and (a >= 0).all() + assert (b >= self.lower_lim).all() and (b <= 0).all() for group_index in range(self.groups): perm = permutation_list[group_index] @@ -234,6 +243,8 @@ def single_layer_update(self, percdamp=0.01): if hasattr(self.layer, "offload_params"): self.layer.offload_params(self.layer) + del thresholds, scales # memory management + class A2GPFQ(GPFQv2): """ @@ -286,43 +297,62 @@ def single_layer_update(self, percdamp=0.01): raise NotImplementedError("Signed inputs not yet supported.") n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) - - s = self.layer.weight_quant.scale() + scales: Tensor = self.layer.weight_quant.scale() + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + scales = scales.transpose(1, 0) # This performs a view + scales = scales.flatten(1) P = torch.tensor(self.max_accumulator_bit_width) N = self.quant_metadata.bit_width # NOTE: using sign-magnitude here, which is sufficient to support both # sign-magnitude and 2s complement accumulators - A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) - B = -A - Z = (pow(2, P) - 2) / float(pow(2, 
N) - 1) + self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) # A + self.lower_lim = -self.upper_lim # B + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) # l1-norm lim for zero-centered weight vector # translating into the quantized range; need to pad to get these thresholds - wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( + wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view( -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] - T = calc_average_nonzero_mag( - wT - wT.mean(axis=1, keepdim=True), Z) # [OC * Tiles, IC / Tiles] - T = T.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] - s = s.view(self.groups, -1) # [Groups, OC/Groups] - T *= s # translating centers back to the float range + thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] + thresholds = thresholds.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups] + del wT + # supporting groupwise quantization where each tile has its own scaling factor + if self.layer.weight_quant.is_groupwise: + scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] + scales = scales[:,0] # [Groups * OC * Tiles, 1] + scales = scales.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups] + # else each tile has the same scaling factor (per-tensor or per-channel) + else: + scales = scales.view(self.groups, 1 , -1) # [Groups, 1, OC/Groups] + scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups] + thresholds *= scales # translating centers back to the float range weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] # initialize cumulative l1-norm - a = torch.zeros_like(T, device=dev) # pos - b = torch.zeros_like(T, device=dev) # neg + a = torch.zeros_like(thresholds, device=dev) # positive limit + b = torch.zeros_like(thresholds, device=dev) # negative limit - # stablize G with a dampening factor and then square root the matrix - norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32) - self.H = self.H.to(dev) - diag = torch.arange(self.columns, device='cpu') - for i in range(self.groups): - damp = percdamp * self.H[i].diag().mean() - self.H[i, diag, diag] += damp - norms[i] = self.H[i].diag() # set the norms post-dampening - eigvals, eigvecs = torch.linalg.eigh(self.H[i]) - eigvals.clamp_min_(0.0).sqrt_() # should be positive-definite - self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t() - del eigvecs, eigvals, diag - self.quant_input = self.H # NOTE: do this here for the `get_permutation_list` function + # Try/Except in case the square root of H cannot be computed + try: + norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32) + self.H = self.H.to(dev) + diag = torch.arange(self.columns, device='cpu') + for i in range(self.groups): + # stablize H with a dampening factor and then square root the matrix + damp = percdamp * self.H[i].diag().mean() + self.H[i, diag, diag] += damp + norms[i] = self.H[i].diag() # set the norms post-dampening + eigvals, eigvecs = torch.linalg.eigh(self.H[i]) + eigvals.clamp_min_(0.0).sqrt_() # should be positive-definite + self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t() + del eigvecs, eigvals, diag + self.quant_input = self.H # NOTE: do this here for the `get_permutation_list` function + except LinAlgError: + warnings.warn( + f'Failed to compute the matrix 
square root of H for layer {self.name} ' + f'GPFQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return # Try/Except in case the inverse of H cannot be computed try: @@ -365,11 +395,11 @@ def single_layer_update(self, percdamp=0.01): q_arg = torch.zeros_like(U[group_index, :, 0]) bx = i // self.max_accumulator_tile_size # block index q_arg = q_arg.sign() * torch.relu( - q_arg.abs() - T[group_index, bx, :]) # soft thresholding + q_arg.abs() - thresholds[group_index, bx, :]) # soft thresholding # TODO: assuming round to nearest; need to generally support other rounding - q_max = s[group_index] * torch.clamp_min(A - a[group_index, bx, :] - 0.5, 0.0) - q_min = s[group_index] * torch.clamp_max(B - b[group_index, bx, :] + 0.5, 0.0) + q_max = scales[group_index, bx] * torch.clamp_min(self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) + q_min = scales[group_index, bx] * torch.clamp_max(self.lower_lim - b[group_index, bx, :] + 0.5, 0.0) q_arg.clamp_(q_min, q_max) weight[group_index, :, i] = q_arg.to(dtype) q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list) @@ -379,11 +409,11 @@ def single_layer_update(self, percdamp=0.01): q_groups[group_index].unsqueeze(1).to(torch.float32), self.quant_input[group_index, :, i].unsqueeze(0)) bx = i // self.max_accumulator_tile_size # block index - q = q_groups[group_index] / s[group_index] # [OC/groups] + q = q_groups[group_index] / scales[group_index, bx] # [OC/groups] # increment cumulative l1-norm a[group_index, bx, q >= 0] += q[q >= 0] b[group_index, bx, q <= 0] += q[q <= 0] - assert (a <= A).all() and (a >= 0).all() - assert (b >= B).all() and (b <= 0).all() + assert (a <= self.upper_lim).all() and (a >= 0).all() + assert (b >= self.lower_lim).all() and (b <= 0).all() del self.quant_input, self.float_input From b17dcc9294cde0f4350caf48366535c9067d1a45 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 22 Oct 2024 06:36:42 +0000 Subject: [PATCH 11/19] Feat (llm): adding assertion for tile/group size --- src/brevitas_examples/llm/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index c66d7c15a..f02dfa1bd 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -80,6 +80,8 @@ def validate(args): assert args.weight_quant_granularity == 'per_group', "Sharded torch group export requires per group weight quant." assert args.input_bit_width is None, "Sharded torch group weight export doesn't support input quant." assert not args.quantize_weight_zero_point, "Quantized weight zero point not supported." + if args.max_accumulator_bit_width is not None: + assert args.max_accumulator_tile_size == args.weight_group_size, "Group size must be equal to tile size." if args.export_target == 'sharded_packed_torchmlir_group_weight': assert args.weight_quant_granularity == 'per_group', "Sharded torch group export requires per group weight quant." assert args.input_bit_width is None, "Sharded packed torch group weight export doesn't support input quant." 
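# Illustrative sketch (not part of the patches above) of the constraint the new assertion enforces,
# assuming per_group weight quantization: each group of `weight_group_size` input channels shares one
# scale, while AXE bounds the accumulator over tiles of `max_accumulator_tile_size` columns, and the
# per-tile scale handling in axe.py (one scale kept per tile) is only well defined when the two match.
# The CLI flag names in the comments come from this patch series; the variable names and the value 64
# are invented for illustration.
group_size = 64  # e.g. --weight-group-size 64
tile_size = 64  # e.g. --gpxq-max-accumulator-tile-size 64
assert tile_size == group_size, "Group size must be equal to tile size."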
@@ -470,7 +472,7 @@ def parse_args(args): parser.add_argument( '--gpxq-max-accumulator-tile-size', type=int, - default=128, + default=None, help='Maximum accumulator tile size for GPxQ using AXE.') parser.add_argument( '--act-calibration', action='store_true', help='Apply activation calibration.') From a429705b8ad18b800d915412ca00d514c95e44e8 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 22 Oct 2024 06:37:53 +0000 Subject: [PATCH 12/19] Pre-commit fixes --- src/brevitas/graph/gpfq.py | 14 +++++++++--- src/brevitas_examples/common/axe.py | 35 ++++++++++++++++++----------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index c1826bb6d..9a32eb6b5 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -241,7 +241,11 @@ def single_layer_update(self): self.float_input = self.float_input.to(dev) self.quant_input = self.quant_input.to(dev) U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=torch.float32) + weight.shape[0], + weight.shape[1], + self.float_input.shape[1], + device=dev, + dtype=torch.float32) # We don't need full Hessian, we just need the diagonal # Summing over batch dimension H_diag = self.quant_input.transpose(2, 1).square().sum(2) @@ -259,7 +263,8 @@ def single_layer_update(self): for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( - weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1).to(torch.float32), + weight[group_index, :, + permutation_list[group_index][t]].unsqueeze(1).to(torch.float32), self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] norm = torch.linalg.norm( @@ -446,7 +451,10 @@ def single_layer_update(self, percdamp: float = 0.01): permutation_list = self._get_permutation_list(weight) U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, + weight.shape[0], + weight.shape[1], + self.float_input.shape[1], + device=dev, dtype=torch.float32) # [Groups, OC/groups, Samples] for t in range(weight.shape[-1]): diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py index 05bc99c22..ff7569e53 100644 --- a/src/brevitas_examples/common/axe.py +++ b/src/brevitas_examples/common/axe.py @@ -122,17 +122,20 @@ def single_layer_update(self, percdamp=0.01): # translating into the quantized range; need to pad to get these thresholds wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view( -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] - thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] + thresholds = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] thresholds = thresholds.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] del wT # supporting groupwise quantization where each tile has its own scaling factor if self.layer.weight_quant.is_groupwise: - scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] - scales = scales[:,0] # [Groups * OC * Tiles, 1] - scales = scales.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups] + scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] + scales = 
scales[:, 0] # [Groups * OC * Tiles, 1] + scales = scales.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] # else each tile has the same scaling factor (per-tensor or per-channel) else: - scales = scales.view(self.groups, 1 , -1) # [Groups, 1, OC/Groups] + scales = scales.view(self.groups, 1, -1) # [Groups, 1, OC/Groups] scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups] thresholds *= scales # translating centers back to the float range weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] @@ -312,17 +315,21 @@ def single_layer_update(self, percdamp=0.01): # translating into the quantized range; need to pad to get these thresholds wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view( -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] - thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] - thresholds = thresholds.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups] + thresholds = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] + thresholds = thresholds.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] del wT # supporting groupwise quantization where each tile has its own scaling factor if self.layer.weight_quant.is_groupwise: - scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] - scales = scales[:,0] # [Groups * OC * Tiles, 1] - scales = scales.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups] + scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles] + scales = scales[:, 0] # [Groups * OC * Tiles, 1] + scales = scales.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] # else each tile has the same scaling factor (per-tensor or per-channel) else: - scales = scales.view(self.groups, 1 , -1) # [Groups, 1, OC/Groups] + scales = scales.view(self.groups, 1, -1) # [Groups, 1, OC/Groups] scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups] thresholds *= scales # translating centers back to the float range @@ -398,8 +405,10 @@ def single_layer_update(self, percdamp=0.01): q_arg.abs() - thresholds[group_index, bx, :]) # soft thresholding # TODO: assuming round to nearest; need to generally support other rounding - q_max = scales[group_index, bx] * torch.clamp_min(self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) - q_min = scales[group_index, bx] * torch.clamp_max(self.lower_lim - b[group_index, bx, :] + 0.5, 0.0) + q_max = scales[group_index, bx] * torch.clamp_min( + self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) + q_min = scales[group_index, bx] * torch.clamp_max( + self.lower_lim - b[group_index, bx, :] + 0.5, 0.0) q_arg.clamp_(q_min, q_max) weight[group_index, :, i] = q_arg.to(dtype) q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list) From 8990fec2bb00ea1dabef13c83d45d68d7806de91 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Wed, 23 Oct 2024 21:28:04 +0000 Subject: [PATCH 13/19] Fix (axe): minor reshaping fix --- src/brevitas_examples/common/axe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py index ff7569e53..5c272c4b1 100644 --- a/src/brevitas_examples/common/axe.py +++ 
b/src/brevitas_examples/common/axe.py @@ -124,7 +124,8 @@ def single_layer_update(self, percdamp=0.01): -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] thresholds = calc_average_nonzero_mag( wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles] - thresholds = thresholds.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] + thresholds = thresholds.view(self.groups, -1, + n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups] del wT # supporting groupwise quantization where each tile has its own scaling factor if self.layer.weight_quant.is_groupwise: @@ -273,7 +274,8 @@ def __init__( assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2." def single_layer_update(self, percdamp=0.01): - assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." + assert not self.layer.weight_quant.requires_quant_input, \ + "Error: GPFQ does not support weight quantizers that require quantized inputs." if self.quant_metadata is None: raise ValueError( "Expected self.quant_metadata to calculate accumualtor bounds, but recevied None. " From 9797698a5f422d7860a2fd23d2c76ab66a975d0f Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 25 Oct 2024 21:40:27 +0000 Subject: [PATCH 14/19] Fix (main): updating validation for AXE CLI args --- src/brevitas_examples/llm/main.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index f02dfa1bd..4a87f5a1a 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -72,6 +72,20 @@ def validate(args): if not args.no_quantize: if args.gptq and args.gpfq: warn("Both GPTQ and GPFQ are enabled.") + if args.gpxq_max_accumulator_bit_width is not None: + assert args.weight_quant_format == 'int', "AXE only supports integer formats." + assert args.input_quant_format == 'int', "AXE only supports integer formats." + assert args.input_bit_width is not None, \ + "Specify input bit width; activation quantization is required to guarantee accumulator bounds." + if not (args.gptq or args.gpfq): + warn("Max accumulator bit width is specified, but no GPxQ is enabled.") + if args.gpxq_max_accumulator_tile_size is not None: + if args.weight_quant_granularity == 'per_group': + assert args.gpxq_max_accumulator_tile_size == args.weight_group_size, \ + "Group size must be equal to tile size with per_group quantization." + if args.input_quant_granularity == 'per_group': + assert args.gpxq_max_accumulator_tile_size == args.input_group_size, \ + "Group size must be equal to tile size with per_group quantization." if args.export_target is not None: assert args.input_quant_format == 'int', "Only integer quantization supported for export currently." if args.export_target is not None and args.input_bit_width is not None: @@ -80,8 +94,6 @@ def validate(args): assert args.weight_quant_granularity == 'per_group', "Sharded torch group export requires per group weight quant." assert args.input_bit_width is None, "Sharded torch group weight export doesn't support input quant." assert not args.quantize_weight_zero_point, "Quantized weight zero point not supported." - if args.max_accumulator_bit_width is not None: - assert args.max_accumulator_tile_size == args.weight_group_size, "Group size must be equal to tile size." 
-        if args.max_accumulator_bit_width is not None:
-            assert args.max_accumulator_tile_size == args.weight_group_size, "Group size must be equal to tile size."
if args.export_target == 'sharded_packed_torchmlir_group_weight': assert args.weight_quant_granularity == 'per_group', "Sharded torch group export requires per group weight quant." assert args.input_bit_width is None, "Sharded packed torch group weight export doesn't support input quant." From 0bef8e0a47b1460bc5738d0e1b08f15650afaf75 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 25 Oct 2024 21:45:20 +0000 Subject: [PATCH 15/19] Fix (gptq): explicitly setting gptq_class default --- src/brevitas/graph/gptq.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index d80ee1069..667e47d40 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -232,6 +232,7 @@ class gptq_mode(gpxq_mode): act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the forward call inside the context manager returns None. Default: False + gptq_class (GPTQ): The uninitialized class to perform GPTQ. Default: `brevitas.graph.gptq.GPTQ` Example: >>> with torch.no_grad(): @@ -254,7 +255,7 @@ def __init__( num_blocks: int = 100, return_forward_output: bool = False, act_order: bool = False, - gptq_class: Optional[GPxQ] = None) -> None: + gptq_class: GPTQ = GPTQ) -> None: if not inplace: model = deepcopy(model) super().__init__( @@ -268,8 +269,6 @@ def __init__( # How many subblock to use during GPTQ for each layer self.num_blocks = num_blocks - if gptq_class is None: - gptq_class = GPTQ self.gptq_class = gptq_class def catch_stopfwd(self, *args, **kwargs): From 3441920afcac64164df9dae39eb7d6f9dd0206eb Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 25 Oct 2024 21:45:43 +0000 Subject: [PATCH 16/19] Fix (gpfq): explicitly setting gpfq_class default --- src/brevitas/graph/gpfq.py | 204 ++++++++++++++++++------------------- 1 file changed, 102 insertions(+), 102 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 9a32eb6b5..92e3da2bf 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -31,108 +31,6 @@ from brevitas.quant_tensor import _unpack_quant_tensor -class gpfq_mode(gpxq_mode): - """ - Apply GPFQ algorithm. - - Args: - model (Module): The model to quantize with GPFQ - group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group - of layer names that can be optimized in parallel. Default: None - inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True - create_weight_orig (bool): If True, store the original floating point weights before applying - gpfq. These weights will be used anytime quantization is disabled. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPFQ. Default: False - p (float): The percentage of processed inputs to use. Default: 1.0 - return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the - forward call inside the context manager returns None. Default: False - act_order (bool): Whether to order greedy path following by Hessian approximation. 
Default: False - - Example: - >>> with torch.no_grad(): - >>> with gpfq_mode(model) as gpfq: - >>> gpfq_model = gpfq.model - >>> for i in tqdm(range(gpfq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gpfq_model(img) - >>> gpfq.update() - """ - - def __init__( - self, - model: nn.Module, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - create_weight_orig: bool = True, - use_quant_activations: bool = True, - p: float = 1.0, - return_forward_output: bool = False, - act_order: bool = False, - gpfq_class: Optional[GPxQ] = None) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - create_weight_orig, - use_quant_activations, - act_order, - return_forward_output) - - self.p = p - if gpfq_class is None: - gpfq_class = GPFQ - self.gpfq_class = gpfq_class - - def catch_stopfwd(self, *args, **kwargs): - # Collect quant input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - - # Disable quantization - self.return_quant_tensor_state = disable_return_quant_tensor(self.model) - self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) - self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) - # Collect float input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - - # Re-enable quantization. If activation quantization is disabled, - # we also disable bias quantization - self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) - if self.use_quant_activations: - self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) - else: - self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) - restore_return_quant_tensor(self.model, self.return_quant_tensor_state) - - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - def initialize_module_optimizer( - self, layer, name, act_order, len_parallel_layers, create_weight_orig): - return self.gpfq_class( - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - p=self.p) - - class GPFQ(GPxQ): """ Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main @@ -480,3 +378,105 @@ def single_layer_update(self, percdamp: float = 0.01): self.layer.offload_params(self.layer) del self.float_input del self.quant_input + + +class gpfq_mode(gpxq_mode): + """ + Apply GPFQ algorithm. + + Args: + model (Module): The model to quantize with GPFQ + group_of_parallel_layers (Optional, List[str]): .List of lists where each inner list is a group + of layer names that can be optimized in parallel. Default: None + inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True + create_weight_orig (bool): If True, store the original floating point weights before applying + gpfq. These weights will be used anytime quantization is disabled. Default: True + use_quant_activations (bool): Wheter to leave quantize activations enabled while performing + GPFQ. Default: False + p (float): The percentage of processed inputs to use. 
Default: 1.0 + return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the + forward call inside the context manager returns None. Default: False + act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False + gpfq_class (GPFQ): The uninitialized class to perform GPFQ. Default: `brevitas.graph.gpfq.GPFQv2`, + which is the memory-efficient formulation + + Example: + >>> with torch.no_grad(): + >>> with gpfq_mode(model) as gpfq: + >>> gpfq_model = gpfq.model + >>> for i in tqdm(range(gpfq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gpfq_model(img) + >>> gpfq.update() + """ + + def __init__( + self, + model: nn.Module, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + create_weight_orig: bool = True, + use_quant_activations: bool = True, + p: float = 1.0, + return_forward_output: bool = False, + act_order: bool = False, + gpfq_class: GPFQ = GPFQv2) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + create_weight_orig, + use_quant_activations, + act_order, + return_forward_output) + + self.p = p + self.gpfq_class = gpfq_class + + def catch_stopfwd(self, *args, **kwargs): + # Collect quant input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + + # Disable quantization + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) + self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) + self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) + # Collect float input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + + # Re-enable quantization. 
If activation quantization is disabled, + # we also disable bias quantization + self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) + if self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) + else: + self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) + + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + def initialize_module_optimizer( + self, layer, name, act_order, len_parallel_layers, create_weight_orig): + return self.gpfq_class( + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + p=self.p) From d75f8fac62916ceb184868a24543aa27098c9d0e Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 25 Oct 2024 21:56:29 +0000 Subject: [PATCH 17/19] Fix (axe): adding assertions for rounding_mode checks --- src/brevitas_examples/common/axe.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py index 5c272c4b1..a89c3c8c2 100644 --- a/src/brevitas_examples/common/axe.py +++ b/src/brevitas_examples/common/axe.py @@ -105,6 +105,11 @@ def single_layer_update(self, percdamp=0.01): # TODO: add support for signed input activations if self.quant_metadata.signed: raise NotImplementedError("Signed inputs not yet supported.") + + # TODO: currently assuming round-to-zero; need to handle other rounding functions + rounding_mode = self.layer.weight_quant.rounding_mode + if rounding_mode.lower() != "round": + raise NotImplementedError(f"{rounding_mode} not yet supported.") n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) scales: Tensor = self.layer.weight_quant.scale() @@ -204,7 +209,6 @@ def single_layer_update(self, percdamp=0.01): perm = permutation_list[group_index] bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index # calculate the q_max and q_min for the right group and right block - # TODO: currently assuming round-to-zero; need to handle other rounding functions q_max = scales[group_index, bx, :] * torch.clamp_min( self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) # [OC/groups] q_min = scales[group_index, bx, :] * torch.clamp_max( @@ -228,7 +232,6 @@ def single_layer_update(self, percdamp=0.01): error.unsqueeze(1).matmul( h_inv_block[group_index, i, i:].unsqueeze(0).to(dev))).to(dtype) # update the tracking mechanisms - # TODO: need to handle non-zero zero points for group_index in range(self.groups): perm = permutation_list[group_index] bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index @@ -301,6 +304,11 @@ def single_layer_update(self, percdamp=0.01): if self.quant_metadata.signed: raise NotImplementedError("Signed inputs not yet supported.") + # TODO: currently assuming round-to-zero; need to handle other rounding functions + rounding_mode = self.layer.weight_quant.rounding_mode + if rounding_mode.lower() != "round": + raise NotImplementedError(f"{rounding_mode} not yet supported.") + n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) 
scales: Tensor = self.layer.weight_quant.scale() if isinstance(self.layer, SUPPORTED_CONV_OP): @@ -405,8 +413,6 @@ def single_layer_update(self, percdamp=0.01): bx = i // self.max_accumulator_tile_size # block index q_arg = q_arg.sign() * torch.relu( q_arg.abs() - thresholds[group_index, bx, :]) # soft thresholding - - # TODO: assuming round to nearest; need to generally support other rounding q_max = scales[group_index, bx] * torch.clamp_min( self.upper_lim - a[group_index, bx, :] - 0.5, 0.0) q_min = scales[group_index, bx] * torch.clamp_max( From 237f72583f6980283b49ceac528122b8f3a20f50 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 25 Oct 2024 21:56:59 +0000 Subject: [PATCH 18/19] Pre-commit fixes --- src/brevitas_examples/common/axe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py index a89c3c8c2..39e22535d 100644 --- a/src/brevitas_examples/common/axe.py +++ b/src/brevitas_examples/common/axe.py @@ -105,7 +105,7 @@ def single_layer_update(self, percdamp=0.01): # TODO: add support for signed input activations if self.quant_metadata.signed: raise NotImplementedError("Signed inputs not yet supported.") - + # TODO: currently assuming round-to-zero; need to handle other rounding functions rounding_mode = self.layer.weight_quant.rounding_mode if rounding_mode.lower() != "round": From ce45127ff1bf394781e0db6361ce5e384824f0d3 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Fri, 25 Oct 2024 22:03:22 +0000 Subject: [PATCH 19/19] Feat (docs): updating readme --- src/brevitas_examples/llm/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 457c74804..5cd067e64 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -34,6 +34,9 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--input-quant-granularity {per_tensor,per_row,per_group}] [--input-group-size INPUT_GROUP_SIZE] [--quantize-input-zero-point] [--quantize-last-layer] [--gptq] + [--gpfq] [--gpxq-act-order] [--gpxq-use-quant-activations] [--gpxq-create-weight-orig] + [--gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH] + [--gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE] [--act-calibration] [--bias-corr] [--ln-affine-merge] [--no-quantize] [--no-float16] [--replace-mha] [--weight-equalization] @@ -105,6 +108,16 @@ options: --quantize-last-layer Quantize last nn.Linear layer. --gptq Apply GPTQ. + --gpfq Apply GPFQ. + --gpxq-act-order Apply GPxQ activation ordering. + --gpxq-use-quant-activations + Use quantized activations in GPxQ. + --gpxq-create-weight-orig + Create weight_orig in GPxQ. + --gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH + Maximum accumulator bit width for GPxQ using AXE. + --gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE + Maximum accumulator tile size for GPxQ using AXE. --act-calibration Apply activation calibration. --bias-corr Apply bias correction. --ln-affine-merge Merge LN affine params.