diff --git a/models/falcon/7b_config.json b/models/falcon/7b_config.json
new file mode 100644
index 0000000..5a3c5f4
--- /dev/null
+++ b/models/falcon/7b_config.json
@@ -0,0 +1,21 @@
+{
+    "emb_size": 4544,
+    "feedforward_size": 18176,
+    "hidden_size": 4544,
+    "hidden_act": "gelu_fast",
+    "heads_num": 71,
+    "layers_num": 32,
+    "dropout": 0.0,
+    "data_processor": "lm",
+    "max_seq_length": 2048,
+    "embedding": ["word"],
+    "remove_transformer_bias": true,
+    "remove_embedding_layernorm": true,
+    "attention": "flash_attention",
+    "encoder": "transformer",
+    "mask": "causal",
+    "layernorm_positioning": "parallel_attn",
+    "layernorm": "normal_torch",
+    "layernorm_eps": 1e-5,
+    "target": ["lm"]
+}
\ No newline at end of file
diff --git a/models/falcon_special_tokens_map.json b/models/falcon_special_tokens_map.json
new file mode 100644
index 0000000..5e1d0b2
--- /dev/null
+++ b/models/falcon_special_tokens_map.json
@@ -0,0 +1,7 @@
+{
+  "pad_token": "<|endoftext|>",
+  "unk_token": "<|endoftext|>",
+  "cls_token": "<|endoftext|>",
+  "sep_token": "<|endoftext|>",
+  "mask_token": ""
+}
diff --git a/preprocess.py b/preprocess.py
index 5e3a245..541c413 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -39,6 +39,7 @@ def main():
                         help="Probability of truncating sequence."
                              "The larger value, the higher probability of using short (truncated) sequence.")
     parser.add_argument("--full_sentences", action="store_true", help="Full sentences.")
+    parser.add_argument("--remove_cls_token", action="store_true", help="Preprocess without CLS tokens.")
     parser.add_argument("--seed", type=int, default=7, help="Random seed.")
 
     # Masking options.
diff --git a/scripts/convert_falcon_from_huggingface_to_tencentpretrain.py b/scripts/convert_falcon_from_huggingface_to_tencentpretrain.py
new file mode 100644
index 0000000..ffc4963
--- /dev/null
+++ b/scripts/convert_falcon_from_huggingface_to_tencentpretrain.py
@@ -0,0 +1,56 @@
+import argparse
+import collections
+import torch
+import os
+import json
+
+
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("--input_model_path", type=str, default="models/falcon-7b/",
+                    help=".")
+parser.add_argument("--output_model_path", type=str, default="models/falcon-7b.bin",
+                    help=".")
+parser.add_argument("--layers_num", type=int, default=32,
+                    help=".")
+
+args = parser.parse_args()
+
+files = os.listdir(args.input_model_path)
+model_files = [f for f in files if f[-4:] == ".bin"]
+input_models = {f: torch.load(os.path.join(args.input_model_path, f), map_location="cpu") for f in model_files}
+
+with open(os.path.join(args.input_model_path, "pytorch_model.bin.index.json")) as f:
+    model_index = json.load(f)
+    weight_map = model_index["weight_map"]
+
+
+output_model = collections.OrderedDict()
+
+def get_weight_from_name(layer_name):
+    return input_models[weight_map[layer_name]][layer_name]
+
+
+output_model["embedding.word.embedding.weight"] = get_weight_from_name("transformer.word_embeddings.weight")
+
+for i in range(args.layers_num):
+
+    output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \
+        get_weight_from_name("transformer.h." + str(i) + ".input_layernorm.weight")
+    output_model["encoder.transformer." + str(i) + ".layer_norm_1.bias"] = \
+        get_weight_from_name("transformer.h." + str(i) + ".input_layernorm.bias")
+
+    output_model["encoder.transformer." + str(i) + ".self_attn.query_key_value.weight"] = \
+        get_weight_from_name("transformer.h." + str(i) + ".self_attention.query_key_value.weight")
+    output_model["encoder.transformer." + str(i) + ".self_attn.dense.weight"] = \
+        get_weight_from_name("transformer.h." + str(i) + ".self_attention.dense.weight")
+
+    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \
+        get_weight_from_name("transformer.h." + str(i) + ".mlp.dense_h_to_4h.weight")
+    output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \
+        get_weight_from_name("transformer.h." + str(i) + ".mlp.dense_4h_to_h.weight")
+
+output_model["encoder.layer_norm.weight"] = get_weight_from_name("transformer.ln_f.weight")
+output_model["encoder.layer_norm.bias"] = get_weight_from_name("transformer.ln_f.bias")
+output_model["target.lm.output_layer.weight"] = get_weight_from_name("lm_head.weight")
+
+torch.save(output_model, args.output_model_path)
diff --git a/scripts/generate_lm.py b/scripts/generate_lm.py
index 69d89ff..e3bf193 100644
--- a/scripts/generate_lm.py
+++ b/scripts/generate_lm.py
@@ -79,11 +79,13 @@ def top_k_top_p_filtering(logits, top_k, top_p):
 
     args = load_hyperparam(args)
 
+    args.tokenizer_type = args.tokenizer
     args.tokenizer = str2tokenizer[args.tokenizer](args)
 
     model = GenerateLm(args)
     model = load_model(model, args.load_model_path)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     model = model.to(device)
     model.eval()
 
@@ -110,6 +112,8 @@ def top_k_top_p_filtering(logits, top_k, top_p):
             f.write(line + "\n")
 
         tokens = [token_id.item() for token_id in src_tensor[0]]
-        if args.tokenizer.sp_model is not None:
+        if args.tokenizer_type in ["hfpretrained"]:
+            generated_sentence = args.tokenizer.decode(tokens)
+        elif args.tokenizer.sp_model is not None:
             generated_sentence = args.tokenizer.sp_model.decode(tokens)
         else:
diff --git a/tencentpretrain/encoders/transformer_encoder.py b/tencentpretrain/encoders/transformer_encoder.py
index 68673be..5db7056 100644
--- a/tencentpretrain/encoders/transformer_encoder.py
+++ b/tencentpretrain/encoders/transformer_encoder.py
@@ -2,7 +2,7 @@ import torch.nn as nn
 from tencentpretrain.utils.rope import precompute_freqs_cis
 from tencentpretrain.layers.transformer import TransformerLayer
-from tencentpretrain.layers.layer_norm import *
+from tencentpretrain.layers import *
 from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding
 
 
 class TransformerEncoder(nn.Module):
@@ -36,13 +36,8 @@ def __init__(self, args):
         self.transformer = nn.ModuleList(
             [TransformerLayer(args) for _ in range(self.layers_num)]
         )
-        if self.layernorm_positioning == "pre":
-            if args.layernorm == "t5":
-                self.layer_norm = T5LayerNorm(args.hidden_size)
-            elif args.layernorm == "rms":
-                self.layer_norm = RMSNorm(args.hidden_size)
-            else:
-                self.layer_norm = LayerNorm(args.hidden_size)
+
+        self.layer_norm = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
 
         if self.relative_position_embedding:
             self.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
@@ -143,7 +138,7 @@ def custom_forward(*inputs):
                                                  has_residual_attention=self.has_residual_attention,
                                                  prev_attn=prev_attn, freqs_cis=freqs_cis)
 
-        if self.layernorm_positioning == "pre":
+        if self.layernorm_positioning in ["pre", "parallel_attn"]:
             return self.layer_norm(hidden)
         else:
             return hidden
diff --git a/tencentpretrain/layers/__init__.py b/tencentpretrain/layers/__init__.py
index e69de29..7883239 100644
--- a/tencentpretrain/layers/__init__.py
+++ b/tencentpretrain/layers/__init__.py
@@ -0,0 +1,15 @@
+from tencentpretrain.layers.layer_norm import *
+from tencentpretrain.layers.multi_headed_attn import *
+from tencentpretrain.layers.position_ffn import *
+import torch.nn as nn
+
+str2layernorm = {"t5": T5LayerNorm, "rms": RMSNorm, "normal_torch": nn.LayerNorm, "normal": LayerNorm}
+
+str2attention = {"multi_head": MultiHeadedAttention, "flash_attention": FlashAttention}
+
+str2feedforward = {"gated": GatedFeedForward, "dense": PositionwiseFeedForward}
+
+__all__ = ["T5LayerNorm", "RMSNorm", "LayerNorm", "MultiHeadedAttention",
+           "FlashAttention", "GatedFeedForward", "PositionwiseFeedForward", "str2layernorm",
+           "str2attention", "str2feedforward"]
+
diff --git a/tencentpretrain/layers/multi_headed_attn.py b/tencentpretrain/layers/multi_headed_attn.py
index e5fcd12..5236c08 100755
--- a/tencentpretrain/layers/multi_headed_attn.py
+++ b/tencentpretrain/layers/multi_headed_attn.py
@@ -1,8 +1,10 @@
 import math
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 from tencentpretrain.utils.rope import apply_rotary_emb
 from tencentpretrain.utils.lora import LoraLinear
+from tencentpretrain.utils.rope import RotaryEmbedding
 
 class MultiHeadedAttention(nn.Module):
     """
@@ -88,3 +90,141 @@ def unshape(x):
         output = unshape(torch.matmul(probs, value))
         output = self.final_linear(output)
         return output, prev_attn_out
+
+
+class FlashAttention(nn.Module):
+    """
+    Flash Attention used in Falcon.
+    https://huggingface.co/tiiuae/falcon-7b/blob/main/modelling_RW.py#L154
+    """
+
+    def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bias=True, with_scale=True,
+                 lora_params=None):
+        super(FlashAttention, self).__init__()
+        self.heads_num = heads_num
+        self.hidden_size = hidden_size
+        self.per_head_size = attention_head_size
+        self.with_scale = with_scale
+        self.inner_hidden_size = heads_num * attention_head_size
+        self.multi_query = True
+
+
+        self.rotary = RotaryEmbedding(self.per_head_size)
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.per_head_size)
+        self.beta = self.inv_norm_factor
+
+        self.query_key_value = nn.Linear(
+            self.hidden_size,
+            3 * self.hidden_size if not self.multi_query else (self.hidden_size + 2 * self.per_head_size),
+            bias=has_bias
+        )
+
+        self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=has_bias)
+        self.attention_dropout = nn.Dropout(dropout)
+        self.num_kv = self.heads_num if not self.multi_query else 1
+
+    def _split_heads(self, fused_qkv: torch.Tensor):
+        """
+        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+        storage as `fused_qkv`
+        Args:
+            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
+        Returns:
+            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+            value: [batch_size, seq_length, num_heads, head_dim]
+        """
+        if not self.multi_query:
+            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+            fused_qkv = fused_qkv.view(batch_size, seq_length, self.heads_num, 3, self.per_head_size)
+            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
+        else:
+            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+            fused_qkv = fused_qkv.view(batch_size, seq_length, self.heads_num + 2, self.per_head_size)
+            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
+
+    def _merge_heads(self, x: torch.Tensor):
+        """
+        Merge heads together over the last dimension
+        Args:
+            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
+        Returns:
+            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+        """
+        # What we want to achieve is:
+        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+        batch_size_and_num_heads, seq_length, _ = x.shape
+        batch_size = batch_size_and_num_heads // self.heads_num
+
+        # First view to decompose the batch size
+        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+        x = x.view(batch_size, self.heads_num, seq_length, self.per_head_size)
+
+        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
+        x = x.permute(0, 2, 1, 3)
+
+        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+        return x.reshape(batch_size, seq_length, self.heads_num * self.per_head_size)
+
+    def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False, prev_attn=None,
+                freqs_cis=None):
+        """
+        Args:
+            key: [batch_size x seq_length x hidden_size]
+            value: [batch_size x seq_length x hidden_size]
+            query: [batch_size x seq_length x hidden_size]
+            mask: [batch_size x 1 x seq_length x seq_length]
+            position_bias: [1 x heads_num x seq_length x seq_length]
+        Returns:
+            output: [batch_size x seq_length x hidden_size]
+        """
+
+        fused_qkv = self.query_key_value(query)
+
+        # 3 x [batch_size, seq_length, num_heads, head_dim]
+        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+        batch_size, q_length, _, _ = query_layer.shape
+
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.heads_num, q_length, self.per_head_size)
+
+        key_layer = key_layer.transpose(1, 2).reshape(
+            batch_size * self.num_kv,
+            q_length,
+            self.per_head_size,
+        )
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.per_head_size)
+
+        query_layer, key_layer = self.rotary(query_layer, key_layer)
+
+        _, kv_length, _ = key_layer.shape
+
+        query_layer_ = query_layer.reshape(batch_size, self.heads_num, -1, self.per_head_size)
+        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.per_head_size)
+        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.per_head_size)
+
+        if torch.__version__ < "2.0.0":
+            scores = torch.matmul(query_layer_, key_layer_.transpose(-2, -1))
+            if self.with_scale:
+                scores = scores / math.sqrt(float(self.per_head_size))
+            scores = scores + mask.type_as(scores)
+            prev_attn_out = None
+            if has_residual_attention:
+                if prev_attn is not None:
+                    scores += prev_attn
+                prev_attn_out = scores
+            probs = nn.Softmax(dim=-1)(scores)
+            attn_output = probs @ value_layer_
+
+        else:
+            prev_attn_out = None
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+            )
+
+        x = attn_output.view(batch_size, self.heads_num, q_length, self.per_head_size)
+        x = x.permute(0, 2, 1, 3)
+        attn_output = x.reshape(batch_size, q_length, self.heads_num * self.per_head_size)
+
+        output_tensor = self.dense(attn_output)
+
+        return output_tensor, prev_attn_out
diff --git a/tencentpretrain/layers/transformer.py b/tencentpretrain/layers/transformer.py
index f859ec0..1c434f1 100755
--- a/tencentpretrain/layers/transformer.py
+++ b/tencentpretrain/layers/transformer.py
@@ -1,8 +1,5 @@
 import torch.nn as nn
-from tencentpretrain.layers.layer_norm import *
-from tencentpretrain.layers.position_ffn import PositionwiseFeedForward, GatedFeedForward
-from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention
-from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding
+from tencentpretrain.layers import *
 
 
 class TransformerLayer(nn.Module):
@@ -28,32 +25,24 @@ def __init__(self, args):
         if hasattr(args, "lora_params"):
             lora_params = args.lora_params
 
-        self.self_attn = MultiHeadedAttention(
+        self.self_attn = str2attention[args.attention](
             args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias,
             with_scale = with_scale, lora_params=lora_params
         )
+
         self.dropout_1 = nn.Dropout(args.dropout)
 
         # Feed forward layer.
-        if args.feed_forward == "gated":
-            self.feed_forward = GatedFeedForward(
-                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
-            )
-        else:
-            self.feed_forward = PositionwiseFeedForward(
-                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
-            )
+        self.feed_forward = str2feedforward[args.feed_forward](
+            args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
+        )
+
         self.dropout_2 = nn.Dropout(args.dropout)
 
-        if args.layernorm == "t5":
-            self.layer_norm_1 = T5LayerNorm(args.hidden_size)
-            self.layer_norm_2 = T5LayerNorm(args.hidden_size)
-        elif args.layernorm == "rms":
-            self.layer_norm_1 = RMSNorm(args.hidden_size)
-            self.layer_norm_2 = RMSNorm(args.hidden_size)
-        else:
-            self.layer_norm_1 = LayerNorm(args.hidden_size)
-            self.layer_norm_2 = LayerNorm(args.hidden_size)
+        self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
+        if self.layernorm_positioning != "parallel_attn":
+            self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
+
 
     def forward(self, hidden, mask, position_bias=None, has_residual_attention=False, prev_attn=None, freqs_cis=None):
         """
@@ -71,13 +60,20 @@ def forward(self, hidden, mask, position_bias=None, has_residual_attention=False
             inter = self.layer_norm_1(inter + hidden)
             output = self.dropout_2(self.feed_forward(inter))
             output = self.layer_norm_2(output + inter)
-        else:
+        elif self.layernorm_positioning == "pre":
             inter = self.layer_norm_1(hidden)
             inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention, prev_attn, freqs_cis)
             inter = self.dropout_1(inter)
             hidden = hidden + inter
             output = self.layer_norm_2(hidden)
             output = self.dropout_2(self.feed_forward(output)) + hidden
+        else: # parallel_attn: Flash Attention
+            inter = self.layer_norm_1(hidden)
+            attn_output, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention, prev_attn, freqs_cis)
+            mlp_output = self.feed_forward(inter)
+            inter = self.dropout_1(mlp_output + attn_output)
+            output = inter + hidden
+
         return output, prev_attn_out
 
 
@@ -107,32 +103,23 @@ def __init__(self, args):
         self.dropout_1 = nn.Dropout(args.dropout)
 
         # Multi-headed context-attention.
-        self.context_attn = MultiHeadedAttention(
+        self.context_attn = str2attention[args.attention](
             args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias,
             with_scale=with_scale, lora_params=lora_params
         )
         self.dropout_2 = nn.Dropout(args.dropout)
 
         # Feed forward layer.
-        if args.feed_forward == "gated":
-            self.feed_forward = GatedFeedForward(
-                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
-            )
-        else:
-            self.feed_forward = PositionwiseFeedForward(
-                args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
-            )
+        self.feed_forward = str2feedforward[args.feed_forward](
+            args.hidden_size, args.feedforward_size, args.hidden_act, has_bias
+        )
+
         self.dropout_3 = nn.Dropout(args.dropout)
 
         # Layer Normalization
-        if args.layernorm == "t5":
-            self.layer_norm_1 = T5LayerNorm(args.hidden_size)
-            self.layer_norm_2 = T5LayerNorm(args.hidden_size)
-            self.layer_norm_3 = T5LayerNorm(args.hidden_size)
-        else:
-            self.layer_norm_1 = LayerNorm(args.hidden_size)
-            self.layer_norm_2 = LayerNorm(args.hidden_size)
-            self.layer_norm_3 = LayerNorm(args.hidden_size)
+        self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
+        self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
+        self.layer_norm_3 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
 
     def forward(self, hidden, encoder_hidden, mask_decoder, mask_encoder, self_position_bias=None, context_position_bias=None):
         """
@@ -156,7 +143,7 @@ def forward(self, hidden, encoder_hidden, mask_decoder, mask_encoder, self_posit
             mid_norm = self.layer_norm_2(mid + query_norm)
             output = self.dropout_3(self.feed_forward(mid_norm))
             output = self.layer_norm_3(output + mid_norm)
-        else:
+        elif self.layernorm_positioning == "pre":
             hidden_norm = self.layer_norm_1(hidden)
             query, _ = self.self_attn(hidden_norm, hidden_norm, hidden_norm, mask_decoder, self_position_bias)
             query = self.dropout_1(query)
@@ -167,4 +154,6 @@ def forward(self, hidden, encoder_hidden, mask_decoder, mask_encoder, self_posit
             mid = mid + query
             mid_norm = self.layer_norm_3(mid)
             output = self.dropout_3(self.feed_forward(mid_norm)) + mid
+        else:
+            raise NotImplementedError
         return output
diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py
index 0fcdfaf..c537452 100755
--- a/tencentpretrain/opts.py
+++ b/tencentpretrain/opts.py
@@ -17,6 +17,8 @@ def model_opts(parser):
     parser.add_argument("--encoder", choices=["transformer", "rnn", "lstm", "gru", "birnn", "bilstm", "bigru",
                                               "gatedcnn", "dual"],
                         default="transformer", help="Encoder type.")
+    parser.add_argument("--attention", choices=["multi_head", "flash_attention"],
+                        default="multi_head", help="Self-attention type.")
     parser.add_argument("--decoder", choices=[None, "transformer"], default=None, help="Decoder type.")
     parser.add_argument("--mask", choices=["fully_visible", "causal", "causal_with_prefix"], default="fully_visible",
                         help="Mask type.")
@@ -30,8 +32,10 @@ def model_opts(parser):
                         help="Remove attention scale.")
     parser.add_argument("--remove_transformer_bias", action="store_true",
                         help="Remove bias on transformer layers.")
-    parser.add_argument("--layernorm", choices=["normal", "t5"], default="normal",
+    parser.add_argument("--layernorm", choices=["normal", "t5", "rms", "normal_torch"], default="normal",
                         help="Layernorm type.")
+    parser.add_argument("--layernorm_eps", type=float, default=1e-6,
+                        help="Layernorm eps.")
     parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")
     parser.add_argument("--parameter_sharing", action="store_true", help="Parameter sharing.")
    parser.add_argument("--has_residual_attention", action="store_true", help="Add residual attention.")
@@ -173,7 +177,8 @@ def infer_opts(parser):
 
 
 def tokenizer_opts(parser):
-    parser.add_argument("--tokenizer", choices=["bert", "bpe", "char", "space", "xlmroberta", "image", "text_image", "virtual"], default="bert",
+    parser.add_argument("--tokenizer", choices=["bert", "bpe", "char", "space", "xlmroberta", "image", "text_image",
+                                                "virtual", "hfpretrained"], default="bert",
                         help="Specify the tokenizer."
                              "Original Google BERT uses bert tokenizer."
                              "Char tokenizer segments sentences into characters."
diff --git a/tencentpretrain/utils/__init__.py b/tencentpretrain/utils/__init__.py
index a97e606..ffdc1da 100644
--- a/tencentpretrain/utils/__init__.py
+++ b/tencentpretrain/utils/__init__.py
@@ -7,7 +7,7 @@
 
 str2tokenizer = {"char": CharTokenizer, "space": SpaceTokenizer, "bert": BertTokenizer,
                  "bpe": BPETokenizer, "xlmroberta": XLMRobertaTokenizer, "image": ImageTokenizer,
-                 "text_image": TextImageTokenizer, "virtual": VirtualTokenizer}
+                 "text_image": TextImageTokenizer, "virtual": VirtualTokenizer, "hfpretrained": HFPreTrainedTokenizer}
 str2dataset = {"bert": BertDataset, "lm": LmDataset, "mlm": MlmDataset,
                "bilm": BilmDataset, "albert": AlbertDataset, "mt": MtDataset, "t5": T5Dataset,
                "gsg": GsgDataset, "bart": BartDataset,
diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py
index 462fe08..e2277f5 100755
--- a/tencentpretrain/utils/dataset.py
+++ b/tencentpretrain/utils/dataset.py
@@ -436,6 +436,7 @@ def __init__(self, args, vocab, tokenizer):
         super(LmDataset, self).__init__(args, vocab, tokenizer)
         self.full_sentences = args.full_sentences
         self.json_format_corpus = args.json_format_corpus
+        self.remove_cls_token = args.remove_cls_token
 
     def worker(self, proc_id, start, end):
         print("Worker %d is building dataset ... " % proc_id)
@@ -455,7 +456,10 @@ def worker(self, proc_id, start, end):
                     pos += 1
 
                 document = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line))
-                document = [self.vocab.get(CLS_TOKEN)] + document + [self.vocab.get(SEP_TOKEN)]
+                if self.remove_cls_token:
+                    document = document + [self.vocab.get(SEP_TOKEN)]
+                else:
+                    document = [self.vocab.get(CLS_TOKEN)] + document + [self.vocab.get(SEP_TOKEN)]
                 if self.full_sentences:
                     buffer.extend(document)
                     instances_num = len(buffer) // (self.seq_length + 1)
diff --git a/tencentpretrain/utils/rope.py b/tencentpretrain/utils/rope.py
index 129858a..59649e0 100644
--- a/tencentpretrain/utils/rope.py
+++ b/tencentpretrain/utils/rope.py
@@ -28,3 +28,58 @@ def apply_rotary_emb(
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq).transpose(1,2), xk_out.type_as(xk).transpose(1,2)
+
+
+# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
+def rotate_half(x):
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """Implementation of RotaryEmbedding from GPT-NeoX.
+    This implementation is designed to operate on queries and keys that are compatible with
+    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
+    """
+
+    def __init__(
+        self,
+        head_dim: int,
+        base=10000,
+    ):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.head_dim = head_dim
+        self.seq_len_cached = None
+        self.batch_size_cached = None
+        self.cos_cached: torch.Tensor | None = None
+        self.sin_cached: torch.Tensor | None = None
+
+    def cos_sin(
+        self,
+        seq_len: int,
+        device="cuda",
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+            if dtype in [torch.float16, torch.bfloat16]:
+                emb = emb.float()
+
+            self.cos_cached = emb.cos()[None, :, :]
+            self.sin_cached = emb.sin()[None, :, :]
+
+            self.cos_cached = self.cos_cached.type(dtype)
+            self.sin_cached = self.sin_cached.type(dtype)
+
+        return self.cos_cached, self.sin_cached
+
+    def forward(self, q, k):
+        batch, seq_len, head_dim = q.shape
+        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
diff --git a/tencentpretrain/utils/tokenizers.py b/tencentpretrain/utils/tokenizers.py
index 2f13a3d..6ee3512 100644
--- a/tencentpretrain/utils/tokenizers.py
+++ b/tencentpretrain/utils/tokenizers.py
@@ -602,3 +602,23 @@ def __init__(self, args, is_src=True):
         self.vocab_bias = len(self.vocab)
         for i in range(args.image_tokenizer["image_vocab_size"]):
             self.vocab[i + self.vocab_bias] = str(i)
+
+
+class HFPreTrainedTokenizer(Tokenizer):
+    def __init__(self, args, is_src=True):
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(args.vocab_path)
+        self.sp_model = None
+        self.vocab = self.tokenizer.vocab
+
+    def tokenize(self, text):
+        return self.tokenizer.tokenize(text)
+
+    def convert_tokens_to_ids(self, tokens):
+        return self.tokenizer.convert_tokens_to_ids(tokens)
+
+    def convert_ids_to_tokens(self, ids):
+        return self.tokenizer.convert_ids_to_tokens(ids)
+
+    def decode(self, ids):
+        return self.tokenizer.decode(ids)
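Note (not part of the diff): the conversion script above copies Falcon's fused transformer.h.<i>.self_attention.query_key_value.weight tensor straight into encoder.transformer.<i>.self_attn.query_key_value.weight, because FlashAttention keeps the same multi-query layout: 71 query heads plus a single key head and a single value head shared by all of them. Below is a minimal standalone sketch of the split performed by FlashAttention._split_heads, using the Falcon-7B sizes from models/falcon/7b_config.json; the tensors are random and the batch/sequence sizes are arbitrary, only the shapes matter.

import torch

# Sizes taken from models/falcon/7b_config.json.
hidden_size, heads_num = 4544, 71
head_dim = hidden_size // heads_num        # 64
batch_size, seq_length = 2, 8              # arbitrary toy sizes for illustration

# Fused projection output: hidden_size query features plus one shared key head and one shared value head,
# matching nn.Linear(hidden_size, hidden_size + 2 * per_head_size) in FlashAttention.__init__.
fused_qkv = torch.randn(batch_size, seq_length, hidden_size + 2 * head_dim)

# Same reshape and slicing as FlashAttention._split_heads with multi_query=True.
qkv = fused_qkv.view(batch_size, seq_length, heads_num + 2, head_dim)
query = qkv[..., :-2, :]    # [2, 8, 71, 64] -- one sub-tensor per attention head
key = qkv[..., [-2], :]     # [2, 8, 1, 64]  -- shared across all 71 heads
value = qkv[..., [-1], :]   # [2, 8, 1, 64]  -- shared across all 71 heads

print(query.shape, key.shape, value.shape)

Because both the Hugging Face checkpoint and this FlashAttention module consume the fused weight in that layout, no reordering of rows or columns is needed during conversion.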