Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MoE #639

Open
wants to merge 144 commits into
base: main
Choose a base branch
from
Open

MoE #639

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
144 commits
Select commit Hold shift + click to select a range
e725eb9
Clean MoE implementation
Muennighoff Jun 20, 2024
db24750
Add conf
Muennighoff Jun 20, 2024
18450de
Fix return args
Muennighoff Jun 20, 2024
4ab7f77
Rmv outdated kwarg
Muennighoff Jun 20, 2024
dba42fd
Rmv legacy kwarg
Muennighoff Jun 20, 2024
6c5f8a3
Merge branch 'Muennighoff/MoE' of github.com:allenai/LLM into Muennig…
Muennighoff Jun 20, 2024
6a8e089
Add distributed_strategy
Muennighoff Jun 20, 2024
1a9a317
Allow w/o weight attr
Muennighoff Jun 20, 2024
ddf6fd4
Merge branch 'Muennighoff/MoE' of github.com:allenai/LLM into Muennig…
Muennighoff Jun 20, 2024
ab55e07
Allow w/o weight attr
Muennighoff Jun 20, 2024
7aeefd4
Add MoE params
Muennighoff Jun 20, 2024
3eab45c
Rmv kwarg
Muennighoff Jun 20, 2024
6d736da
Reduce lb & moe losses
Muennighoff Jun 21, 2024
d07c638
LN & Emb Dec
Muennighoff Jun 21, 2024
cdb592f
Merge branch 'Muennighoff/MoE' of github.com:allenai/LLM into Muennig…
Muennighoff Jun 21, 2024
1399841
Do not decay emb
Muennighoff Jun 21, 2024
a13b5b8
Tmp - debug throughput
Muennighoff Jun 21, 2024
935167e
Fix
Muennighoff Jun 21, 2024
b96972d
Fix
Muennighoff Jun 21, 2024
0079490
maintain init order
Muennighoff Jun 21, 2024
8b1c441
Merge branch 'Muennighoff/MoE' of github.com:allenai/LLM into Muennig…
Muennighoff Jun 21, 2024
e2c7286
Decay emb
Muennighoff Jun 21, 2024
d39a37c
Keep EA on CPU
Muennighoff Jun 21, 2024
3acfc04
Do not decay emb
Muennighoff Jun 21, 2024
4432261
Change norm
Muennighoff Jun 21, 2024
cef7707
Confs
Muennighoff Jun 22, 2024
2a6df33
Adapt wrap
Muennighoff Jun 22, 2024
7421890
Add conf
Muennighoff Jun 23, 2024
021974e
decemb conf
Muennighoff Jun 24, 2024
d5a0626
Updates
Muennighoff Jun 24, 2024
fce086f
up
Muennighoff Jun 25, 2024
daa7c91
Update weka
Muennighoff Jun 25, 2024
3a40b7f
Rev weka
Muennighoff Jun 25, 2024
2676d03
Adapt
Muennighoff Jun 25, 2024
2fb4c96
Add fine
Muennighoff Jun 25, 2024
448a9a8
Adapt
Muennighoff Jun 26, 2024
e361918
Adapt
Muennighoff Jun 26, 2024
9377aa5
Add eps
Muennighoff Jun 26, 2024
bad3a34
Update confs
Muennighoff Jun 27, 2024
3dbd156
Change eps
Muennighoff Jun 27, 2024
a6496d6
Add conf
Muennighoff Jun 28, 2024
5dd6135
Confs
Muennighoff Jun 28, 2024
195c77d
Add conf
Muennighoff Jun 29, 2024
4ea9f0a
Fix path
Muennighoff Jun 29, 2024
44fa5ae
Fix path
Muennighoff Jun 29, 2024
b7658eb
Adapt confs
Muennighoff Jun 30, 2024
16edece
add s3 conf
Muennighoff Jun 30, 2024
f8e061f
fix paths
Muennighoff Jun 30, 2024
c8a51cf
Update confs
Muennighoff Jul 1, 2024
b95a05f
Increase bs
Muennighoff Jul 1, 2024
20103ec
Adjust paths
Muennighoff Jul 1, 2024
459ef27
conf
Muennighoff Jul 1, 2024
84f21db
confs
Muennighoff Jul 1, 2024
08d3253
confs
Muennighoff Jul 1, 2024
4b07140
Update confs
Muennighoff Jul 1, 2024
cf6fa33
Update conf
Muennighoff Jul 1, 2024
d95f978
Conf
Muennighoff Jul 1, 2024
3c15abf
Finegrained
Muennighoff Jul 2, 2024
4163e70
Cx5
Muennighoff Jul 2, 2024
aae0e0b
Add conf
Muennighoff Jul 2, 2024
76e6e6d
Add cx5
Muennighoff Jul 2, 2024
31e387b
Up confs
Muennighoff Jul 2, 2024
f631d8e
Up confs
Muennighoff Jul 2, 2024
142720b
Dfix
Muennighoff Jul 4, 2024
412a55e
reddit
Muennighoff Jul 4, 2024
669bba5
flan
Muennighoff Jul 4, 2024
7d90908
fine
Muennighoff Jul 4, 2024
3c97519
indent
Muennighoff Jul 4, 2024
6e08b09
Add conf
Muennighoff Jul 4, 2024
1a6a2e2
Conf
Muennighoff Jul 5, 2024
103e450
Add shared
Muennighoff Jul 5, 2024
48b6c14
Add
Muennighoff Jul 5, 2024
8a387fa
Add conf
Muennighoff Jul 5, 2024
3203724
Add conf
Muennighoff Jul 6, 2024
5309137
Add conf
Muennighoff Jul 6, 2024
bf66e68
Addqk
Muennighoff Jul 6, 2024
63c12e1
Make QK Norm parametric
Muennighoff Jul 6, 2024
1c4aa8d
dense comp
Muennighoff Jul 7, 2024
6a94263
add conf
Muennighoff Jul 7, 2024
b58c316
conf
Muennighoff Jul 7, 2024
3bbfaed
conf
Muennighoff Jul 7, 2024
364659f
Change torch version
Muennighoff Jul 18, 2024
15f5503
Add anneal
Muennighoff Jul 20, 2024
6c516d8
Resets
Muennighoff Jul 20, 2024
8a0758e
add conf
Muennighoff Jul 21, 2024
c23b048
conf
Muennighoff Jul 21, 2024
285ff10
Adjust mlp hidden
Muennighoff Jul 21, 2024
99aec31
Change ratio
Muennighoff Jul 21, 2024
a43eae8
Add conf
Muennighoff Jul 21, 2024
aaefc58
Conf
Muennighoff Jul 21, 2024
270271f
Adjust
Muennighoff Jul 21, 2024
c084d34
add conf
Muennighoff Jul 21, 2024
6789ee2
Fix conf
Muennighoff Jul 21, 2024
f084fa0
Adjust
Muennighoff Jul 21, 2024
14ee7e4
Add alt
Muennighoff Jul 21, 2024
f831adf
Fix typo; update script
Muennighoff Jul 23, 2024
26eb3f3
merge main
Muennighoff Jul 23, 2024
0a3b076
add moe reorder
Muennighoff Jul 23, 2024
d6ccbf0
Add datafix
Muennighoff Jul 23, 2024
6a2c17e
Max doc len MoE
Muennighoff Jul 23, 2024
9ccb2f1
add fa varlen
Muennighoff Jul 23, 2024
f5291ec
Add conf
Muennighoff Jul 23, 2024
ed571a8
add conf
Muennighoff Jul 24, 2024
71d2d2a
fix conf
Muennighoff Jul 24, 2024
8b72521
fixconf
Muennighoff Jul 24, 2024
fc822a0
fixonf
Muennighoff Jul 24, 2024
1ccaf9a
fix
Muennighoff Jul 24, 2024
c157441
Add 8k
Muennighoff Jul 25, 2024
06b8010
add conf
Muennighoff Jul 26, 2024
cf31e53
conf
Muennighoff Jul 26, 2024
a750fde
Adapt
Muennighoff Jul 31, 2024
c853a43
Merge branch 'main' into Muennighoff/MoE
Muennighoff Jul 31, 2024
bea80ec
Sort
Muennighoff Jul 31, 2024
dfdcfc5
Simplify configs for merge
Muennighoff Jul 31, 2024
4606598
Update megablocks installation
Muennighoff Aug 1, 2024
f6c707d
Simplify
Muennighoff Aug 1, 2024
76f0376
Simplify
Muennighoff Aug 1, 2024
9631c80
typing & changes
Muennighoff Aug 1, 2024
2fa2acb
Format
Muennighoff Aug 1, 2024
c517703
Change torch.tensor to Any so type check passes?
Muennighoff Aug 1, 2024
b4eb33f
Simplify
Muennighoff Aug 1, 2024
ae6f16a
Simplify
Muennighoff Aug 1, 2024
02781be
Sort
Muennighoff Aug 1, 2024
a62a1ee
Fix type checks
Muennighoff Aug 1, 2024
0ecd4b8
Format
Muennighoff Aug 1, 2024
d8452a0
Fix typo; MoEArgs func
Muennighoff Aug 3, 2024
8a28ced
Format
Muennighoff Aug 3, 2024
91f5553
Check for act ckpt strategy & moe; fix typo
Muennighoff Aug 3, 2024
61ac104
fix import
Muennighoff Aug 3, 2024
f4faf8a
Sort impot
Muennighoff Aug 20, 2024
fdc1021
Merge branch 'main' into Muennighoff/MoE
Muennighoff Aug 20, 2024
ed82181
Fix typo
Muennighoff Aug 21, 2024
b0cc754
Simplify isinstance
Muennighoff Aug 21, 2024
ca9b41f
Clean conf & move constructor
Muennighoff Aug 21, 2024
215c0f5
Add ref
Muennighoff Sep 4, 2024
43baf74
Merge main
Muennighoff Sep 4, 2024
775e514
Sort imports
Muennighoff Sep 4, 2024
cd0004b
Format
Muennighoff Sep 4, 2024
acb23dd
No exp ass
Muennighoff Sep 12, 2024
a143469
Revert
Muennighoff Sep 12, 2024
1a4bdae
Simplify
Muennighoff Sep 12, 2024
410064a
Rmv interleave
Muennighoff Oct 3, 2024
671bc8e
Merge branch 'main' into Muennighoff/MoE
Muennighoff Oct 3, 2024
04a2da5
expert_assignments on gpu
Muennighoff Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `OLMoE`: Configurations & modeling for training Mixture-of-Experts models.
- Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`.
- Added support for flash attention and gradient checkpointing to `hf_olmo`.

Expand Down
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,15 @@ See [Debugging](https://github.com/allenai/OLMo/blob/main/docs/NOTES.md#debuggin
journal={arXiv preprint},
}
```

