forked from argonne-lcf/Megatron-DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dpo_training.py
1542 lines (1354 loc) · 70.2 KB
/
dpo_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Direct Preference Optimization"""
import os
from rich import print
import torch
import math
import numpy as np
# The earliest we can measure the start time.
import time
from datetime import datetime
import threading
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group, update_rotary_pos_emb
from megatron.arguments import core_transformer_config_from_args
from megatron.utils import (
report_memory,
throughput_calculator,
checkpoint_throughput_calculator
)
from pathlib import Path
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.accelerator.real_accelerator import get_accelerator
import subprocess
import wandb
import time
from torch import nn
import torch.nn.functional as F
# from ezpz import get_logger
from ezpz.dist import get_world_size, setup_wandb, get_rank
# More imports
from megatron.initialize import initialize_megatron
from megatron.initialize import set_jit_fusion_options
from megatron.training import print_datetime, _create_ds_config_dict
from megatron.training import setup_model_and_optimizer
from megatron.training import load_model_weights_only, get_model
from megatron.training import load_model_weights_only_modified
from megatron.training import get_optimizer_param_scheduler, cyclic_iter
from megatron.training import train, train_step
from megatron.training import train_step_dpo, training_log_dpo
from megatron.optimizer import get_megatron_optimizer
from megatron.checkpointing import load_checkpoint
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.arguments import core_transformer_config_from_args
from megatron import update_num_microbatches
from megatron import get_num_microbatches
from megatron.utils import throughput_calculator, get_parameters_in_billions
from megatron.text_generation import generate_and_post_process, beam_search_and_post_process
from megatron.text_generation.forward_step import ForwardStep, InferenceParams
from megatron.text_generation.sampling import sample
from megatron.text_generation.tokenization import detokenize_generations
from megatron.text_generation.communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from megatron.checkpointing import save_checkpoint
from megatron.utils import get_ltor_masks_and_position_ids
from generate_utils import generate_post_training
# Distributed identity of this process, resolved once at import time via ezpz.
RANK = get_rank()
WORLD_SIZE = get_world_size()
LEVEL = "DEBUG" if RANK == 0 else "CRITICAL"

# W&B counts as disabled only when WANDB_MODE is explicitly "disabled"
# (case-insensitive); an unset WANDB_MODE leaves it enabled.
WANDB_MODE = os.environ.get('WANDB_MODE', None)
DISABLE_WANDB = WANDB_MODE is not None and str(WANDB_MODE).lower() == 'disabled'

if RANK == 0 and not DISABLE_WANDB:
    # Project name precedence: WB_PROJECT, then WANDB_PROJECT, then 'AuroraGPT'.
    project_name = os.environ.get(
        'WB_PROJECT', os.environ.get('WANDB_PROJECT', 'AuroraGPT')
    )
    print('--------------------------------------------------')
    print(f"Setting up W&B from: {RANK} with {project_name}")
    print('--------------------------------------------------')
    # NOTE: the W&B run is not created here; the setup_wandb call was
    # deliberately left disabled (setup_wandb(project_name=project_name)).
def model_provider(pre_process=True, post_process=True):
    """Build the GPT model (policy or reference) for DPO training.

    When DeepSpeed pipeline parallelism is active
    (``args.deepspeed and not args.no_pipeline_parallel``) a ``GPTModelPipe``
    is built and a constant causal attention mask is cached on ``args``;
    otherwise a plain ``GPTModel`` is built. Construction runs inside
    ``deepspeed.zero.Init``, which is enabled only when ``args.zero_stage == 3``.

    Args:
        pre_process: forwarded to ``GPTModel`` (non-pipeline path only).
        post_process: forwarded to ``GPTModel`` (non-pipeline path only).

    Returns:
        The constructed model.
    """
    print_rank_0('building GPT model ...')
    see_memory_usage("Before Building Model", force=True)
    args = get_args()
    config = core_transformer_config_from_args(args)
    # Record the argument namespace on the active W&B run (once per call).
    if wandb.run is not None:
        print(f"Updating WandB run: [{wandb.run.name}]({wandb.run.url})")
        wandb.run.config.update({"args": vars(args)}, allow_val_change=True)
    if RANK == 0:
        git_ds_info()
    # Process group handed to deepspeed.zero.Init as its data_parallel_group:
    # prefer the sequence-parallel group when mpu provides one.
    if hasattr(mpu, 'get_sequence_parallel_group'):
        dpg = mpu.get_sequence_parallel_group()
    elif hasattr(mpu, 'get_data_parallel_group'):
        dpg = mpu.get_data_parallel_group()
    else:
        dpg = None
    with deepspeed.zero.Init(
        data_parallel_group=dpg,
        remote_device=(
            None if args.remote_device == 'none' else args.remote_device
        ),
        config_dict_or_path=args.deepspeed_config_dict,
        enabled=args.zero_stage == 3,
        mpu=mpu
    ):
        if args.deepspeed and not args.no_pipeline_parallel:
            model = GPTModelPipe(
                config=config,
                num_tokentypes=0,
                parallel_output=True
            )
            # This is a hack to give us a reference to
            # get_batch_pipe from within training.py
            # We need to call model.set_batch_fn after deepspeed.initialize
            model._megatron_batch_fn = get_batch_pipe
            # Precompute the attention mask and store it in args.
            # This avoids having to pipeline it
            # as an activation during training.
            # The mask is constant, and thus we can reuse it.
            attention_mask = torch.tril(
                torch.ones(
                    (1, args.seq_length, args.seq_length),
                    device=get_accelerator().current_device_name()
                )
            ).view(1, 1, args.seq_length, args.seq_length)
            # Convert attention mask to binary (True = masked out).
            attention_mask = (attention_mask < 0.5)
            if args.fp16:
                attention_mask = attention_mask.half()
            elif args.bf16:
                attention_mask = attention_mask.bfloat16()
            # Attention mask must be bool.
            args.attn_mask = attention_mask.to(torch.bool)
            # For pretraining, since sequence length is fixed,
            # cache rotary embedding in args, to avoid communicating around
            if args.use_rotary_position_embeddings:
                update_rotary_pos_emb(args.seq_length)
        else:
            print(f'Building model check..')
            model = GPTModel(
                config=config,
                num_tokentypes=0,
                parallel_output=True,
                pre_process=pre_process,
                post_process=post_process
            )
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print_rank_0(80 * '-')
    print_rank_0(f"Number of parameters in model: {num_params}")
    print_rank_0(80 * '-')
    see_memory_usage("After Building Model", force=True)
    if wandb.run is not None:
        wandb.run.config.update({'num_params': num_params}, allow_val_change=True)
    return model
def throughput_flops(model, args, iteration_time, total_iterations):
    """Estimate achieved model TFLOPs per GPU.

    Uses the general transformer FLOPs formula from Equation 3 (Section 5.1)
    of https://arxiv.org/pdf/2104.04473.pdf.

    Args:
        model: not used by the computation; kept so existing callers work.
        args: Megatron args (batch sizes, model dims, ``world_size``, ...).
        iteration_time: total time spent on ``total_iterations`` iterations
            (presumably seconds — TODO confirm at call sites).
        total_iterations: number of iterations covered by ``iteration_time``.

    Returns:
        Estimated TFLOPs per GPU as a float.
    """
    batch_size = args.micro_batch_size * get_num_microbatches() * args.data_parallel_size
    elapsed_time_per_iter = iteration_time / total_iterations
    hidden_size = args.hidden_size
    num_layers = args.num_layers
    vocab_size = args.padded_vocab_size
    # The factor is 3 without activation recomputation and 4 with it
    # (legacy checkpoint_activations or selective/full recompute).
    recompute = (
        getattr(args, 'checkpoint_activations', False)
        or getattr(args, 'recompute_granularity', None) in ('selective', 'full')
    )
    checkpoint_activations_factor = 4 if recompute else 3
    # Prefer args.actual_seq_length when it is set; fall back otherwise
    # (an attribute that exists but is None would otherwise crash below).
    seq_len = getattr(args, 'actual_seq_length', None)
    if seq_len is None:
        seq_len = args.seq_length
    flops_per_iteration = (
        24 * checkpoint_activations_factor * batch_size * seq_len
        * num_layers * (hidden_size ** 2)
    ) * (
        1.
        + (seq_len / (6. * hidden_size))
        + (vocab_size / (16. * num_layers * hidden_size))
    )
    tflops = flops_per_iteration / (elapsed_time_per_iter * args.world_size * (10 ** 12))
    return tflops
def get_batch(data_iterator):
    """Generate a batch.

    Pulls the next sample dict from ``data_iterator`` (``None`` when this rank
    has no iterator), broadcasts its 'text' field as int64 via
    ``tensor_parallel.broadcast_data``, and derives GPT inputs from it.

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
        ``tokens`` and ``position_ids`` are cut to this rank's slice of the
        sequence; ``labels`` is cut only when DeepSpeed sequence parallelism
        has world size > 1; ``loss_mask`` and ``attention_mask`` are not cut.
    """
    args = get_args()
    tokenizer = get_tokenizer()
    data = None if data_iterator is None else next(data_iterator)
    broadcast = tensor_parallel.broadcast_data(['text'], data, torch.int64)
    text = broadcast['text'].long()
    # Next-token prediction: inputs drop the last token, targets the first.
    tokens = text[:, :-1].contiguous()
    labels = text[:, 1:].contiguous()
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.use_flash_attn or args.use_flash_attn_triton,
    )
    # DeepSpeed sequence parallel sizes by default; Megatron's sequence
    # parallel (args.sequence_parallel) shards over the tensor-parallel group.
    ds_sp_size = mpu.get_sequence_parallel_world_size()
    sp_size, sp_rank = ds_sp_size, mpu.get_sequence_parallel_rank()
    if args.sequence_parallel:
        sp_size = mpu.get_tensor_model_parallel_world_size()
        sp_rank = mpu.get_tensor_model_parallel_rank()
    full_len = tokens.size(1)
    assert full_len % sp_size == 0
    chunk = full_len // sp_size
    window = slice(sp_rank * chunk, (sp_rank + 1) * chunk)
    tokens = tokens[:, window]
    position_ids = position_ids[:, window]
    if ds_sp_size > 1:
        labels = labels[:, window]
    return tokens, labels, loss_mask, attention_mask, position_ids
def data_post_process(data, data_sampler_state_dict):
    """Apply sequence-length curriculum learning to a raw batch dict.

    Does nothing unless ``args.data_efficiency_curriculum_learning`` is set.
    Otherwise it inspects ``data_sampler_state_dict['current_difficulties']``
    and records the curriculum type on
    ``args.data_efficiency_curriculum_learning_seqlen_type``:

    * 'seqlen_truncate': keep the first ``seqlen + 1`` columns of 'text'.
    * 'seqlen_reshape': refold 'text' into rows of ``seqlen + 1`` tokens,
      then keep at most ceil(original_tokens / (seqlen + 1)) rows, rounded
      down to an even count when more than one row remains.
    * neither: the type is set to None.

    Both transforms apply only when the current seqlen < ``args.seq_length``.
    Returns ``data`` (mutated in place).
    """
    args = get_args()
    if not args.data_efficiency_curriculum_learning:
        return data
    difficulties = data_sampler_state_dict['current_difficulties']
    if 'seqlen_truncate' in difficulties:
        args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate'
        seqlen = difficulties['seqlen_truncate']
        if seqlen < args.seq_length:
            data['text'] = data['text'][:, :(seqlen + 1)].contiguous()
    elif 'seqlen_reshape' in difficulties:
        args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape'
        seqlen = difficulties['seqlen_reshape']
        if seqlen < args.seq_length:
            text = data['text']
            row_len = seqlen + 1
            orig_num_token = torch.numel(text)
            reshape_len = (text.size()[1] // row_len) * row_len
            # Full rows from the folded prefix plus the trailing window.
            text = torch.cat(
                (
                    text[:, :reshape_len].contiguous().view(-1, row_len),
                    text[:, -row_len:],
                ),
                0,
            ).contiguous()
            num_row = min(math.ceil(orig_num_token / row_len), text.size()[0])
            if num_row > 1 and num_row % 2 != 0:
                num_row -= 1
            data['text'] = text[:num_row, :].contiguous()
    else:
        args.data_efficiency_curriculum_learning_seqlen_type = None
    return data
def get_batch_pipe(data):
    """Pipeline-parallel variant of ``get_batch``.

    Operates on an already-fetched sample (``next(data_iterator)``) instead of
    the iterator itself, and packs the result the way DeepSpeed's pipeline
    engine consumes it.

    Returns:
        ``((tokens, position_ids, attention_mask), (labels, loss_mask))``.
    """
    args = get_args()
    tokenizer = get_tokenizer()
    text = tensor_parallel.broadcast_data(['text'], data, torch.int64)['text'].long()
    tokens = text[:, :-1].contiguous()
    labels = text[:, 1:].contiguous()
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )
    if args.curriculum_learning_legacy and args.curriculum_seqlen < tokens.size()[1]:
        # Legacy seqlen curriculum: cut every [batch, seqlen] tensor.
        cut = args.curriculum_seqlen
        tokens = tokens[:, :cut].contiguous()
        position_ids = position_ids[:, :cut].contiguous()
        if labels is not None:
            labels = labels[:, :cut].contiguous()
        loss_mask = loss_mask[:, :cut].contiguous()
    return (tokens, position_ids, attention_mask), (labels, loss_mask)
def loss_func(loss_mask, moe_loss, mos_loss, output_tensor):
    """Reduce per-token LM losses to a scalar and add auxiliary losses.

    Args:
        loss_mask: mask over token positions; flattened and used to weight
            ``output_tensor``.
        moe_loss: mixture-of-experts auxiliary loss (``forward_step`` passes
            it already multiplied by ``args.moe_loss_coeff``).
        mos_loss: distillation loss; logged as 'mos loss' under ``args.mos``
            and as 'kd loss' under ``args.kd``.
        output_tensor: per-token LM losses.

    Returns:
        ``(loss, logging_dict)``.
    """
    args = get_args()
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    # Masked mean over all token positions.
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    if args.mos or args.kd:
        loss = loss + moe_loss + mos_loss
        if args.mos:
            return loss, {
                'total loss': loss,
                'lm loss': averaged_loss[0],
                'moe loss': moe_loss,
                'mos loss': mos_loss
            }
        # args.kd is necessarily truthy here.
        return loss, {
            'total loss': loss,
            'lm loss': averaged_loss[0],
            'moe loss': moe_loss,
            'kd loss': mos_loss
        }
    if max(args.num_experts) <= 1:
        return loss, {'lm loss': averaged_loss[0]}
    loss = loss + moe_loss
    return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss}
def dpo_loss_func(loss_mask, dpo_loss, output_tensor, moe_loss=0.0, mos_loss=0.0):
    """Loss reduction for DPO training.

    Computes the masked-mean LM loss from ``output_tensor`` for logging. In
    the default path the returned (optimised) loss is ``dpo_loss``.

    Args:
        loss_mask: mask over token positions; flattened and used to weight
            ``output_tensor``.
        dpo_loss: precomputed DPO loss.
        output_tensor: per-token LM losses.
        moe_loss: auxiliary MoE loss, used only under ``args.mos``/``args.kd``.
            Previously this name was referenced without being defined
            (NameError); it defaults to 0.
        mos_loss: distillation loss, used only under ``args.mos``/``args.kd``;
            defaults to 0 for the same reason.

    Returns:
        ``(loss, logging_dict)``.
    """
    args = get_args()
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    if args.mos or args.kd:
        # NOTE(review): dpo_loss does not enter the loss in this branch —
        # confirm this is intended.
        loss = loss + moe_loss + mos_loss
        if args.mos:
            return loss, {
                'total loss': loss,
                'lm loss': averaged_loss[0],
                'moe loss': moe_loss,
                'mos loss': mos_loss
            }
        # args.kd is necessarily truthy here.
        return loss, {
            'total loss': loss,
            'lm loss': averaged_loss[0],
            'moe loss': moe_loss,
            'kd loss': mos_loss
        }
    # DPO objective: optimise dpo_loss; the LM loss is only reported.
    loss = dpo_loss
    return loss, {'lm loss': averaged_loss[0], 'dpo loss': dpo_loss}
def batch_seq_logprobs(logits, labels, loss_mask=None):
    """Compute the summed log-probability of each sequence in a batch.

    The logit at position ``t`` is scored against the label at ``t + 1``
    (next-token prediction), so the last logit and the first label are
    dropped before gathering. The shift is applied along the sequence axis
    (dim 1), and the per-token log-probs are summed over that same axis.
    Previously the shift was applied along dim 0 while the sum ran over the
    last dim, which is inconsistent for any layout.

    Args:
        logits: ``[batch, seq, vocab]`` unnormalised scores.
        labels: ``[batch, seq]`` token ids aligned with the *input* positions
            (this function performs the one-position shift itself).
            NOTE(review): if callers pass labels from ``get_batch`` (already
            shifted relative to tokens), this extra shift misaligns them —
            confirm at the call site.
        loss_mask: optional ``[batch, seq]`` mask in the same alignment as
            ``labels``; positions with 0 do not contribute.

    Returns:
        ``[batch]`` tensor of summed per-token log-probabilities.
    """
    # NOTE(review): models here are built with parallel_output=True; with
    # tensor parallelism > 1 the vocab dim may be sharded, and log_softmax
    # over a shard is not the full softmax — confirm logits are gathered.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].long()
    logprobs = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    if loss_mask is not None:
        logprobs = logprobs * loss_mask[:, 1:].to(logprobs.dtype)
    return logprobs.sum(-1)
def calculate_mos_loss(
    args,
    stu_output,
    teacher_model,
    tokens,
    position_ids,
    attention_mask
):
    """Knowledge-distillation loss between student and teacher logits.

    Runs ``teacher_model`` under ``torch.no_grad`` on the same inputs (cut to
    ``args.curriculum_seqlen`` for legacy curriculum learning). It then returns

        kd_temp**2 * KL(softmax(tea / T) || log_softmax(stu / T)) [batchmean]
        / args.seq_length * args.kd_beta_ce

    with ``T = args.kd_temp``. Returns 0 when ``teacher_model`` is falsy.
    """
    mos_loss = 0
    beta = args.kd_beta_ce
    kd_temp = args.kd_temp
    if teacher_model:
        with torch.no_grad():
            if (
                args.curriculum_learning_legacy and
                args.curriculum_seqlen < args.seq_length
            ):
                assert args.curriculum_seqlen is not None
                csl = args.curriculum_seqlen
                tokens = tokens[:, :csl].contiguous()
                position_ids = position_ids[:, :csl].contiguous()
                attention_mask = attention_mask[:, :, :csl, :csl].contiguous()
                # No need to truncate labels: the teacher logits do not use them.
            tea_output, _ = teacher_model(
                tokens,
                position_ids,
                attention_mask
            )
            assert stu_output.size() == tea_output.size(), (
                'teacher and student output should match in size. '
                f'Student: {stu_output.size()}, '
                f'Teacher: {tea_output.size()}, '
                f'CL seq length {args.curriculum_seqlen}'
            )
        student_logits = F.log_softmax(stu_output / kd_temp, dim=2)
        # KLDivLoss expects the target as probabilities (not log-probs)
        # unless log_target=True is passed, hence softmax here.
        tea_logits = F.softmax(tea_output / kd_temp, dim=2)
        mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')(
            student_logits,
            tea_logits
        )
        mos_loss = mos_loss.div(args.seq_length) * beta
    return mos_loss
def calculate_dpo_loss(
    args,
    stu_output,
    teacher_model,
    logprobs_p,
    logprobs_u,
    ref_logprobs_p,
    ref_logprobs_u,
    tokens,
    position_ids,
    attention_mask
):
    """Compute the DPO loss and the implicit-reward accuracy.

    The loss is computed only from the supplied sequence log-probabilities
    of the policy (``logprobs_*``) and of the reference model
    (``ref_logprobs_*``) for preferred (``_p``) and unpreferred (``_u``)
    responses::

        loss = mean(-logsigmoid(beta * ((logp_p - ref_p) - (logp_u - ref_u))))

    If ``teacher_model`` (the reference model) is truthy, it is run once
    under ``torch.no_grad`` on ``tokens``/``position_ids``/``attention_mask``.
    These are cut to ``args.curriculum_seqlen`` for legacy curriculum
    learning. The run only asserts that the reference output has the same
    size as ``stu_output``; that output does not enter the loss.

    Previously this function gathered from an undefined ``labels`` (NameError)
    whenever a reference model was given; that unused computation is removed.

    Args:
        args: Megatron argument namespace. ``args.dpo_beta`` is used as the
            DPO temperature when present and not None; otherwise 0.1.
        stu_output: policy-model output, only used for the size check.
        teacher_model: reference model, or a falsy value to skip the check.
        logprobs_p, logprobs_u: policy sequence log-probs.
        ref_logprobs_p, ref_logprobs_u: reference sequence log-probs.
        tokens, position_ids, attention_mask: reference-model inputs.

    Returns:
        ``(dpo_loss, rewards)``. ``dpo_loss`` is the mean DPO loss.
        ``rewards`` is a CPU scalar: the fraction of pairs whose preferred
        implicit reward exceeds the unpreferred one.
    """
    beta = getattr(args, 'dpo_beta', None)
    if beta is None:
        beta = 0.1
    if teacher_model:
        # NOTE(review): this forward pass only feeds the size check below;
        # drop it if the check is not needed.
        with torch.no_grad():
            if (
                args.curriculum_learning_legacy and
                args.curriculum_seqlen < args.seq_length
            ):
                assert args.curriculum_seqlen is not None
                csl = args.curriculum_seqlen
                tokens = tokens[:, :csl].contiguous()
                position_ids = position_ids[:, :csl].contiguous()
                attention_mask = attention_mask[:, :, :csl, :csl].contiguous()
            ref_output, _ = teacher_model(
                tokens,
                position_ids,
                attention_mask
            )
        assert stu_output.size() == ref_output.size(), (
            'ref and student output should match in size. '
            f'Student: {stu_output.size()}, '
            f'Reference: {ref_output.size()}, '
            f'CL seq length {args.curriculum_seqlen}'
        )
    # Policy-vs-reference log-ratio for preferred / unpreferred responses.
    logprob_ratio_p = logprobs_p - ref_logprobs_p
    logprob_ratio_u = logprobs_u - ref_logprobs_u
    scaled_diff_logprob_ratios = beta * (logprob_ratio_p - logprob_ratio_u)
    losses = -F.logsigmoid(scaled_diff_logprob_ratios)
    # Implicit DPO rewards (detached: logging only).
    pref_dpo_rewards = (beta * logprob_ratio_p).detach()
    unpref_dpo_rewards = (beta * logprob_ratio_u).detach()
    implicit_dpo_rewards = (pref_dpo_rewards > unpref_dpo_rewards).float()
    rewards = implicit_dpo_rewards.cpu().mean()
    dpo_loss = losses.mean()
    return dpo_loss, rewards
def compute_dp_loss(logprobs_p, ref_logprobs_p,
                    logprobs_u, ref_logprobs_u,
                    beta=0.1):
    """Mean DPO loss from policy and reference sequence log-probabilities.

    ``_p`` denotes the preferred response and ``_u`` the unpreferred one.
    Each pair contributes
    ``-logsigmoid(beta * ((logp_p - ref_p) - (logp_u - ref_u)))``.
    """
    preferred_margin = logprobs_p - ref_logprobs_p
    unpreferred_margin = logprobs_u - ref_logprobs_u
    per_pair = -F.logsigmoid(beta * (preferred_margin - unpreferred_margin))
    return per_pair.mean()
def forward_step(data_iterator, model):
    """Forward step.

    Fetches a batch, runs ``model`` on it and returns the per-token loss
    tensor together with a ``loss_func`` partial that reduces it. Under
    ``args.mos``/``args.kd`` the model is run without labels, so the
    cross-entropy is computed here from the logits. When a teacher model is
    configured, a distillation loss is also computed.
    """
    args = get_args()
    timers = get_timers()

    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    if args.data_efficiency_curriculum_learning:
        args.curriculum_seqlen = tokens.size()[1]
        seqlen_type = getattr(
            args, 'data_efficiency_curriculum_learning_seqlen_type', None
        )
        if seqlen_type == 'seqlen_reshape':
            args.data_efficiency_curriculum_learning_numel = torch.numel(tokens)

    distilling = args.mos or args.kd

    def _cl_truncating():
        # Legacy curriculum learning currently uses a shorter sequence.
        return (
            args.curriculum_learning_legacy
            and args.curriculum_seqlen < args.seq_length
        )

    if distilling:
        # Without labels the model returns logits rather than the loss.
        stu_output, other_losses = model(tokens, position_ids, attention_mask)
        if _cl_truncating():
            assert args.curriculum_seqlen is not None
            labels = labels[:, :args.curriculum_seqlen].contiguous()
        output_tensor = tensor_parallel.vocab_parallel_cross_entropy(
            stu_output.contiguous().float(),
            labels
        )
    else:
        output_tensor, other_losses = model(
            tokens, position_ids, attention_mask, labels=labels
        )
    if _cl_truncating():
        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()

    moe_loss = sum(
        extra for extra in other_losses if extra is not None
    ) * args.moe_loss_coeff

    mos_loss = 0
    if distilling:
        assert model.training
        if args.teacher_forward and args.teacher_model is not None:
            mos_loss = calculate_mos_loss(
                args,
                stu_output,
                args.teacher_model[0],
                tokens,
                position_ids,
                attention_mask
            )
    # output_tensor holds the per-token LM loss; loss_func builds the total.
    return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss)
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets.

    The data blend (alternating weight, prefix) is taken from, in order of
    precedence:

    * ``args.data_file_list``: a text file with one ``<weight> <prefix>``
      pair per line. Blank lines are skipped; other malformed lines raise
      ``ValueError``.
    * a single directory in ``args.data_path``: every regular file ending in
      ``.bin`` contributes its prefix (name minus ``.bin``) with weight 1.
    * ``args.data_path`` as given.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` from ``build_train_valid_test_datasets``.
    """
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
    files = []
    if args.data_file_list is not None:
        with open(args.data_file_list, 'r') as flist:
            for line in flist:
                parts = line.split()
                if not parts:
                    continue
                if len(parts) != 2:
                    raise ValueError(
                        f"Expected '<weight> <prefix>' per line in "
                        f"{args.data_file_list!r}, got: {line!r}"
                    )
                weight, fname = parts
                files.append(float(weight))
                files.append(fname)
    elif len(args.data_path) == 1 and os.path.isdir(args.data_path[0]):
        path = args.data_path[0] + "/"
        suffix = ".bin"
        for entry in os.listdir(path):
            # Match only real ".bin" files and strip exactly that suffix
            # (".find"/".split" mis-parsed names like "x.binary").
            if os.path.isfile(path + entry) and entry.endswith(suffix):
                files.append(1)
                files.append(path + entry[:-len(suffix)])
    else:
        files = args.data_path
    print_rank_0(f"file list {files}")
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=files,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=True,
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path,
        data_cache_path=args.data_cache_path)
    print_rank_0("> finished creating GPT datasets ...")
    return train_ds, valid_ds, test_ds
