-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add support for Qwen #129
Open
cjw-d
wants to merge
17
commits into
Tencent:main
Choose a base branch
from
cjw-d:qwen
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
add support for Qwen #129
Changes from 15 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
a08a34e
add logN-scaling
cjw-d 7f22312
add dynamic ntk
cjw-d 7240c2d
add qwen config
cjw-d da9d7b4
add qwen conversion script
cjw-d 24e00f0
modify conversion script
cjw-d af5f1ce
remove old conversion script
cjw-d e261047
add QwenTokenizer
cjw-d 97cdb6a
modify logN-scaling
cjw-d ede088f
fix bug
cjw-d 98508ac
modify logN-scaling
cjw-d a80884e
modify QwenTokenizer
cjw-d 121eb3a
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
cjw-d 6b4eaff
add qwen_special_tokens_map
cjw-d 68c4ad5
add qwen config
cjw-d c2527db
refactor
cjw-d d999642
refact rope
cjw-d 7258fa2
fix file format
cjw-d File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
{ | ||
"emb_size": 2048, | ||
"feedforward_size": 5504, | ||
"hidden_size": 2048, | ||
"hidden_act": "silu", | ||
"heads_num": 16, | ||
"layers_num": 24, | ||
"dropout": 0.0, | ||
"data_processor": "lm", | ||
"max_seq_length": 8192, | ||
"embedding": ["word"], | ||
"remove_transformer_bias": true, | ||
"remove_attention_bias": false, | ||
"remove_embedding_layernorm": true, | ||
"rotary_position_embedding": true, | ||
"encoder": "transformer", | ||
"feed_forward": "gated", | ||
"mask": "causal", | ||
"layernorm_positioning": "pre", | ||
"layernorm": "rms", | ||
"target": ["lm"], | ||
"use_logn_attn": true, | ||
"use_dynamic_ntk": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
{ | ||
"emb_size": 4096, | ||
"feedforward_size": 11008, | ||
"hidden_size": 4096, | ||
"hidden_act": "silu", | ||
"heads_num": 32, | ||
"layers_num": 32, | ||
"dropout": 0.0, | ||
"data_processor": "lm", | ||
"max_seq_length": 8192, | ||
"embedding": ["word"], | ||
"remove_transformer_bias": true, | ||
"remove_attention_bias": false, | ||
"remove_embedding_layernorm": true, | ||
"rotary_position_embedding": true, | ||
"encoder": "transformer", | ||
"feed_forward": "gated", | ||
"mask": "causal", | ||
"layernorm_positioning": "pre", | ||
"layernorm": "rms", | ||
"target": ["lm"], | ||
"use_logn_attn": true, | ||
"use_dynamic_ntk": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"pad_token": "<|endoftext|>", | ||
"unk_token": "<|UNK|>", | ||
"cls_token": "<|im_start|>", | ||
"sep_token": "<|im_end|>", | ||
"mask_token": "<|MASK|>", | ||
"sentinel_token": "<|extra_0|>" | ||
} |
54 changes: 54 additions & 0 deletions
54
scripts/convert_qwen_from_huggingface_to_tencentpretrain.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import argparse | ||
import os | ||
import collections | ||
from safetensors.torch import load_file | ||
import torch | ||
|
||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
parser.add_argument("--input_model_path", type=str, default="models/input_model.bin", | ||
help=".") | ||
parser.add_argument("--output_model_path", type=str, default="models/output_model.bin", | ||
help=".") | ||
parser.add_argument("--layers_num", type=int, default=12) | ||
|
||
args = parser.parse_args() | ||
|
||
input_model = {} | ||
for file_name in os.listdir(args.input_model_path): | ||
if os.path.splitext(file_name)[-1][1:] == "safetensors": | ||
dict = load_file(filename=os.path.join(args.input_model_path, file_name)) | ||
input_model.update(dict) | ||
|
||
output_model = collections.OrderedDict() | ||
emb_size = input_model["transformer.h." + str(0) + ".attn.c_attn.weight"].shape[1] | ||
|
||
output_model["embedding.word.embedding.weight"] = input_model["transformer.wte.weight"] | ||
|
||
|
||
for i in range(args.layers_num): | ||
for j in range(3): | ||
output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers." + str(j) + ".weight"] = \ | ||
input_model["transformer.h." + str(i) + ".attn.c_attn.weight"][j*emb_size:(j+1)*emb_size, :] | ||
output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers." + str(j) + ".bias"] = \ | ||
input_model["transformer.h." + str(i) + ".attn.c_attn.bias"][j*emb_size:(j+1)*emb_size] | ||
|
||
output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \ | ||
input_model["transformer.h." + str(i) + ".attn.c_proj.weight"] | ||
|
||
output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \ | ||
input_model["transformer.h." + str(i) + ".ln_1.weight"] | ||
|
||
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"] = \ | ||
input_model["transformer.h." + str(i) + ".mlp.w2.weight"] | ||
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \ | ||
input_model["transformer.h." + str(i) + ".mlp.w1.weight"] | ||
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \ | ||
input_model["transformer.h." + str(i) + ".mlp.c_proj.weight"] | ||
|
||
output_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \ | ||
input_model["transformer.h." + str(i) + ".ln_2.weight"] | ||
|
||
output_model["encoder.layer_norm.weight"] = input_model["transformer.ln_f.weight"] | ||
output_model["target.lm.output_layer.weight"] = input_model["lm_head.weight"] | ||
|
||
torch.save(output_model, args.output_model_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
import torch | ||
import torch.nn as nn | ||
from tencentpretrain import mpu | ||
from tencentpretrain.utils.rope import apply_rotary_emb | ||
from tencentpretrain.utils.rope import apply_rotary_emb, apply_rotary_emb_with_ntk | ||
from tencentpretrain.utils.lora import LoraLinear | ||
|
||
|
||
|
@@ -26,7 +26,7 @@ class MultiHeadedAttention(nn.Module): | |
self-attention refers to https://arxiv.org/pdf/1706.03762.pdf | ||
""" | ||
|
||
def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_num, dropout, has_bias=True, with_scale=True, | ||
def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_num, dropout, max_seq_length, has_bias=True, has_attention_bias=None, with_scale=True, | ||
lora_params=None, layer_number=None): | ||
super(MultiHeadedAttention, self).__init__() | ||
self.heads_num = heads_num | ||
|
@@ -41,6 +41,15 @@ def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_n | |
assert heads_num % self.local_kv_heads_num == 0, "heads_num should be divisible by n_local_kv_heads" | ||
self.repeat_num = self.heads_num // self.local_kv_heads_num | ||
|
||
self.max_seq_length = max_seq_length | ||
|
||
logn_list = [ | ||
math.log(i, self.max_seq_length) if i > self.max_seq_length else 1 | ||
for i in range(1, 32768) | ||
] | ||
logn_tensor = torch.tensor(logn_list)[None, None, :, None] | ||
self.register_buffer("logn_tensor", logn_tensor, persistent=False) | ||
|
||
if lora_params is not None: | ||
|
||
self.linear_layers = nn.ModuleList( | ||
|
@@ -53,8 +62,9 @@ def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_n | |
lora_dropout=lora_params['lora_dropout'], bias=has_bias)] | ||
) | ||
else: | ||
has_attention_bias = has_attention_bias if has_attention_bias is not None else has_bias | ||
self.linear_layers = nn.ModuleList( | ||
[nn.Linear(hidden_size, self.inner_hidden_size, bias=has_bias) if i==0 else nn.Linear(hidden_size, self.kv_embed_dim, bias=has_bias) for i in range(3)] | ||
[nn.Linear(hidden_size, self.inner_hidden_size, bias=has_attention_bias) if i==0 else nn.Linear(hidden_size, self.kv_embed_dim, bias=has_attention_bias) for i in range(3)] | ||
) | ||
self.dropout = nn.Dropout(dropout) | ||
self.final_linear = nn.Linear(self.inner_hidden_size, hidden_size, bias=has_bias) | ||
|
@@ -66,7 +76,7 @@ def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_n | |
self.layer_number = None | ||
|
||
def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False, prev_attn=None, | ||
freqs_cis=None, alibi=None): | ||
freqs_cis=None, alibi=None, use_logn_attn=False, use_dynamic_ntk=False): | ||
""" | ||
Args: | ||
key: [batch_size x seq_length x hidden_size] | ||
|
@@ -103,9 +113,18 @@ def unshape(x): | |
key = repeat_kv(key, self.repeat_num).transpose(1, 2) | ||
value = repeat_kv(value, self.repeat_num).transpose(1, 2) | ||
|
||
|
||
if freqs_cis is not None: | ||
query, key = apply_rotary_emb(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis) | ||
if use_dynamic_ntk: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里建议可以封装一下 |
||
query, key = apply_rotary_emb_with_ntk(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis) | ||
else: | ||
query, key = apply_rotary_emb(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis) | ||
|
||
key_size = key.size(2) | ||
if key_size > self.max_seq_length and use_logn_attn and not self.training: | ||
seq_start = key_size - query.size(2) | ||
seq_end = key_size | ||
logn_tensor = self.logn_tensor[:, :, seq_start:seq_end, :].type_as(query) | ||
query = query * logn_tensor.expand_as(query) | ||
|
||
scores = torch.matmul(query, key.transpose(-2, -1)) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import torch | ||
import torch.nn as nn | ||
from tencentpretrain.utils.rope import get_ntk_alpha, update_freqs_cis | ||
from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention, ParallelMultiHeadedAttention | ||
from tencentpretrain.layers import * | ||
|
||
|
@@ -16,10 +17,18 @@ def __init__(self, args, layer_number=None): | |
self.relative_position_embedding = args.relative_position_embedding | ||
self.rotary_position_embedding = args.rotary_position_embedding | ||
self.has_residual_attention = args.has_residual_attention | ||
self.use_logn_attn = args.use_logn_attn | ||
self.use_dynamic_ntk = args.use_dynamic_ntk | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 训练不需要ntk,只有推理需要,如果只考虑训练的话这里是否有可能简化? |
||
|
||
if self.relative_position_embedding: | ||
self.relative_pos_emb = args.relative_pos_emb | ||
if self.rotary_position_embedding: | ||
self.freqs_cis = args.freqs_cis | ||
if self.use_dynamic_ntk: | ||
self.max_seq_length = args.max_seq_length | ||
self.attention_head_size = args.hidden_size // args.heads_num | ||
self.seq_len_cached = 0 | ||
self.ntk_alpha_cached = 1.0 | ||
|
||
if hasattr(args, "attention_head_size"): | ||
attention_head_size = args.attention_head_size | ||
|
@@ -32,6 +41,7 @@ def __init__(self, args, layer_number=None): | |
local_kv_heads_num = args.heads_num | ||
|
||
has_bias = bool(1 - args.remove_transformer_bias) | ||
has_attention_bias = bool(1 - args.remove_attention_bias) if hasattr(args, "remove_attention_bias") else has_bias | ||
with_scale = bool(1 - args.remove_attention_scale) | ||
|
||
# Multi-headed self-attention. | ||
|
@@ -40,7 +50,7 @@ def __init__(self, args, layer_number=None): | |
lora_params = args.lora_params | ||
|
||
self.self_attn = MultiHeadedAttention( | ||
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias, | ||
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, args.max_seq_length, has_bias=has_bias, has_attention_bias=has_attention_bias, | ||
with_scale=with_scale, lora_params=lora_params, layer_number=layer_number | ||
) | ||
self.dropout_1 = nn.Dropout(args.dropout) | ||
|
@@ -77,22 +87,32 @@ def forward(self, *inputs): | |
else: | ||
position_bias = None | ||
|
||
if self.rotary_position_embedding: | ||
if self.rotary_position_embedding and self.use_dynamic_ntk: | ||
ntk_alpha = get_ntk_alpha(seq_length, self.max_seq_length) if seq_length > self.max_seq_length else 1.0 | ||
if seq_length > self.seq_len_cached or ntk_alpha != self.ntk_alpha_cached: | ||
self.freqs_cis = update_freqs_cis(self.attention_head_size, seq_length * 2, ntk_alpha=ntk_alpha) | ||
self.seq_len_cached = seq_length * 2 | ||
self.ntk_alpha_cached = ntk_alpha | ||
|
||
if self.rotary_position_embedding and not self.use_dynamic_ntk: | ||
freqs_cis = self.freqs_cis[:seq_length].to(hidden.device) | ||
elif self.use_dynamic_ntk: | ||
cos, sin = self.freqs_cis | ||
freqs_cis = [cos[:, :seq_length], sin[:, :seq_length]] | ||
else: | ||
freqs_cis = None | ||
|
||
if self.layernorm_positioning == "post": | ||
inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, self.has_residual_attention, | ||
prev_attn, freqs_cis) | ||
prev_attn, freqs_cis, use_logn_attn=self.use_logn_attn, use_dynamic_ntk=self.use_dynamic_ntk) | ||
inter = self.dropout_1(inter) | ||
inter = self.layer_norm_1(inter + hidden) | ||
output = self.dropout_2(self.feed_forward(inter)) | ||
output = self.layer_norm_2(output + inter) | ||
else: | ||
inter = self.layer_norm_1(hidden) | ||
inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, self.has_residual_attention, | ||
prev_attn, freqs_cis) | ||
prev_attn, freqs_cis, use_logn_attn=self.use_logn_attn, use_dynamic_ntk=self.use_dynamic_ntk) | ||
inter = self.dropout_1(inter) | ||
hidden = hidden + inter | ||
output = self.layer_norm_2(hidden) | ||
|
@@ -281,14 +301,14 @@ def __init__(self, args): | |
lora_params = args.lora_params | ||
|
||
self.self_attn = MultiHeadedAttention( | ||
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias, | ||
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, args.max_seq_length, has_bias=has_bias, | ||
with_scale=with_scale, lora_params=lora_params | ||
) | ||
self.dropout_1 = nn.Dropout(args.dropout) | ||
|
||
# Multi-headed context-attention. | ||
self.context_attn = MultiHeadedAttention( | ||
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias, | ||
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, args.max_seq_length, has_bias=has_bias, | ||
with_scale=with_scale, lora_params=lora_params | ||
) | ||
self.dropout_2 = nn.Dropout(args.dropout) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议同时提供 互相转换脚本
convert_qwen_from_tencentpretrain_to_huggingface.py