-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy path train.py
191 lines (173 loc) · 7.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import html
import time
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from countdown_task import CountdownTasksDataset, reward_function
from grpo import rollout, update_policy
from optimizer import MemoryEfficientAdamW
from qwen2_model import Transformer
from tokenizer import Tokenizer
def evaluate(model, tokenizer, device, dtype, config):
    """Measure answer accuracy on the held-out ("test") Countdown split.

    Builds the test split of ``CountdownTasksDataset``, generates a single
    answer per question with ``rollout`` and averages each episode's
    ``reward_info["answer_reward"]``.

    Args:
        model: Policy model, forwarded unchanged to ``rollout``.
        tokenizer: Tokenizer used by the dataset and by ``rollout``.
        device: ``torch.device`` used for generation and for the DataLoader's
            generator.
        dtype: Compute dtype forwarded to ``rollout``.
        config: Parsed training config. Reads ``data.path``,
            ``data.test_size``, ``training.batch_size`` and
            ``training.max_gen_len``.

    Returns:
        Mean answer reward over all test episodes as a Python float, or
        ``nan`` when the test split produces no episodes.
    """
    test_dataset = CountdownTasksDataset(
        data_path=config["data"]["path"],
        tokenizer=tokenizer,
        split="test",
        test_size=config["data"]["test_size"],
    )
    # NOTE(review): shuffle is off, so this generator does not affect ordering.
    # Presumably it is passed so DataLoader internals that draw random numbers
    # use a generator on `device` (main() changes the default device) — confirm.
    generator = torch.Generator(device=device)
    # Halve the batch size because evaluation generates trajectories twice as
    # long as training (see max_gen_len below). Clamp to 1 so a training batch
    # size of 1 does not yield an invalid batch_size of 0.
    eval_batch_size = max(1, config["training"]["batch_size"] // 2)
    dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        collate_fn=CountdownTasksDataset.collate_fn,
        generator=generator,
        batch_size=eval_batch_size,
        drop_last=False,
    )
    success = []
    for batch in dataloader:
        episodes = rollout(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
            max_gen_len=config["training"]["max_gen_len"] * 2,
            num_answer_per_question=1,
            reward_function=reward_function,
            device=device,
            dtype=dtype,
        )
        success.extend(episode.reward_info["answer_reward"] for episode in episodes)
    if not success:
        # np.mean([]) also yields nan, but emits a "Mean of empty slice" warning.
        return float("nan")
    return float(np.mean(success))