def command_exists(cmd):
    """Return True if ``cmd`` resolves to a command in the shell.

    Uses the shell's ``type`` builtin, so shell builtins count as well as
    executables on PATH. ``cmd`` is shell-quoted, so it is looked up as a
    single word rather than interpreted as shell syntax. Output is discarded:
    previously stdout went to an undrained PIPE and ``wait()`` could block if
    the pipe filled.
    """
    import shlex
    result = subprocess.run(
        f'type {shlex.quote(str(cmd))}',
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        shell=True,
    )
    return result.returncode == 0
def git_ds_info():
    """On rank 0 only, print the DeepSpeed environment report and git info.

    The git hash and branch come from ``git rev-parse`` run in the current
    working directory. Both fall back to "unknown" when git is not available
    or either command fails.
    """
    if RANK != 0:
        return
    from deepspeed.env_report import main as ds_report
    ds_report()

    def _git(command):
        return subprocess.check_output(command, shell=True).decode('utf-8').strip()

    git_hash = git_branch = "unknown"
    if command_exists('git'):
        try:
            git_hash = _git("git rev-parse --short HEAD")
            git_branch = _git("git rev-parse --abbrev-ref HEAD")
        except subprocess.CalledProcessError:
            git_hash = git_branch = "unknown"
    print(
        f'**** Git info for Megatron: '
        f'git_hash={git_hash} git_branch={git_branch} ****'
    )
