From 24d356df6fdd757f96086326e9377ebc49f72fb4 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 15 Oct 2024 15:25:24 +0000 Subject: [PATCH] Feat (axe): adding accumulator-aware extensions for GPxQ --- src/brevitas_examples/llm/llm_quant/axe.py | 382 ++++++++++++++++++++ src/brevitas_examples/llm/llm_quant/gpxq.py | 50 ++- src/brevitas_examples/llm/main.py | 39 +- 3 files changed, 450 insertions(+), 21 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/axe.py diff --git a/src/brevitas_examples/llm/llm_quant/axe.py b/src/brevitas_examples/llm/llm_quant/axe.py new file mode 100644 index 000000000..c7c96af67 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/axe.py @@ -0,0 +1,382 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import math +import warnings + +import numpy as np +import torch +from torch import Tensor + +try: + from torch.linalg import LinAlgError +except: + LinAlgError = RuntimeError + +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ +from brevitas.graph.gpxq import SUPPORTED_CONV_OP +from brevitas.graph.gpxq import SUPPORTED_TCONV_OP + + +def _get_average_of_nonzero_magnitudes(vec: np.ndarray, radius: float = 1.0): + assert radius > 0, "Error: radius needs to be strictly positive." + assert vec.ndim == 1, "Error: projection assumes a vector, not a matrix." + assert vec.min() >= 0, "Error: assuming a vector of non-negative numbers." + n_elems = vec.shape[0] + # if we are already within the simplex, then the best projection is itself + if vec.sum() <= radius: + return 0.0 + # using algorithm detailed in "Efficient Projections onto the L1-Ball for Learning in High Dimensions" + v = vec + u = np.sort(v)[::-1] + cumsum_u = np.cumsum(u) + rho = np.nonzero(u * np.arange(1, n_elems + 1) > (cumsum_u - radius))[0][-1] + theta = float(cumsum_u[rho] - radius) / (rho + 1) + return theta + + +def calc_average_nonzero_mag(weight: Tensor, lim: Tensor) -> Tensor: + thetas = torch.zeros(weight.shape[0], device=weight.device) + for i in range(weight.shape[0]): + l = lim[i].item() if lim.ndim > 0 else lim.item() + w = weight[i].cpu().detach().numpy() + t = _get_average_of_nonzero_magnitudes(np.abs(w), l) + thetas[i] = t + return thetas + + +def pad_tensor_with_zeros(tensor: Tensor, tile_size: int) -> Tensor: + pad_size = tile_size - (tensor.shape[1] % tile_size) + if pad_size == tile_size: + return tensor + padding = torch.zeros((tensor.shape[0], pad_size), device=tensor.device) + pad_tensor = torch.concat([tensor, padding], axis=1) + return pad_tensor + + +class A2GPTQ(GPTQ): + """ + Accumulator-aware GPTQ as proposed in https://arxiv.org/pdf/2409.17092 + """ + + def __init__( + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + num_blocks, + max_accumulator_bit_width, + max_accumulator_tile_size) -> None: + super().__init__( + layer, name, act_order, len_parallel_layers, create_weight_orig, num_blocks) + self.max_accumulator_bit_width = max_accumulator_bit_width + self.max_accumulator_tile_size = max_accumulator_tile_size + if self.max_accumulator_tile_size is None: + self.max_accumulator_tile_size = self.columns + assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2." + assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2." 
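An illustrative sketch of the property behind `_get_average_of_nonzero_magnitudes` defined above: for a non-negative vector whose l1 norm exceeds the radius, the returned theta is the soft-thresholding offset from the cited l1-ball projection algorithm, so thresholding by theta brings the l1 norm down to exactly the radius. A toy check, assuming this module's imports (the numbers are hypothetical, not part of the patch):

    v = np.array([2.0, 1.0, 0.5])  # l1 norm 3.5, outside the radius-2.0 ball
    theta = _get_average_of_nonzero_magnitudes(v, radius=2.0)  # works out to 0.5 for this vector
    projected = np.maximum(v - theta, 0.0)  # soft thresholding by theta
    assert abs(projected.sum() - 2.0) < 1e-6  # projection lands on the l1 ball of radius 2.0

`calc_average_nonzero_mag` applies the same computation row-wise to the absolute values of its input, so each (output channel, tile) row of the padded weight matrix gets its own threshold.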
+ + def single_layer_update(self, percdamp=0.01): + assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." + if self.quant_metadata is None: + raise ValueError( + "Expected self.quant_metadata to calculate accumulator bounds, but received None. " + "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. " + "Also, check if `use_quant_activations=True` in `gptq_mode` when `max_accumulator_bit_width` is specified. " + ) + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) + weight = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. + # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # TODO: add support for signed input activations + assert not self.quant_metadata.signed + + n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) + + s = self.layer.weight_quant.scale() + P = torch.tensor(self.max_accumulator_bit_width) + N = self.quant_metadata.bit_width + # TODO: add support for two's complement accumulator representation + A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) + B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1) + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) + # translating into the quantized range; need to pad to get these thresholds + wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] + T = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [OC * Tiles] + T = T.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] + s = s.view(self.groups, -1) # [Groups, OC/Groups] + T *= s # translating centers back to the float range + + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + + # List with permutation tensors for the Hessian and weight matrix. + # If act_order is False, the tensors will be ordered indexes. + # For groupwise convolution, we have one tensor per group, + # thus len(permutation_list) is always equal to self.groups. + # We do not explicitly permute the weight matrix, only the Hessian. + permutation_list = [] + weight = weight.view(self.groups, -1, weight.shape[-1]) + # For groupwise convolution, these operations are groupwise so we iterate + for i in range(self.groups): + # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding + # column in the weight matrix. 
+ # The diagonal element is set to 1 to avoid division-by-zero + dead = torch.diag(self.H[i, :, :]) == 0 + self.H[i, dead, dead] = 1 + # If the diagonal of activations is zero, we set the weight to zero + weight[i, :, dead] = 0 + if self.act_order: + # Re-order Hessian so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True) + self.H[i, :, :] = self.H[i, perm, :][:, perm] + else: + # No permutation, permutation tensor is an ordered index + perm = torch.tensor(range(self.H.shape[-1]), device=dev) + permutation_list.append(perm) + + # Try/Except in case the inverse Hessian cannot be computed + try: + for i in range(self.groups): + damp = percdamp * torch.mean(torch.diag(self.H[i, :, :])) + diag = torch.arange(self.columns, device='cpu') + self.H[i, diag, diag] += damp + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) + self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) + h_inv = self.H + except LinAlgError: + warnings.warn( + f'Failed to compute the inverse of the Hessian for layer {self.name} ' + f'GPTQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H, self.B + + # initialize cumulative l1-norm + a = torch.zeros_like(T, device=dev) # pos + b = torch.zeros_like(T, device=dev) # neg + + for i1 in range(0, self.columns, self.blocksize): + i2 = min(i1 + self.blocksize, self.columns) + count = i2 - i1 + error_block = torch.zeros_like( + weight[:, :, permutation_list[-1][i1:i2]], + dtype=torch.float32, + ) # [groups, OC/groups, i2-i1] + + h_inv_block = h_inv[:, i1:i2, i1:i2] + for i in range(count): + # need to apply soft thresholding and clamping before quantization + for group_index in range(self.groups): + perm = permutation_list[group_index] + bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index + # calculate the q_max and q_min for the right group and right block + # TODO: currently assuming round-to-zero; need to handle other rounding functions + q_max = s[group_index, :] * torch.clamp_min( + A - a[group_index, bx, :] - 0.5, 0.0) # [OC/groups] + q_min = s[group_index, :] * torch.clamp_max( + B - b[group_index, bx, :] + 0.5, 0.0) # [OC/groups] + q_arg = weight[group_index, :, perm[i1:i2][i]] # [OC/groups] + # soft thresholding then clamping + q_arg = q_arg.sign() * torch.relu( + q_arg.abs() - T[group_index, bx]) # [OC/groups] + q_arg.clamp_(q_min, q_max) # clamping to bounds + weight[group_index, :, perm[i1:i2][i]] = q_arg + q_groups = self.get_quant_weights(i, i1, permutation_list) # [Groups, OC/groups] + for group_index in range(self.groups): + perm = permutation_list[group_index] + q = q_groups[group_index] # [OC/groups] + w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] + d = h_inv_block[group_index, i, i] # [1] + error = (w - q) / d # [OC/groups] + error_block[group_index, :, i] = error + # We need to update the original weights + weight[group_index, :, perm[i1:i2][i:]] -= ( + error.unsqueeze(1).matmul( + h_inv_block[group_index, i, i:].unsqueeze(0).to(dev))).to(dtype) + # update the tracking mechanisms + # TODO: need to handle non-zero zero points + for group_index in range(self.groups): + perm = permutation_list[group_index] + bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index + q = q_groups[group_index] / s[group_index] # [OC/groups] + # increment cumulative l1-norm + a[group_index, 
bx, q >= 0] += q[q >= 0] + b[group_index, bx, q <= 0] += q[q <= 0] + assert (a <= A).all() and (a >= 0).all() + assert (b >= B).all() and (b <= 0).all() + + for group_index in range(self.groups): + perm = permutation_list[group_index] + weight[group_index, :, perm[i2:]] -= ( + error_block[group_index].matmul(h_inv[group_index, i1:i2, + i2:].to(dev))).to(dtype) + if hasattr(self.layer, "offload_params"): + self.layer.offload_params(self.layer) + + +class A2GPFQ(GPFQv2): + """ + Memory-efficient, accumulator-aware GPFQ as proposed in https://arxiv.org/pdf/2409.17092 + """ + + def __init__( + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + p, + max_accumulator_bit_width, + max_accumulator_tile_size) -> None: + super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig, p) + self.max_accumulator_bit_width = max_accumulator_bit_width + self.max_accumulator_tile_size = max_accumulator_tile_size + if self.max_accumulator_tile_size is None: + self.max_accumulator_tile_size = self.columns + assert self.max_accumulator_tile_size > 2, "Error: accumulator tile size needs to be bigger than 2." + assert self.max_accumulator_bit_width > 2, "Error: accumulator bit width needs to be bigger than 2." + + def single_layer_update(self, percdamp=0.01): + assert not self.layer.weight_quant.requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs." + if self.quant_metadata is None: + raise ValueError( + "Expected self.quant_metadata to calculate accumulator bounds, but received None. " + "Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. " + "Also, check if `use_quant_activations=True` in `gpfq_mode` when `max_accumulator_bit_width` is specified. " + ) + if hasattr(self.layer, "allocate_params"): + self.layer.allocate_params(self.layer) + weight: Tensor = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. 
+ # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, SUPPORTED_TCONV_OP): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # TODO: add support for signed input activations + assert not self.quant_metadata.signed + + n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size) + + s = self.layer.weight_quant.scale() + P = torch.tensor(self.max_accumulator_bit_width) + N = self.quant_metadata.bit_width + # TODO: add support for two's complement accumulator representation + A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1) + B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1) + Z = (pow(2, P) - 2) / float(pow(2, N) - 1) + # translating into the quantized range; need to pad to get these thresholds + wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view( + -1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles] + T = calc_average_nonzero_mag( + wT - wT.mean(axis=1, keepdim=True), Z) # [OC * Tiles] + T = T.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups] + s = s.view(self.groups, -1) # [Groups, OC/Groups] + T *= s # translating centers back to the float range + + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + + # initialize cumulative l1-norm + a = torch.zeros_like(T, device=dev) # pos + b = torch.zeros_like(T, device=dev) # neg + + # stabilize G with a dampening factor and then square root the matrix + norms = torch.zeros((self.groups, self.columns), device=dev, dtype=dtype) + self.H = self.H.to(dev) + diag = torch.arange(self.columns, device='cpu') + for i in range(self.groups): + damp = percdamp * self.H[i].diag().mean() + self.H[i, diag, diag] += damp + norms[i] = self.H[i].diag() # set the norms post-dampening + eigvals, eigvecs = torch.linalg.eigh(self.H[i]) + eigvals.clamp_min_(0.0).sqrt_() # should be positive-definite + self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t() + del eigvecs, eigvals, diag + self.quant_input = self.H # NOTE: do this here for the `get_permutation_list` function + + # Try/Except in case the inverse of H cannot be computed + try: + self.float_input = self.H.clone() # going to calculate H^{-1} here + for i in range(self.groups): + # from our matrix sqrt, we know G is symmetric and positive-definite, so we + # can use Cholesky decomposition as an efficient, numerically stable inverse + L = torch.linalg.cholesky(self.float_input[i]) + self.float_input[i] = torch.cholesky_inverse(L) + self.float_input = torch.bmm(self.float_input.to(dev), self.G.to(dev)) + del L # memory management + except LinAlgError: + warnings.warn( + f'Failed to compute the inverse of H for layer {self.name} ' + f'GPFQ will not be applied. 
' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H, self.G, self.B # memory management + + permutation_list = self._get_permutation_list(weight) + + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, + dtype=dtype) # [Groups, OC/groups, Samples] + + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + i = permutation_list[group_index][t] + U[group_index] += torch.matmul( + weight[group_index, :, i].unsqueeze(1), + self.float_input[group_index, :, i].unsqueeze(0)) + norm = norms[group_index, i] + if norm > 0: + q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + bx = i // self.max_accumulator_tile_size # block index + q_arg = q_arg.sign() * torch.relu( + q_arg.abs() - T[group_index, bx, :]) # soft thresholding + + # TODO: assuming round to nearest; need to generally support other rounding + q_max = s[group_index] * torch.clamp_min(A - a[group_index, bx, :] - 0.5, 0.0) + q_min = s[group_index] * torch.clamp_max(B - b[group_index, bx, :] + 0.5, 0.0) + q_arg.clamp_(q_min, q_max) + weight[group_index, :, i] = q_arg + q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list) + for group_index in range(self.groups): + i = permutation_list[group_index][t] + U[group_index] -= torch.matmul( + q_groups[group_index].unsqueeze(1), + self.quant_input[group_index, :, i].unsqueeze(0)) + bx = i // self.max_accumulator_tile_size # block index + q = q_groups[group_index] / s[group_index] # [OC/groups] + # increment cumulative l1-norm + a[group_index, bx, q >= 0] += q[q >= 0] + b[group_index, bx, q <= 0] += q[q <= 0] + assert (a <= A).all() and (a >= 0).all() + assert (b >= B).all() and (b <= 0).all() + + del self.quant_input, self.float_input diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index e2bfba989..ef1d161f1 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -1,14 +1,19 @@ -""" -Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -""" + +from functools import partial import torch from tqdm import tqdm from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gpfq import GPFQv2 +from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode +from .axe import A2GPFQ +from .axe import A2GPTQ + @torch.no_grad() def apply_gptq( @@ -17,12 +22,24 @@ def apply_gptq( act_order=True, group_of_parallel_layers=None, use_quant_activations=True, - create_weight_orig=False): + create_weight_orig=False, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gptq_class = partial( + A2GPTQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gptq_class = GPTQ with gptq_mode(model, act_order=act_order, group_of_parallel_layers=group_of_parallel_layers, use_quant_activations=use_quant_activations, - create_weight_orig=create_weight_orig) as gptq: + create_weight_orig=create_weight_orig, + gptq_class=gptq_class) as gptq: gptq_model = gptq.model for _ in tqdm(range(gptq.num_layers)): for inps in dataloader: @@ -31,9 +48,26 @@ def apply_gptq( @torch.no_grad() -def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None): - with gpfq_mode(model, act_order=act_order, - group_of_parallel_layers=group_of_parallel_layers) as gpfq: +def apply_gpfq( + model, + dataloader, + act_order=True, + group_of_parallel_layers=None, + max_accumulator_bit_width=None, + max_accumulator_tile_size=128): + if max_accumulator_bit_width is not None: + # Use accumulator-aware extension (AXE) framework + print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...") + gpfq_class = partial( + A2GPFQ, + max_accumulator_bit_width=max_accumulator_bit_width, + max_accumulator_tile_size=max_accumulator_tile_size) + else: + gpfq_class = GPFQv2 + with gpfq_mode(model, + act_order=act_order, + group_of_parallel_layers=group_of_parallel_layers, + gpfq_class=gpfq_class) as gpfq: gpfq_model = gpfq.model for _ in tqdm(range(gpfq.num_layers)): for inps in dataloader: diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index c33de54c8..ba3dd6cab 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -1,7 +1,5 @@ -""" -Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -""" import argparse import sys @@ -159,8 +157,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=args.fuse_sequences, - ) + fuse_sequences=args.fuse_sequences) validation_loader = get_dataset_for_model( args.model, @@ -172,8 +169,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=args.fuse_sequences, - ) + fuse_sequences=args.fuse_sequences) device = next(iter(model.parameters())).device print("Data loaded.") @@ -285,12 +281,19 @@ def main(args): calibration_loader, act_order=args.gpxq_act_order, use_quant_activations=args.gpxq_use_quant_activations, - create_weight_orig=args.gpxq_create_weight_orig) + create_weight_orig=args.gpxq_create_weight_orig, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPTQ applied.") if args.gpfq: print("Applying GPFQ...") - apply_gpfq(model, calibration_loader, act_order=args.gpxq_act_order) + apply_gpfq( + model, + calibration_loader, + act_order=args.gpxq_act_order, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) print("GPFQ applied.") if args.bias_corr: @@ -298,7 +301,7 @@ def main(args): apply_bias_correction(model, calibration_loader) print("Bias correction applied.") - if args.eval: + if args.eval and not args.no_quantize: print("Model eval...") quant_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) @@ -442,13 +445,23 @@ def parse_args(args): parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') parser.add_argument('--gpfq', action='store_true', help='Apply GPFQ.') parser.add_argument( - '--gpxq-act-order', action='store_true', help='Apply GPXQ activation ordering.') + '--gpxq-act-order', action='store_true', help='Apply GPxQ activation ordering.') parser.add_argument( '--gpxq-use-quant-activations', action='store_true', - help='Use quantized activations in GPXQ.') + help='Use quantized activations in GPxQ.') parser.add_argument( - '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPXQ.') + '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPxQ.') + parser.add_argument( + '--gpxq-max-accumulator-bit-width', + type=int, + default=None, + help='Maximum accumulator bit width for GPxQ using AXE.') + parser.add_argument( + '--gpxq-max-accumulator-tile-size', + type=int, + default=128, + help='Maximum accumulator tile size for GPxQ using AXE.') parser.add_argument( '--act-calibration', action='store_true', help='Apply activation calibration.') parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')
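A usage sketch of the API surface added by this patch, assuming `model` and `calibration_loader` have been prepared as in `main.py`; the 16-bit accumulator target and 128-element tile size below are placeholder choices, not recommendations. The same options are exposed on the command line through the new `--gpxq-max-accumulator-bit-width` and `--gpxq-max-accumulator-tile-size` flags; per the error message in `axe.py`, quantized activations (e.g. `--gpxq-use-quant-activations`) must be enabled so the accumulator bounds can be derived from the input bit width.

    from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq
    from brevitas_examples.llm.llm_quant.gpxq import apply_gptq

    # Accumulator-aware GPTQ (A2GPTQ) targeting a 16-bit accumulator with 128-wide tiles
    apply_gptq(
        model,
        calibration_loader,
        act_order=True,
        use_quant_activations=True,
        max_accumulator_bit_width=16,
        max_accumulator_tile_size=128)

    # Accumulator-aware GPFQ (A2GPFQ) with the same accumulator constraint
    apply_gpfq(
        model,
        calibration_loader,
        act_order=True,
        max_accumulator_bit_width=16,
        max_accumulator_tile_size=128)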