This repository has been archived by the owner on Jun 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 38
/
train_pythia_flash_toolformer.py
182 lines (156 loc) · 6.24 KB
/
train_pythia_flash_toolformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# From: https://github.com/kyleliang919/Long-context-transformers
import torch
import numpy as np
import evaluate
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding
from transformers.trainer_utils import get_last_checkpoint
from itertools import chain
from typing import Optional
from dataclasses import dataclass, field
from transformers import (
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
default_data_collator,
set_seed,
)
from flash_attention.flash_attention_gptj_wrapper import FlashAttentionWrapper
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which pretrained model/tokenizer to fine-tune and
    how far its context window is extended.

    Attributes:
        model_name_or_path: checkpoint passed to ``from_pretrained`` for both
            the model and the tokenizer.
        max_positions: target sequence length; the training script resizes the
            rotary embeddings / causal mask to this length and packs training
            examples into blocks of exactly this many tokens.
    """

    model_name_or_path: Optional[str] = field(
        default="pythia-1.3b",
        metadata={
            "help": (
                "The model checkpoint (local path or hub id) used to initialize the "
                "weights and tokenizer. A pretrained checkpoint is required; "
                "training from scratch is not supported by this script."
            )
        },
    )
    max_positions: Optional[int] = field(
        default=8192,
        metadata={"help": "The maximum sequence length of the model."},
    )
@dataclass
class DataTrainingArguments:
    """Options controlling which corpus supplies the training and evaluation data.

    ``main()`` currently recognizes only ``"pile"``; any other value makes it raise.
    """

    # Name of the corpus to train/evaluate on.
    dataset_name: Optional[str] = field(
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
        default="pile",
    )
def main():
    """Fine-tune a pretrained GPT-NeoX (Pythia) model on long sequences.

    Steps:
      1. Parse ModelArguments / DataTrainingArguments / TrainingArguments from
         the command line.
      2. Load the pretrained model and tokenizer, rebuild every layer's rotary
         embedding and causal-mask buffer for ``max_positions``, and wrap each
         attention module in ``FlashAttentionWrapper``.
      3. Stream the Pile, tokenize it, and pack it into blocks of exactly
         ``max_positions`` tokens.
      4. Train with ``transformers.Trainer`` (resuming from a checkpoint when
         one is available), then save the model, train metrics and state.

    Raises:
        ValueError: if ``--dataset_name`` is anything other than ``"pile"``.
    """
    import os  # local import: only needed for the output-dir check below

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Fail fast, before downloading/patching the model.
    if data_args.dataset_name != "pile":
        raise ValueError(
            f"Unrecognized dataset_name {data_args.dataset_name!r}; "
            "only 'pile' is supported."
        )

    # get_last_checkpoint() lists output_dir and errors if it does not exist yet
    # (e.g. on the very first run), so only look for a checkpoint when the
    # directory exists and the user has not asked to overwrite it.
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    set_seed(training_args.seed)

    model = GPTNeoXForCausalLM.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    # NOTE(review): if the tokenizer has no mask token this leaves pad_token
    # unset; padding is never needed because examples are packed to a fixed
    # length below.
    tokenizer.pad_token = tokenizer.mask_token
    max_positions = model_args.max_positions
    tokenizer.model_max_length = max_positions

    # Extend every layer to max_positions: fresh rotary table, fresh
    # lower-triangular causal mask, then swap in the flash-attention wrapper.
    for layer in model.gpt_neox.layers:
        # 10000 is the third positional argument (presumably the rotary base).
        layer.attention.rotary_emb = RotaryEmbedding(
            layer.attention.rotary_ndims, max_positions, 10000
        )
        layer.attention.bias = torch.tril(
            torch.ones((max_positions, max_positions), dtype=torch.uint8)
        ).view(1, 1, max_positions, max_positions)
        layer.attention = FlashAttentionWrapper(
            layer.attention, max_seqlen=max_positions
        )

    # Patch for the random non-contiguous tensors bug. Rebinding the loop
    # variable (``p = p.contiguous()``) would be a no-op, so replace the
    # parameter's underlying storage instead.
    for param in model.parameters():
        param.data = param.data.contiguous()

    block_size = tokenizer.model_max_length

    def group_texts(examples):
        """Concatenate a batch of tokenized examples and cut it into block_size chunks.

        Any remainder shorter than block_size is dropped, unless the whole batch
        is shorter than one block, in which case a single short chunk is kept.
        ``labels`` is a copy of ``input_ids`` (the model shifts internally).
        """
        concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
        # Measure on the token ids explicitly rather than on whichever column
        # happens to come first.
        total_length = len(concatenated["input_ids"])
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    base_url = "https://the-eye.eu/public/AI/pile/"
    data_files = {
        "train": [f"{base_url}train/{idx:02d}.jsonl.zst" for idx in range(30)],
        "validation": base_url + "val.jsonl.zst",
        "test": base_url + "test.jsonl.zst",
    }
    raw_datasets = load_dataset("json", data_files=data_files, streaming=True)
    # Cheap pre-filter on character count (not tokens) to keep long documents.
    raw_datasets = raw_datasets.filter(lambda x: len(x["text"]) >= max_positions)
    # Drop the raw Pile columns so group_texts only packs tokenizer outputs.
    tokenized_datasets = raw_datasets.map(
        lambda examples: tokenizer(examples["text"]),
        batched=True,
        remove_columns=["text", "meta"],
    )
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
    )
    # Discard the short chunk group_texts keeps when a batch is under one block.
    lm_datasets = lm_datasets.filter(lambda x: len(x["input_ids"]) >= max_positions)

    def preprocess_logits_for_metrics(logits, labels):
        """Reduce logits to predicted token ids before they are collected for metrics."""
        if isinstance(logits, tuple):
            # Depending on the model and config, logits may contain extra tensors,
            # like past_key_values, but logits always come first
            logits = logits[0]
        return logits.argmax(dim=-1)

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        """Next-token accuracy: prediction at position i is scored against label i + 1."""
        preds, labels = eval_pred
        labels = labels[:, 1:].reshape(-1)
        preds = preds[:, :-1].reshape(-1)
        return metric.compute(predictions=preds, references=labels)

    train_dataset = lm_datasets["train"]
    eval_dataset = lm_datasets["validation"]
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # An explicit --resume_from_checkpoint wins over the auto-detected one.
    checkpoint = training_args.resume_from_checkpoint
    if checkpoint is None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()  # Saves the tokenizer too for easy upload

    metrics = train_result.metrics
    # The train split is a streaming (iterable) dataset with no len(); calling
    # len() on it would raise TypeError after training had already finished,
    # before metrics and state were saved. Only record the count when known.
    try:
        metrics["train_samples"] = len(train_dataset)
    except TypeError:
        pass
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()


if __name__ == "__main__":
    main()