From cd1449d20aae36252ff2161a9c21019d108cd8da Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 7 Nov 2024 09:47:05 -0500
Subject: [PATCH] [GPTQ] Iterative Parameter Updating (#863)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Implement iterative parameter updating

Signed-off-by: Kyle Sayers

* [Bugfix] Use weight parameter of linear layer (#836)

* use weight parameter of linear layer
* add weight attribute check

Signed-off-by: Kyle Sayers

* [Bugfix] Rename files to remove colons (#846)

* rename files to remove colons

Signed-off-by: Kyle Sayers

* [Bugfix] Workaround tied tensors bug (#659)

* load offload state dict
* add test
* remove merge duplication
* prepare to fix tie_word_embeddings
* add full tests
* patch second bug
* comment out failing tests, point to next pr
* link to issue
* accommodate offloaded models in test
* add back passing test
* WIP
* add error if not in expected list
* apply style
* update passing failing list
* add shared tensors tests
* clean up
* add comment with link
* make failing tests a todo
* Remove failing tests
* explicitly set safe_serialization
* separate out gpu tests, apply style

---------

Co-authored-by: Kyle Sayers
Co-authored-by: Dipika Sikka
Signed-off-by: Kyle Sayers

* only untie word embeddings (#839)

Signed-off-by: Kyle Sayers

* check for config hidden size (#840)

Signed-off-by: Kyle Sayers

* Use float32 for Hessian dtype (#847)

* use float32 for hessian dtype
* explicitly set inp dtype as well
* float precision for obcq hessian

Signed-off-by: Kyle Sayers

* GPTQ: Deprecate non-sequential update option (#762)

* remove from gptq, apply style
* remove instances of sequential_update argument in GPTQ tests
* update examples
* update example tests
* documentation, remove from example
* apply style
* revert back to auto type
* apply style

---------

Co-authored-by: Dipika Sikka
Signed-off-by: Kyle Sayers

* Typehint nits (#826)

Signed-off-by: Kyle Sayers

* [DOC] Remove version restrictions in W8A8 example (#849)

The latest compressed-tensors 0.8.0 removed some APIs
(https://github.com/neuralmagic/compressed-tensors/pull/156/files).
If the older llmcompressor is installed from pip, it throws an error like:

```
ImportError: cannot import name 'update_layer_weight_quant_params' from 'compressed_tensors.quantization'
```

Signed-off-by: Kyle Sayers

* Fix inconsistency (#80)

Use group strategy with group size 128 instead of channel

Co-authored-by: Dipika Sikka
Signed-off-by: Kyle Sayers

* 2of4

Signed-off-by: Kyle Sayers

* revert change to unrelated example

Signed-off-by: Kyle Sayers

* rename test file

Signed-off-by: Kyle Sayers

* fix fwd func call (#845)

Signed-off-by: Kyle Sayers

---------

Signed-off-by: Kyle Sayers
Co-authored-by: Kyle Sayers
Co-authored-by: Kyle Sayers
Co-authored-by: Dipika Sikka
Co-authored-by: Jincheng Miao
Co-authored-by: 黄石
Signed-off-by: Kyle Sayers

* cover all 3.9-3.12 in commit testing (#864)

Co-authored-by: dhuangnm
Signed-off-by: Kyle Sayers

* Add marlin-24 recipe/configs for e2e testing (#866)

* add marlin-24 recipe/configs for e2e testing
* update

Signed-off-by: Kyle Sayers

* [Bugfix] onload during sparsity calculation (#862)

* onload during sparsity calculation
* fix sparsity

---------

Co-authored-by: Dipika
Signed-off-by: Kyle Sayers

* Fix HFTrainer overloads (#869)

* add missing arguments
Signed-off-by: Kyle Sayers
* names
Signed-off-by: Kyle Sayers
* style
Signed-off-by: Kyle Sayers
* named args all around
Signed-off-by: Kyle Sayers

---------

Signed-off-by: Kyle Sayers
Co-authored-by: Dipika Sikka
Signed-off-by: Kyle Sayers

* Support Model Offloading Tied Tensors Patch (#872)

* update parameter of offloaded modules
Signed-off-by: Kyle Sayers
* in place function
Signed-off-by: Kyle Sayers

---------

Signed-off-by: Kyle Sayers

* add advice about dealing with non-invertible Hessians (#875)

Signed-off-by: Kyle Sayers

* seed commit workflow (#877)

* seed commit workflow
Signed-off-by: andy-neuma
* tickle
Signed-off-by: andy-neuma
* let's give it a try
Signed-off-by: andy-neuma
* whitespace
Signed-off-by: andy-neuma
* delete unneeded workflow
Signed-off-by: andy-neuma
* adjust trigger
Signed-off-by: andy-neuma

---------

Signed-off-by: andy-neuma
Co-authored-by: andy-neuma
Signed-off-by: Kyle Sayers

* [Observer Restructure]: Add Observers; Add `calibration` and `frozen` steps to `QuantizationModifier` (#837)

* update function
* wip
* clean-up; fix imports
* clean-up
* more clean-up
* bug fix
* update for kvcache
* get kv_cache to work
* docstring
* fix comment
* fix condition for dynamic
* update
* update tests
* add observer tests
* add flake8 skip
* apply updated mse fixes
* fix import
* Update src/llmcompressor/modifiers/quantization/calibration.py
Co-authored-by: Kyle Sayers
* Update src/llmcompressor/modifiers/quantization/calibration.py
Co-authored-by: Kyle Sayers
* PR comments
* clean-up
* move hook check to observer call
* update
* separate out calibration step

---------

Co-authored-by: Kyle Sayers
Signed-off-by: Kyle Sayers

* WIP, observer

Signed-off-by: Kyle Sayers

* use minmax observer

Signed-off-by: Kyle Sayers

* Bugfix get observer from name (#883)

Signed-off-by: Rahul Tuli

* BugFix: Fix Sparsity Reload Testing (#882)

* fix
* fix remaining test cases
* add comments
* fix

Signed-off-by: Kyle Sayers

* Use custom unique test names for e2e tests (#892)

* Include `testconfig_path` in parsed config data
Signed-off-by: Domenic Barbuzzi
* Use custom unique names for e2e tests
Signed-off-by: Domenic Barbuzzi

---------

Signed-off-by: Domenic Barbuzzi
Signed-off-by: Kyle Sayers

* Revert "Use custom unique test names for e2e tests (#892)" (#893)

This reverts commit 10facf2633e58778e82d5d53bd661d970c610258.

Signed-off-by: Kyle Sayers

* Move config["testconfig_path"] assignment (#895)

* Use custom unique test names for e2e tests (#892)

* Include `testconfig_path` in parsed config data
Signed-off-by: Domenic Barbuzzi
* Use custom unique names for e2e tests
Signed-off-by: Domenic Barbuzzi

---------

Signed-off-by: Domenic Barbuzzi

* Revert "Use custom unique test names for e2e tests (#892)" (#893)

This reverts commit 10facf2633e58778e82d5d53bd661d970c610258.

Signed-off-by: Domenic Barbuzzi

* Move config["testconfig_path"] assignment
Signed-off-by: Domenic Barbuzzi
* Use a function name generator for e2e test names
Signed-off-by: Domenic Barbuzzi

---------

Signed-off-by: Domenic Barbuzzi
Co-authored-by: Dipika Sikka
Signed-off-by: Kyle Sayers

* cap accelerate version to avoid bug (#897)

Signed-off-by: Kyle Sayers

* Fix observing offloaded weight (#896)

* load weight within onloading
Signed-off-by: Kyle Sayers
* remove moving activation to execution device, since this is already done because activation calibration always happens within the forward pass
Signed-off-by: Kyle Sayers

---------

Signed-off-by: Kyle Sayers
Co-authored-by: Dipika Sikka
Signed-off-by: Kyle Sayers

* Update image in README.md (#861)

Co-authored-by: Dipika Sikka
Signed-off-by: Kyle Sayers

* use user-specified observer

Signed-off-by: Kyle Sayers

---------

Signed-off-by: Kyle Sayers
Signed-off-by: andy-neuma
Signed-off-by: Rahul Tuli
Signed-off-by: Domenic Barbuzzi
Co-authored-by: Kyle Sayers
Co-authored-by: Kyle Sayers
Co-authored-by: Dipika Sikka
Co-authored-by: Jincheng Miao
Co-authored-by: 黄石
Co-authored-by: dhuangnm <74931910+dhuangnm@users.noreply.github.com>
Co-authored-by: dhuangnm
Co-authored-by: Andy Linfoot <78757007+andy-neuma@users.noreply.github.com>
Co-authored-by: andy-neuma
Co-authored-by: Rahul Tuli
Co-authored-by: Domenic Barbuzzi
Co-authored-by: Michael Goin
---
 .../quantization/gptq/utils/gptq_wrapper.py | 65 ++++++++++---------
 1 file changed, 36 insertions(+), 29 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
index bc8b43284..542f64bab 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -1,11 +1,7 @@
 import time
 from typing import Tuple
 
-from compressed_tensors.quantization import (
-    ActivationOrdering,
-    QuantizationArgs,
-    QuantizationStrategy,
-)
+from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import fake_quantize
 
 from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD
@@ -100,20 +96,27 @@ def compress(
             diagonal norm
         """
         args_loc = "quantization_scheme.weights"
-        weight_quant_args = getattr_chain(self.layer, args_loc, None)
-        if weight_quant_args is None:
+        quant_args = getattr_chain(self.layer, args_loc, None)
+        if quant_args is None:
             logger.debug(f"Skipping unquantized layer {self.name}...")
             return
 
         if is_module_offloaded(self.layer):
             self.layer._hf_hook.pre_forward(self.layer)
 
-        strategy = weight_quant_args.strategy
-        actorder = weight_quant_args.actorder
+        strategy = quant_args.strategy
+        actorder = quant_args.actorder
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
         W = self.layer.weight.data.clone()
 
+        # create observer for calculating quantization parameters
+        observer = Observer.load_from_registry(
+            quant_args.observer,
+            quantization_args=quant_args,
+            averaging_constant=1.0,  # ignore moving average
+        )
+
         # standardize shape and dtype
         if isinstance(self.layer, nn.Conv2d):
             W = W.flatten(1)
@@ -127,26 +130,28 @@ def compress(
             # mapping from column index to group index
             g_idx = (
                 torch.arange(self.columns, device=W.device, dtype=torch.int)
-                // weight_quant_args.group_size
+                // quant_args.group_size
             )
 
             if actorder == ActivationOrdering.GROUP:
                 # permute by activation order first, then update groups
                 W, self.H, perm = self._apply_activation_ordering(W, self.H)
-                self._update_quantization_parameters(weight_quant_args, W)
+                scale, zero_point = observer(W, g_idx=None)
 
                 # use identity g_idx (invert permutation later)
 
             elif actorder == ActivationOrdering.WEIGHT:
                 # update groups first, then permute by activation order
-                self._update_quantization_parameters(weight_quant_args, W)
+                scale, zero_point = observer(W, g_idx=None)
                 W, self.H, perm = self._apply_activation_ordering(W, self.H)
 
                 # permute g_idx to maintain identity mapping after unpermutation
                 g_idx = g_idx[perm]
 
-            scale = self.layer.weight_scale
-            zero_point = self.layer.weight_zero_point
+            else:
+                scale, zero_point = observer(W, g_idx=None)
+
+        else:
+            scale, zero_point = observer(W, g_idx=None)
 
         # sparsity mask
         sparsity = tensor_sparsity(W)
@@ -212,16 +217,28 @@ def compress(
                         q,
                         scale[:, 0],
                         zero_point[:, 0],
-                        weight_quant_args,
+                        quant_args,
                     )
                 elif strategy == QuantizationStrategy.GROUP:
                     # get the group index for the current column
                     column_idx = i1 + i
                     group_index = g_idx[column_idx]
 
+                    # update quantization parameters to reflect changes
+                    # resulting from previous blocks
+                    if (
+                        actorder != ActivationOrdering.WEIGHT
+                        and column_idx % quant_args.group_size == 0
+                    ):
+                        _scale, _zero_point = observer.get_qparams_along_dim(
+                            W[:, g_idx == group_index], dim=0
+                        )
+                        scale[:, group_index] = _scale[:, 0]
+                        zero_point[:, group_index] = _zero_point[:, 0]
+
                     # Since we're only applying quantization to a slice, this
                     # ends up being a channelwise application
-                    altered_qargs = copy(weight_quant_args)
+                    altered_qargs = copy(quant_args)
                     altered_qargs.strategy = QuantizationStrategy.CHANNEL
                     q = fake_quantize(
                         q,
@@ -279,6 +296,9 @@ def compress(
             W.transpose_(0, 1)
         W = W.reshape(final_shape).to(final_dtype)
 
+        update_parameter_data(self.layer, scale, "weight_scale")
+        update_parameter_data(self.layer, zero_point, "weight_zero_point")
+
         # This is a bit hacky, but FSDP updates only work if we change
         # the weight in place, clone() or direct assignment won't work
         self.layer.weight -= self.layer.weight
@@ -296,19 +316,6 @@ def free(self):
         delattr(self, "H")
         super().free()
 
-    def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor):
-        """
-        Update layer quantization parameters with potentially permuted weight
-
-        :param args: quantization arguments
-        :param W: weight to calculate quantization parameters from
-        """
-        observer = args.get_observer()
-        observer = Observer.load_from_registry(observer, quantization_args=args)
-        _scale, _zero_point = observer(W, g_idx=None)
-        update_parameter_data(self.layer, _scale, "weight_scale")
-        update_parameter_data(self.layer, _zero_point, "weight_zero_point")
-
     def _apply_activation_ordering(
         self, W: torch.Tensor, H: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
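
For readers following the diff: the change means GPTQ no longer relies on scales and zero points computed once up front; an observer created with `averaging_constant=1.0` (no moving average) recomputes each group's quantization parameters from the error-compensated weights when the algorithm reaches that group. The snippet below is a minimal, self-contained sketch of that iterative update in plain PyTorch. The helper names (`minmax_qparams`, `fake_quantize_column`), the 4-bit asymmetric scheme, and the toy sizes are illustrative assumptions, not the llm-compressor or compressed-tensors API, and the Hessian-based error propagation of real GPTQ is omitted.

```python
import torch


def minmax_qparams(w_group: torch.Tensor, num_bits: int = 4):
    """Per-output-channel min/max scale and zero point for one column group.

    Mirrors an observer with averaging_constant=1.0: statistics come only from
    the tensor passed in, with no moving average across calls.
    """
    w_min = w_group.min(dim=1, keepdim=True).values.clamp(max=0)
    w_max = w_group.max(dim=1, keepdim=True).values.clamp(min=0)
    q_max = 2**num_bits - 1
    scale = (w_max - w_min).clamp(min=1e-8) / q_max
    zero_point = torch.round(-w_min / scale)
    return scale, zero_point


def fake_quantize_column(col: torch.Tensor, scale, zero_point, num_bits: int = 4):
    """Quantize-dequantize one weight column with asymmetric integer quantization."""
    q_max = 2**num_bits - 1
    q = torch.clamp(torch.round(col / scale + zero_point), 0, q_max)
    return (q - zero_point) * scale


# Toy example: 8 output channels, 16 columns, group_size = 4
torch.manual_seed(0)
W = torch.randn(8, 16)
group_size = 4

W_q = W.clone()
scale = zero_point = None
for col in range(W_q.shape[1]):
    if col % group_size == 0:
        # Iterative parameter update: recompute this group's qparams from the
        # *current* weights, which already reflect whatever error compensation
        # was applied while quantizing earlier columns.
        group_cols = slice(col, col + group_size)
        scale, zero_point = minmax_qparams(W_q[:, group_cols])
    W_q[:, col : col + 1] = fake_quantize_column(W_q[:, col : col + 1], scale, zero_point)
    # A real GPTQ pass would now propagate this column's quantization error
    # into the remaining columns using the inverse Hessian.
```

The reason for the identity averaging constant is that each group's statistics should describe only the already-compensated weights at that point in the pass, not a running average over earlier snapshots of the weight matrix.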