```bibtex
@misc{muennighoff2024olmoeopenmixtureofexpertslanguage,
title={OLMoE: Open Mixture-of-Experts Language Models},
author={Niklas Muennighoff and Luca Soldaini and Dirk Groeneveld and Kyle Lo and Jacob Morrison and Sewon Min and Weijia Shi and Pete Walsh and Oyvind Tafjord and Nathan Lambert and Yuling Gu and Shane Arora and Akshita Bhagia and Dustin Schwenk and David Wadden and Alexander Wettig and Binyuan Hui and Tim Dettmers and Douwe Kiela and Ali Farhadi and Noah A. Smith and Pang Wei Koh and Amanpreet Singh and Hannaneh Hajishirzi},
year={2024},
eprint={2409.02060},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2409.02060},
}
```
1,493 changes: 1,493 additions & 0 deletions configs/official/OLMoE-7B-A1B.yaml

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from omegaconf.errors import OmegaConfBaseException
Expand Down Expand Up @@ -198,6 +199,11 @@ class BlockType(StrEnum):
implementations of operations like attention to imitate the behavior of Llama.
"""

moe = "moe"
"""
A block for OLMoE-style Mixture-of-Experts models.
"""


class InitFnType(StrEnum):
mitchell = "mitchell"
Expand Down Expand Up @@ -457,6 +463,56 @@ class ModelConfig(BaseConfig):
See :data:`TrainConfig.precision` instead.
"""