def main():
# if RANK == 0:
# setup_wandb()
if os.getenv('TORCH_PROFILER_ENABLED') == '1':
from torch.profiler import profile, record_function, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(
# extra_args_provider=extra_args_provider,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
# external_args=external_args
)
# Set pytorch JIT layer fusion options and warmup JIT functions.
if get_accelerator().device_name() == 'cuda':
set_jit_fusion_options()
args = get_args()
timers = get_timers()
# model = model_provider()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder)
prof.export_chrome_trace(f"{args.tensorboard_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}.json")
else:
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(
# extra_args_provider=extra_args_provider,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
# external_args=external_args
)
# Set pytorch JIT layer fusion options and warmup JIT functions.
if get_accelerator().device_name() == 'cuda':
set_jit_fusion_options()
args = get_args()
timers = get_timers()
if args.deepspeed:
args.deepspeed_config_dict = _create_ds_config_dict()
if "curriculum_learning" in args.deepspeed_config_dict and \
"enabled" in args.deepspeed_config_dict["curriculum_learning"]:
args.curriculum_learning_legacy = args.deepspeed_config_dict[ \
"curriculum_learning"]["enabled"]
if args.curriculum_learning_legacy and not args.no_pipeline_parallel:
from deepspeed.runtime.data_pipeline.curriculum_scheduler \
import CurriculumScheduler
args.curriculum_scheduler = CurriculumScheduler( \
args.deepspeed_config_dict["curriculum_learning"])
if "compression_training" in args.deepspeed_config_dict:
args.compression_training = True
from copy import deepcopy
ds_config_copy = deepcopy(args.deepspeed_config_dict)
ds_config_copy["flops_profiler"]["output_file"] = f"dsflops_nlayer{args.num_layers}_worldsize{WORLD_SIZE}_seq{args.seq_length}_mb{args.micro_batch_size}.log"
print_rank_0(f'Deepspeed config updated with out: {ds_config_copy["flops_profiler"]}')
# ---------- Policy model -------------
# Build the trainable model, create a Megatron optimizer and LR scheduler for
# it, and wrap all three in a DeepSpeed engine. Statement order matters here:
# these calls run on every rank.
# NOTE(review): in the visible code, load_checkpoint() is only called for
# model_ref (below), not for this model — confirm the policy weights are
# loaded elsewhere if training should start from a checkpoint.
# model = model_provider()
# model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder)
model = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes?
# TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider)
optimizer = get_megatron_optimizer(model, None, None, 1.0)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
# Megatron's mpu is only handed to DeepSpeed when pipeline parallelism is off;
# get_model() returns a list, so its first element is the module to wrap.
model, optimizer, _, opt_param_scheduler = deepspeed.initialize(
model=model[0],
optimizer=optimizer,
args=args,
lr_scheduler=opt_param_scheduler,
mpu=mpu if args.no_pipeline_parallel else None,
config=args.deepspeed_config_dict,
)
# Re-wrap the engine in a list, matching the list form returned by get_model().
model = [model]
print_rank_0(get_parameters_in_billions(model))
#exit()
# ---------- Reference model -------------
# Build a second copy of the model to serve as the reference model. It is
# wrapped in a DeepSpeed engine without an optimizer further below, and its
# weights are loaded with load_checkpoint().
# model_ref, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder) # throwing assertion error
model_ref = get_model(model_provider, ModelType.encoder_or_decoder) # works but does it load from a checkpoint or randomly initializes?
# Earlier initialization attempts for the reference model, kept for reference:
# # TRY deepspeed init and load_checkpoint directly here from model_ref = get_model(model_provider)
# optimizer_2 = get_megatron_optimizer(model_ref, None, None, 1.0)
# opt_param_scheduler_2 = get_optimizer_param_scheduler(optimizer_2)
# model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize(
# model=model_ref[0],
# optimizer=optimizer_2,
# args=args,
# lr_scheduler=opt_param_scheduler_2,
# mpu=mpu if args.no_pipeline_parallel else None,
# config=args.deepspeed_config_dict,
# )
# model_ref, _, _, _ = deepspeed.initialize(
# model=model_ref[0],
# optimizer=None,
# args=args,
# lr_scheduler=None,
# mpu=mpu if args.no_pipeline_parallel else None,
# config=args.deepspeed_config_dict,
# )
# engine = deepspeed.init_inference(model=model_ref[0],
# mp_size=args.tensor_model_parallel_size,
# tensor_parallel={"mpu": mpu},
# dtype=torch.half,
# replace_with_kernel_inject=True,
# # moe_experts=args.num_experts,
# # moe_type=args.mlp_type
# )
# model_ref = engine.module
# deepspeed initialization of reference model without optimizer
# .copy() is a shallow copy: deleting top-level keys below leaves
# args.deepspeed_config_dict untouched (the before/after prints show this),
# but nested sections are still shared objects with the original.
ds_config_ref_dict = args.deepspeed_config_dict.copy()
# Remove ZeRO, optimizer and train_batch_size entries from the reference
# config, presumably so DeepSpeed builds no optimizer for the reference engine
# and does not re-check the batch-size relation (see the train_batch_size error
# noted after load_checkpoint below).
# NOTE(review): 'zero_optimization' is tested twice (here and in the next
# check) — confirm which of the following statements are meant to depend on it.
if 'zero_optimization' in ds_config_ref_dict:
print_rank_0(f'args.deepspeed_config_dict before: {args.deepspeed_config_dict}')
print_rank_0(f'ds_config_ref_dict before: {ds_config_ref_dict}')
if 'zero_optimization' in ds_config_ref_dict.keys():
del ds_config_ref_dict['zero_optimization']
if 'optimizer' in ds_config_ref_dict.keys():
del ds_config_ref_dict['optimizer']
if 'train_batch_size' in ds_config_ref_dict.keys():
del ds_config_ref_dict['train_batch_size']
print_rank_0(f'args.deepspeed_config_dict after: {args.deepspeed_config_dict}')
print_rank_0(f'ds_config_ref_dict after: {ds_config_ref_dict}')
# Wrap the reference model in a DeepSpeed engine using the reduced config
# (no optimizer or LR scheduler is passed in). The asserts confirm DeepSpeed
# did not create an optimizer or scheduler for it on its own.
# NOTE(review): unlike the policy model, no `mpu=` is passed here; with tensor
# parallelism and no pipeline engine this may matter for checkpoint rank
# naming — confirm against the DeepSpeed version in use.
model_ref, optimizer_2, _, opt_param_scheduler_2 = deepspeed.initialize(
    model=model_ref[0],
    config=ds_config_ref_dict
)
print_rank_0(f'ref optimizer: {optimizer_2}')
print_rank_0(f'ref param scheduler: {opt_param_scheduler_2}')
# Compare against None by identity (PEP 8), not with ==.
assert optimizer_2 is None, "Reference model optimizer is not None"
assert opt_param_scheduler_2 is None, "Reference param scheduler is not None"
# Pipeline-parallel case: give the reference engine the module's Megatron
# batch function and check its pipe/slice/data-parallel ranks agree with
# Megatron's mpu (slice rank is compared to the tensor-model-parallel rank).
if isinstance(model_ref, deepspeed.PipelineEngine):
    # Plain string literal: there is nothing to interpolate. Note that this
    # uses print(), so it is emitted on every rank.
    print('Doing assertion checks on model_ref..')
    # hack to get batch_fn from pretrain_gpt.py
    model_ref.set_batch_fn(model_ref.module._megatron_batch_fn)
    assert model_ref.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank()
    assert model_ref.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank()
    assert model_ref.grid.get_data_parallel_rank() == mpu.get_data_parallel_rank()
