Commit 86758c3

Authored by yang, github-actions, and Quentin-Anthony
Add MoE (#1129)
* Add DeepSpeed MoE

  Thanks to dayofthepenguin for extensive testing

  Closes #479

* Update NeoXArgs docs automatically

* pre-commit

* Update NeoXArgs docs automatically

---------

Co-authored-by: Yang Zhang <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
1 parent df8cf24 commit 86758c3

File tree

10 files changed: +434 −31 lines


configs/125M-moe.yml

+103
@@ -0,0 +1,103 @@
# GPT-2 pretraining setup
{
  # Have 4 experts per layer (every 2 layers by default)
  # So with 12 layers total:
  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
  # Experts would be in layers:
  # 0, 2, 4, 6, 8, 10
  "num_experts": 4,

  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,
  "moe_expert_parallel_size": 1,

  # model settings
  "num_layers": 12,
  "hidden_size": 768,
  "num_attention_heads": 12,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "layernorm",
  "pos_emb": "rotary",
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",

  # these should provide some speedup but takes a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",


  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0006,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00006,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "train_micro_batch_size_per_gpu": 4,
  "data_impl": "mmap",

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0.0,
  "attention_dropout": 0.0,

  # precision settings
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,
  "checkpoint_factor": 10000,
  "eval_interval": 1000,
  "eval_iters": 10,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "keep_last_n_checkpoints": 4,
  "wall_clock_breakdown": true,

  # networking
  "hostfile": "/mock_path"
}
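To make the layer comments above concrete, here is a small illustrative Python sketch (not part of this commit) that reproduces the expert placement implied by `num_layers: 12` with the default `expert_interval` of 2; the helper name is hypothetical.

```python
# Hypothetical helper, shown only to illustrate the placement rule described in
# the comments of configs/125M-moe.yml: one MoE layer every `expert_interval`
# transformer layers, starting at layer 0.

def moe_layer_indices(num_layers: int, expert_interval: int = 2) -> list:
    """Return the indices of the transformer layers that receive MoE experts."""
    return [i for i in range(num_layers) if i % expert_interval == 0]

print(moe_layer_indices(num_layers=12))
# [0, 2, 4, 6, 8, 10] -- the layers listed in the config comment above
```

With GPT-NeoX's usual launcher, this config would be passed on the command line like any other config file (for example `python ./deepy.py train.py configs/125M-moe.yml`); consult the repository README for the exact invocation.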

configs/neox_arguments.md

+97 −1 lines
@@ -111,7 +111,7 @@ Logging Arguments
  - **git_hash**: str

-   Default = 2a3c4e1
+   Default = ae06be5

    current git hash of repository

@@ -1007,6 +1007,14 @@ Parallelism Arguments
+ - **expert_interval**: int
+
+   Default = 2
+
+   Have one MoE layer every expert_interval layers
+

  ## NeoXArgsTemplate

  NeoXArgsTemplate()
@@ -1128,6 +1136,94 @@ Text Generation arguments
+ - **moe_top_k**: int
+
+   Default = 1
+
+   Activate top K experts in MoE
+
+
+ - **use_tutel**: bool
+
+   Default = False
+
+   Use Tutel optimizations in MoE
+
+
+ - **num_experts**: int
+
+   Default = 1
+
+   Number of MoE experts
+
+
+ - **moe_loss_coeff**: float
+
+   Default = 0.1
+
+   Coefficient for MoE loss
+
+
+ - **moe_train_capacity_factor**: float
+
+   Default = 1.0
+
+   The capacity of the expert at train time
+
+
+ - **moe_eval_capacity_factor**: float
+
+   Default = 1.0
+
+   The capacity of the expert at eval time
+
+
+ - **moe_min_capacity**: int
+
+   Default = 4
+
+   The minimum capacity per expert regardless of the capacity_factor
+
+
+ - **moe_token_dropping**: bool
+
+   Default = True
+
+   Whether to drop tokens when exceeding capacity
+
+
+ - **create_moe_param_group**: bool
+
+   Default = True
+
+   Whether to create a separate parameter group for MoE parameters
+
+
+ - **moe_use_residual**: bool
+
+   Default = True
+
+   Whether to use residual in MoE
+
+
+ - **moe_expert_parallel_size**: int
+
+   Default = 1
+
+   Number of parallel experts in MoE
+

  ## NeoXArgsTokenizer

  Tokenizer Arguments
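As a rough guide to how the capacity-related arguments above interact, the sketch below shows the usual GShard/DeepSpeed-style calculation of a per-expert token budget. It is an illustrative approximation, not the DeepSpeed source, and the function name is hypothetical.

```python
import math

# Illustrative sketch of how DeepSpeed-style MoE typically derives a per-expert
# token budget from the arguments documented above; the exact behaviour is
# defined by DeepSpeed's gating implementation, not by this snippet.
def expert_capacity(tokens_in_batch: int,
                    num_experts: int,
                    capacity_factor: float,  # moe_train_capacity_factor or moe_eval_capacity_factor
                    min_capacity: int = 4,   # moe_min_capacity
                    top_k: int = 1) -> int:  # moe_top_k
    """Max tokens an expert accepts; overflow is dropped when moe_token_dropping is True."""
    capacity = math.ceil(top_k * capacity_factor * tokens_in_batch / num_experts)
    return max(capacity, min_capacity)

# Example: one micro-batch of 4 sequences of length 2048 routed over 4 experts
print(expert_capacity(4 * 2048, num_experts=4, capacity_factor=1.0))  # 2048
```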