def main(config_path: str):
    """Train the policy on the Countdown task as described by a YAML config.

    Loads the pretrained model and tokenizer, then loops once over the
    shuffled training split. For each batch of questions it:

    - samples several answers per question with ``rollout``;
    - optionally drops unfinished episodes;
    - updates the model with ``update_policy``;
    - logs metrics to stdout and TensorBoard.

    Every ``training.eval_interval`` steps it runs :func:`evaluate`. Every
    ``training.ckpt_save_interval`` steps it saves ``model.state_dict()`` to
    ``training.ckpt_dir``.

    Args:
        config_path: Path to the YAML config file.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    pretrained_model_path = Path(config["model"]["pretrained_model_path"])
    device = torch.device(config["model"]["device"])
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    # NOTE: an unrecognized dtype string silently falls back to bfloat16.
    dtype = dtype_map.get(config["model"]["dtype"], torch.bfloat16)
    # Tensors created without an explicit device will be placed on `device`.
    torch.set_default_device(device)
    torch.random.manual_seed(config["training"]["random_seed"])

    BATCH_SIZE = config["training"]["batch_size"]
    NUM_QUESTIONS_PER_BATCH = config["training"]["num_questions_per_batch"]
    # The training batch is split evenly across questions: each question in a
    # batch gets this many sampled answers.
    NUM_ANSWERS_PER_QUESTION = BATCH_SIZE // NUM_QUESTIONS_PER_BATCH

    tokenizer = Tokenizer(str(pretrained_model_path / "tokenizer.json"))
    train_dataset = CountdownTasksDataset(
        data_path=config["data"]["path"],
        tokenizer=tokenizer,
        split="train",
        test_size=config["data"]["test_size"],
    )
    # A freshly constructed torch.Generator starts from torch's fixed default
    # seed, which torch.manual_seed does not touch. Seed it explicitly so the
    # shuffle order follows training.random_seed.
    generator = torch.Generator(device=device)
    generator.manual_seed(config["training"]["random_seed"])
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=CountdownTasksDataset.collate_fn,
        generator=generator,
        batch_size=NUM_QUESTIONS_PER_BATCH,
    )

    model = Transformer.from_pretrained(pretrained_model_path, device=device).train()
    optimizer = MemoryEfficientAdamW(
        model.parameters(),
        lr=config["training"]["learning_rate"],
        weight_decay=config["training"]["weight_decay"],
        betas=config["training"]["betas"],
        enabled=config["training"]["memory_efficient_adamw"],
    )

    current_time = datetime.now().strftime(r"%Y%m%d-%H%M%S")
    tb_writer = SummaryWriter(log_dir=f"{config['training']['log_dir']}/{current_time}")
    start_time = time.time()
    ckpt_dir = Path(config["training"]["ckpt_dir"])
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    try:
        for step, batch in enumerate(train_dataloader, start=1):
            episodes = rollout(
                model=model,
                tokenizer=tokenizer,
                batch=batch,
                max_gen_len=config["training"]["max_gen_len"],
                num_answer_per_question=NUM_ANSWERS_PER_QUESTION,
                reward_function=reward_function,
                device=device,
                dtype=dtype,
            )
            if config["training"]["skip_unfinished_episodes"]:
                episodes = [episode for episode in episodes if episode.is_finished]
            # NOTE(review): if every episode is filtered out, update_policy
            # receives an empty list and the np.mean/np.std metrics below become
            # nan. Confirm update_policy tolerates this, or skip the step.
            results = update_policy(
                model=model,
                optimizer=optimizer,
                episodes=episodes,
                micro_batch_size=config["training"]["micro_batch_size"],
                pad_token_id=tokenizer.pad_token_id,
                max_grad_norm=config["training"]["max_grad_norm"],
                device=device,
                dtype=dtype,
            )
            # Wait for queued CUDA work so the wall-clock step duration is
            # accurate. torch.cuda.synchronize() fails on non-CUDA devices.
            # Without an argument it syncs the *current* CUDA device, which
            # is not necessarily `device`.
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            end_time = time.time()
            duration = end_time - start_time
            start_time = end_time

            # compute and log important metrics
            reward = [episode.reward for episode in episodes]
            formatted_reward = [
                episode.reward_info["format_reward"] for episode in episodes
            ]
            answer_reward = [episode.reward_info["answer_reward"] for episode in episodes]
            num_finished_episodes = sum(episode.is_finished for episode in episodes)
            mean_reward = np.mean(reward)
            std_reward = np.std(reward)
            success_rate = np.mean(answer_reward)
            format_reward = np.mean(formatted_reward)
            grad_norm = results["grad_norm"]
            entropy = results["entropy"]
            lr = optimizer.param_groups[0]["lr"]
            loss = results["loss"]
            mean_response_len = np.mean(
                [len(episode.generated_token_ids) for episode in episodes]
            )
            print(
                f"\rStep {step}, mean_reward: {mean_reward:.2f}, "
                f"train success_rate: {success_rate:.2f}, "
                f"grad_norm: {grad_norm:.2f}, duration: {duration:.2f}, "
                f"num_finished_episodes: {num_finished_episodes}, "
                f"mean_response_len: {mean_response_len:.2f}, "
                f"entropy: {entropy:.2f}"
            )
            if step % config["training"]["eval_interval"] == 0:
                eval_success_rate = evaluate(model, tokenizer, device, dtype, config)
                print(f"\rEval success rate: {eval_success_rate:.2f}" + " " * 100)
                tb_writer.add_scalar("success_rate/eval", eval_success_rate, step)

            tb_writer.add_scalar("loss", loss, step)
            tb_writer.add_scalar("mean_reward", mean_reward, step)
            tb_writer.add_scalar("std_reward", std_reward, step)
            tb_writer.add_scalar("success_rate/train", success_rate, step)
            tb_writer.add_scalar("format_reward", format_reward, step)
            tb_writer.add_scalar("grad_norm", grad_norm, step)
            tb_writer.add_scalar("duration", duration, step)
            tb_writer.add_scalar("num_finished_episodes", num_finished_episodes, step)
            tb_writer.add_scalar("learning_rate", lr, step)
            tb_writer.add_scalar("mean_response_len", mean_response_len, step)
            tb_writer.add_scalar("entropy", entropy, step)
            for i, episode in enumerate(episodes):
                # TensorBoard treats text as markdown; escape the generated
                # text so it renders literally inside the <pre> block.
                text = html.escape(episode.text)
                tb_writer.add_text(f"text_{i}", f"<pre>{text}</pre>", step)

            # save checkpoint
            if step % config["training"]["ckpt_save_interval"] == 0:
                output_file = ckpt_dir / f"ckpt_{step:06d}.pt"
                torch.save(model.state_dict(), output_file)
                print(f"Saved checkpoint to {output_file}")
    finally:
        # SummaryWriter buffers events; close() flushes them to disk so the
        # last steps are not lost on normal exit or on an exception.
        tb_writer.close()
if __name__ == "__main__":
    # Command-line entry point: `python train.py [--config PATH]`.
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, default="config.yaml")
    main(cli.parse_args().config)