# Wrap the reference engine in a list and load its weights. optimizer_2 and
# opt_param_scheduler_2 are None here (asserted above). The returned iteration
# is stored in iteration2, which is not used in the visible code.
# NOTE(review): the inline comment says this only works after removing
# `assert args.consumed_train_samples == 0` inside load_checkpoint(); the call
# may therefore update global args state (e.g. args.consumed_train_samples)
# that the data loaders below read — confirm this is intended.
model_ref = [model_ref]
iteration2 = load_checkpoint(model_ref, optimizer_2, opt_param_scheduler_2) # THIS WORKED!! After commenting out assert args.consumed_train_samples == 0 in load_checkpoint()
# THINGS THAT DID NOT WORK FOR LOADING FROM CHECKPOINT
# model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only(model_provider) # DID NOT WORK - train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size 32 != 8 * 1 * 8
# model_ref, optimizer_ref, lr_scheduler_ref = load_model_weights_only_modified(model_provider) # DID NOT WORK - optimizer = FusedAdam(TypeError: FusedAdam.__init__() got an unexpected keyword argument 'beta1'
# ----------------------------------------
# Dataset sizing. The target sample counts are:
#   train = args.train_samples, or train_iters * global_batch_size
#   valid = (train_iters // eval_interval + 1) * eval_iters * global_batch_size
#   test  = eval_iters * global_batch_size
# NOTE(review): these diagnostics use print() (every rank) while most of this
# section uses print_rank_0().
if args.data_file_list_u is not None:
print(f'data files list unpreferred: {args.data_file_list_u}')
# Number of train/valid/test samples.
if args.train_samples:
print(f'args.train_samples: {args.train_samples}')
train_samples = args.train_samples
else:
print(f'args.train_iters: {args.train_iters}')
print(f'args.global_batch_size: {args.global_batch_size}')
train_samples = args.train_iters * args.global_batch_size
print(f'args.eval_interval: {args.eval_interval}')
print(f'args.eval_iters: {args.eval_iters}')
# NOTE(review): args.train_iters is used here even when args.train_samples
# was given; if train_iters can be None in that case, `//` raises TypeError.
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(f'train_val_test_num_samples: {train_val_test_num_samples}')
# print(f'args.data_impl: {args.data_impl}')
# print(f'args.split: {args.split}')
# print(f'args.seq_length: {args.seq_length}')
# print(f'args.seed: {args.seed}')
# print(f'args.train_data_path: {args.train_data_path}')
# print(f'args.valid_data_path: {args.valid_data_path}')
# print(f'args.test_data_path: {args.test_data_path}')
# print(f'args.data_cache_path: {args.data_cache_path}')
# Read the unpreferred data-file list. Each non-blank line must contain exactly
# two whitespace-separated fields, "<weight> <dataset prefix>". They are
# flattened into [w0, prefix0, w1, prefix1, ...] for data_prefix below.
files_u = []
with open(args.data_file_list_u, 'r') as flist:
    for lineno, line in enumerate(flist, start=1):
        fields = line.split()
        if not fields:
            # Blank lines (e.g. a trailing empty line) would otherwise fail
            # with an opaque "not enough values to unpack" error.
            continue
        if len(fields) != 2:
            raise ValueError(
                f'{args.data_file_list_u}:{lineno}: expected "<weight> <prefix>", '
                f'got {line.rstrip()!r}')
        w, fname = fields
        files_u.append(float(w))
        files_u.append(fname)
