Commit 14bb81d

Support internlm3 awq quantization
AllentDan committed Oct 11, 2024
1 parent fd33b59
Showing 2 changed files with 25 additions and 4 deletions.
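
With both files updated, an InternLM3 checkpoint should quantize through the existing lite entry point. A hedged sketch of the invocation (the auto_awq import path and keyword names are assumptions based on lmdeploy's `lmdeploy lite auto_awq` CLI, not part of this diff; the model path is a placeholder):

    # Hedged sketch, not part of this diff: assumes the `lmdeploy lite auto_awq`
    # CLI maps to lmdeploy.lite.apis.auto_awq.auto_awq with these keywords.
    from lmdeploy.lite.apis.auto_awq import auto_awq

    auto_awq('path/to/internlm3-checkpoint',  # placeholder HF model path
             work_dir='./internlm3-w4a16',    # where quantized weights are written
             w_bits=4,                        # AWQ weight bit-width
             w_group_size=128)                # per-group quantization size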

lmdeploy/lite/apis/calibrate.py (3 additions, 0 deletions)

@@ -15,6 +15,7 @@
 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
     'InternLM2ForCausalLM': 'InternLM2DecoderLayer',
+    'InternLM3ForCausalLM': 'InternLM3DecoderLayer',
     'QWenLMHeadModel': 'QWenBlock',
     'Qwen2ForCausalLM': 'Qwen2DecoderLayer',
     'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B
@@ -30,6 +31,7 @@
 NORM_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMRMSNorm',
     'InternLM2ForCausalLM': 'InternLM2RMSNorm',
+    'InternLM3ForCausalLM': 'InternLM3RMSNorm',
     'QWenLMHeadModel': 'RMSNorm',
     'Qwen2ForCausalLM': 'Qwen2RMSNorm',
     'BaiChuanForCausalLM': 'RMSNorm',  # Baichuan 7B
@@ -45,6 +47,7 @@
 HEAD_NAME_MAP = {
     'InternLMForCausalLM': 'lm_head',
     'InternLM2ForCausalLM': 'output',
+    'InternLM3ForCausalLM': 'output',
     'QWenLMHeadModel': 'lm_head',
     'Qwen2ForCausalLM': 'lm_head',
     'BaiChuanForCausalLM': 'lm_head',  # Baichuan 7B
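
These three registries key on the model's HF architectures string and name, respectively, the decoder-layer class to hook during calibration, its matching RMSNorm class, and the attribute holding the output head (InternLM2 and InternLM3 call it output rather than lm_head). A minimal lookup sketch, assuming only that the maps are importable module-level constants as shown above:

    # Illustrative lookup of the registries extended above.
    from lmdeploy.lite.apis.calibrate import (HEAD_NAME_MAP, LAYER_TYPE_MAP,
                                              NORM_TYPE_MAP)

    model_type = 'InternLM3ForCausalLM'  # e.g. read from the HF config
    print(LAYER_TYPE_MAP[model_type])    # InternLM3DecoderLayer
    print(NORM_TYPE_MAP[model_type])     # InternLM3RMSNorm
    print(HEAD_NAME_MAP[model_type])     # output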

lmdeploy/lite/quantization/awq.py (22 additions, 4 deletions)

@@ -19,6 +19,10 @@
         'attention_norm': ['attention.wqkv'],
         'ffn_norm': ['feed_forward.w1', 'feed_forward.w3']
     },
+    'InternLM3DecoderLayer': {
+        'attention_norm': ['attention.wqkv', 'attention.wq'],
+        'ffn_norm': ['feed_forward.w1', 'feed_forward.w3']
+    },
     'QWenBlock': {
         'ln_1': ['attn.c_attn'],
         'ln_2': ['mlp.w1', 'mlp.w2']
@@ -54,6 +58,9 @@
     'InternLM2DecoderLayer': {
         'feed_forward.w3': ['feed_forward.w2']
     },
+    'InternLM3DecoderLayer': {
+        'feed_forward.w3': ['feed_forward.w2']
+    },
     'QWenBlock': {
         'attn.c_attn': ['attn.c_proj'],
         'mlp.w1': ['mlp.c_proj']
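
NORM_FCS_MAP and FC_FCS_MAP record, per decoder-layer class, which linear modules consume each norm's (or preceding fc's) output; AWQ smoothing folds per-channel activation scales across exactly these pairs. InternLM3 lists both attention.wqkv and attention.wq under attention_norm, presumably because a given checkpoint exposes only one of the two projections; the smooth_layers change below guards against the absent one. A conceptual sketch of why folding scales across a norm-to-fc pair is safe (not the repository's exact implementation):

    import torch

    # Conceptual sketch: moving per-channel scales s from a norm's elementwise
    # weight g into a consumer fc's weight W is an exact identity,
    #   W @ (g * x) == (W * s) @ ((g / s) * x),
    # so activation outliers can be flattened before 4-bit weight quantization.
    hidden, out = 8, 16
    s = torch.rand(hidden) + 0.5  # per-channel smoothing scales
    g = torch.ones(hidden)        # norm weight
    W = torch.randn(out, hidden)  # fc consuming the norm's output
    x = torch.randn(hidden)

    assert torch.allclose(W @ (g * x), (W * s) @ ((g / s) * x), atol=1e-5)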

@@ -269,18 +276,29 @@ def smooth_layers(layers,

     for l_name, layer in layers.items():
         layer.to(device)
+        submodule_names = [name for name, _ in layer.named_modules()]
         for ln_name, fc_names in norm2fcs.items():
-            a_name = [f'{l_name}.{n}' for n in fc_names][0]
+            a_name = [
+                f'{l_name}.{n}' for n in fc_names if n in submodule_names
+            ][0]

             ln = layer.get_submodule(ln_name)
-            fcs = [layer.get_submodule(n) for n in fc_names]
+            fcs = [
+                layer.get_submodule(n) for n in fc_names
+                if n in submodule_names
+            ]
             smooth_ln_fcs(ln, fcs, a_scales[a_name], group_size)

         for f_name, fc_names in fc2fcs.items():
-            a_name = [f'{l_name}.{n}' for n in fc_names][0]
+            a_name = [
+                f'{l_name}.{n}' for n in fc_names if n in submodule_names
+            ][0]

             fc = layer.get_submodule(f_name)
-            fcs = [layer.get_submodule(n) for n in fc_names]
+            fcs = [
+                layer.get_submodule(n) for n in fc_names
+                if n in submodule_names
+            ]

             smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size)

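The behavioral change in smooth_layers: fc names listed in the maps but missing from a particular layer instance are now filtered against named_modules() before get_submodule() is called, which would otherwise raise AttributeError for the absent projection. A self-contained sketch of the pattern with toy modules (LayerNorm stands in for RMSNorm; shapes are arbitrary):

    import torch.nn as nn

    # Abridged from the NORM_FCS_MAP entry added above.
    INTERNLM3_NORM_FCS = {
        'attention_norm': ['attention.wqkv', 'attention.wq'],
        'ffn_norm': ['feed_forward.w1', 'feed_forward.w3'],
    }

    class ToyAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.wqkv = nn.Linear(8, 24)  # fused QKV only; no separate wq

    class ToyFFN(nn.Module):
        def __init__(self):
            super().__init__()
            self.w1 = nn.Linear(8, 16)
            self.w3 = nn.Linear(8, 16)

    class ToyLayer(nn.Module):
        """Toy stand-in for an InternLM3DecoderLayer."""
        def __init__(self):
            super().__init__()
            self.attention = ToyAttention()
            self.attention_norm = nn.LayerNorm(8)
            self.feed_forward = ToyFFN()
            self.ffn_norm = nn.LayerNorm(8)

    layer = ToyLayer()
    submodule_names = [name for name, _ in layer.named_modules()]
    for ln_name, fc_names in INTERNLM3_NORM_FCS.items():
        # Keep only the consumers that actually exist in this layer.
        present = [n for n in fc_names if n in submodule_names]
        fcs = [layer.get_submodule(n) for n in present]
        print(ln_name, '->', present)
    # attention_norm -> ['attention.wqkv']
    # ffn_norm -> ['feed_forward.w1', 'feed_forward.w3']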
