From b9eecf7ed2a3bd2ea72658d4752feee70f862187 Mon Sep 17 00:00:00 2001 From: nickfraser Date: Tue, 20 Aug 2024 14:38:28 +0100 Subject: [PATCH] Feat: Update LLM entry-point (#987) * Feat (example/llm): Added fnuz/ocp args * Docs (example/llm): typo fix in args description * Feat (example/llm): Add zero bias to linear layers when doing bias correction. * Fix (example/llm): Remove unnecessary forward pass * Feat (example/llm): Leveraged data utils from optimum-amd integration * Feat (example/llm): Load KV Cache to correct `dtype` * Feat (example/llm): Added progress bar to bias correction * Fix (example/llm): Fix formatting. * feat (example/llm): Switched `ln_affine_merge` to use HF's tracer * feat (example/llm): decompose `quantize_model` into component parts. * Fix (example/llm): Assert that TorchQCDQ export & Eval aren't both enabled. * feat (example/llm): Added option not to quantize the last linear layer * Fix precommit * Fix (example/llm): disable embedded lookup quantization --- .../llm/llm_quant/bias_corr.py | 3 +- .../llm/llm_quant/data_utils.py | 108 +++++++++++++++ src/brevitas_examples/llm/llm_quant/eval.py | 2 +- .../llm/llm_quant/ln_affine_merge.py | 6 +- .../llm/llm_quant/prepare_for_quantize.py | 16 +++ src/brevitas_examples/llm/main.py | 131 ++++++++++++++---- 6 files changed, 231 insertions(+), 35 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/data_utils.py diff --git a/src/brevitas_examples/llm/llm_quant/bias_corr.py b/src/brevitas_examples/llm/llm_quant/bias_corr.py index dc603f8a3..049ae0baa 100644 --- a/src/brevitas_examples/llm/llm_quant/bias_corr.py +++ b/src/brevitas_examples/llm/llm_quant/bias_corr.py @@ -4,6 +4,7 @@ """ import torch +from tqdm import tqdm from brevitas.graph.calibrate import bias_correction_mode @@ -11,5 +12,5 @@ @torch.no_grad() def apply_bias_correction(model, dataloader): with bias_correction_mode(model): - for inps in dataloader: + for inps in tqdm(dataloader): model(**inps) diff --git a/src/brevitas_examples/llm/llm_quant/data_utils.py b/src/brevitas_examples/llm/llm_quant/data_utils.py new file mode 100644 index 000000000..5375fcddf --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/data_utils.py @@ -0,0 +1,108 @@ +""" +Adapted from https://github.com/huggingface/optimum-amd, released under the following LICENSE: + +MIT License + +Copyright (c) 2023 Hugging Face + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import random +from typing import Any, Optional, Union + +import numpy as np +from optimum.amd.brevitas.data_utils import DatasetToDevice +from optimum.amd.brevitas.data_utils import get_c4 +from optimum.amd.brevitas.data_utils import get_wikitext2 +from optimum.utils.normalized_config import NormalizedConfigManager +import torch +from transformers import AutoConfig + + +def get_dataset_for_model( + model_name_or_path: str, + dataset_name: str, + tokenizer: Any, + nsamples: int = 128, + seqlen: int = 2048, + seed: int = 0, + split: str = "train", + fuse_sequences: bool = True, + require_fx: bool = False, + device: Optional[Union[str, torch.device]] = None, +): + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + get_dataset_map = { + "wikitext2": get_wikitext2, + "c4": get_c4,} + if split not in ["train", "validation"]: + raise ValueError(f"The split need to be 'train' or 'validation' but found {split}") + if dataset_name not in get_dataset_map: + raise ValueError( + f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}") + get_dataset_fn = get_dataset_map[dataset_name] + + data = get_dataset_fn( + tokenizer=tokenizer, + nsamples=nsamples, + seqlen=seqlen, + split=split, + fuse_sequences=fuse_sequences, + seed=seed) + + # In case the dataset is loaded to be used with an fx.GraphModule, we need to add empty past_key_values inputs in the dataset. + if require_fx: + config = AutoConfig.from_pretrained(model_name_or_path) + + normalized_config_class = NormalizedConfigManager.get_normalized_config_class( + config.model_type) + normalized_config = normalized_config_class(config) + + num_heads = normalized_config.num_attention_heads + if hasattr(normalized_config, "num_key_value_heads"): + num_kv_heads = normalized_config.num_key_value_heads + else: + num_kv_heads = num_heads + head_dim = normalized_config.hidden_size // num_heads + num_layers = normalized_config.num_layers + + for sample in data: + sample["past_key_values"] = tuple(( + torch.zeros( + 1, + num_kv_heads, + 0, + head_dim, + device=sample["input_ids"].device, + dtype=sample["input_ids"].dtype), + torch.zeros( + 1, + num_kv_heads, + 0, + head_dim, + device=sample["input_ids"].device, + dtype=sample["input_ids"].dtype), + ) for _ in range(num_layers)) + + data = DatasetToDevice(data, device=device) + + return data diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py index 271a5b36e..0691e5cfa 100644 --- a/src/brevitas_examples/llm/llm_quant/eval.py +++ b/src/brevitas_examples/llm/llm_quant/eval.py @@ -34,12 +34,12 @@ def create_validation_dataloader(data, seqlen, device): @torch.no_grad() def model_eval(model, valenc, seqlen): nsamples = len(valenc) - dev = next(iter(model.parameters())).device with torch.no_grad(): nlls = [] for inps in valenc: lm_logits = model(**inps)['logits'] shift_logits = lm_logits[:, :-1, :].contiguous() + dev = shift_logits.device shift_labels = inps['input_ids'][:, 1:].to(dev) loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) diff --git a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py index 37aa8d5d3..7ac39347f 100644 --- a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py +++ b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py @@ -6,7 +6,6 @@ import torch from torch import nn -from brevitas.fx import value_trace from brevitas.graph.equalize import _is_reshaping_op from brevitas.graph.equalize import _is_scale_invariant_module from brevitas.graph.utils import get_module @@ -84,9 +83,8 @@ def merge_layernorm_affine_params(graph_model): @torch.no_grad() -def apply_layernorm_affine_merge(model, dtype, ref_kwargs): +def apply_layernorm_affine_merge(graph_model, dtype): # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply merging, and then cast back - with cast_to_float32(model, dtype): - graph_model = value_trace(model, value_args=ref_kwargs) + with cast_to_float32(graph_model, dtype): merge_layernorm_affine_params(graph_model) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index 2a9505227..d22b2eff1 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -1,5 +1,6 @@ import warnings +import torch from transformers.models.opt.modeling_opt import OPTAttention from brevitas.graph import ModuleToModuleByClass @@ -21,3 +22,18 @@ def replace_mha_with_quantizable_layers(model, dtype): for rewriter in rewriters: model = rewriter.apply(model) return model + + +@torch.no_grad() +def add_zero_bias_to_linear(model: torch.nn.Module) -> torch.nn.Module: + for name, module in model.named_modules(): + if type(module) == torch.nn.Linear: + if module.bias is None: + module.register_parameter( + "bias", + torch.nn.Parameter( + torch.zeros((module.weight.shape[0],), + device=module.weight.device, + dtype=module.weight.dtype)), + ) + return model diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 5237c31c7..b8c19a9d7 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -9,6 +9,7 @@ import numpy as np from optimum.amd.brevitas.accelerate_utils import offload_model from optimum.amd.brevitas.accelerate_utils import remove_hooks +from optimum.amd.brevitas.data_utils import compute_perplexity from optimum.exporters.onnx import onnx_export_from_model import torch from transformers import AutoModelForCausalLM @@ -16,12 +17,14 @@ from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.common.generative.quantize import quantize_model +from brevitas.graph.quantize import layerwise_quantize +from brevitas_examples.common.generative.quantize import generate_quant_maps +from brevitas_examples.common.generative.quantize import generate_quantizers +from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction from brevitas_examples.llm.llm_quant.calibrate import apply_calibration -from brevitas_examples.llm.llm_quant.data import get_c4 -from brevitas_examples.llm.llm_quant.data import get_wikitext2 +from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization from brevitas_examples.llm.llm_quant.eval import create_validation_dataloader @@ -30,6 +33,7 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gptq import apply_gptq from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge +from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 @@ -47,7 +51,13 @@ parser.add_argument( '--nsamples', type=int, default=128, help='Number of calibration data samples. Default: 128.') parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length. Default: 2048.') -parser.add_argument('--eval', action='store_true', help='Eval model PPL on C4.') +parser.add_argument('--eval', action='store_true', help='Eval model PPL on the chosen Dataset.') +parser.add_argument( + '--dataset', + type=str, + choices=['wikitext2', 'c4'], + default='wikitext2', + help='Dataset to use for quantization (default: %(default)s)') parser.add_argument('--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.') parser.add_argument( '--weight-param-method', @@ -135,8 +145,6 @@ help='Group size for per_group input quantization. Default: 64.') parser.add_argument( '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.') -parser.add_argument( - '--quantize-embedding', action='store_true', help='Quantize first nn.Embedding layer.') parser.add_argument( '--quantize-last-layer', action='store_true', help='Quantize last nn.Linear layer.') parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') @@ -174,6 +182,15 @@ 'sharded_torchmlir_group_weight', 'sharded_packed_torchmlir_group_weight'], help='Model export.') +parser.add_argument( + '--checkpoint-name', + type=str, + default=None, + help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)") +add_bool_arg( + parser, 'use-ocp', default=False, help='Use OCP format for float quantization. Default: False') +add_bool_arg( + parser, 'use-fnuz', default=True, help='Use FNUZ format for float quantization. Default: True') def set_seed(seed): @@ -274,19 +291,56 @@ def main(): with CastFloat16ToFloat32(): apply_awq(model, awq_results) - calibration_loader = get_wikitext2( - nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0) - val_data = get_wikitext2( - nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0) + require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge else False + fuse_sequences = False + + # Load the data for calibration and evaluation. + calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=require_fx, + device=None, + fuse_sequences=fuse_sequences, + ) + + validation_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="validation", + seed=args.seed, + require_fx=require_fx, + device=None, + fuse_sequences=fuse_sequences, + ) + device = next(iter(model.parameters())).device - val_data = create_validation_dataloader(val_data, args.seqlen, device) print("Data loaded.") + if args.eval: + assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" + print("Float model eval...") + model = offload_model(model) + ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + remove_hooks(model) + print(f"Float perplexity ({args.dataset}): {ppl}") + + if require_fx: + model = get_fx(model) + # Apply LN affine merging before inserting MHA layers # since currently there is support only for merging into Linear if args.ln_affine_merge: print("Apply LN affine merge...") - apply_layernorm_affine_merge(model, dtype, ref_kwargs={'input_ids': calibration_loader[0]}) + apply_layernorm_affine_merge(model, dtype) print("LN affine merge applied.") # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing @@ -296,11 +350,6 @@ def main(): model = replace_mha_with_quantizable_layers(model, dtype) print("Replacing done.") - if args.weight_equalization or args.act_equalization == 'fx': - model = get_fx(model) - calibration_loader = modify_dataloader(args.model, calibration_loader, dtype=dtype) - val_data = modify_dataloader(args.model, val_data, dtype=dtype) - if args.weight_equalization: print("Apply weight equalization...") # In case of float16 model, we need to offload to account for missing ops @@ -317,39 +366,58 @@ def main(): remove_hooks(model) if not args.no_quantize: + name_blacklist = [] print("Applying model quantization...") - model = quantize_model( - model, + linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant = generate_quantizers( dtype=dtype, - weight_quant_format=args.weight_quant_format, - weight_quant_type=args.weight_quant_type, weight_bit_width=args.weight_bit_width, weight_param_method=args.weight_param_method, weight_scale_precision=args.weight_scale_precision, + weight_quant_type=args.weight_quant_type, weight_quant_granularity=args.weight_quant_granularity, weight_group_size=args.weight_group_size, quantize_weight_zero_point=args.quantize_weight_zero_point, + weight_quant_format=args.weight_quant_format, input_bit_width=args.input_bit_width, - input_quant_type=args.input_quant_type, input_quant_format=args.input_quant_format, - input_param_method=args.input_param_method, input_scale_precision=args.input_scale_precision, input_scale_type=args.input_scale_type, + input_param_method=args.input_param_method, + input_quant_type=args.input_quant_type, input_quant_granularity=args.input_quant_granularity, input_group_size=args.input_group_size, quantize_input_zero_point=args.quantize_input_zero_point, - quantize_embedding=args.quantize_embedding) + use_ocp=args.use_ocp, + use_fnuz=args.use_fnuz, + device=device) + layer_map = generate_quant_maps( + linear_input_quant=linear_input_quant, + weight_quant=weight_quant, + input_quant=input_quant, + q_scaled_quant=q_scaled_quant, + k_transposed_quant=k_transposed_quant, + v_quant=v_quant, + attn_output_weights_quant=attn_output_weights_quant, + dtype=dtype, + device=device, + input_quant_format=args.input_quant_format, + quantize_embedding=False) + if not args.quantize_last_layer: + name_blacklist += ["lm_head"] + model = layerwise_quantize( + model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist) # Tie back first/last layer weights in case they got untied print("Model quantization applied.") # If any equalization has taken places, the embedding layer and the fully connected one are # not tied anymore, and they need to be treated as standalone, separate layers. # In all other cases we can tie them back so to preserve memory. - if args.act_equalization is None and not args.weight_equalization: + if args.act_equalization is None and not require_fx: model.tie_weights() - with cast_to_float32(model, dtype): - model(**calibration_loader[0]) + if args.bias_corr: + model = add_zero_bias_to_linear(model) + model = offload_model(model) if args.act_calibration: @@ -369,10 +437,15 @@ def main(): if args.eval: print("Model eval...") - ppl = model_eval(model, val_data, args.seqlen) - print(f"C4 perplexity: {ppl}") + ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + print(f"Quantized perplexity ({args.dataset}): {ppl}") remove_hooks(model) + if args.checkpoint_name is not None: + print(f"Saving checkpoint to {args.checkpoint_name}") + torch.save(model.state_dict(), args.checkpoint_name) + if args.export_target: print(f"Export to {args.export_target}") # Currently we always export on CPU with a float32 container to avoid float16 CPU errors