add support for Qwen #129

Open
wants to merge 17 commits into base: main
24 changes: 24 additions & 0 deletions models/qwen/1_8b_config.json
@@ -0,0 +1,24 @@
{
"emb_size": 2048,
"feedforward_size": 5504,
"hidden_size": 2048,
"hidden_act": "silu",
"heads_num": 16,
"layers_num": 24,
"dropout": 0.0,
"data_processor": "lm",
"max_seq_length": 8192,
"embedding": ["word"],
"remove_transformer_bias": true,
"remove_attention_bias": false,
"remove_embedding_layernorm": true,
"rotary_position_embedding": true,
"encoder": "transformer",
"feed_forward": "gated",
"mask": "causal",
"layernorm_positioning": "pre",
"layernorm": "rms",
"target": ["lm"],
"use_logn_attn": true,
"use_dynamic_ntk": true
}
24 changes: 24 additions & 0 deletions models/qwen/7b_config.json
@@ -0,0 +1,24 @@
{
"emb_size": 4096,
"feedforward_size": 11008,
"hidden_size": 4096,
"hidden_act": "silu",
"heads_num": 32,
"layers_num": 32,
"dropout": 0.0,
"data_processor": "lm",
"max_seq_length": 8192,
"embedding": ["word"],
"remove_transformer_bias": true,
"remove_attention_bias": false,
"remove_embedding_layernorm": true,
"rotary_position_embedding": true,
"encoder": "transformer",
"feed_forward": "gated",
"mask": "causal",
"layernorm_positioning": "pre",
"layernorm": "rms",
"target": ["lm"],
"use_logn_attn": true,
"use_dynamic_ntk": true
}
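For orientation, these two configs correspond to the Qwen-1.8B and Qwen-7B checkpoints. A quick sanity check (an illustration only, not part of the PR; the path is taken from this diff) that derives the per-head size the attention code will use:

import json

# Derive the per-head dimension implied by the 7B config above.
with open("models/qwen/7b_config.json") as f:
    cfg = json.load(f)

head_dim = cfg["hidden_size"] // cfg["heads_num"]
print(head_dim, cfg["layers_num"], cfg["max_seq_length"])  # 128 32 8192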
8 changes: 8 additions & 0 deletions models/qwen_special_tokens_map.json
@@ -0,0 +1,8 @@
{
"pad_token": "<|endoftext|>",
"unk_token": "<|UNK|>",
"cls_token": "<|im_start|>",
"sep_token": "<|im_end|>",
"mask_token": "<|MASK|>",
"sentinel_token": "<|extra_0|>"
}
54 changes: 54 additions & 0 deletions scripts/convert_qwen_from_huggingface_to_tencentpretrain.py
Contributor comment: please also provide the reverse conversion script, convert_qwen_from_tencentpretrain_to_huggingface.py. (A hedged sketch of that reverse mapping is included after the script below.)

@@ -0,0 +1,54 @@
import argparse
import os
import collections
from safetensors.torch import load_file
import torch

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
                    help="Path to the directory that holds the Hugging Face Qwen checkpoint (safetensors shards).")
parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
                    help="Path of the converted TencentPretrain model file.")
parser.add_argument("--layers_num", type=int, default=12,
                    help="Number of transformer layers (24 for Qwen-1.8B, 32 for Qwen-7B).")

args = parser.parse_args()

# Merge all safetensors shards into a single state dict.
input_model = {}
for file_name in os.listdir(args.input_model_path):
    if os.path.splitext(file_name)[-1][1:] == "safetensors":
        shard = load_file(filename=os.path.join(args.input_model_path, file_name))
        input_model.update(shard)

output_model = collections.OrderedDict()
emb_size = input_model["transformer.h." + str(0) + ".attn.c_attn.weight"].shape[1]

output_model["embedding.word.embedding.weight"] = input_model["transformer.wte.weight"]

for i in range(args.layers_num):
    # Qwen fuses Q, K, V into a single c_attn projection; split it into three linear layers.
    for j in range(3):
        output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers." + str(j) + ".weight"] = \
            input_model["transformer.h." + str(i) + ".attn.c_attn.weight"][j*emb_size:(j+1)*emb_size, :]
        output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers." + str(j) + ".bias"] = \
            input_model["transformer.h." + str(i) + ".attn.c_attn.bias"][j*emb_size:(j+1)*emb_size]

    output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \
        input_model["transformer.h." + str(i) + ".attn.c_proj.weight"]

    output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \
        input_model["transformer.h." + str(i) + ".ln_1.weight"]

    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"] = \
        input_model["transformer.h." + str(i) + ".mlp.w2.weight"]
    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \
        input_model["transformer.h." + str(i) + ".mlp.w1.weight"]
    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \
        input_model["transformer.h." + str(i) + ".mlp.c_proj.weight"]

    output_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \
        input_model["transformer.h." + str(i) + ".ln_2.weight"]

output_model["encoder.layer_norm.weight"] = input_model["transformer.ln_f.weight"]
output_model["target.lm.output_layer.weight"] = input_model["lm_head.weight"]

torch.save(output_model, args.output_model_path)
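A minimal sketch of the reverse script the reviewer asks for (convert_qwen_from_tencentpretrain_to_huggingface.py). It simply inverts the mapping above; the argument defaults and writing a single safetensors file are assumptions, not something this PR defines:

import argparse
import collections
import torch
from safetensors.torch import save_file

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/output_model.bin",
                    help="TencentPretrain model file produced by the script above.")
parser.add_argument("--output_model_path", type=str, default="models/model.safetensors",
                    help="Hugging Face style safetensors file to write (assumed single shard).")
parser.add_argument("--layers_num", type=int, default=32)
args = parser.parse_args()

tp = torch.load(args.input_model_path, map_location="cpu")
hf = collections.OrderedDict()

hf["transformer.wte.weight"] = tp["embedding.word.embedding.weight"]
for i in range(args.layers_num):
    prefix = "encoder.transformer." + str(i)
    # Re-pack the three Q/K/V projections into Qwen's fused c_attn.
    hf["transformer.h." + str(i) + ".attn.c_attn.weight"] = torch.cat(
        [tp[prefix + ".self_attn.linear_layers." + str(j) + ".weight"] for j in range(3)], dim=0)
    hf["transformer.h." + str(i) + ".attn.c_attn.bias"] = torch.cat(
        [tp[prefix + ".self_attn.linear_layers." + str(j) + ".bias"] for j in range(3)], dim=0)
    hf["transformer.h." + str(i) + ".attn.c_proj.weight"] = tp[prefix + ".self_attn.final_linear.weight"]
    hf["transformer.h." + str(i) + ".ln_1.weight"] = tp[prefix + ".layer_norm_1.weight"]
    hf["transformer.h." + str(i) + ".mlp.w2.weight"] = tp[prefix + ".feed_forward.linear_gate.weight"]
    hf["transformer.h." + str(i) + ".mlp.w1.weight"] = tp[prefix + ".feed_forward.linear_1.weight"]
    hf["transformer.h." + str(i) + ".mlp.c_proj.weight"] = tp[prefix + ".feed_forward.linear_2.weight"]
    hf["transformer.h." + str(i) + ".ln_2.weight"] = tp[prefix + ".layer_norm_2.weight"]
hf["transformer.ln_f.weight"] = tp["encoder.layer_norm.weight"]
hf["lm_head.weight"] = tp["target.lm.output_layer.weight"]

# safetensors requires contiguous tensors.
save_file({k: v.contiguous() for k, v in hf.items()}, args.output_model_path)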
5 changes: 4 additions & 1 deletion scripts/generate_lm.py
@@ -113,6 +113,9 @@ def top_k_top_p_filtering(logits, top_k, top_p):
if args.tokenizer.sp_model is not None:
generated_sentence = args.tokenizer.sp_model.decode(tokens)
else:
generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens))
tokens = args.tokenizer.convert_ids_to_tokens(tokens)
if hasattr(args.tokenizer, "convert_tokens_to_string"):
tokens = args.tokenizer.convert_tokens_to_string(tokens)
generated_sentence = "".join(tokens)

f.write(generated_sentence)
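The convert_tokens_to_string fallback matters for byte-level BPE vocabularies such as Qwen's: joining the token strings directly yields raw byte-level symbols instead of text. GPT-2's tokenizer shows the same effect and stands in here as an illustration (not part of this PR):

from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
tokens = tok.convert_ids_to_tokens(tok.encode("héllo world"))
print("".join(tokens))                       # roughly 'hÃ©lloĠworld' -- raw byte-level symbols
print(tok.convert_tokens_to_string(tokens))  # 'héllo world' -- properly decoded text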
31 changes: 25 additions & 6 deletions tencentpretrain/layers/multi_headed_attn.py
@@ -2,7 +2,7 @@
import torch
import torch.nn as nn
from tencentpretrain import mpu
from tencentpretrain.utils.rope import apply_rotary_emb
from tencentpretrain.utils.rope import apply_rotary_emb, apply_rotary_emb_with_ntk
from tencentpretrain.utils.lora import LoraLinear


@@ -26,7 +26,7 @@ class MultiHeadedAttention(nn.Module):
self-attention refers to https://arxiv.org/pdf/1706.03762.pdf
"""

def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_num, dropout, has_bias=True, with_scale=True,
def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_num, dropout, max_seq_length, has_bias=True, has_attention_bias=None, with_scale=True,
lora_params=None, layer_number=None):
super(MultiHeadedAttention, self).__init__()
self.heads_num = heads_num
@@ -41,6 +41,15 @@ def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_n
assert heads_num % self.local_kv_heads_num == 0, "heads_num should be divisible by n_local_kv_heads"
self.repeat_num = self.heads_num // self.local_kv_heads_num

self.max_seq_length = max_seq_length

logn_list = [
math.log(i, self.max_seq_length) if i > self.max_seq_length else 1
for i in range(1, 32768)
]
logn_tensor = torch.tensor(logn_list)[None, None, :, None]
self.register_buffer("logn_tensor", logn_tensor, persistent=False)

if lora_params is not None:

self.linear_layers = nn.ModuleList(
@@ -53,8 +62,9 @@ def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_n
lora_dropout=lora_params['lora_dropout'], bias=has_bias)]
)
else:
has_attention_bias = has_attention_bias if has_attention_bias is not None else has_bias
self.linear_layers = nn.ModuleList(
[nn.Linear(hidden_size, self.inner_hidden_size, bias=has_bias) if i==0 else nn.Linear(hidden_size, self.kv_embed_dim, bias=has_bias) for i in range(3)]
[nn.Linear(hidden_size, self.inner_hidden_size, bias=has_attention_bias) if i==0 else nn.Linear(hidden_size, self.kv_embed_dim, bias=has_attention_bias) for i in range(3)]
)
self.dropout = nn.Dropout(dropout)
self.final_linear = nn.Linear(self.inner_hidden_size, hidden_size, bias=has_bias)
@@ -66,7 +76,7 @@ def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_n
self.layer_number = None

def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False, prev_attn=None,
freqs_cis=None, alibi=None):
freqs_cis=None, alibi=None, use_logn_attn=False, use_dynamic_ntk=False):
"""
Args:
key: [batch_size x seq_length x hidden_size]
@@ -103,9 +113,18 @@ def unshape(x):
key = repeat_kv(key, self.repeat_num).transpose(1, 2)
value = repeat_kv(value, self.repeat_num).transpose(1, 2)


if freqs_cis is not None:
query, key = apply_rotary_emb(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis)
if use_dynamic_ntk:
Collaborator comment: suggest wrapping this branch into a helper.

query, key = apply_rotary_emb_with_ntk(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis)
else:
query, key = apply_rotary_emb(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis)

key_size = key.size(2)
if key_size > self.max_seq_length and use_logn_attn and not self.training:
seq_start = key_size - query.size(2)
seq_end = key_size
logn_tensor = self.logn_tensor[:, :, seq_start:seq_end, :].type_as(query)
query = query * logn_tensor.expand_as(query)

scores = torch.matmul(query, key.transpose(-2, -1))

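The logn_tensor registered in __init__ above implements Qwen's log-n attention scaling: at inference, queries at positions beyond max_seq_length are multiplied by log base max_seq_length of the position, which grows slowly past the training window. A standalone sketch of the schedule (assuming max_seq_length = 8192 as in the configs; not part of the PR):

import math
import torch

max_seq_length = 8192
logn = torch.tensor(
    [math.log(i, max_seq_length) if i > max_seq_length else 1.0 for i in range(1, 32768)]
)
# Positions up to 8192 are left untouched; longer contexts get a mild query boost.
print(logn[8191].item(), logn[16383].item(), logn[32766].item())
# 1.0  ~1.0769  ~1.1538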
32 changes: 26 additions & 6 deletions tencentpretrain/layers/transformer.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from tencentpretrain.utils.rope import get_ntk_alpha, update_freqs_cis
from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention, ParallelMultiHeadedAttention
from tencentpretrain.layers import *

@@ -16,10 +17,18 @@ def __init__(self, args, layer_number=None):
self.relative_position_embedding = args.relative_position_embedding
self.rotary_position_embedding = args.rotary_position_embedding
self.has_residual_attention = args.has_residual_attention
self.use_logn_attn = args.use_logn_attn
self.use_dynamic_ntk = args.use_dynamic_ntk
Collaborator comment: NTK is not needed for training, only for inference. If only training is considered, could this be simplified?

if self.relative_position_embedding:
self.relative_pos_emb = args.relative_pos_emb
if self.rotary_position_embedding:
self.freqs_cis = args.freqs_cis
if self.use_dynamic_ntk:
self.max_seq_length = args.max_seq_length
self.attention_head_size = args.hidden_size // args.heads_num
self.seq_len_cached = 0
self.ntk_alpha_cached = 1.0

if hasattr(args, "attention_head_size"):
attention_head_size = args.attention_head_size
@@ -32,6 +41,7 @@ def __init__(self, args, layer_number=None):
local_kv_heads_num = args.heads_num

has_bias = bool(1 - args.remove_transformer_bias)
has_attention_bias = bool(1 - args.remove_attention_bias) if hasattr(args, "remove_attention_bias") else has_bias
with_scale = bool(1 - args.remove_attention_scale)

# Multi-headed self-attention.
@@ -40,7 +50,7 @@ def __init__(self, args, layer_number=None):
lora_params = args.lora_params

self.self_attn = MultiHeadedAttention(
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, args.max_seq_length, has_bias=has_bias, has_attention_bias=has_attention_bias,
with_scale=with_scale, lora_params=lora_params, layer_number=layer_number
)
self.dropout_1 = nn.Dropout(args.dropout)
@@ -77,22 +87,32 @@ def forward(self, *inputs):
else:
position_bias = None

if self.rotary_position_embedding:
if self.rotary_position_embedding and self.use_dynamic_ntk:
ntk_alpha = get_ntk_alpha(seq_length, self.max_seq_length) if seq_length > self.max_seq_length else 1.0
if seq_length > self.seq_len_cached or ntk_alpha != self.ntk_alpha_cached:
self.freqs_cis = update_freqs_cis(self.attention_head_size, seq_length * 2, ntk_alpha=ntk_alpha)
self.seq_len_cached = seq_length * 2
self.ntk_alpha_cached = ntk_alpha

if self.rotary_position_embedding and not self.use_dynamic_ntk:
freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
elif self.use_dynamic_ntk:
cos, sin = self.freqs_cis
freqs_cis = [cos[:, :seq_length], sin[:, :seq_length]]
else:
freqs_cis = None

if self.layernorm_positioning == "post":
inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, self.has_residual_attention,
prev_attn, freqs_cis)
prev_attn, freqs_cis, use_logn_attn=self.use_logn_attn, use_dynamic_ntk=self.use_dynamic_ntk)
inter = self.dropout_1(inter)
inter = self.layer_norm_1(inter + hidden)
output = self.dropout_2(self.feed_forward(inter))
output = self.layer_norm_2(output + inter)
else:
inter = self.layer_norm_1(hidden)
inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, self.has_residual_attention,
prev_attn, freqs_cis)
prev_attn, freqs_cis, use_logn_attn=self.use_logn_attn, use_dynamic_ntk=self.use_dynamic_ntk)
inter = self.dropout_1(inter)
hidden = hidden + inter
output = self.layer_norm_2(hidden)
@@ -281,14 +301,14 @@ def __init__(self, args):
lora_params = args.lora_params

self.self_attn = MultiHeadedAttention(
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, args.max_seq_length, has_bias=has_bias,
with_scale=with_scale, lora_params=lora_params
)
self.dropout_1 = nn.Dropout(args.dropout)

# Multi-headed context-attention.
self.context_attn = MultiHeadedAttention(
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, args.max_seq_length, has_bias=has_bias,
with_scale=with_scale, lora_params=lora_params
)
self.dropout_2 = nn.Dropout(args.dropout)
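To see what the caching above buys, here is the alpha schedule produced by get_ntk_alpha (added in tencentpretrain/utils/rope.py below) for a few context lengths against the 8192-token training window; a small worked example, not part of the PR:

import math

def get_ntk_alpha(true_seq_len, max_seq_length):
    # Mirrors get_ntk_alpha from tencentpretrain/utils/rope.py in this PR.
    context_value = math.log(true_seq_len / max_seq_length, 2) + 1
    return max(2 ** math.ceil(context_value) - 1, 1)

for seq_length in (8192, 12000, 16384, 32768):
    print(seq_length, get_ntk_alpha(seq_length, 8192))
# 8192 -> 1, 12000 -> 3, 16384 -> 3, 32768 -> 7; a larger alpha stretches the RoPE base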
6 changes: 5 additions & 1 deletion tencentpretrain/opts.py
@@ -52,6 +52,10 @@ def model_opts(parser):
help="whether use alibi position embedding.")
parser.add_argument("--layer_number_scale", action="store_true",
help="whether use layer number scaling.")
parser.add_argument("--use_logn_attn", action="store_true",
help="whether use logn scaling.")
parser.add_argument("--use_dynamic_ntk", action="store_true",
help="whether use dynamic ntk.")

vision_opts(parser)
audio_opts(parser)
@@ -176,7 +180,7 @@ def infer_opts(parser):

def tokenizer_opts(parser):
parser.add_argument("--tokenizer", choices=["bert", "bpe", "char", "space", "xlmroberta", "image", "text_image",
"virtual", "hfpretrained"], default="bert",
"virtual", "hfpretrained", "qwen"], default="bert",
help="Specify the tokenizer."
"Original Google BERT uses bert tokenizer."
"Char tokenizer segments sentences into characters."
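Both new switches are store_true flags, so they default to False and non-Qwen setups keep the old behaviour unless a config enables them. A minimal parsing check (an illustration only, not part of the PR):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_logn_attn", action="store_true", help="whether use logn scaling.")
parser.add_argument("--use_dynamic_ntk", action="store_true", help="whether use dynamic ntk.")
print(parser.parse_args([]))                                        # both default to False
print(parser.parse_args(["--use_logn_attn", "--use_dynamic_ntk"]))  # both True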
3 changes: 2 additions & 1 deletion tencentpretrain/utils/__init__.py
@@ -7,7 +7,8 @@

str2tokenizer = {"char": CharTokenizer, "space": SpaceTokenizer, "bert": BertTokenizer,
"bpe": BPETokenizer, "xlmroberta": XLMRobertaTokenizer, "image": ImageTokenizer,
"text_image": TextImageTokenizer, "virtual": VirtualTokenizer, "hfpretrained": HFPreTrainedTokenizer}
"text_image": TextImageTokenizer, "virtual": VirtualTokenizer, "hfpretrained": HFPreTrainedTokenizer,
"qwen": QwenTokenizer}
str2dataset = {"bert": BertDataset, "lm": LmDataset, "mlm": MlmDataset,
"bilm": BilmDataset, "albert": AlbertDataset, "mt": MtDataset,
"t5": T5Dataset, "gsg": GsgDataset, "bart": BartDataset,
51 changes: 51 additions & 0 deletions tencentpretrain/utils/rope.py
@@ -1,4 +1,5 @@
import torch
import math
from typing import Tuple

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
@@ -8,6 +9,24 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
    return freqs_cis

def get_ntk_alpha(true_seq_len, max_seq_length):
    context_value = math.log(true_seq_len / max_seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    ntk_alpha = max(ntk_alpha, 1)
    return ntk_alpha

def update_freqs_cis(dim: int, end: int, theta: float = 10000.0, ntk_alpha: float = 1.0):
    theta = theta * ntk_alpha ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    seq = torch.arange(end, device=freqs.device)
    freqs = torch.outer(seq.type_as(freqs), freqs)
    emb = torch.cat((freqs, freqs), dim=-1)

    from einops import rearrange
    emb = rearrange(emb, "n d -> 1 n 1 d")

    cos, sin = emb.cos(), emb.sin()
    return [cos[:, :end], sin[:, :end]]

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
@@ -28,3 +47,35 @@ def apply_rotary_emb(
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq).transpose(1,2), xk_out.type_as(xk).transpose(1,2)

def _rotate_half(x: torch.Tensor):
    from einops import rearrange

    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(t: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]):
    rot_dim = freqs[0].shape[-1]
    cos, sin = freqs
    t_float = t.float()
    t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
    cos = cos.to(t_rot.device)
    sin = sin.to(t_rot.device)
    t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
    return torch.cat((t_rot, t_pass), dim=-1).type_as(t).transpose(1, 2)

def apply_rotary_emb_with_ntk(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    rotary_pos_emb = freqs_cis
    seq_length = xq.size(1)
    # Slice the pos emb for current inference
    rotary_pos_emb = [i[:, -seq_length:, :, :] for i in rotary_pos_emb]
    rotary_pos_emb = (rotary_pos_emb,) * 2
    q_pos_emb, k_pos_emb = rotary_pos_emb
    xq_out = apply_rotary_pos_emb(xq, q_pos_emb)
    xk_out = apply_rotary_pos_emb(xk, k_pos_emb)
    return xq_out, xk_out
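A quick shape check for the NTK rotary path added above (assumes the [batch, seq, heads, head_dim] layout the attention layer passes in after its transpose, plus einops installed; a sketch, not part of the PR):

import torch
from tencentpretrain.utils.rope import update_freqs_cis, apply_rotary_emb_with_ntk

batch, seq_length, heads_num, head_dim = 1, 32, 2, 8
xq = torch.randn(batch, seq_length, heads_num, head_dim)
xk = torch.randn(batch, seq_length, heads_num, head_dim)

# update_freqs_cis returns [cos, sin], each shaped [1, 2*seq_length, 1, head_dim].
freqs_cis = update_freqs_cis(head_dim, seq_length * 2, ntk_alpha=3.0)
xq_out, xk_out = apply_rotary_emb_with_ntk(xq, xk, freqs_cis=freqs_cis)
print(xq_out.shape, xk_out.shape)  # torch.Size([1, 2, 32, 8]) twice -- [batch, heads, seq, head_dim]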