diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py index 555d0c297..372480f6f 100644 --- a/lmdeploy/lite/apis/calibrate.py +++ b/lmdeploy/lite/apis/calibrate.py @@ -25,7 +25,11 @@ 'MGMLlamaForCausalLM': 'LlamaDecoderLayer', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2DecoderLayer', 'Phi3ForCausalLM': 'Phi3DecoderLayer', - 'ChatGLMForConditionalGeneration': 'GLMBlock' + 'ChatGLMForConditionalGeneration': 'GLMBlock', + 'MixtralForCausalLM': 'MixtralDecoderLayer', + 'Qwen2VLForConditionalGeneration': 'Qwen2VLDecoderLayer', + 'MistralForCausalLM': 'MistralDecoderLayer', + 'InternLM3MoEForCausalLM': 'InternLM3MoEDecoderLayer', } NORM_TYPE_MAP = { @@ -41,7 +45,11 @@ 'MGMLlamaForCausalLM': 'LlamaRMSNorm', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2RMSNorm', 'Phi3ForCausalLM': 'Phi3RMSNorm', - 'ChatGLMForConditionalGeneration': 'RMSNorm' + 'ChatGLMForConditionalGeneration': 'RMSNorm', + 'MixtralForCausalLM': 'MixtralRMSNorm', + 'Qwen2VLForConditionalGeneration': 'Qwen2RMSNorm', + 'MistralForCausalLM': 'MistralRMSNorm', + 'InternLM3MoEForCausalLM': 'InternLM3MoERMSNorm', } HEAD_NAME_MAP = { @@ -57,7 +65,11 @@ 'MGMLlamaForCausalLM': 'lm_head', # mini gemini 'InternLMXComposer2ForCausalLM': 'output', 'Phi3ForCausalLM': 'lm_head', - 'ChatGLMForConditionalGeneration': 'output_layer' + 'ChatGLMForConditionalGeneration': 'output_layer', + 'MixtralForCausalLM': 'lm_head', + 'Qwen2VLForConditionalGeneration': 'lm_head', + 'MistralForCausalLM': 'lm_head', + 'InternLM3MoEForCausalLM': 'output', } @@ -185,7 +197,7 @@ def calibrate(model: str, trust_remote_code=True) model = load_hf_from_pretrained(model, - torch_dtype=torch.float16, + torch_dtype=torch.bfloat16, trust_remote_code=True) vl_model = None elif model_type == 'vlm': diff --git a/lmdeploy/lite/apis/gptq.py b/lmdeploy/lite/apis/gptq.py index 12b88a52c..3d8049bde 100644 --- a/lmdeploy/lite/apis/gptq.py +++ b/lmdeploy/lite/apis/gptq.py @@ -54,6 +54,10 @@ def auto_gptq(model: str, SUPPORTED_MODELS.append('internlm2') GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(internlm2=InternLM2GPTQForCausalLM)) + from ..modeling.internlm3_moe_gptq import InternLM3MoEGPTQForCausalLM + SUPPORTED_MODELS.append('InternLM3_MoE') + GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(InternLM3_MoE=InternLM3MoEGPTQForCausalLM)) + pretrained_model_dir = model quantized_model_dir = work_dir @@ -85,6 +89,7 @@ def auto_gptq(model: str, # the model will always be loaded into CPU memory model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, + torch_dtype=torch.bfloat16, revision=revision, trust_remote_code=True) diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py index eca7d46d7..fb701329a 100644 --- a/lmdeploy/lite/quantization/awq.py +++ b/lmdeploy/lite/quantization/awq.py @@ -43,7 +43,29 @@ 'GLMBlock': { 'input_layernorm': ['self_attention.query_key_value'], 'post_attention_layernorm': ['mlp.dense_h_to_4h'] - } + }, + 'MixtralDecoderLayer': { + 'input_layernorm': + ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], + 'post_attention_layernorm': + ['block_sparse_moe.experts.{i}.w1', 'block_sparse_moe.experts.{i}.w3'] + }, + 'InternLM3MoEDecoderLayer': { + 'attention_norm': + ['attention.wqkv'], + 'ffn_norm': + ['feed_forward.experts.fused_w1w3'] + }, + 'Qwen2VLDecoderLayer': { + 'input_layernorm': + ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], + 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] + }, + 'MistralDecoderLayer': { + 'input_layernorm': + ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], + 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] + }, } FC_FCS_MAP = { @@ -80,6 +102,22 @@ 'GLMBlock': { # 'self_attention.query_key_value': ['self_attention.dense'] # 'mlp.dense_h_to_4h': ['mlp.dense_4h_to_h'] + }, + 'MixtralDecoderLayer': { + 'self_attn.v_proj': ['self_attn.o_proj'], + 'block_sparse_moe.experts.{i}.w3': ['block_sparse_moe.experts.{i}.w2'] + }, + 'InternLM3MoEDecoderLayer': { + 'feed_forward.experts.fused_w1w3': + ['feed_forward.experts.w2'] + }, + 'Qwen2VLDecoderLayer': { + 'self_attn.v_proj': ['self_attn.o_proj'], + 'mlp.up_proj': ['mlp.down_proj'] + }, + 'MistralDecoderLayer': { + 'self_attn.v_proj': ['self_attn.o_proj'], + 'mlp.up_proj': ['mlp.down_proj'] } }