moe_num_experts: Optional[int] = 8
"""
The number of experts to use in the MoE block.
"""

moe_top_k: Optional[int] = 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these are Optional, what does it mean when it's None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're optional when no MoE is used, otherwise required. Is this not an acceptable usage of Optional[int]? Can change it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, when we have a config setting that is not always required we should either 1) always make it optional type, set it to None by default, and set it in every config when it is needed; or 2) don't make it optional type unless None is needed. I prefer 1 since it makes our config more readable (less irrelevant settings) and slightly more backwards compatible.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can change it to option 1) if others agree? Note that there's other params not following this:

    embedding_size: Optional[int] = 50304
    gen1_gc_interval: Optional[int] = 1
    distributed_strategy: Optional[DistributedStrategy] = DistributedStrategy.fsdp
    fsdp: Optional[FSDPConfig] = field(default_factory=FSDPConfig)
    auxiliary_loss_multiplier: Optional[float] = 1e-4

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you actually rely on the defaults you put in here anywhere? If not, let's go with Shane's version, and default these to None. I assume something somewhere will fail if they are not set and you need them.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you actually rely on the defaults you put in here anywhere?

Yes quite a lot, e.g. the loss weights; the use of dropless MoEs (moe_dropless); leaving moe_interleave,moe_lbl_in_fp32,moe_shared_expert as False