# Build the train/valid/test splits for the unpreferred samples.
train_ds_u, valid_ds_u, test_ds_u = build_train_valid_test_datasets(
    data_prefix=files_u,
    data_impl=args.data_impl,
    splits_string=args.split,
    train_valid_test_num_samples=train_val_test_num_samples,
    seq_length=args.seq_length,
    seed=args.seed,
    skip_warmup=True,
    # skip_warmup=(not args.mmap_warmup),
    train_data_prefix=args.train_data_path,
    valid_data_prefix=args.valid_data_path,
    test_data_prefix=args.test_data_path,
    data_cache_path=args.data_cache_path)
print_rank_0("> finished creating unpreferred GPT datasets ...")
# Preferred datasets: log which data-file list is used.
if args.data_file_list_p is not None:
print_rank_0(f'data files list preferred: {args.data_file_list_p}')
# Read the preferred data-file list. Each non-blank line must contain exactly
# two whitespace-separated fields, "<weight> <dataset prefix>". They are
# flattened into [w0, prefix0, w1, prefix1, ...] for data_prefix below.
files_p = []
with open(args.data_file_list_p, 'r') as flist:
    for lineno, line in enumerate(flist, start=1):
        fields = line.split()
        if not fields:
            # Blank lines (e.g. a trailing empty line) would otherwise fail
            # with an opaque "not enough values to unpack" error.
            continue
        if len(fields) != 2:
            raise ValueError(
                f'{args.data_file_list_p}:{lineno}: expected "<weight> <prefix>", '
                f'got {line.rstrip()!r}')
        w, fname = fields
        files_p.append(float(w))
        files_p.append(fname)
# Build the train/valid/test splits for the preferred samples.
train_ds_p, valid_ds_p, test_ds_p = build_train_valid_test_datasets(
    data_prefix=files_p,
    data_impl=args.data_impl,
    splits_string=args.split,
    train_valid_test_num_samples=train_val_test_num_samples,
    seq_length=args.seq_length,
    seed=args.seed,
    skip_warmup=True,
    # skip_warmup=(not args.mmap_warmup),
    train_data_prefix=args.train_data_path,
    valid_data_prefix=args.valid_data_path,
    test_data_prefix=args.test_data_path,
    data_cache_path=args.data_cache_path)
print_rank_0("> finished creating preferred GPT datasets ...")
# Data loaders: log the loader-related args, then build the training loader
# for the unpreferred datasets.
for _field in ('consumed_train_samples', 'dataloader_type'):
    print_rank_0(f'args.{_field}: {getattr(args, _field)}')
# args.consumed_train_samples is passed through, presumably so the sampler can
# resume from it — confirm against build_pretraining_data_loader.
train_dataloader_u = build_pretraining_data_loader(train_ds_u, args.consumed_train_samples)