train_gpt.py
import random
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
def fine_tune_gpt2(
    model_name,
    train_file,
    eval_file,
    output_dir,
    epochs=1,
    batch_size=4,
    max_seq_length=128,
    bos="<|startoftext|>",
    eos="<|endoftext|>",
    pad="<|pad|>",
):
    # Load the pretrained model and a tokenizer extended with the BOS/EOS/pad
    # special tokens used when the training files were written.
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(
        model_name,
        bos_token=bos,
        eos_token=eos,
        pad_token=pad,
    )
    # Added special tokens enlarge the vocabulary, so resize the embeddings to match.
    model.resize_token_embeddings(len(tokenizer))

    # Chunk the train/eval text into fixed-length blocks and use a causal-LM
    # collator (mlm=False means plain next-token prediction, no masking).
    train_dataset = TextDataset(
        tokenizer=tokenizer, file_path=train_file, block_size=max_seq_length
    )
    eval_dataset = TextDataset(
        tokenizer=tokenizer, file_path=eval_file, block_size=max_seq_length
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        evaluation_strategy="steps",
        eval_steps=100,
        logging_steps=1000,
        save_total_limit=2,
        logging_dir="./diffusion/gpt-logs/",
        gradient_accumulation_steps=2,
        num_train_epochs=epochs,
        weight_decay=0.1,
        warmup_steps=1000,
        lr_scheduler_type="cosine",
        learning_rate=5e-4,
        save_steps=1000,
        fp16=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    # Save the fine-tuned weights, plus the tokenizer (with its added special
    # tokens) so the model can be reloaded for generation later.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
if __name__ == "__main__":
    random.seed(42)

    # Read the raw text corpus, one example per line.
    with open("./resources/v7.txt") as f:
        texts = [line for line in f]

    # Shuffle and split 95% / 5% into training and validation sets.
    random.shuffle(texts)
    n_train = int(len(texts) * 0.95)
    train_data = texts[:n_train]
    valid_data = texts[n_train:]

    # Wrap every line with the BOS token and append EOS before the newline,
    # matching the special tokens passed to fine_tune_gpt2 below.
    bos = "<|startoftext|>"
    eos = "<|endoftext|>"
    with open("./diffusion/gpt_train.txt", "w", encoding="utf-8") as f:
        for line in train_data:
            f.write(bos + line.replace("\n", f"{eos}\n"))
    with open("./diffusion/gpt_valid.txt", "w", encoding="utf-8") as f:
        for line in valid_data:
            f.write(bos + line.replace("\n", f"{eos}\n"))
    # model_name = "distilgpt2"  # using the GPT-2 model instead
    train_file = "./diffusion/gpt_train.txt"
    valid_file = "./diffusion/gpt_valid.txt"
    output_dir = "./diffusion/gpt-outputs-gpt2-3/"
    model_name = "gpt2"  # "distilgpt2"
    epochs = 10
    batch_size = 128
    max_seq_length = 77
    fine_tune_gpt2(
        model_name,
        train_file,
        valid_file,
        output_dir,
        epochs=epochs,
        batch_size=batch_size,
        max_seq_length=max_seq_length,
    )
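
    # Hedged usage sketch (not part of the training run): after fine-tuning,
    # the model and tokenizer saved in output_dir can be reloaded for sampling
    # roughly like this. The prompt string is only an illustrative placeholder.
    #
    #     from transformers import GPT2LMHeadModel, GPT2Tokenizer
    #
    #     tokenizer = GPT2Tokenizer.from_pretrained(output_dir)
    #     model = GPT2LMHeadModel.from_pretrained(output_dir)
    #     inputs = tokenizer(tokenizer.bos_token + "a photo of", return_tensors="pt")
    #     outputs = model.generate(
    #         **inputs,
    #         max_length=max_seq_length,
    #         do_sample=True,
    #         top_p=0.95,
    #         pad_token_id=tokenizer.pad_token_id,
    #     )
    #     print(tokenizer.decode(outputs[0], skip_special_tokens=True))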