From dbec9569ddb63c3fc3236fab216a5d31beaeabe3 Mon Sep 17 00:00:00 2001
From: dleunji
Date: Thu, 14 Nov 2024 21:20:16 +0900
Subject: [PATCH 1/2] Create INT8 KV Cache on Qserve

---
 examples/llama/convert_checkpoint.py  |  6 ++-
 tensorrt_llm/models/llama/convert.py  | 65 ++++++++++++++++++++++++++-
 tensorrt_llm/models/llama/model.py    |  9 +++-
 tensorrt_llm/quantization/quantize.py |  2 +-
 4 files changed, 77 insertions(+), 5 deletions(-)
 mode change 100644 => 100755 examples/llama/convert_checkpoint.py
 mode change 100644 => 100755 tensorrt_llm/models/llama/convert.py
 mode change 100644 => 100755 tensorrt_llm/models/llama/model.py
 mode change 100644 => 100755 tensorrt_llm/quantization/quantize.py

diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py
old mode 100644
new mode 100755
index 25aa340ed..17e0870ec
--- a/examples/llama/convert_checkpoint.py
+++ b/examples/llama/convert_checkpoint.py
@@ -439,7 +439,7 @@ def convert_and_save_hf(args):
         # llava_llama needs its own defined config.
         logger.warning("AutoConfig cannot load the huggingface config.")
 
-    if args.smoothquant is not None or args.int8_kv_cache:
+    if (args.smoothquant is not None or args.int8_kv_cache) and not args.use_qserve:
         assert not args.load_by_shard, "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported"
         mapping = Mapping(world_size=world_size,
                           tp_size=args.tp_size,
@@ -474,6 +474,10 @@ def convert_and_save_rank(args, rank):
            args.dtype,
            mapping=mapping,
            quant_config=quant_config,
+           device='cpu' if args.load_model_on_cpu else 'cuda',
+           calib_dataset=args.calib_dataset,
+           calib_batches=args.calib_size,
+           calib_max_seq_length=args.calib_max_seq_length,
            load_by_shard=load_by_shard,
            **override_fields,
         )
diff --git a/tensorrt_llm/models/llama/convert.py b/tensorrt_llm/models/llama/convert.py
old mode 100644
new mode 100755
index d99e57e4a..803f2ea5d
--- a/tensorrt_llm/models/llama/convert.py
+++ b/tensorrt_llm/models/llama/convert.py
@@ -43,7 +43,7 @@
                              retrieved_layer_index_from_name, smooth_gemm,
                              smooth_gemm_fc1_gate, split, split_matrix_tp,
                              split_qkv_bias_tp, split_qkv_tp)
-from ..modeling_utils import PretrainedConfig
+from ..modeling_utils import PretrainedConfig, QuantConfig
 from .config import LLaMAConfig
 
 
@@ -1921,7 +1921,16 @@ def process_and_assign_weight(v: List[torch.Tensor],
     return weights
 
 
-def load_weights_from_lmquant(lmquant_ckpt_path: str, config: LLaMAConfig):
+def load_weights_from_lmquant(
+        lmquant_ckpt_path: str,
+        config: LLaMAConfig,
+        quant_config: QuantConfig,
+        hf_model_dir: str,
+        device: str = "cuda",
+        calib_dataset: str = "cnn_dailymail",
+        calib_batches: int = 512,
+        calib_max_seq_length: int = 512,
+):
     logger.info(
         'Loading weights from lmquant torch checkpoint for QServe W4A8 inference...'
     )
@@ -1945,6 +1954,40 @@ def load_weights_from_lmquant(lmquant_ckpt_path: str, config: LLaMAConfig):
     quant_params = torch.load(lmquant_ckpt_path + '/scale.pt',
                               map_location='cpu')
 
+    int8_kv_cache = quant_config.kv_cache_quant_algo == QuantAlgo.INT8
+
+    act_range = {}
+    if int8_kv_cache:
+        hf_model = AutoModelForCausalLM.from_pretrained(
+            hf_model_dir,
+            device_map=device if device != 'cpu' else 'cpu',
+            torch_dtype='auto',
+            trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            hf_model_dir,
+            trust_remote_code=True,
+            use_fast=False,
+            padding_side='left')
+
+        dataset = load_calib_dataset(calib_dataset)
+
+        if calib_batches == -1:
+            calib_batches = len(dataset)
+
+        model_prefix, layer_prefix, param_name_map = get_prefix_and_param_name_map(
+            config.architecture, use_safetensors=True)
+
+        lmquant_keys = fake_quant_weights.keys()
+        for name, param in hf_model.named_parameters():
+            if name in lmquant_keys:
+                param.data = fake_quant_weights[name].data.to(device)
+
+        act_range = capture_activation_range(hf_model,
+                                             tokenizer,
+                                             dataset,
+                                             num_samples=calib_batches,
+                                             seq_len=calib_max_seq_length)
+
     def load(key):
         if 'zero' in key:
             v = quant_params[key]
@@ -2082,6 +2125,24 @@ def process_weight_and_params(v: List[torch.Tensor], tllm_prex: str):
         ]
         weights.update(
             process_weight_and_params(qkv, f'{tllm_prex}.attention.qkv'))
+
+        if int8_kv_cache:
+            act_range_prefix = f'{model_prefix}.{layer_prefix}.{layer_idx}.'
+            qkv_y = torch.cat([
+                # act_range.get(act_range_prefix +
+                #               f'{param_name_map["attention.qkv"]}.q_proj')["y"],
+                act_range.get(act_range_prefix +
+                              f'{param_name_map["attention.qkv"]}.k_proj')["y"],
+                act_range.get(act_range_prefix +
+                              f'{param_name_map["attention.qkv"]}.v_proj')["y"]
+            ], dim=0)
+
+            int8_kv_scales = qkv_y.max() / 127.
+
+            kv_cache_weights = {}
+
+            kv_cache_weights[f'{tllm_prex}.attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape([1])
+            weights.update(kv_cache_weights)
 
         # 4.2 attention.dense
         v = [
diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py
old mode 100644
new mode 100755
index 0387cd0ad..34ad03809
--- a/tensorrt_llm/models/llama/model.py
+++ b/tensorrt_llm/models/llama/model.py
@@ -36,6 +36,7 @@
                       load_weights_from_hf_safetensors,
                       load_weights_from_lmquant,
                       load_weights_from_meta_ckpt)
+from tensorrt_llm.quantization import QuantAlgo
 
 
 class LLaMADecoderLayer(Module):
@@ -331,6 +332,10 @@ def from_hugging_face(
             dtype: str = 'auto',
             mapping: Optional[Mapping] = None,
             quant_config: Optional[QuantConfig] = None,
+            device: str = 'cuda',
+            calib_dataset: str = 'cnn_dailymail',
+            calib_batches: int = 512,
+            calib_max_seq_length: int = 512,
             **kwargs):
         ''' Create a LLaMAForCausalLM object from give parameters
         '''
@@ -413,7 +418,9 @@ def from_hugging_face(
             if quant_config.quant_mode.is_int4_weight_only():
                 weights = load_weights_from_gptq(quant_ckpt_path, config)
             elif quant_config.quant_mode.is_qserve_w4a8():
-                weights = load_weights_from_lmquant(quant_ckpt_path, config)
+                weights = load_weights_from_lmquant(quant_ckpt_path,
+                    config, quant_config, hf_model_dir,
+                    device, calib_dataset, calib_batches, calib_max_seq_length)
             else:
                 raise ValueError(
                     "quant_ckpt_path should be specified only for GPTQ or QServe"
diff --git a/tensorrt_llm/quantization/quantize.py b/tensorrt_llm/quantization/quantize.py
old mode 100644
new mode 100755
index 8f383f2c2..94e166b88
--- a/tensorrt_llm/quantization/quantize.py
+++ b/tensorrt_llm/quantization/quantize.py
@@ -521,7 +521,7 @@ def qserve_quantize(model, quant_config: QuantConfig):
 def kv_cache_quantize(model):
     for name, module in model.named_modules():
         if isinstance(module,
-                      (Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
+                      (Attention, SmoothQuantAttention, Fp8RowwiseAttention, QServeAttention)):
             module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                        dtype='float32')
     return model

From 6a6ec86e9cf74dba8042aac0c268900d227593db Mon Sep 17 00:00:00 2001
From: lkm2835
Date: Thu, 14 Nov 2024 23:21:38 +0900
Subject: [PATCH 2/2] fix unused module

---
 tensorrt_llm/models/llama/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py
index 34ad03809..e6dfde2cf 100755
--- a/tensorrt_llm/models/llama/model.py
+++ b/tensorrt_llm/models/llama/model.py
@@ -36,7 +36,6 @@
                       load_weights_from_hf_safetensors,
                       load_weights_from_lmquant,
                       load_weights_from_meta_ckpt)
-from tensorrt_llm.quantization import QuantAlgo
 
 
 class LLaMADecoderLayer(Module):
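
Not part of the patches above: a rough end-to-end sketch of how the new calibration arguments are meant to be driven. Only the keyword names (`device`, `calib_dataset`, `calib_batches`, `calib_max_seq_length`), the `QuantAlgo.INT8` KV-cache setting, and the fact that `quant_ckpt_path` routes into the lmquant loader come from the diffs; the entry points, the QServe `QuantAlgo` enum value, and all paths are assumptions to double-check against the TensorRT-LLM tree.

```python
# Hypothetical driver for the QServe W4A8 + INT8 KV cache path sketched above.
# Entry points, enum names and paths are assumptions; the calibration kwargs
# mirror the new signatures of from_hugging_face()/load_weights_from_lmquant().
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

quant_config = QuantConfig(
    quant_algo=QuantAlgo.W4A8_QSERVE_PER_GROUP,  # assumed QServe W4A8 algo name
    kv_cache_quant_algo=QuantAlgo.INT8)  # enables the int8_kv_cache branch above

llama = LLaMAForCausalLM.from_hugging_face(
    '/path/to/hf/llama',  # placeholder HF model directory (hf_model_dir)
    dtype='float16',
    mapping=Mapping(world_size=1, tp_size=1, pp_size=1),
    quant_config=quant_config,
    quant_ckpt_path='/path/to/lmquant/ckpt',  # placeholder lmquant export (holds scale.pt)
    device='cuda',  # new arg: where the HF model is loaded for calibration
    calib_dataset='cnn_dailymail',  # new arg: dataset fed to capture_activation_range()
    calib_batches=512,  # new arg: -1 means "use the whole calibration set"
    calib_max_seq_length=512)  # new arg: per-sample sequence length during calibration

llama.save_checkpoint('/tmp/llama_qserve_int8kv')  # assumed PretrainedModel API
```

On the CLI side the same knobs ride on the existing flags wired through `convert_and_save_rank()` in examples/llama/convert_checkpoint.py (`args.calib_dataset`, `args.calib_size`, `args.calib_max_seq_length`, `args.load_model_on_cpu`).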