Actually, I don't think setting them all to None is a good idea, as it means that everytime we add a new MoE-specific configuration parameter all MoE configs become outdated since every MoE-specific configuration parameter is Optional in that dense.

I can also remove the Optional from it as they have defaults anyways but then as seen in the examples I pasted above, we do have Optional config params with default values in the codebase anyways.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it doesn't break everything, I'd prefer to have a special config object for MoE, which is Optional, but none of the items inside of that object are Optional. This may break backwards compatibility with the model we already released though?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it would break compat with the configs we released but can pin a commit to our released repo if people want to reuse our configs to reproduce things exactly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, that's unfortunate, but I think I prefer the MoEConfigObject. It reduces the impact on old-school dense model training.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it would make the name ModelConfig a bit inaccurate though; maybe it should inherit from ModelConfig or sth

"""
The number of experts to select for each token.
"""

moe_mlp_impl: Optional[str] = "sparse"
"""
Choose "grouped" for grouped GEMM installable via `pip install git+https://[email protected]/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb`.
"""

moe_log_expert_assignment: Optional[bool] = True
"""
Whether to log the expert assignment.
"""

moe_shared_expert: Optional[bool] = False
"""
Whether to have an always-used expert like in [DeepSeekMoE](https://arxiv.org/abs/2401.06066).
"""

moe_lbl_in_fp32: Optional[bool] = False
"""
Whether to perform load balancing in FP32.
"""

moe_loss_weight: Optional[float] = 0.1
"""
The weight to use for the MoE load balancing loss.
"""

moe_zloss_weight: Optional[float] = None
"""
Weight for MoE router z-loss where None means no router z-loss. 0.001 is a common value.
"""

moe_dropless: Optional[bool] = True
"""
Whether to use [dMoE](https://arxiv.org/abs/2211.15841).
"""

moe_capacity_factor: Optional[float] = 1.25
"""
The capacity factor to use in the MoE block. Only applies if not using dMoE.
"""

