-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathtrain_lora.py
194 lines (162 loc) · 7.6 KB
/
train_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os
import platform
import argparse
import random
import time
import math
import warnings
import torch.distributed as dist
from contextlib import nullcontext
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset
from model.model_lora import *
warnings.filterwarnings('ignore')
# Logger function
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
# 代码和full_sft「几乎」一致
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item(),
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
# 【区别1】只保存lora权重即可
save_lora(model, f'{args.save_dir}/lora/{args.lora_name}_{lm_config.dim}.pth')
model.train()
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
return model.to(args.device), tokenizer
def init_distributed_mode():
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="./dataset/lora_identity.jsonl")
parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)")
args = parser.parse_args()
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * lm_config.max_seq_len
torch.manual_seed(1337)
device_type = "cuda" if "cuda" in args.device else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
model, tokenizer = init_model(lm_config)
apply_lora(model)
total_params = sum(p.numel() for p in model.parameters()) # 总参数数量
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name) # LoRA 参数数量
if not ddp or dist.get_rank() == 0:
print(f"LLM 总参数量: {total_params}")
print(f"LoRA 参数量: {lora_params_count}")
print(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
for name, param in model.named_parameters():
if 'lora' not in name:
param.requires_grad = False
lora_params = []
for name, param in model.named_parameters():
if 'lora' in name:
lora_params.append(param)
# 只对 LoRA 参数进行优化
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)