forked from open-mmlab/mmengine
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfsdp_finetune.py
168 lines (143 loc) · 6.08 KB
/
fsdp_finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import argparse
import copy
from functools import partial
import torch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.data import default_data_collator
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from mmengine import load
from mmengine._strategy import FSDPStrategy
from mmengine.dataset import DefaultSampler
from mmengine.dist.utils import is_main_process
from mmengine.optim import StepLR
from mmengine.utils import apply_to
from mmengine.visualization import Visualizer, WandbVisBackend
ORI_BATCH_SIZE = 4
# Label value ignored by ``transformers``' internal CrossEntropyLoss
# (``ignore_index=-100``); used to mask prompt and padding positions so
# the loss is computed only on the response tokens.
IGNORE_INDEX = -100
PROMPT_DICT = {
    'prompt_input':
    ('Below is an instruction that describes a task, paired with an input '
     'that provides further context. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n'
     '### Input:\n{input}\n\n### Response:'),
    'prompt_no_input':
    ('Below is an instruction that describes a task. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n### Response:'),
}


# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/ft_datasets/alpaca_dataset.py # noqa: E501
class AlpacaDataset(Dataset):
    """Instruction-tuning dataset in the Stanford Alpaca format.

    Each annotation is a dict with ``instruction``, optional ``input`` and
    ``output`` keys. ``__getitem__`` renders a prompt from ``PROMPT_DICT``,
    appends the expected output and returns fixed-length tensors suitable
    for ``LlamaForCausalLM``.

    Args:
        data_path (str): Path to the annotation file loadable by
            ``mmengine.load`` (e.g. the alpaca json file).
        tokenizer: HuggingFace tokenizer providing ``encode`` and
            ``eos_token_id``.
        max_words (int): Fixed sequence length; longer samples are
            truncated, shorter ones padded. Defaults to 224.
    """

    def __init__(self, data_path, tokenizer, max_words=224):
        self.ann = load(data_path)
        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return a dict of ``input_ids``, ``labels`` and ``attention_mask``,
        each a tensor of length ``max_words``."""
        ann = self.ann[index]
        # Pick the template depending on whether the sample has an input.
        if ann.get('input', '') == '':
            prompt = PROMPT_DICT['prompt_no_input'].format_map(ann)
        else:
            prompt = PROMPT_DICT['prompt_input'].format_map(ann)
        example = prompt + ann['output']
        # Prompt tokens are only needed for their length, to mask the
        # prompt portion of the labels below.
        prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        # Pad with -1 (a sentinel no real token id uses) or truncate so the
        # sequence is exactly ``max_words`` long.
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat(
                (example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[:self.max_words]
        labels = copy.deepcopy(example)
        # Mask the prompt so only the response contributes to the loss.
        labels[:len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        # Padding positions become token id 0 in the inputs...
        example[~example_mask] = 0
        # ...while masked label positions use IGNORE_INDEX (-100) so the
        # model's CrossEntropyLoss skips them. (The previous value of 0
        # trained the model to predict token id 0 on prompt/padding
        # positions, since only -100 is ignored by transformers.)
        labels[~label_mask] = IGNORE_INDEX
        return {
            'input_ids': example,
            'labels': labels,
            'attention_mask': example_mask.float(),
        }
def parse_args():
    """Parse command-line arguments for the finetuning script.

    Positional arguments are the dataset root and the pretrained llama2
    checkpoint; optional flags control output location and the training
    schedule.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Train alpaca with llama2')
    parser.add_argument('data_root', type=str)
    parser.add_argument('checkpoint', type=str)
    # Optional flags: (name, type, default).
    optional = (
        ('--output-dir', str, 'work_dirs'),
        ('--max-epoch', int, 3),
        ('--batch-size', int, 4),
        ('--save-interval', int, 500),
    )
    for flag, flag_type, default in optional:
        parser.add_argument(flag, type=flag_type, default=default)
    return parser.parse_args()