scale_emb_init: bool = False
"""
If ``True``, embeddings are scaled up by ``sqrt(d_model)`` during initialization.
Expand Down Expand Up @@ -1283,3 +1339,41 @@ def update_legacy_settings(cls, config: D) -> D:
new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)

return new_config


def config_to_moe_args(config: ModelConfig) -> Dict[str, Any]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to have this as an instance method of ModelConfig that can be invoked with something like config.build_moe_args()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the moe args may include things outside of the ModelConfig in the future. Currently, I put some things that may be considered as TrainingConfig params like moe_zloss_weight in the ModelConfig but in case we move them in the future to TrainingConfig then it would not only use the ModelConfig anymore.

from megablocks.layers.arguments import Arguments as MoEArgs

from .model import Activation

hidden_size = (
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
)
act = Activation.build(config)
kwargs = {
"activation_fn": F.silu if "swiglu" in config.activation_type.lower() else Activation.build(config),
"mlp_type": "glu" if "glu" in config.activation_type.lower() else "mlp",
"mlp_impl": config.moe_mlp_impl,
"hidden_size": config.d_model,
"ffn_hidden_size": int(act.output_multiplier * hidden_size),
"moe_num_experts": config.moe_num_experts,
"num_layers": config.n_layers,
# Handled by FSDP (https://github.com/databricks/megablocks/issues/57#issuecomment-1854594483)
"moe_weight_parallelism": False,
"moe_expert_model_parallelism": False,
"moe_top_k": config.moe_top_k,
"moe_capacity_factor": config.moe_capacity_factor,
"moe_loss_weight": config.moe_loss_weight,
"device": config.init_device,
# Handled by FSDP
"bf16": False,
"fp16": False,
"bias": config.include_bias,
"return_bias": False,
"shared_expert": config.moe_shared_expert,
"moe_lbl_in_fp32": config.moe_lbl_in_fp32,
}
if config.moe_zloss_weight:
kwargs["moe_zloss_weight"] = config.moe_zloss_weight

return MoEArgs(**kwargs)
10 changes: 8 additions & 2 deletions olmo/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@ def init_normal(
# weights
if init_cutoff_factor is not None:
cutoff_value = init_cutoff_factor * std
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
if hasattr(module, "weight"):
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
else:
nn.init.trunc_normal_(module, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
dirkgr marked this conversation as resolved.
Show resolved Hide resolved
else:
nn.init.normal_(module.weight, mean=0.0, std=std)
if hasattr(module, "weight"):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
nn.init.normal_(module, mean=0.0, std=std)

# biases
if isinstance(module, nn.Linear) and module.bias is not None:
Expand Down
182 changes: 173 additions & 9 deletions olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,14 +448,15 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
)

# Feed-forward output projection.
self.ff_out = nn.Linear(
int(self.act.output_multiplier * self.hidden_size),
config.d_model,
bias=config.include_bias,
device=config.init_device,
)
self.ff_out._is_residual = True # type: ignore
if self.config.block_type != BlockType.moe:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make this dependent on whether the block has a ff_out, instead of the block type?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with hasattr(), I mean

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean if hasattr(self, "ff_out"):? Not sure that will work because the next lines are about creating self.ff_out so no block has it yet afaict

# Feed-forward output projection.
self.ff_out = nn.Linear(
int(self.act.output_multiplier * self.hidden_size),
config.d_model,
bias=config.include_bias,
device=config.init_device,
)
self.ff_out._is_residual = True # type: ignore

# Rotary embeddings.
if self.config.rope:
Expand Down Expand Up @@ -664,10 +665,164 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl
return OLMoSequentialBlock(layer_id, config, cache)
elif config.block_type == BlockType.llama:
return OLMoLlamaBlock(layer_id, config, cache)
elif config.block_type == BlockType.moe:
return OLMoEBlock(layer_id, config, cache)
else:
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")


class OLMoEBlock(OLMoBlock):
"""
This is a transformer MoE block where the output is computed as ``MoE(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""

def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
try:
from megablocks.layers.dmoe import dMoE
from megablocks.layers.moe import MoE
except ImportError:
raise ImportError(
"To train MoEs, run `pip install git+https://github.com/Muennighoff/megablocks.git@olmoe`"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's different about your branch for the original source?

Copy link
Collaborator Author

@Muennighoff Muennighoff Aug 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It includes zloss which we use during training for better stability

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can view the exact difference here: databricks/megablocks@main...Muennighoff:megablocks:olmoe ; besides zloss it also has expert choice which is currently not used but i think we may want to try in the future when we go multimodal

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you upstream this, so we don't have to depend on a private fork?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, opened a PR here databricks/megablocks#133 - If / when it gets merged, I will update the install instructions. If people don't want to use zloss, it also works with the regular megablocks - it's not a big difference.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Muennighoff , so they decided to merge their version instead. Is our version compatible? Will the model you trained work with their implementation of zloss?

)
from .config import config_to_moe_args

super().__init__(layer_id, config, cache)

self.moe_args = config_to_moe_args(config)
self.ffn = dMoE(self.moe_args) if self.config.moe_dropless else MoE(self.moe_args)

self.attn_norm = LayerNorm.build(config)
self.ff_norm = LayerNorm.build(config)

# Attention input projection. Projects x -> (q, k, v)
head_dim = config.d_model // config.n_heads
self.fused_dims = (
config.d_model,
config.effective_n_kv_heads * head_dim,
config.effective_n_kv_heads * head_dim,
)
self.att_proj = nn.Linear(
config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
)

def reset_parameters(self):
if self.k_norm is not None:
self.k_norm.reset_parameters()
if self.q_norm is not None:
self.q_norm.reset_parameters()

if self.config.init_fn == InitFnType.normal:
attn_out_std = ff_out_std = in_std = self.config.init_std
cutoff_factor = self.config.init_cutoff_factor
elif self.config.init_fn == InitFnType.mitchell:
in_std = 1 / math.sqrt(self.config.d_model)
attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1)))
cutoff_factor = self.config.init_cutoff_factor or 3.0
elif self.config.init_fn == InitFnType.full_megatron:
in_std = self.config.init_std
attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
cutoff_factor = self.config.init_cutoff_factor or 3.0
else:
raise NotImplementedError(self.config.init_fn)

