Varying experts per layer #43

Status: Open. Wants to merge 85 commits into base: main.

Commits (85)
9261410  Update switch_mlp.py  (pglorio, Dec 5, 2023)
f2ecc71  Update switch_mlp.py  (pglorio, Dec 5, 2023)
c7ea711  Update switch_mlp.py  (pglorio, Dec 5, 2023)
d5f7af4  Update switch_mlp.py  (pglorio, Dec 6, 2023)
35d3ac4  Update transformer_block.py  (pglorio, Dec 6, 2023)
0e7eb23  Update switch_mlp.py  (pglorio, Dec 6, 2023)
4ad505a  Update transformer_block.py  (pglorio, Dec 6, 2023)
3af0059  Update transformer_block.py  (pglorio, Dec 6, 2023)
13c54d9  Update transformer_block.py  (pglorio, Dec 6, 2023)
df2eeb2  Update arguments.py  (pglorio, Dec 7, 2023)
1f07db5  Update initialize.py  (pglorio, Dec 7, 2023)
b9d6475  Update initialize.py  (pglorio, Dec 7, 2023)
1c20628  Update initialize.py  (pglorio, Dec 7, 2023)
413bf43  Update initialize.py  (pglorio, Dec 7, 2023)
9237b34  Update initialize.py  (pglorio, Dec 7, 2023)
01dc1e9  Update arguments.py  (pglorio, Dec 7, 2023)
68b99a0  Update arguments.py  (pglorio, Dec 7, 2023)
28bc3a3  Update pretrain_gpt.py  (pglorio, Dec 7, 2023)
b819330  Update arguments.py  (pglorio, Dec 7, 2023)
56ad120  Update transformer_block.py  (pglorio, Dec 7, 2023)
8ce18aa  Update switch_mlp.py  (pglorio, Dec 7, 2023)
0875fc5  Update arguments.py  (pglorio, Dec 7, 2023)
4989c58  Update transformer_block.py  (pglorio, Dec 7, 2023)
ef743f4  Update transformer_block.py  (pglorio, Dec 7, 2023)
25056d5  Update switch_mlp.py  (pglorio, Dec 7, 2023)
7dad158  Update switch_mlp.py  (pglorio, Dec 7, 2023)
2aee703  Update switch_mlp.py  (pglorio, Dec 7, 2023)
babbebb  Update switch_mlp.py  (pglorio, Dec 7, 2023)
179d2a4  Update switch_mlp.py  (pglorio, Dec 7, 2023)
253c87c  Update transformer_block.py  (pglorio, Dec 7, 2023)
a058ac2  Update README.md  (pglorio, Dec 7, 2023)
c178073  Merge pull request #47 from Zyphra/main  (Quentin-Anthony, Dec 7, 2023)
cec3437  Update mlp.py  (pglorio, Dec 7, 2023)
381e899  Update switch_mlp.py  (pglorio, Dec 7, 2023)
cd4b5b1  Update transformer_layer.py  (pglorio, Dec 7, 2023)
3c0a58f  Update mlp.py  (pglorio, Dec 7, 2023)
2644ff8  Update switch_mlp.py  (pglorio, Dec 7, 2023)
911ceeb  Update mlp.py  (pglorio, Dec 7, 2023)
7d90105  Update arguments.py  (pglorio, Dec 7, 2023)
0ff85c7  Update mlp.py  (pglorio, Dec 7, 2023)
e91b587  Update mlp.py  (pglorio, Dec 7, 2023)
7f7d9ab  Update mlp.py  (pglorio, Dec 7, 2023)
1f052fc  Update README.md  (pglorio, Dec 7, 2023)
25dd9a8  Update mlp.py  (pglorio, Dec 7, 2023)
e0f37f1  Update mlp.py  (pglorio, Dec 7, 2023)
7e4a6ae  Update switch_mlp.py  (pglorio, Dec 14, 2023)
42be8d4  Update switch_mlp.py  (pglorio, Dec 14, 2023)
f787add  Update switch_mlp.py  (pglorio, Dec 14, 2023)
03b50ec  Update gpt_model.py  (pglorio, Dec 15, 2023)
02969e7  Update switch_mlp.py  (pglorio, Dec 15, 2023)
fb61494  Update switch_mlp.py  (pglorio, Dec 19, 2023)
e26ca57  Update training.py  (pglorio, Dec 19, 2023)
83bcde1  Update switch_mlp.py  (pglorio, Dec 20, 2023)
a8ac369  Update switch_mlp.py  (pglorio, Dec 20, 2023)
c2edfbb  Update switch_mlp.py  (pglorio, Dec 20, 2023)
74b0c7b  print statement of data shape  (pglorio, Dec 20, 2023)
8f7bb1c  print statement of data shape  (pglorio, Dec 20, 2023)
714cd8d  Update transformer_layer.py  (pglorio, Dec 20, 2023)
5fd134f  Update gpt_model.py  (pglorio, Dec 20, 2023)
d95239c  Update training.py  (pglorio, Dec 20, 2023)
8ac39c9  Update gpt_model.py  (pglorio, Dec 20, 2023)
687687e  Update gpt_model.py  (pglorio, Dec 20, 2023)
fff6470  Update gpt_model.py  (pglorio, Dec 21, 2023)
fac5f79  Update gpt_model.py  (pglorio, Dec 21, 2023)
b157390  Update gpt_model.py  (pglorio, Dec 21, 2023)
214e0d0  Update pretrain_gpt.py  (pglorio, Dec 21, 2023)
0597d75  Update gpt_model.py  (pglorio, Dec 21, 2023)
2b017d1  Update transformer_layer.py  (pglorio, Dec 21, 2023)
09b1228  Update switch_mlp.py  (pglorio, Dec 29, 2023)
3800c0e  Update switch_mlp.py  (pglorio, Dec 29, 2023)
5a583a2  Update gpt_model.py  (pglorio, Dec 29, 2023)
5123488  Update fused_softmax.py  (pglorio, Dec 29, 2023)
0d6f7c9  Update fused_softmax.py  (pglorio, Dec 29, 2023)
f8c8de9  Update fused_softmax.py  (pglorio, Dec 29, 2023)
2bfb7c0  Update fused_softmax.py  (pglorio, Dec 29, 2023)
36a2376  Update fused_softmax.py  (pglorio, Dec 29, 2023)
7048fc1  Update switch_mlp.py  (pglorio, Jan 1, 2024)
431fa63  Update switch_mlp.py  (pglorio, Jan 1, 2024)
fc5390c  Update switch_mlp.py  (pglorio, Jan 2, 2024)
5791a87  Update switch_mlp.py  (pglorio, Jan 2, 2024)
03aecf6  Update switch_mlp.py  (pglorio, Jan 2, 2024)
0e5b34f  Update switch_mlp.py  (pglorio, Jan 2, 2024)
704726d  Update switch_mlp.py  (pglorio, Jan 2, 2024)
277afeb  Update switch_mlp.py  (pglorio, Jan 2, 2024)
4a31d0f  Update switch_mlp.py  (pglorio, Jan 2, 2024)

Changes from all commits

8 changes: 8 additions & 0 deletions README.md
@@ -91,6 +91,14 @@ A sample plot for `top2` routing mode (obtained from a tiny toy model) is:

<img src="images/token_count.png" alt="Token Counts" width="70%">

## Varying expert number and MLP hidden dimension across layers

To set a different number of experts per layer, use the flag `--moe-layers` followed by a sequence of integers giving the number of experts for each layer. For example, in a model with 5 layers, one can write `--moe-layers 1 8 16 8 1`.

This flag does not currently support pipeline parallelism. For MoE layers, each value must be a multiple of `--expert-model-parallel-size` and greater than 2; for a dense layer, set the value to 1.

To change the MLP hidden dimension across layers, use the flag `--ffn-hidden-ratio` followed by a sequence of integers giving, for each layer, the ratio between the MLP hidden dimension and the model's embedding dimension. Without this flag, the ratio defaults to 4 for all layers (unless `--ffn-hidden-size` is used). For example, for a model with 5 layers, one can write `--ffn-hidden-ratio 4 4 2 4 4`.
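
As an illustration (not part of the diff above), a launch command combining these flags might look like the sketch below; the launcher invocation, model sizes, and parallelism settings are assumptions, not values prescribed by this PR:

```bash
# Minimal sketch: a 5-layer model with dense first/last layers, MoE layers in
# between, and a narrower MLP in the middle layer. Data, tokenizer, and
# optimizer flags are omitted.
torchrun --nproc_per_node 8 pretrain_gpt.py \
    --use-mcore-models \
    --num-layers 5 \
    --hidden-size 1024 \
    --num-attention-heads 16 \
    --expert-model-parallel-size 4 \
    --moe-layers 1 8 16 8 1 \
    --ffn-hidden-ratio 4 4 2 4 4
```

Per the checks added in `megatron/arguments.py`, each MoE entry here is a multiple of `--expert-model-parallel-size` and greater than 2, and `--use-mcore-models` is required when `--moe-layers` is set.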

# NVIDIA Megatron-LM (copied from upstream)

Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf), [2](https://arxiv.org/pdf/2104.04473.pdf), and [3](https://arxiv.org/pdf/2205.05198)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel ([tensor](https://arxiv.org/pdf/1909.08053.pdf), [sequence](https://arxiv.org/pdf/2205.05198), and [pipeline](https://arxiv.org/pdf/2104.04473.pdf)), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision.
13 changes: 13 additions & 0 deletions megatron/arguments.py
@@ -390,8 +390,13 @@ def validate_args(args, defaults={}):
# MoE Spec check
if args.num_experts is not None:
assert args.model_spec is None, "Model Spec must be None when using MoEs"
assert args.num_experts > 1, "--num-experts should be greater than 2."
if args.use_balancing_loss is not None:
assert (args.routing_mode == 'top1' or args.routing_mode == 'top2'), "Need --routing-mode = 'top1' or 'top2' if setting --use-balancing-loss."
if args.moe_layers is not None:
assert len(args.moe_layers) == args.num_layers, "length of --moe-layers should equal --num-layers."
assert min(x for x in args.moe_layers if x != 1) > 2, "Experts per layer should be greater than 2."
assert args.use_mcore_models == True, "--moe-layers supported only with --use-mcore-models."

# Expert parallelism check
if args.expert_model_parallel_size > 1:
@@ -401,6 +406,8 @@ def validate_args(args, defaults={}):
if args.tensor_model_parallel_size > 1:
assert args.sequence_parallel, \
"When using expert parallelism and tensor parallelism, sequence parallelism must be used."
if args.moe_layers is not None:
assert all(x % args.expert_model_parallel_size == 0 for x in args.moe_layers if x != 1), "Experts per layer should be multiple of --expert-model-parallel-size."

# Print arguments.
_print_args("arguments", args)
@@ -628,6 +635,12 @@ def _add_network_size_args(parser):
dest='bert_binary_head')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
group.add_argument('--moe-layers', nargs='+', type=int, default=None,
help='Number of experts for each layer (`1` means dense layer). '
'Does not support pipeline parallelism.')
group.add_argument('--ffn-hidden-ratio', nargs='+', type=int, default=None,
help='Ratio of MLP intermediate layer over embedding dimension (4 is default). '
'It can be different in each layer.')
group.add_argument('--routing-mode', type=str, default='sinkhorn',
choices=['sinkhorn', 'top1', 'top2', 'sinkhorn_top2'],
help='Mode of the expert routing.')
3 changes: 2 additions & 1 deletion megatron/core/models/gpt/gpt_model.py
@@ -148,6 +148,7 @@ def forward(
args = get_args()
if args.use_balancing_loss is not None:
args.l_aux = 0.0

hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
@@ -171,7 +172,7 @@ def forward(
loss = self.compute_language_model_loss(labels, logits)
if args.use_balancing_loss is not None:
loss += args.use_balancing_loss * args.l_aux

return loss

def shared_embedding_or_output_weight(self) -> Tensor:
23 changes: 16 additions & 7 deletions megatron/core/transformer/mlp.py
@@ -11,6 +11,7 @@
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron import get_args


@dataclass
@@ -37,21 +38,29 @@ class MLP(MegatronModule):
"""

def __init__(
self, config: TransformerConfig, submodules: MLPSubmodules, is_expert: bool = False
self, config: TransformerConfig, submodules: MLPSubmodules, is_expert: bool = False, layer=None
):
super().__init__(config=config)


args = get_args()
self.config: TransformerConfig = config

self.layer = layer
if layer and args.ffn_hidden_ratio:
ffn_hidden_size_1 = self.config.hidden_size * args.ffn_hidden_ratio[layer-1]
ffn_hidden_size_2 = self.config.hidden_size * args.ffn_hidden_ratio[layer-1]
else:
ffn_hidden_size_1 = self.config.ffn_hidden_size
ffn_hidden_size_2 = self.config.ffn_hidden_size

# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
ffn_hidden_size_1 *= 2


self.linear_fc1 = build_module(
submodules.linear_fc1,
self.config.hidden_size,
ffn_hidden_size,
ffn_hidden_size_1,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
@@ -72,7 +81,7 @@ def glu(x):

self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
ffn_hidden_size_2,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
35 changes: 16 additions & 19 deletions megatron/core/transformer/switch_mlp.py
@@ -50,8 +50,11 @@ def __init__(self, config: TransformerConfig, submodules: MLPSubmodules, layer=N
args = get_args()

self.config: TransformerConfig = config

self.router = torch.nn.Linear(self.config.hidden_size, self.config.num_moe_experts)
if args.moe_layers:
self.num_moe_experts = args.moe_layers[layer-1]
else:
self.num_moe_experts = self.config.num_moe_experts
self.router = torch.nn.Linear(self.config.hidden_size, self.num_moe_experts)
self.add_bias = config.add_bias_linear
self.routing = args.routing_mode # 'sinkhorn', 'top1', 'top2', 'sinkhorn_top2'
self.layer = layer
@@ -61,8 +64,8 @@ def __init__(self, config: TransformerConfig, submodules: MLPSubmodules, layer=N
self.router_activation = torch.sigmoid
self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()

assert self.config.num_moe_experts % self.expert_parallel_size == 0
self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
assert self.num_moe_experts % self.expert_parallel_size == 0
self.num_local_experts = self.num_moe_experts // self.expert_parallel_size
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
@@ -72,7 +75,7 @@ def __init__(self, config: TransformerConfig, submodules: MLPSubmodules, layer=N

self.local_experts = torch.nn.ModuleList()
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
expert = MLP(self.config, submodules, is_expert=True, layer=layer)
self.local_experts.append(expert)

def gather_indices(self, local_indices):
@@ -97,7 +100,7 @@ def forward(self, hidden_states):
args = get_args()
hidden_shape = hidden_states.shape
route = self.router(hidden_states)
route = route.view(-1, self.config.num_moe_experts)
route = route.view(-1, self.num_moe_experts)

if self.config.timers is not None:
self.config.timers('routing_block1', log_level=2).start()
@@ -163,20 +166,18 @@ def forward(self, hidden_states):
if self.config.timers is not None:
self.config.timers('routing_gather').stop()



# Evaluate balancing loss.
if (args.use_balancing_loss is not None) and self.training:
if hasattr(args, 'l_aux'):
me = torch.mean(route, dim=0)
mask1 = F.one_hot(global_indices, num_classes=self.config.num_moe_experts)
mask1 = F.one_hot(global_indices, num_classes=self.num_moe_experts)
ce = torch.mean(mask1.float(), dim=0)
args.l_aux += torch.sum(me * ce) * self.config.num_moe_experts
args.l_aux += torch.sum(me * ce) * self.num_moe_experts
if self.routing == 'top2':
me_2 = torch.mean(masked_route, dim=0)
mask1 = F.one_hot(global_indices_2, num_classes=self.config.num_moe_experts)
mask1 = F.one_hot(global_indices_2, num_classes=self.num_moe_experts)
ce_2 = torch.mean(mask1.float(), dim=0)
args.l_aux += torch.sum(me_2 * ce_2) * self.config.num_moe_experts
args.l_aux += torch.sum(me_2 * ce_2) * self.num_moe_experts

# Collect token count for each expert and save to file
if self.router_profiling_interval and (args.curr_iteration % self.router_profiling_interval == 0) and args.curr_iteration > 0:
@@ -223,7 +224,6 @@ def forward(self, hidden_states):
if self.config.timers is not None:
self.config.timers('routing_loop').stop()


if self.config.timers is not None:
self.config.timers('ep_scatter', log_level=2).start()
if self.sequence_parallel or (self.expert_parallel_size > 1):
@@ -259,23 +259,20 @@ def forward(self, hidden_states):
if self.config.timers is not None:
self.config.timers('ep_scatter').stop()


if self.config.timers is not None:
self.config.timers('final_route', log_level=2).start()
output_total = output_total * max_prob
if self.routing == 'top2' or self.routing == 'sinkhorn_top2':
output_total_2 = output_total_2 * max_prob_2
output_total = output_total + output_total_2
output_total = (output_total + output_total_2 * max_prob_2)
output_total = output_total.view(hidden_shape)
if self.add_bias:
output_bias_total = output_bias_total * max_prob
if self.routing == 'top2' or self.routing == 'sinkhorn_top2':
output_bias_total_2 = output_bias_total_2 * max_prob_2
output_bias_total = output_bias_total + output_bias_total_2
output_bias_total = (output_bias_total + output_bias_total_2 * max_prob_2)
output_bias_total = output_bias_total.view(hidden_shape)
else:
output_bias_total = None
if self.config.timers is not None:
self.config.timers('final_route').stop()

return output_total, output_bias_total
31 changes: 25 additions & 6 deletions megatron/core/transformer/transformer_block.py
@@ -14,6 +14,8 @@
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
from megatron.core.models.gpt.gpt_layer_specs import gpt_layer_with_transformer_engine_spec
from megatron import get_args


class TransformerBlock(MegatronModule):
@@ -57,12 +59,29 @@ def _build_layers(self, transformer_layer_spec):
# coeff = self.layer_number
# self.norm_factor *= coeff
def build_layer(layer_number):
layer = TransformerLayer(
config=self.config,
submodules=transformer_layer_spec.submodules,
layer_number=layer_number,
self_attn_mask_type=self.self_attn_mask_type,
)
args = get_args()
if args.moe_layers:
if args.moe_layers[layer_number-1] == 1:
layer = TransformerLayer(
config=self.config,
submodules=gpt_layer_with_transformer_engine_spec.submodules,
layer_number=layer_number,
self_attn_mask_type=self.self_attn_mask_type,
)
else:
layer = TransformerLayer(
config=self.config,
submodules=transformer_layer_spec.submodules,
layer_number=layer_number,
self_attn_mask_type=self.self_attn_mask_type,
)
else:
layer = TransformerLayer(
config=self.config,
submodules=transformer_layer_spec.submodules,
layer_number=layer_number,
self_attn_mask_type=self.self_attn_mask_type,
)
return layer

if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
2 changes: 1 addition & 1 deletion megatron/core/transformer/transformer_layer.py
@@ -108,7 +108,7 @@ def __init__(
if submodules.mlp.module == SwitchMLP:
self.mlp = build_module(submodules.mlp, config=self.config, layer=layer_number)
else:
self.mlp = build_module(submodules.mlp, config=self.config)
self.mlp = build_module(submodules.mlp, config=self.config, layer=layer_number)

## [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
6 changes: 3 additions & 3 deletions megatron/initialize.py
@@ -95,9 +95,9 @@ def finish_mpu_init():
dir_path = os.path.join(args.router_profiling_path)
if not os.path.exists(dir_path):
os.makedirs(dir_path)

# No continuation function
return None
# No continuation function
return None


def _compile_dependencies():
12 changes: 12 additions & 0 deletions megatron/training.py
@@ -50,6 +50,8 @@
from megatron.model.vision.knn_monitor import compute_feature_bank
from megatron.eval_harness import Evaluator

global prev_params
prev_params = [[] for i in range(1000)]

def print_datetime(string):
"""Note that this call will sync across all ranks."""
@@ -455,6 +457,16 @@ def train_step(forward_step_func, data_iterator,
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

# Update parameters.

# print("JUST BEFORE STEP:" )
if args.curr_iteration % 10 == 0 and torch.distributed.get_rank() == 0 and 1 == 0:
for i,(n, p) in enumerate(model[0].named_parameters()):
if len(prev_params[i]) == 0:
prev_params[i] = p.detach().clone()
param_diff = p - prev_params[i]
grad_sum = str(p.grad.sum().item()) if p.grad is not None else "NO GRAD!"
print(args.curr_iteration, n, p.shape, torch.norm(p).item(), torch.norm(param_diff).item())
prev_params[i] = p.detach().clone()
if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"Optimizer step")
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
2 changes: 1 addition & 1 deletion pretrain_gpt.py
@@ -52,7 +52,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
if args.model_spec is not None:
transformer_layer_spec = import_module(args.model_spec)
else:
if args.num_experts is None:
if (args.num_experts is None) and (args.moe_layers is None):
transformer_layer_spec = gpt_layer_with_transformer_engine_spec
else:
transformer_layer_spec = gpt_layer_with_transformer_engine_spec_moe