MoE #639
Changes from 125 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -198,6 +198,11 @@ class BlockType(StrEnum): | |
implementations of operations like attention to imitate the behavior of Llama. | ||
""" | ||
|
||
moe = "moe" | ||
""" | ||
A block for OLMoE-style Mixture-of-Experts models. | ||
""" | ||
|
||
|
||
class InitFnType(StrEnum): | ||
mitchell = "mitchell" | ||
|
@@ -457,6 +462,61 @@ class ModelConfig(BaseConfig): | |
See :data:`TrainConfig.precision` instead. | ||
""" | ||
|
||
moe_num_experts: Optional[int] = 8 | ||
""" | ||
The number of experts to use in the MoE block. | ||
""" | ||
|
||
moe_top_k: Optional[int] = 2 | ||
""" | ||
The number of experts to select for each token. | ||
""" | ||
|
||
moe_mlp_impl: Optional[str] = "sparse" | ||
""" | ||
Choose "grouped" for grouped GEMM installable via `pip install git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb`. | ||
""" | ||
|
||
moe_log_expert_assignment: Optional[bool] = True | ||
""" | ||
Whether to log the expert assignment. | ||
""" | ||
|
||
moe_shared_expert: Optional[bool] = False | ||
""" | ||
Whether to have an always-used expert like in [DeepSeekMoE](https://arxiv.org/abs/2401.06066). | ||
""" | ||
|
||
moe_lbl_in_fp32: Optional[bool] = False | ||
""" | ||
Whether to perform load balancing in FP32. | ||
""" | ||
|
||
moe_interleave: Optional[bool] = False | ||
""" | ||
Interleave sequential with MoE blocks starting with sequential. | ||
""" | ||
|
||
|
||
moe_loss_weight: Optional[float] = 0.1 | ||
""" | ||
The weight to use for the MoE load balancing loss. | ||
""" | ||
|
||
moe_zloss_weight: Optional[float] = None | ||
""" | ||
Weight for MoE router z-loss where None means no router z-loss. 0.001 is a common value. | ||
""" | ||
|
||
moe_dropless: Optional[bool] = True | ||
""" | ||
Whether to use [dMoE](https://arxiv.org/abs/2211.15841). | ||
""" | ||
|
||
moe_capacity_factor: Optional[float] = 1.25 | ||
""" | ||
The capacity factor to use in the MoE block. Only applies if not using dMoE. | ||
""" | ||
|
||
scale_emb_init: bool = False | ||
""" | ||
If ``True``, embeddings are scaled up by ``sqrt(d_model)`` during initialization. | ||
|
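The `moe_loss_weight` and `moe_zloss_weight` settings above scale two auxiliary router losses. As a rough illustration of what such terms compute, here is a minimal pure-Python sketch. It assumes a Switch-Transformer-style load-balancing loss and the usual squared log-sum-exp form of the router z-loss; the function names and the exact normalization are assumptions of this sketch and are not taken from megablocks.

```python
import math
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def router_z_loss(router_logits: List[List[float]]) -> float:
    """Mean over tokens of logsumexp(logits) ** 2."""
    total = 0.0
    for logits in router_logits:
        m = max(logits)
        lse = m + math.log(sum(math.exp(v - m) for v in logits))
        total += lse ** 2
    return total / len(router_logits)


def load_balancing_loss(router_logits: List[List[float]], top_k: int) -> float:
    """num_experts * sum_e f_e * p_e (assumed Switch-style formulation)."""
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    assigned = [0] * num_experts    # assignments received by each expert
    prob_sum = [0.0] * num_experts  # summed router probability per expert
    for logits in router_logits:
        probs = softmax(logits)
        top = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)[:top_k]
        for e in top:
            assigned[e] += 1
        for e in range(num_experts):
            prob_sum[e] += probs[e]
    f = [a / (num_tokens * top_k) for a in assigned]  # fraction of assignments
    p = [s / num_tokens for s in prob_sum]            # mean router probability
    return num_experts * sum(fe * pe for fe, pe in zip(f, p))


def auxiliary_loss(
    router_logits: List[List[float]],
    top_k: int,
    loss_weight: float = 0.1,
    zloss_weight: Optional[float] = None,
) -> float:
    loss = loss_weight * load_balancing_loss(router_logits, top_k)
    if zloss_weight:  # None (or 0) disables the router z-loss
        loss += zloss_weight * router_z_loss(router_logits)
    return loss
```

In this formulation the load-balancing term grows when both the assignments and the router probability mass concentrate on the same few experts, while the z-loss penalizes large router logits.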
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -660,10 +660,245 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl | |
return OLMoSequentialBlock(layer_id, config, cache) | ||
elif config.block_type == BlockType.llama: | ||
return OLMoLlamaBlock(layer_id, config, cache) | ||
elif config.block_type == BlockType.moe: | ||
return OLMoEBlock(layer_id, config, cache) | ||
else: | ||
raise NotImplementedError(f"Unknown block type: '{config.block_type}'") | ||
|
||
|
||
class OLMoEBlock(OLMoBlock): | ||
""" | ||
This is a transformer MoE block where the output is computed as ``MoE(LN(x + Attention(LN(x))))`` | ||
|
||
(plus another skip connection). | ||
""" | ||
|
||
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): | ||
try: | ||
from megablocks.layers.arguments import Arguments as MoEArgs | ||
from megablocks.layers.dmoe import dMoE | ||
from megablocks.layers.moe import MoE | ||
except ImportError: | ||
raise ImportError( | ||
"To train MoEs, run `pip install git+https://github.com/Muennighoff/megablocks.git@olmoe`" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's different about your branch for the original source? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It includes zloss which we use during training for better stability There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can view the exact difference here: databricks/megablocks@main...Muennighoff:megablocks:olmoe ; besides zloss it also has expert choice which is currently not used but i think we may want to try in the future when we go multimodal There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you upstream this, so we don't have to depend on a private fork? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, opened a PR here databricks/megablocks#133 - If / when it gets merged, I will update the install instructions. If people don't want to use zloss, it also works with the regular megablocks - it's not a big difference. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Muennighoff , so they decided to merge their version instead. Is our version compatible? Will the model you trained work with their implementation of zloss? |
||
) | ||
|
||
nn.Module.__init__(self) | ||
|
||
self.layer_id = layer_id | ||
self.config = config | ||
self.hidden_size = ( | ||
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model | ||
) | ||
self.__cache = cache | ||
assert config.d_model % config.n_heads == 0 | ||
|
||
self._activation_checkpoint_fn = None | ||
|
||
# Dropout. | ||
self.dropout = Dropout(config.residual_dropout) | ||
|
||
# Layer norms. | ||
self.k_norm: Optional[LayerNormBase] = None | ||
self.q_norm: Optional[LayerNormBase] = None | ||
if config.attention_layer_norm: | ||
assert config.effective_n_kv_heads is not None | ||
self.k_norm = LayerNormBase.build( | ||
config, | ||
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, | ||
elementwise_affine=config.attention_layer_norm_with_affine, | ||
) | ||
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) | ||
|
||
# Make sure QKV clip coefficient is positive, otherwise it's not well-defined. | ||
if config.clip_qkv is not None: | ||
assert config.clip_qkv > 0 | ||
|
||
# Activation function. | ||
self.act = Activation.build(config) | ||
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 | ||
|
||
# Attention output projection. | ||
self.attn_out = nn.Linear( | ||
config.d_model, config.d_model, bias=config.include_bias, device=config.init_device | ||
) | ||
|
||
# MoE Block | ||
kwargs = { | ||
"activation_fn": F.silu if "swiglu" in config.activation_type.lower() else self.act, | ||
"mlp_type": "glu" if "glu" in config.activation_type.lower() else "mlp", | ||
"mlp_impl": config.moe_mlp_impl, | ||
"hidden_size": config.d_model, | ||
"ffn_hidden_size": int(self.act.output_multiplier * self.hidden_size), | ||
"moe_num_experts": config.moe_num_experts, | ||
# Handled by FSDP (https://github.com/databricks/megablocks/issues/57#issuecomment-1854594483) | ||
"moe_weight_parallelism": False, | ||
"moe_expert_model_parallelism": False, | ||
"moe_top_k": config.moe_top_k, | ||
"moe_capacity_factor": config.moe_capacity_factor, | ||
"moe_loss_weight": config.moe_loss_weight, | ||
"device": config.init_device, | ||
# Handled by FSDP | ||
"bf16": False, | ||
"fp16": False, | ||
"bias": self.config.include_bias, | ||
"return_bias": False, | ||
"shared_expert": self.config.moe_shared_expert, | ||
"moe_lbl_in_fp32": config.moe_lbl_in_fp32, | ||
} | ||
if config.moe_zloss_weight: | ||
kwargs["moe_zloss_weight"] = config.moe_zloss_weight | ||
|
||
self.moe_args = MoEArgs(**kwargs) | ||
self.ffn = dMoE(self.moe_args) if self.config.moe_dropless else MoE(self.moe_args) | ||
|
||
# Rotary embeddings. | ||
if self.config.rope: | ||
self.rotary_emb = RotaryEmbedding(config, self.__cache) | ||
|
||
self.flash_attn_func = None | ||
self.flash_attn_varlen_func = None | ||
if config.flash_attention: | ||
try: | ||
from flash_attn import ( # type: ignore | ||
flash_attn_func, | ||
flash_attn_varlen_func, | ||
) | ||
|
||
self.flash_attn_func = flash_attn_func | ||
self.flash_attn_varlen_func = flash_attn_varlen_func | ||
except ModuleNotFoundError: | ||
pass | ||
|
||
self.attn_norm = LayerNorm.build(config) | ||
self.ff_norm = LayerNorm.build(config) | ||
|
||
# Attention input projection. Projects x -> (q, k, v) | ||
head_dim = config.d_model // config.n_heads | ||
self.fused_dims = ( | ||
config.d_model, | ||
config.effective_n_kv_heads * head_dim, | ||
config.effective_n_kv_heads * head_dim, | ||
) | ||
self.att_proj = nn.Linear( | ||
config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device | ||
) | ||
|
||
def reset_parameters(self): | ||
if self.k_norm is not None: | ||
self.k_norm.reset_parameters() | ||
if self.q_norm is not None: | ||
self.q_norm.reset_parameters() | ||
|
||
if self.config.init_fn == InitFnType.normal: | ||
attn_out_std = ff_out_std = in_std = self.config.init_std | ||
cutoff_factor = self.config.init_cutoff_factor | ||
elif self.config.init_fn == InitFnType.mitchell: | ||
in_std = 1 / math.sqrt(self.config.d_model) | ||
attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1))) | ||
ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1))) | ||
cutoff_factor = self.config.init_cutoff_factor or 3.0 | ||
elif self.config.init_fn == InitFnType.full_megatron: | ||
in_std = self.config.init_std | ||
attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers) | ||
cutoff_factor = self.config.init_cutoff_factor or 3.0 | ||
else: | ||
raise NotImplementedError(self.config.init_fn) | ||
|
||
init_normal(self.att_proj, std=in_std, init_cutoff_factor=cutoff_factor) | ||
init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor) | ||
self.attn_norm.reset_parameters() | ||
self.ff_norm.reset_parameters() | ||
init_normal(self.ffn.experts.mlp.w1, std=in_std, init_cutoff_factor=cutoff_factor) | ||
init_normal(self.ffn.experts.mlp.w2, std=ff_out_std, init_cutoff_factor=cutoff_factor) | ||
if hasattr(self.ffn.experts.mlp, "v1"): | ||
init_normal(self.ffn.experts.mlp.v1, std=in_std, init_cutoff_factor=cutoff_factor) | ||
|
||
if self.ffn.experts.bias is not None: | ||
torch.nn.init.zeros_(self.ffn.experts.bias) | ||
init_normal(self.ffn.router.layer, std=in_std, init_cutoff_factor=cutoff_factor) | ||
|
||
def forward( | ||
self, | ||
x: torch.Tensor, | ||
attention_bias: Optional[torch.Tensor] = None, | ||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | ||
use_cache: bool = False, | ||
max_doc_len: Optional[int] = None, | ||
cu_doc_lens: Optional[torch.Tensor] = None, | ||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: | ||
# Get query, key, value projections. | ||
# shape: | ||
# - for regular attn q, k, v: (batch_size, seq_len, d_model) | ||
# - for multi-query attn q: (batch_size, seq_len, d_model) | ||
# k, v: (batch_size, seq_len, d_model // n_heads) | ||
# - for group query attn q: (batch_size, seq_len, d_model) | ||
# k, v: (batch_size, seq_len, d_model // n_kv_heads) | ||
if not self.config.norm_after: | ||
if self._activation_checkpoint_fn is not None: | ||
qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)) | ||
else: | ||
qkv = self.att_proj(self.attn_norm(x)) | ||
else: | ||
qkv = self.att_proj(x) | ||
|
||
if self.config.clip_qkv is not None: | ||
qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) | ||
|
||
q, k, v = qkv.split(self.fused_dims, dim=-1) | ||
|
||
# Get attention scores. | ||
if self._activation_checkpoint_fn is not None: | ||
att, cache = self._activation_checkpoint_fn( # type: ignore | ||
self.attention, | ||
q, | ||
k, | ||
v, | ||
attention_bias, | ||
layer_past=layer_past, | ||
use_cache=use_cache, | ||
max_doc_len=max_doc_len, | ||
cu_doc_lens=cu_doc_lens, | ||
) | ||
else: | ||
att, cache = self.attention( | ||
q, | ||
k, | ||
v, | ||
attention_bias, | ||
layer_past=layer_past, | ||
use_cache=use_cache, | ||
max_doc_len=max_doc_len, | ||
cu_doc_lens=cu_doc_lens, | ||
) | ||
|
||
if self.config.norm_after: | ||
if self._activation_checkpoint_fn is not None: | ||
att = self._activation_checkpoint_fn(self.attn_norm, att) | ||
else: | ||
att = self.attn_norm(att) | ||
|
||
# Add attention scores. | ||
# shape: (B, T, C) | ||
x = x + self.dropout(att) | ||
|
||
# Add feed-forward projection. | ||
# shape: (batch_size, seq_len, d_model) | ||
og_x = x | ||
|
||
if self.config.norm_after: | ||
x = self.ffn(x) | ||
if self._activation_checkpoint_fn is not None: | ||
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore | ||
else: | ||
x = self.ff_norm(x) | ||
return og_x + self.dropout(x), cache | ||
else: | ||
if self._activation_checkpoint_fn is not None: | ||
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore | ||
else: | ||
x = self.ff_norm(x) | ||
# Activation checkpointing for the MoE FFN is not supported | ||
Why not? If there is a technical problem with it, will it affect
It fails with
This paper has some explanations why it is difficult to do act ckpt for MoEs: https://dspace.mit.edu/bitstream/handle/1721.1/153897/wisdom-dwisdom-meng-eecs-2024-thesis.pdf
Ok, I see. Interesting. It would be fixable I think (by saving the active experts per token in the forward pass), but out of scope for this PR.
This is probably a fairly big blocker to going bigger though. For dense models, our fastest settings still use a lot of checkpointing. |
||
return og_x + self.dropout(self.ffn(x)), cache | ||
|
||
|
||
class OLMoSequentialBlock(OLMoBlock): | ||
""" | ||
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` | ||
|
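The `OLMoEBlock` above hands `moe_top_k`, `moe_capacity_factor` and the dropless switch to megablocks, which does the actual routing. As a rough mental model only, the sketch below shows top-k routing in pure Python and how a capacity limit drops assignments while a dropless layer keeps them all. The helper names and the capacity formula are assumptions of this sketch, not the megablocks implementation.

```python
import math
from typing import Dict, List, Optional, Tuple


def top_k_route(router_logits: List[List[float]], top_k: int) -> List[List[Tuple[int, float]]]:
    """For each token, return (expert_index, softmax_weight) for its top_k experts."""
    routes = []
    for logits in router_logits:
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        top = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:top_k]
        routes.append([(e, probs[e]) for e in top])
    return routes


def expert_capacity(num_tokens: int, num_experts: int, top_k: int, capacity_factor: float) -> int:
    # Assumed formula: capacity_factor times the perfectly balanced per-expert load.
    return int(capacity_factor * num_tokens * top_k / num_experts)


def assign_with_capacity(
    routes: List[List[Tuple[int, float]]],
    num_experts: int,
    capacity: Optional[int],
) -> Tuple[Dict[int, List[int]], List[Tuple[int, int]]]:
    """Return (tokens per expert, dropped (token, expert) pairs).

    capacity=None models the dropless case, where nothing is dropped.
    """
    buckets: Dict[int, List[int]] = {e: [] for e in range(num_experts)}
    dropped: List[Tuple[int, int]] = []
    for token, choices in enumerate(routes):
        for expert, _weight in choices:
            if capacity is not None and len(buckets[expert]) >= capacity:
                dropped.append((token, expert))
            else:
                buckets[expert].append(token)
    return buckets, dropped
```

With four tokens that all prefer experts 0 and 1, a capacity factor of 1.25 over four experts gives each expert room for two tokens, so the assignments of the last two tokens are dropped; passing `capacity=None` keeps all of them.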
@@ -1096,8 +1331,15 @@ def __init__(self, config: ModelConfig, init_params: bool = True): | |
ln_f=LayerNorm.build(config), | ||
) | ||
) | ||
|
||
blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)] | ||
if self.config.moe_interleave: | ||
blocks = [] | ||
for i in range(config.n_layers): | ||
if i % 2 == 0: | ||
blocks.append(OLMoSequentialBlock(i, config, self.__cache)) | ||
else: | ||
blocks.append(OLMoEBlock(i, config, self.__cache)) | ||
|
||
else: | ||
blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)] | ||
if self.config.block_group_size > 1: | ||
block_groups = [ | ||
OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size]) | ||
|
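The `moe_interleave` branch above alternates block types by layer index, starting with a dense block. A tiny sketch of the resulting layout, using plain strings in place of the real block classes (the helper name is hypothetical):

```python
from typing import List


def interleaved_block_types(n_layers: int) -> List[str]:
    # Even layer indices are dense ("sequential"); odd indices are MoE.
    return ["sequential" if i % 2 == 0 else "moe" for i in range(n_layers)]


layout = interleaved_block_types(5)
print(layout)
```

With an odd number of layers the stack both starts and ends with a dense block, and index 1 is the first MoE block.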
@@ -1538,7 +1780,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): | |
else: | ||
raise NotImplementedError(wrap_strategy) | ||
|
||
def num_params(self, include_embedding: bool = True) -> int: | ||
def num_params(self, include_embedding: bool = True, include_inactivate_params: bool = True) -> int: | ||
|
||
""" | ||
Get the total number of parameters. | ||
""" | ||
|
@@ -1548,6 +1790,15 @@ def num_params(self, include_embedding: bool = True) -> int: | |
lambda np: ".wte." not in np[0] and ".wpe." not in np[0], | ||
params, | ||
) | ||
if not include_inactivate_params: | ||
# Need to reduce blocks to the number of experts that are selected | ||
# If not dropless 'transformer.blocks.0.ffn.experts.mlp.w1' has shape (total_experts, in_dim, out_dim) | ||
# change to 'transformer.blocks.0.ffn.experts.mlp.w1' with shape (selected_experts, in_dim, out_dim) | ||
# If dropless, the total_experts & out_dim are combined into one dimension | ||
idx = self.config.moe_top_k | ||
if self.config.moe_dropless: | ||
idx *= self.transformer.blocks[1].moe_args.ffn_hidden_size | ||
params = [(np[0], np[1][:idx]) if "experts.mlp" in np[0] else np for np in params] # type: ignore | ||
return sum(p.numel() for _, p in params) | ||
|
||
@property | ||
|
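The slicing in `num_params` above counts only the first `moe_top_k` experts' worth of each expert weight. The shape arithmetic can be checked with a small sketch that uses no tensors. The non-dropless layout follows the comment in the diff; the dropless layout, with experts and `ffn_hidden_size` fused into the first dimension, is an assumption based on that same comment, and both helper names are hypothetical.

```python
from typing import Tuple


def sliced_numel(shape: Tuple[int, ...], idx: int) -> int:
    """Number of elements in tensor[:idx] for a tensor of the given shape."""
    rows = min(idx, shape[0])
    rest = 1
    for dim in shape[1:]:
        rest *= dim
    return rows * rest


def active_expert_numel(
    num_experts: int, top_k: int, d_model: int, ffn_hidden_size: int, dropless: bool
) -> int:
    if dropless:
        # Assumed layout: experts and ffn_hidden_size fused into the first dimension.
        shape = (num_experts * ffn_hidden_size, d_model)
        idx = top_k * ffn_hidden_size
    else:
        # Layout from the diff comment: (total_experts, in_dim, out_dim).
        shape = (num_experts, d_model, ffn_hidden_size)
        idx = top_k
    return sliced_numel(shape, idx)
```

For 8 experts with `top_k = 2`, both layouts count a quarter of each expert weight matrix as active.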
If these are `Optional`, what does it mean when it's `None`?

They're optional when no MoE is used, otherwise required. Is this not an acceptable usage of `Optional[int]`? Can change it

In my opinion, when we have a config setting that is not always required we should either 1) always make it optional type, set it to None by default, and set it in every config when it is needed; or 2) don't make it optional type unless `None` is needed. I prefer 1 since it makes our config more readable (less irrelevant settings) and slightly more backwards compatible.

I can change it to option 1) if others agree? Note that there's other params not following this:

Do you actually rely on the defaults you put in here anywhere? If not, let's go with Shane's version, and default these to `None`. I assume something somewhere will fail if they are not set and you need them.

Yes quite a lot, e.g. the loss weights; the use of dropless MoEs (moe_dropless); leaving moe_interleave, moe_lbl_in_fp32, moe_shared_expert as False.
Actually, I don't think setting them all to None is a good idea, as it means that every time we add a new MoE-specific configuration parameter all MoE configs become outdated, since every MoE-specific configuration parameter is Optional in that sense.
I can also remove the `Optional` from it as they have defaults anyways, but then as seen in the examples I pasted above, we do have `Optional` config params with default values in the codebase anyways.

If it doesn't break everything, I'd prefer to have a special config object for MoE, which is `Optional`, but none of the items inside of that object are `Optional`. This may break backwards compatibility with the model we already released though?

Yes it would break compat with the configs we released but can pin a commit to our released repo if people want to reuse our configs to reproduce things exactly

Hm, that's unfortunate, but I think I prefer the `MoEConfigObject`. It reduces the impact on old-school dense model training.

I guess it would make the name ModelConfig a bit inaccurate though; maybe it should inherit from ModelConfig or sth
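To make the proposal in this thread concrete, here is a minimal sketch of an optional, nested MoE config whose own fields are all non-`Optional`. The class names `MoEConfig` and `ModelConfigSketch` are hypothetical, the dense fields are a trimmed-down stand-in for the real `ModelConfig`, and using `0.0` instead of `None` to disable the z-loss is an assumption of this sketch; none of this is code from the PR.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoEConfig:
    """Hypothetical grouping of the MoE settings; no field is Optional."""

    num_experts: int = 8
    top_k: int = 2
    mlp_impl: str = "sparse"
    shared_expert: bool = False
    lbl_in_fp32: bool = False
    interleave: bool = False
    loss_weight: float = 0.1
    zloss_weight: float = 0.0  # assumption: 0.0 replaces None as "no z-loss"
    dropless: bool = True
    capacity_factor: float = 1.25


@dataclass
class ModelConfigSketch:
    """Stand-in for the real ModelConfig with only a couple of dense fields."""

    d_model: int = 768
    n_layers: int = 12
    moe: Optional[MoEConfig] = None  # None means a dense model

    def validate(self) -> None:
        if self.moe is not None and self.moe.top_k > self.moe.num_experts:
            raise ValueError("top_k cannot exceed num_experts")


dense = ModelConfigSketch()
sparse = ModelConfigSketch(moe=MoEConfig(num_experts=64, top_k=8))
sparse.validate()
```

With this shape a dense config carries no MoE settings at all, and a new MoE field with a default can be added without touching existing MoE configs. As the thread notes, it would break compatibility with configs that use the flat `moe_*` names.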