init_normal(self.att_proj, std=in_std, init_cutoff_factor=cutoff_factor)
init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
self.attn_norm.reset_parameters()
self.ff_norm.reset_parameters()
init_normal(self.ffn.experts.mlp.w1, std=in_std, init_cutoff_factor=cutoff_factor)
init_normal(self.ffn.experts.mlp.w2, std=ff_out_std, init_cutoff_factor=cutoff_factor)
if hasattr(self.ffn.experts.mlp, "v1"):
init_normal(self.ffn.experts.mlp.v1, std=in_std, init_cutoff_factor=cutoff_factor)
Muennighoff marked this conversation as resolved.
Show resolved Hide resolved
if self.ffn.experts.bias is not None:
torch.nn.init.zeros_(self.ffn.experts.bias)
init_normal(self.ffn.router.layer, std=in_std, init_cutoff_factor=cutoff_factor)

def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
max_doc_len: Optional[int] = None,
cu_doc_lens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Get query, key, value projections.
# shape:
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
# - for group query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
if not self.config.norm_after:
if self._activation_checkpoint_fn is not None:
qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
else:
qkv = self.att_proj(self.attn_norm(x))
else:
qkv = self.att_proj(x)

if self.config.clip_qkv is not None:
qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)

q, k, v = qkv.split(self.fused_dims, dim=-1)