def train():
    """Finetune llama2 on the alpaca dataset with mmengine's FSDP strategy.

    Reads CLI arguments (see ``parse_args``), shards the model across ranks
    with ``FSDPStrategy``, trains for ``--max-epoch`` epochs and saves a
    HuggingFace-format checkpoint after every epoch.
    """
    args = parse_args()
    # Setup distributed related component in Strategy.
    strategy = FSDPStrategy(
        # Wrap each LlamaDecoderLayer as its own FSDP unit.
        model_wrapper=dict(
            auto_wrap_policy=partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={LlamaDecoderLayer})),
        # 'full' state dict: gather the complete, unsharded weights when
        # model.state_dict() is called below for saving.
        state_dict_cfg='full',
        env_kwargs=dict(randomness=dict(seed=42)))
    # Metrics are streamed to Weights & Biases under args.output_dir.
    visualizer = Visualizer(
        name='mmengine',
        save_dir=args.output_dir,
        vis_backends=[dict(type=WandbVisBackend)])
    # Prepare model
    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
    tokenizer.add_special_tokens({'pad_token': '<PAD>'})
    model = LlamaForCausalLM.from_pretrained(args.checkpoint)
    # Train in bfloat16 to reduce memory; conversion happens before FSDP
    # wrapping in strategy.prepare().
    model.to(torch.bfloat16)
    model.train()
    # Prepare dataset
    train_dataset = AlpacaDataset(
        tokenizer=tokenizer, data_path=args.data_root)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DefaultSampler handles distributed sharding of the dataset.
        sampler=DefaultSampler(train_dataset, seed=0),
        collate_fn=default_data_collator,
        drop_last=True)
    # Get the prepared model, scheduler and optimizer from strategy
    epoch_length = len(train_dataloader)
    max_iters = epoch_length * args.max_epoch
    optim_cfg = dict(
        optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0),
        # Gradient accumulation keeps the effective batch size at
        # ORI_BATCH_SIZE. NOTE(review): this is a float whenever
        # batch_size does not divide ORI_BATCH_SIZE — confirm the optim
        # wrapper accepts non-integer counts.
        accumulative_counts=ORI_BATCH_SIZE / args.batch_size)
    scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)]
    model, optimizer, schedulers = strategy.prepare(
        model,
        optim_wrapper=optim_cfg,
        param_scheduler=scheduler_cfgs,
        dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch))
    for epoch in range(args.max_epoch):
        for idx, inputs in enumerate(train_dataloader):
            # Convert inputs to target device.
            inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor),
                              lambda m: m.cuda())
            loss = model(**inputs).loss
            # update_params handles backward, accumulation and step.
            optimizer.update_params(loss)
            max_memory = torch.cuda.max_memory_allocated()
            strategy.logger.info(f'Epoch: {epoch+1}/{args.max_epoch}, '
                                 f'Iter: {idx+1}/{epoch_length}, '
                                 f'Loss: {loss.item():.3f}, '
                                 f'Lr: {optimizer.get_lr()["lr"][0]:.6f} '
                                 f'Memory: {max_memory/1e9:.3f}G')
            visualizer.add_scalars({'loss': loss.item()})
            # Reset so the next iteration logs per-iteration peak memory.
            torch.cuda.reset_peak_memory_stats()
        # Step the LR schedulers once per epoch.
        for scheduler in schedulers:
            scheduler.step()
        save_dir = f'{args.output_dir}/epoch_{epoch+1}'
        # state_dict() must run on ALL ranks: with state_dict_cfg='full'
        # it collectively gathers the unsharded weights, so it cannot be
        # moved inside the is_main_process() branch.
        state_dict = model.state_dict()
        if is_main_process():
            model.save_pretrained(save_dir, state_dict=state_dict)
            tokenizer.save_pretrained(save_dir)
# Entry point: run training only when executed as a script.
if __name__ == '__main__':
    train()