From 424ce6fd88d63b188bbb90d0b2f20c3ad3e449c1 Mon Sep 17 00:00:00 2001 From: Ian Colbert <88047104+i-colbert@users.noreply.github.com> Date: Wed, 9 Oct 2024 08:15:24 -0700 Subject: [PATCH] Feat (gpfq): adding memory-efficient formulation (#1043) --- src/brevitas/graph/gpfq.py | 377 +++++++++++--------- src/brevitas/graph/gpxq.py | 11 +- src/brevitas_examples/llm/llm_quant/gptq.py | 23 -- src/brevitas_examples/llm/llm_quant/gpxq.py | 41 +++ src/brevitas_examples/llm/main.py | 38 +- tests/brevitas/graph/test_gpxq.py | 69 +--- 6 files changed, 299 insertions(+), 260 deletions(-) delete mode 100644 src/brevitas_examples/llm/llm_quant/gptq.py create mode 100644 src/brevitas_examples/llm/llm_quant/gpxq.py diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index fd7df9223..33ee5fbb4 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -3,42 +3,32 @@ from copy import deepcopy import math -from math import pi -from typing import Callable, List, Optional +from typing import List, Optional import numpy as np +from packaging import version import torch -from torch.fft import fft -from torch.fft import fftn +from torch import Tensor import torch.nn as nn + +try: + from torch.linalg import LinAlgError +except: + LinAlgError = RuntimeError +import warnings + import unfoldNd -from brevitas.function import get_upper_bound_on_l1_norm +from brevitas import torch_version from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.gpxq import GPxQ from brevitas.graph.gpxq import gpxq_mode from brevitas.graph.gpxq import StopFwdException from brevitas.graph.gpxq import SUPPORTED_CONV_OP +from brevitas.graph.gpxq import SUPPORTED_TCONV_OP import brevitas.nn as qnn - - -def random_projection( - float_input: torch.Tensor, quantized_input: torch.Tensor, compression_rate: float): - # use random projection to reduce dimensionality - n = quantized_input.size(1) - target_dim = int(compression_rate * n) - dev = float_input.device - # create gaussian random matrix - R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev) - quantized_input = torch.transpose(quantized_input, 1, 2) @ R.T - float_input = torch.transpose(float_input, 1, 2) @ R.T - del R - # reshape back - quantized_input = torch.transpose(quantized_input, 1, 2) - float_input = torch.transpose(float_input, 1, 2) - - return float_input, quantized_input +from brevitas.quant_tensor import _unpack_quant_tensor class gpfq_mode(gpxq_mode): @@ -58,10 +48,6 @@ class gpfq_mode(gpxq_mode): return_forward_output (bool): If True, returns the output of the forward pass. Otherwise the forward call inside the context manager returns None. Default: False act_order (bool): Whether to order greedy path following by Hessian approximation. Default: False - use_gpfa2q (bool): Whether to use accumulator-aware GPFQ. Default: False - accumulator_bit_width (Optional, int): The target accumulator bit width. Default: None - a2q_layer_filter_fnc (Optional, callable): An optional lambda function to filter layers for - accumulator cosntraints. Should return True for layers to constrain. 
Default: `lambda x: True` Example: >>> with torch.no_grad(): @@ -84,10 +70,7 @@ def __init__( p: float = 1.0, return_forward_output: bool = False, act_order: bool = False, - use_gpfa2q: bool = False, - accumulator_bit_width: Optional[int] = None, - a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True, - compression_rate: Optional[float] = 0.0) -> None: + gpfq_class: Optional[nn.Module] = None) -> None: if not inplace: model = deepcopy(model) super().__init__( @@ -100,16 +83,11 @@ def __init__( return_forward_output) self.p = p - - # GPFA2Q params - self.use_gpfa2q = use_gpfa2q - self.accumulator_bit_width = accumulator_bit_width - self.a2q_layer_filter_fnc = a2q_layer_filter_fnc # returns true when to use GPFA2Q - - # selecting impl of random proj - self.compression_rate = compression_rate - if self.compression_rate < 0.0 or self.compression_rate > 1.0: - raise ValueError('Compression rate for random projection must be between 0 and 1.') + if gpfq_class is None: + gpfq_class = GPFQ + self.gpfq_class = gpfq_class + assert issubclass(gpfq_class, GPxQ), \ + "Error: expected `gpfq_class` to be derived from `brevitas.graph.gpxq.GPxQ`." def catch_stopfwd(self, *args, **kwargs): # Collect quant input @@ -148,25 +126,13 @@ def catch_stopfwd(self, *args, **kwargs): def initialize_module_optimizer( self, layer, name, act_order, len_parallel_layers, create_weight_orig): - if (not self.a2q_layer_filter_fnc(layer)) or (not self.use_gpfa2q): - return GPFQ( - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - p=self.p, - compression_rate=self.compression_rate) - else: - return GPFA2Q( - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - p=self.p, - accumulator_bit_width=self.accumulator_bit_width, - compression_rate=self.compression_rate) + return self.gpfq_class( + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + p=self.p) class GPFQ(GPxQ): @@ -174,17 +140,14 @@ class GPFQ(GPxQ): Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main """ - def __init__( - self, layer, name, act_order, len_parallel_layers, create_weight_orig, p, - compression_rate) -> None: + def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_orig, p) -> None: super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig) self.float_input = None - self.quantized_input = None + self.quant_input = None self.index_computed = False self.p = p - self.compression_rate = compression_rate def update_batch(self, module, input, current_layer): if self.disable_pre_forward_hook: @@ -206,9 +169,7 @@ def update_batch(self, module, input, current_layer): if isinstance(self.layer, SUPPORTED_CONV_OP): # Pick the correct unfoldNd class - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + if isinstance(self.layer, SUPPORTED_TCONV_OP): unfold_impl = unfoldNd.UnfoldTransposeNd else: unfold_impl = unfoldNd.UnfoldNd @@ -255,10 +216,10 @@ def update_batch(self, module, input, current_layer): else: self.float_input = torch.cat([self.float_input, inp_processed], dim=1) else: - if self.quantized_input is None: - self.quantized_input = inp_processed + if self.quant_input is None: + self.quant_input = inp_processed else: - self.quantized_input = 
torch.cat([self.quantized_input, inp_processed], dim=1) + self.quant_input = torch.cat([self.quant_input, inp_processed], dim=1) # If we are executing GPFQ with group of parallel layers, we keep track of how many forward # we executed. Once we executed as many as the number of parallel_layers, we raise # StopFwdException @@ -268,36 +229,35 @@ def update_batch(self, module, input, current_layer): raise StopFwdException def single_layer_update(self): - assert not self.layer.weight_quant.requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs." + assert not self.layer.weight_quant.requires_quant_input, \ + "Error: GPFQ does not support weight quantizers that require quantized inputs." weight = self.layer.weight.data dev = weight.device dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance( - self.layer, - (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + if isinstance(self.layer, SUPPORTED_TCONV_OP): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] - if self.compression_rate > 0.0: - self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate) + self.float_input = self.float_input.to(dev) - self.quantized_input = self.quantized_input.to(dev) + self.quant_input = self.quant_input.to(dev) U = torch.zeros( weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) # We don't need full Hessian, we just need the diagonal - self.H_diag = self.quantized_input.transpose(2, 1).square().sum( - 2) # summing over Batch dimension + # Summing over batch dimension + H_diag = self.quant_input.transpose(2, 1).square().sum(2) permutation_list = [] for group_index in range(self.groups): if self.act_order: # Re-order Hessian_diagonal so that weights associated to # higher magnitude activations are quantized first - perm = torch.argsort(self.H_diag[group_index, :], descending=True) + perm = torch.argsort(H_diag[group_index, :], descending=True) else: # No permutation, permutation tensor is a ordered index perm = torch.tensor(range(weight.shape[-1]), device=dev) permutation_list.append(perm) + del H_diag for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( @@ -305,11 +265,10 @@ def single_layer_update(self): self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] norm = torch.linalg.norm( - self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2 + self.quant_input[group_index, :, permutation_list[group_index][t]], 2) ** 2 if norm > 0: q_arg = U[group_index].matmul( - self.quantized_input[group_index, :, - permutation_list[group_index][t]]) / norm + self.quant_input[group_index, :, permutation_list[group_index][t]]) / norm else: q_arg = torch.zeros_like(U[group_index, :, 0]) @@ -318,111 +277,197 @@ def single_layer_update(self): for group_index in range(self.groups): U[group_index] -= torch.matmul( q[group_index].unsqueeze(1), - self.quantized_input[group_index, :, - permutation_list[group_index][t]].unsqueeze(0)) + self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0)) del self.float_input - del self.quantized_input + del self.quant_input -class GPFA2Q(GPFQ): +class GPFQv2(GPFQ): + """ + Memory-efficient GPFQ formulation introduced in 
https://arxiv.org/pdf/2409.17092 + """ - def __init__( - self, - layer, - name, - act_order, - len_parallel_layers, - create_weight_orig, - accumulator_bit_width, - p, - compression_rate) -> None: - GPFQ.__init__( - self, - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - p=p, - compression_rate=compression_rate) - self.accumulator_bit_width = accumulator_bit_width - assert self.accumulator_bit_width is not None + def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_orig, p) -> None: + super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig, p) + # Initialize covariance matrices. We need it in float32 to compute the inverse + # H = (\hat{X} \hat{X}^T)^{1/2} + self.H: Tensor = torch.zeros((self.groups, self.columns, self.columns), + device="cpu", + dtype=torch.float32) + # G = X \hat{X}^T + self.G: Tensor = torch.zeros((self.groups, self.columns, self.columns), + device="cpu", + dtype=torch.float32) + # buffer to speed-up GPU to CPU transfer + self.B: Tensor = torch.zeros((self.groups, self.columns, self.columns), + device="cpu", + dtype=torch.float32, + pin_memory=torch.cuda.is_available()) + self.nsamples = 0 + + assert torch_version >= version.parse('1.10'), "GPFQv2 requires torch 1.10 or higher" - def single_layer_update(self): - # raise error in case no quant-input is here - if self.quant_metadata is None: - raise ValueError('Expected self.quant_metadata to calculate L1-norm upper bound, but recevied None. ' + \ - 'Make sure that either the input to the model is a IntQuantTensor or the layer has an input quant enabled. ' \ - 'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \ - 'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.') + def update_batch(self, module, input, current_layer): + if self.disable_pre_forward_hook: + return input + + # Update reference to current layer + current_layer.layer_names.add(self.name) + is_quant_enabled = module.weight_quant.is_quant_enabled + + inp = self.process_input(input) + inp = _unpack_quant_tensor(inp) + + # Preprocess the input to compute the Hessian + if isinstance(self.layer, qnn.QuantLinear): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + inp = inp.t() + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) + + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance(self.layer, SUPPORTED_TCONV_OP): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd + + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride) + + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for inp in inp_by_group: + inp = unfold(inp) + inp = inp.transpose(1, 0) + inp = inp.flatten(1) + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + # NOTE: in the gpfq_mode context manager, we first collect quant inputs, then + # we collect float inputs for the same batch. We assume this pattern here, but + # will add a check just in case. 
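# --- Editor's note: illustrative sketch, not part of the patch ---------------
# The memory saving in GPFQv2 comes from never storing the per-sample float and
# quantized inputs. Instead, two columns-by-columns running sums are kept per
# group: H accumulates the quantized-input Gram matrix \hat{X}\hat{X}^T and G
# the cross term X\hat{X}^T, as in the hunk that follows. A minimal
# single-group sketch of that accumulation, assuming `batches` yields
# (float_x, quant_x) pairs already unfolded to shape [columns, samples]
# (names here are illustrative, not from the patch):
import math

import torch


def accumulate_covariances(batches, num_cols):
    H = torch.zeros(num_cols, num_cols, dtype=torch.float32)  # ~ \hat{X} \hat{X}^T
    G = torch.zeros(num_cols, num_cols, dtype=torch.float32)  # ~ X \hat{X}^T
    scale = math.sqrt(2.0 / num_cols)  # mirrors the patch's sqrt(2 / n) scaling
    for float_x, quant_x in batches:
        float_x = scale * float_x.to(torch.float32)
        quant_x = scale * quant_x.to(torch.float32)
        H += quant_x @ quant_x.t()  # quantized-input Gram matrix
        G += float_x @ quant_x.t()  # float/quantized cross-covariance
    return H, G
# Memory stays O(columns^2) regardless of the number of calibration samples,
# whereas the original GPFQ keeps every float and quantized input around.
# -----------------------------------------------------------------------------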
+ n = inp_processed.size(1) + inp_processed = math.sqrt(2 / n) * inp_processed.to(torch.float32) + + # if quant is not enabled, then it is the float input; if it is a float input + # then a quant input has already happened and we can update G + if not is_quant_enabled: + # Computing the normalized H matrix using CPU buffer + self.B.copy_(self.quant_input.bmm(inp_processed.transpose(2, 1))) + self.G += self.B + self.quant_input = None # NOTE: set back to None now that we've used it + else: + # Computing the normalized H matrix using CPU buffer + self.B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1))) + self.H += self.B + # store the quantized input for computing the H matrix + assert self.quant_input is None + self.quant_input = inp_processed + + # If we are executing GPFQ with group of parallel layers, we keep track of how many forward + # we executed. Once we executed as many as the number of parallel_layers, we raise + # StopFwdException + current_layer.forward_count += 1 + if current_layer.forward_count == self.len_parallel_layers: + current_layer.forward_count = 0 + raise StopFwdException + + def _get_permutation_list(self, weight: Tensor): + permutation_list = [] + if self.act_order: + # We don't need full Hessian, we just need the diagonal + H_diag = self.quant_input.transpose(2, 1).square().sum(2) + for group_index in range(self.groups): + # Re-order Hessian_diagonal so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(H_diag[group_index, :], descending=True) + perm = perm.to(weight.device) + permutation_list.append(perm) + else: + for group_index in range(self.groups): + # No permutation, permutation tensor is a ordered index + perm = torch.tensor(range(weight.shape[-1]), device=weight.device) + permutation_list.append(perm) + return permutation_list + + def single_layer_update(self, percdamp: float = 0.01): + assert not self.layer.weight_quant.requires_quant_input, \ + "Error: GPFQ does not support weight quantizers that require quantized inputs." 
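# --- Editor's note: illustrative sketch, not part of the patch ---------------
# The update that follows never touches the raw calibration data again. It
# derives two matrices from the accumulated statistics: a symmetric square root
# H^{1/2} = (\hat{X}\hat{X}^T)^{1/2} and P = H^{-1/2} G, which stand in for
# `quant_input` and `float_input` of the original GPFQ loop. A minimal
# single-group sketch of that preprocessing (function name is illustrative):
import torch


def preprocess_covariances(H, G, percdamp=0.01):
    # dampen the diagonal so the Gram matrix is safely positive-definite
    damp = percdamp * H.diag().mean()
    H = H + damp * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    norms = H.diag().clone()  # post-dampening diagonal, reused as denominators
    # symmetric matrix square root via eigendecomposition
    eigvals, eigvecs = torch.linalg.eigh(H)
    H_sqrt = eigvecs @ torch.diag(eigvals.clamp_min(0.0).sqrt()) @ eigvecs.t()
    # Cholesky-based inverse of the square root, then P = H^{-1/2} G
    L = torch.linalg.cholesky(H_sqrt)
    P = torch.cholesky_inverse(L) @ G
    return H_sqrt, P, norms
# -----------------------------------------------------------------------------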
weight = self.layer.weight.data dev = weight.device dtype = weight.dtype if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + if isinstance(self.layer, SUPPORTED_TCONV_OP): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] - U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) - if self.compression_rate > 0.0: - self.float_input, self.quantized_input = random_projection(self.float_input, self.quantized_input, self.compression_rate) - self.float_input = self.float_input.to(dev) - self.quantized_input = self.quantized_input.to(dev) - # get upper bound - input_bit_width = self.quant_metadata.bit_width - input_is_signed = self.quant_metadata.signed - T = get_upper_bound_on_l1_norm( - torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed) - s = self.layer.weight_quant.scale() - if s.ndim > 1: - s = s.view(self.groups, -1) # [Groups, OC/Groups] - - # initialize cumulative l1-norm - z = torch.zeros(weight.shape[:-1], device=dev) + # stablize H with a dampening factor and then square root the matrix + norms = torch.zeros((self.groups, self.columns), device=dev, dtype=dtype) + self.H = self.H.to(dev) + diag = torch.arange(self.columns, device='cpu') + for i in range(self.groups): + damp = percdamp * self.H[i].diag().mean() + self.H[i, diag, diag] += damp + norms[i] = self.H[i].diag() # set the norms post-dampening + eigvals, eigvecs = torch.linalg.eigh(self.H[i]) + eigvals.clamp_min_(0.0).sqrt_() # should be positive-definite + self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t() + del eigvecs, eigvals, diag + self.quant_input = self.H # NOTE: do this here for the `get_permutation_list` function + + # Try/Except in case the inverse of H cannot be computed + try: + self.float_input = self.H.clone() # going to calculate H^{-1} here + for i in range(self.groups): + # from our matrix sqrt, we know G is symmetric and positive-definite, so we + # can use Cholesky decomposition as an efficient, numerically stable inverse + L = torch.linalg.cholesky(self.float_input[i]) + self.float_input[i] = torch.cholesky_inverse(L) + self.float_input = torch.bmm(self.float_input.to(dev), self.G.to(dev)) + del L # memory management + except LinAlgError: + warnings.warn( + f'Failed to compute the inverse of H for layer {self.name} ' + f'GPFQ will not be applied. 
' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H, self.G, self.B # memory management + + permutation_list = self._get_permutation_list(weight) - # We don't need full Hessian, we just need the diagonal - self.H_diag = self.quantized_input.transpose(2, 1).square().sum( - 2) # summing over Batch dimension - permutation_list = [] - for group_index in range(self.groups): - if self.act_order: - # Re-order Hessian_diagonal so that weights associated to - # higher magnitude activations are quantized first - perm = torch.argsort(self.H_diag[group_index, :], descending=True) - else: - # No permutation, permutation tensor is a ordered index - perm = torch.tensor(range(weight.shape[-1]), device=dev) - permutation_list.append(perm) + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, + dtype=dtype) # [Groups, OC/groups, Samples] for t in range(weight.shape[-1]): for group_index in range(self.groups): + i = permutation_list[group_index][t] U[group_index] += torch.matmul( - weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1), - self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( - 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] - norm = torch.linalg.norm( - self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2 + weight[group_index, :, i].unsqueeze(1), + self.float_input[group_index, :, i].unsqueeze(0), + ) # [OC/Groups, 1] * [1, INSHAPE[1]] + norm = norms[group_index, i] if norm > 0: - q_arg = U[group_index].matmul( - self.quantized_input[group_index, :, - permutation_list[group_index][t]]) / norm + q_arg = U[group_index].matmul(self.quant_input[group_index, :, i]) / norm else: q_arg = torch.zeros_like(U[group_index, :, 0]) - - max_q_arg = s * torch.clamp_min(T - z, 0.) 
- q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg[group_index, :]) - weight[group_index, :, permutation_list[group_index][t]] = q_arg - q = self.get_quant_weights(t, 0, permutation_list) - z += q.abs() / s # increment cumulative l1-norm - + weight[group_index, :, i] = q_arg + q_groups = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): U[group_index] -= torch.matmul( - q[group_index].unsqueeze(1), - self.quantized_input[group_index, :, - permutation_list[group_index][t]].unsqueeze(0)) + q_groups[group_index].unsqueeze(1), + self.quant_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0), + ) del self.float_input - del self.quantized_input + del self.quant_input diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index d8b436fc1..e71a273c3 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -21,16 +21,11 @@ from brevitas.graph.utils import is_conv_transposed import brevitas.nn as qnn from brevitas.quant_tensor import IntQuantTensor -from brevitas.quant_tensor.base_quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO -SUPPORTED_CONV_OP = ( - qnn.QuantConv1d, - qnn.QuantConv2d, - qnn.QuantConv3d, - qnn.QuantConvTranspose1d, - qnn.QuantConvTranspose2d, - qnn.QuantConvTranspose3d) +SUPPORTED_TCONV_OP = (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d) + +SUPPORTED_CONV_OP = (qnn.QuantConv1d, qnn.QuantConv2d, qnn.QuantConv3d, *SUPPORTED_TCONV_OP) class StopFwdException(Exception): diff --git a/src/brevitas_examples/llm/llm_quant/gptq.py b/src/brevitas_examples/llm/llm_quant/gptq.py deleted file mode 100644 index 1eafa2851..000000000 --- a/src/brevitas_examples/llm/llm_quant/gptq.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -""" - -import torch -from tqdm import tqdm - -from brevitas.graph.gptq import gptq_mode - - -@torch.no_grad() -def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None): - with gptq_mode(model, - use_quant_activations=False, - group_of_parallel_layers=group_of_parallel_layers, - act_order=act_order, - create_weight_orig=False) as gptq: - gptq_model = gptq.model - for _ in tqdm(range(gptq.num_layers)): - for inps in dataloader: - gptq_model(**inps) - gptq.update() diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py new file mode 100644 index 000000000..e2bfba989 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -0,0 +1,41 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +""" + +import torch +from tqdm import tqdm + +from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gptq import gptq_mode + + +@torch.no_grad() +def apply_gptq( + model, + dataloader, + act_order=True, + group_of_parallel_layers=None, + use_quant_activations=True, + create_weight_orig=False): + with gptq_mode(model, + act_order=act_order, + group_of_parallel_layers=group_of_parallel_layers, + use_quant_activations=use_quant_activations, + create_weight_orig=create_weight_orig) as gptq: + gptq_model = gptq.model + for _ in tqdm(range(gptq.num_layers)): + for inps in dataloader: + gptq_model(**inps) + gptq.update() + + +@torch.no_grad() +def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None): + with gpfq_mode(model, act_order=act_order, + group_of_parallel_layers=group_of_parallel_layers) as gpfq: + gpfq_model = gpfq.model + for _ in tqdm(range(gpfq.num_layers)): + for inps in dataloader: + gpfq_model(**inps) + gpfq.update() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index cf6f01895..c33de54c8 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -4,7 +4,6 @@ """ import argparse -import re import sys from warnings import warn @@ -22,25 +21,21 @@ from brevitas.graph.quantize import layerwise_quantize from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers -from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction from brevitas_examples.llm.llm_quant.calibrate import apply_calibration from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization -from brevitas_examples.llm.llm_quant.eval import create_validation_dataloader -from brevitas_examples.llm.llm_quant.eval import model_eval from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode -from brevitas_examples.llm.llm_quant.gptq import apply_gptq +from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq +from brevitas_examples.llm.llm_quant.gpxq import apply_gptq from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers -from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import get_fx -from brevitas_examples.llm.llm_quant.run_utils import modify_dataloader def set_seed(seed): @@ -77,6 +72,8 @@ def model_export(model, ref_input, args): def validate(args): if not args.no_quantize: + if args.gptq and args.gpfq: + warn("Both GPTQ and GPFQ are enabled.") if args.export_target is not None: assert args.input_quant_format == 'int', "Only integer quantization supported for export currently." 
if args.export_target is not None and args.input_bit_width is not None: @@ -188,7 +185,7 @@ def main(args): float_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) remove_hooks(model) - print(f"Float perplexity ({args.dataset}): {float_ppl}") + print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}") if require_fx: model = get_fx(model) @@ -259,7 +256,7 @@ def main(args): input_quant_format=args.input_quant_format, quantize_embedding=False) if not args.quantize_last_layer: - name_blacklist += ["lm_head"] + name_blacklist += ["lm_head", "embed_out"] model = layerwise_quantize( model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist) # Tie back first/last layer weights in case they got untied @@ -283,9 +280,19 @@ def main(args): if args.gptq: print("Applying GPTQ...") - apply_gptq(model, calibration_loader) + apply_gptq( + model, + calibration_loader, + act_order=args.gpxq_act_order, + use_quant_activations=args.gpxq_use_quant_activations, + create_weight_orig=args.gpxq_create_weight_orig) print("GPTQ applied.") + if args.gpfq: + print("Applying GPFQ...") + apply_gpfq(model, calibration_loader, act_order=args.gpxq_act_order) + print("GPFQ applied.") + if args.bias_corr: print("Applying bias correction...") apply_bias_correction(model, calibration_loader) @@ -295,7 +302,7 @@ def main(args): print("Model eval...") quant_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - print(f"Quantized perplexity ({args.dataset}): {quant_ppl}") + print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") remove_hooks(model) if args.checkpoint_name is not None: @@ -433,6 +440,15 @@ def parse_args(args): parser.add_argument( '--quantize-last-layer', action='store_true', help='Quantize last nn.Linear layer.') parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') + parser.add_argument('--gpfq', action='store_true', help='Apply GPFQ.') + parser.add_argument( + '--gpxq-act-order', action='store_true', help='Apply GPXQ activation ordering.') + parser.add_argument( + '--gpxq-use-quant-activations', + action='store_true', + help='Use quantized activations in GPXQ.') + parser.add_argument( + '--gpxq-create-weight-orig', action='store_true', help='Create weight_orig in GPXQ.') parser.add_argument( '--act-calibration', action='store_true', help='Apply activation calibration.') parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.') diff --git a/tests/brevitas/graph/test_gpxq.py b/tests/brevitas/graph/test_gpxq.py index 4293b8582..33116332c 100644 --- a/tests/brevitas/graph/test_gpxq.py +++ b/tests/brevitas/graph/test_gpxq.py @@ -9,7 +9,9 @@ from torch.utils.data import DataLoader from torch.utils.data import TensorDataset +from brevitas.graph.gpfq import GPFQ from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gpfq import GPFQv2 from brevitas.graph.gptq import gptq_mode from .equalization_fixtures import * @@ -19,22 +21,16 @@ def apply_gpfq( calib_loader: DataLoader, model: nn.Module, act_order: bool, - use_quant_activations: bool = True, - accumulator_bit_width: int = 32, - a2q_layer_filter_fnc=lambda x: True): + use_quant_activations: bool, + gpfq_class: GPFQ): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - # use A2GPFQ if accumulator is less than 32 is specified - with gpfq_mode( - model, - use_quant_activations=use_quant_activations, - 
act_order=act_order, - use_gpfa2q=accumulator_bit_width < 32, - accumulator_bit_width=accumulator_bit_width, - a2q_layer_filter_fnc=a2q_layer_filter_fnc, - ) as gpfq: + with gpfq_mode(model, + use_quant_activations=use_quant_activations, + act_order=act_order, + gpfq_class=gpfq_class) as gpfq: gpfq_model = gpfq.model for _ in range(gpfq.num_layers): for _, (images, _) in enumerate(calib_loader): @@ -64,44 +60,24 @@ def apply_gptq( gptq.update() -def custom_layer_filter_fnc(layer: nn.Module) -> bool: - if isinstance(layer, nn.Conv2d) and layer.in_channels == 3: - return False - elif isinstance(layer, nn.ConvTranspose2d) and layer.in_channels == 3: - return False - return True - - -apply_gpxq_func_map = {"gpfq": apply_gpfq, "gptq": apply_gptq} +apply_gpxq_func_map = { + "gpfq": partial(apply_gpfq, gpfq_class=GPFQ), + "gpfq2": partial(apply_gpfq, gpfq_class=GPFQv2), + "gptq": apply_gptq} @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("use_quant_activations", [True, False]) -@pytest.mark.parametrize("acc_bit_width", [32, 24, 16, 12]) -@pytest.mark.parametrize("apply_gpxq_tuple", apply_gpxq_func_map.items()) -def test_toymodels( - toy_quant_model, act_order, use_quant_activations, acc_bit_width, apply_gpxq_tuple, - request): +@pytest.mark.parametrize( + "apply_gpxq_tuple", apply_gpxq_func_map.items(), ids=apply_gpxq_func_map.keys()) +def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq_tuple, request): test_id = request.node.callspec.id - input_quant = test_id.split('-')[1] - weight_quant = test_id.split('-')[2] - - if ('MXFloat' in input_quant or 'MXInt' in weight_quant) and acc_bit_width < 32: - pytest.skip("MX quant does not support accumulator-aware quantization.") torch.manual_seed(SEED) name, apply_gpxq = apply_gpxq_tuple - if (name == 'gptq' and acc_bit_width < 32): - pytest.skip("GPTQ does not support accumulator-aware quantization.") - - if name == 'gpfq': - filter_func = custom_layer_filter_fnc - apply_gpxq = partial( - apply_gpxq, accumulator_bit_width=acc_bit_width, a2q_layer_filter_fnc=filter_func) - model_class = toy_quant_model model = model_class() if 'mha' in test_id: @@ -113,8 +89,8 @@ def test_toymodels( dataset = TensorDataset(inp, inp) calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True) - if (name == 'gptq' and torch_version < version.parse('1.10')): - # GPTQ usage of linalg_cholesky() is not compatible with torch 1.9.1 and below + if ((name == 'gptq' or name == 'gpfq2') and torch_version < version.parse('1.10')): + # Usage of linalg_cholesky() is not compatible with torch 1.9.1 and below with pytest.raises(AssertionError): apply_gpxq( calib_loader=calib_loader, @@ -122,17 +98,6 @@ def test_toymodels( act_order=act_order, use_quant_activations=use_quant_activations) - elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or - input_quant == 'None'): - # GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will - # raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will - # happen when `use_quant_activations=False` or when the input to a model is not quantized - with pytest.raises(ValueError): - apply_gpxq( - calib_loader=calib_loader, - model=model, - act_order=act_order, - use_quant_activations=use_quant_activations) else: apply_gpxq( calib_loader=calib_loader,
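# --- Editor's note: illustrative usage sketch, not part of the patch ---------
# After this change, the memory-efficient formulation is selected by passing
# the new `gpfq_class` argument to `gpfq_mode` (it defaults to the original
# GPFQ). A minimal calibration loop mirroring the test helper above, assuming
# `calib_loader` yields (images, labels) pairs on the model's device:
import torch

from brevitas.graph.gpfq import GPFQv2
from brevitas.graph.gpfq import gpfq_mode


@torch.no_grad()
def apply_gpfq_v2(model, calib_loader, act_order=False):
    model.eval()
    with gpfq_mode(model, act_order=act_order, gpfq_class=GPFQv2) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for images, _ in calib_loader:
                gpfq_model(images)
            gpfq.update()
    return model
# -----------------------------------------------------------------------------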