# Get attention scores.
if self._activation_checkpoint_fn is not None:
att, cache = self._activation_checkpoint_fn( # type: ignore
self.attention,
q,
k,
v,
attention_bias,
layer_past=layer_past,
use_cache=use_cache,
max_doc_len=max_doc_len,
cu_doc_lens=cu_doc_lens,
)
else:
att, cache = self.attention(
q,
k,
v,
attention_bias,
layer_past=layer_past,
use_cache=use_cache,
max_doc_len=max_doc_len,
cu_doc_lens=cu_doc_lens,
)

if self.config.norm_after:
if self._activation_checkpoint_fn is not None:
att = self._activation_checkpoint_fn(self.attn_norm, att)
else:
att = self.attn_norm(att)

# Add attention scores.
# shape: (B, T, C)
x = x + self.dropout(att)

# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x

if self.config.norm_after:
x = self.ffn(x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
return og_x + self.dropout(x), cache
else:
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
# Activation checkpointing for the MoE FFN is not supported
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not? If there is a technical problem with it, will it affect whole_layer activation checkpointing as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It fails with

torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Unpack is being triggered for a tensor that was already unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do so only once. Otherwise please open an issue with details on your use case. 2024-05-15T20:15:01.172963498Z 2024-05-15 13:15:01.171 jupiter-cs-aus-133.reviz.ai2.in:3 olmo.util:158 CRITICAL Uncaught CheckpointError: torch.utils.checkpoint: Unpack is being triggered for a tensor that was already unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do so only once. Otherwise please open an issue with details on your use case.

This paper has some explanations why it is difficult to do act ckpt for MoEs: https://dspace.mit.edu/bitstream/handle/1721.1/153897/wisdom-dwisdom-meng-eecs-2024-thesis.pdf

whole_layer is not supported with MoE, only fine_grained - I added code to raise an error if it's not fine_grained & MoE is configured.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see. Interesting. It would be fixable I think (by saving the active experts per token in the forward pass), but out of scope for this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a fairly big blocker to going bigger though. For dense models, our fastest settings still use a lot of checkpointing.

return og_x + self.dropout(self.ffn(x)), cache


class OLMoSequentialBlock(OLMoBlock):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
Expand Down Expand Up @@ -1552,7 +1707,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
else:
raise NotImplementedError(wrap_strategy)

def num_params(self, include_embedding: bool = True) -> int:
def num_params(self, include_embedding: bool = True, include_inactive_params: bool = True) -> int:
"""
Get the total number of parameters.
"""
Expand All @@ -1562,6 +1717,15 @@ def num_params(self, include_embedding: bool = True) -> int:
lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
params,
)
if not include_inactive_params:
# Need to reduce blocks to the number of experts that are selected
# If not dropless 'transformer.blocks.0.ffn.experts.mlp.w1' has shape (total_experts, in_dim, out_dim)
# change to 'transformer.blocks.0.ffn.experts.mlp.w1' with shape (selected_experts, in_dim, out_dim)
# If dropless, the total_experts & out_dim are combined into one dimension
idx = self.config.moe_top_k
if self.config.moe_dropless:
idx *= self.transformer.blocks[1].moe_args.ffn_hidden_size
params = [(np[0], np[1][:idx]) if "experts.mlp" in np[0] else np for np in params] # type: ignore
return sum(p.numel() for _, p in params)

@property
Expand Down
9 changes: 9 additions & 0 deletions olmo/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
from .torch_util import get_default_device, is_distributed

try:
from megablocks.layers.mlp import MLP, SparseMLP

megablocks_available = True
except ImportError:
megablocks_available = False

__all__ = [
"Optimizer",
"LionW",
Expand Down Expand Up @@ -858,6 +865,8 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
decay.add(fpn)
else:
no_decay.add(fpn)
elif megablocks_available and pn.endswith(("w1", "w2", "v1")) and isinstance(m, (MLP, SparseMLP)):
decay.add(fpn)

# Validate that we've considered every parameter
inter_params = decay & no_decay
Expand Down
Loading
Loading