# convert2ckpt_double_head.py
# (from a repository generated from SparkJiao/pytorch-transformers-template)
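# Converts a Hugging Face LLaMA checkpoint that carries two output heads (a language-model
# head plus a token-level reward head, see models.llama.LlamaTokenRewardModel) into a
# layer-sharded checkpoint directory of the kind consumed by DeepSpeed pipeline-parallel
# training: one layer_XX-model_00-model_states.pt file per pipeline layer plus
# mp_rank_XX_model_states.pt engine-state stubs under global_step001/. Extra special
# tokens are added to the tokenizer and the embeddings are resized before writing.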
from pathlib import Path
from typing import Optional, Literal
from dataclasses import dataclass, field

import torch
import transformers
from transformers.models.llama.modeling_llama import LlamaConfig
from transformers.models.llama.tokenization_llama import LlamaTokenizer

from general_util.tokenization_utils import expand_special_tokenizer, PreTrainedTokenizer


@dataclass
class Arguments:
    # Path to the Hugging Face LLaMA checkpoint to convert.
    model_name_or_path: Optional[str] = field(default="/path/to/llama-7b-hf")
    # Directory that will receive the converted checkpoint.
    output_dir: str = field(default="./llama-7B-init-ckpt")
    # Model-parallel world size; one mp_rank_XX stub is written per rank.
    mp_world_size: int = field(default=1)


def smart_tokenizer_and_embedding_resize(
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Resize the tokenizer and the model embeddings.

    Note: This is the unoptimized version, which may leave the embedding size
    not divisible by 64.
    """
    # TODO: pad the embedding size so that it is divisible by 64.
    original_vocab_size = model.get_input_embeddings().weight.shape[0]
    num_new_tokens = len(tokenizer) - original_vocab_size
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the new rows with the mean of the pre-existing embeddings.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
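
# Example with illustrative numbers: for LLaMA-7B (vocab size 32000), adding the five
# special tokens in main() gives len(tokenizer) == 32005, so num_new_tokens == 5 and the
# last five rows of both embedding matrices are filled with the mean embedding.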


def write_ckpt(outpath: Path, model: torch.nn.Module, model_config: LlamaConfig, mp: int):
    """Split the model state dict into per-layer DeepSpeed pipeline checkpoint files."""
    loaded = model.state_dict()
    n_layers = model_config.num_hidden_layers

    # embedding
    sd = {"weight": loaded["model.embed_tokens.weight"]}
    torch.save(sd, outpath / "layer_00-model_00-model_states.pt")

    # final norm
    sd = {"weight": loaded["model.norm.weight"]}
    torch.save(sd, outpath / f"layer_{n_layers + 1}-model_00-model_states.pt")

    # lm head + reward head
    sd = {
        "lm_head.weight": loaded["lm_head.weight"],
        "rw_head.weight": loaded["reward_head.weight"],
    }
    torch.save(sd, outpath / f"layer_{n_layers + 2}-model_00-model_states.pt")

    # decoder layers
    for layer_i in range(n_layers):
        sd = {nm.replace(f"model.layers.{layer_i}.", ""): weight
              for nm, weight in loaded.items() if nm.startswith(f"model.layers.{layer_i}.")}
        torch.save(sd, outpath / f"layer_{layer_i + 1:02d}-model_00-model_states.pt")

    # Minimal engine-state stub, written once per model-parallel rank.
    model_state = {
        "dp_world_size": 1,
        "mp_world_size": mp,
        "module": None,
        "optimizer": None,
        "global_steps": 1,
        "skipped_steps": 1,
        "iteration": 1,
    }
    for rank in range(mp):
        torch.save(model_state, outpath / f"mp_rank_{rank:02d}_model_states.pt")
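
# Resulting layout under <output_dir>/global_step001/ for a model with N decoder layers:
#   layer_00-model_00-model_states.pt                 token embedding
#   layer_01 ... layer_N-model_00-model_states.pt     decoder layers
#   layer_{N+1}-model_00-model_states.pt              final RMSNorm weight
#   layer_{N+2}-model_00-model_states.pt              lm_head.weight + rw_head.weight
#   mp_rank_00..{mp-1}_model_states.pt                per-rank engine-state stubs
#
# Quick sanity check after conversion (paths are illustrative):
#   >>> import torch
#   >>> sd = torch.load("llama-7B-init-ckpt/global_step001/layer_00-model_00-model_states.pt")
#   >>> sd["weight"].shape   # (len(tokenizer), hidden_size) embedding matrix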


def main():
    parser = transformers.HfArgumentParser((Arguments,))
    args, = parser.parse_args_into_dataclasses()

    tokenizer: PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path)

    # model = transformers.AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    from models.llama import LlamaTokenRewardModel
    model = LlamaTokenRewardModel.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True,
                                                  device_map={"": "cpu"})
    model_config = model.config
    original_vocab_size = model_config.vocab_size

    # Add the extra special tokens and reuse EOS as the padding token.
    tokenizer.add_tokens(["<eot>", "<ext_0>", "<ext_1>", "<ext_2>", "<ext_3>"])
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if len(tokenizer) > original_vocab_size:
        print(f"expand vocab size from {original_vocab_size} to {len(tokenizer)}")
        smart_tokenizer_and_embedding_resize(tokenizer, model)

    outpath = Path(args.output_dir)
    if outpath.exists():
        print(f"{outpath} exists. Do nothing.")
        exit(0)
    print(f"create {outpath}")
    outpath.mkdir()
    steppath = outpath / "global_step001"
    steppath.mkdir()

    write_ckpt(steppath, model, model_config, args.mp_world_size)

    # Mark global_step001 as the latest checkpoint and save the tokenizer/config alongside it.
    with open(outpath / "latest", "w") as fout:
        fout.write("global_step001")

    tokenizer.save_pretrained(outpath)
    model_config.save_pretrained(outpath)


if __name__ == "__main__":
    main()
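
# Example invocation (paths are illustrative; argument names come from the
# Arguments dataclass above):
#   python convert2ckpt_double_head.py \
#       --model_name_or_path /path/to/llama-7b-hf \
#       --output_dir ./llama-7B-init-ckpt \
#       --mp_world_size 1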