forked from bigscience-workshop/Megatron-DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
finetune_t0_non_causal_decoder.py
175 lines (149 loc) · 6.58 KB
/
finetune_t0_non_causal_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Multitask Finetuning T0"""
import torch
from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.enums import PositionEmbeddingType, AttnMaskType
from megatron.model import GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_packed_attention_mask
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
# `record` is an optional torch-elastic decorator; when this torch build does
# not ship it, substitute a decorator that does nothing.
try:
    from torch.distributed.elastic.multiprocessing.errors import record
except ImportError:

    def record(fn):
        """Fallback no-op decorator: hand ``fn`` back unchanged."""
        return fn
def model_provider(pre_process=True, post_process=True):
    """Construct the pipeline GPT model used for T0 finetuning.

    ``pre_process`` and ``post_process`` are accepted but not read by this
    function.

    Returns:
        The ``GPTModelPipe`` instance, with ``get_batch_pipe`` attached as
        ``_megatron_batch_fn``.

    Raises:
        NotImplementedError: if ``args.deepspeed`` is falsy.
    """
    print_rank_0("building GPT model ...")
    see_memory_usage("Before Building Model", force=True)

    args = get_args()

    dp_group = mpu.get_data_parallel_group()
    # The literal string "none" on the command line means "no remote device".
    remote_device = None if args.remote_device == "none" else args.remote_device
    zero_init = deepspeed.zero.Init(
        data_parallel_group=dp_group,
        remote_device=remote_device,
        config_dict_or_path=args.deepspeed_config,
        # The ZeRO init context is only switched on for stage 3.
        enabled=args.zero_stage == 3,
        mpu=mpu,
    )

    with zero_init:
        if not args.deepspeed:
            raise NotImplementedError("DeepSpeed is required for T0")

        model = GPTModelPipe(
            num_tokentypes=0,
            parallel_output=True,
            attn_mask_type=AttnMaskType.custom,
        )
        # This is a hack to give us a reference to get_batch_pipe from within training.py
        # We need to call model.set_batch_fn after deepspeed.initialize
        model._megatron_batch_fn = get_batch_pipe

    see_memory_usage("After Building Model", force=True)
    return model
def get_batch_pipe(data):
    """Turn one packed multitask batch into pipeline inputs and labels.

    Modification of `get_batch` to work on `next(data_iterator)` instead of
    `data_iterator` & in packed fashion.

    Example of the entries read from ``data`` (the trailing 0 is presumably
    the pad token -- confirm against the dataset builder):
        decoder_token_ids   = [[6, 7, 8, 3, 4, 5, 0]]
        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
        decoder_is_inputs   = [[1, 1, 0, 1, 1, 0, 0]]

    Returns:
        ``((tokens, position_ids, attention_mask), (labels, loss_mask))``
        where ``tokens`` drops the last column of ``decoder_token_ids`` and
        ``labels`` drops the first.

    Raises:
        NotImplementedError: if ``args.position_embedding_type`` is neither
            alibi nor rotary.
    """
    args = get_args()

    # Checked first so an unsupported configuration fails before any data is
    # broadcast or any mask is built.
    if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
        raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

    tokenizer = get_tokenizer()

    # Broadcast data.
    data_b = mpu.broadcast_data(["decoder_token_ids", "decoder_segment_ids"], data, torch.int64)
    data_c = mpu.broadcast_data(["decoder_is_inputs"], data, torch.bool)

    # Unpack: inputs are every column but the last, labels are the same
    # sequence shifted left by one.
    tokens_ = data_b["decoder_token_ids"].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Per-position metadata aligned with `tokens` (input side).
    segment_ids = data_b["decoder_segment_ids"].long()[:, :-1]
    decoder_is_inputs = data_c["decoder_is_inputs"][:, :-1]

    # Get the masks and position ids.
    causal_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        prefix_indices=None,
        loss_on_targets_only=False  # This is done below
    )

    # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding.
    # Both masks are aligned with `labels` (the `[:, 1:]` shift): a position
    # counts only if the token being *predicted* is a target and is not
    # padding. Testing `tokens` for padding instead would leave the position
    # whose label is the first pad token inside the loss.
    # NOTE(review): assumes tokenizer.pad differs from the end-of-sequence
    # token; otherwise end-of-sequence targets would be masked too -- confirm.
    loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
    loss_on_non_pad_only = (labels != tokenizer.pad)
    loss_mask *= loss_on_targets_only * loss_on_non_pad_only

    attention_mask = get_packed_attention_mask(
        # Run non-causal decoder
        is_causal=False,
        causal_mask=~(causal_mask.bool()),
        decoder_is_inputs=decoder_is_inputs.bool(),
        segment_ids=segment_ids.long(),
    )

    return (tokens, position_ids, attention_mask), (labels, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets.

    Two mutually exclusive loading modes are supported, chosen from the
    parsed arguments:

    1. ``--data-path``: one call to ``build_train_valid_test_datasets``
       yields the three datasets.
    2. ``--(train|valid|test)-weighted-split-paths``: each split that was
       given becomes a list with one dataset group per configured group;
       splits that were not given stay ``None``. This mode is only entered
       when the train paths are set.

    Args:
        train_val_test_num_samples: forwarded unchanged to the dataset
            builders.

    Returns:
        ``(train_ds, valid_ds, test_ds)``.

    Raises:
        NotImplementedError: if neither loading mode is configured.
    """
    args = get_args()
    datasets = {"train": None, "valid": None, "test": None}
    tokenizer = get_tokenizer()

    print_rank_0("> building train, validation, and test datasets for T0 ...")

    # Option 1 of data loading using --data-path
    if args.data_path:
        # TODO: Not yet compatible with dataset weights (Will break at prefixes, weights = analyze_data_prefix(args.data_path))
        datasets["train"], datasets["valid"], datasets["test"] = build_train_valid_test_datasets(
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string=args.split,
            # One extra token so inputs and labels can be formed by shifting
            # (see get_batch_pipe).
            seq_length=args.seq_length + 1,
            pad_token=tokenizer.pad,
            eos_token=tokenizer.eos,
            train_valid_test_num_samples=train_val_test_num_samples,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup)
        )
    # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
    elif args.train_weighted_split_paths:
        # Work out which splits were requested before building anything.
        # Plain getattr replaces the eval()-based attribute lookups.
        assigned_train_valid_test = [
            split
            for split in ("train", "valid", "test")
            if getattr(args, f"{split}_weighted_split_paths") is not None
        ]
        for split in assigned_train_valid_test:
            datasets[split] = []

        for split in assigned_train_valid_test:
            data_groups = zip(
                getattr(args, f"{split}_weighted_split_paths"),
                getattr(args, f"{split}_weighted_split_weights"),
                getattr(args, f"{split}_weighted_split_splits"),
                getattr(args, f"{split}_weighted_split_names"),
            )
            for paths, weights, splits, name in data_groups:
                group = build_dataset_group(
                    dataset_group_name=name,
                    paths=paths,
                    weights=weights,
                    splits=splits,
                    data_impl=args.data_impl,
                    train_valid_test_num_samples=train_val_test_num_samples,
                    seq_length=args.seq_length + 1,
                    pad_token=tokenizer.pad,
                    eos_token=tokenizer.eos,
                    seed=args.seed,
                    skip_warmup=(not args.mmap_warmup),
                    train_valid_test=split
                )
                datasets[split].append(group)
    else:
        raise NotImplementedError("No dataloading argument passed")

    print_rank_0("> finished creating T0 datasets ...")
    return datasets["train"], datasets["valid"], datasets["test"]
@record
def main():
    """Launch T0 multitask finetuning through Megatron's ``pretrain`` driver."""
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        # No forward-step function is supplied; presumably the batch function
        # attached in model_provider is used by the pipeline engine instead
        # -- confirm in megatron.training.
        forward_step_func=None,
        args_defaults={},
    )


# Only start training when this file is executed as a script.
if __name__ == "__main__":
    main()