FSDP fails with [ lm_head ] layer #332

Open
gotzmann opened this issue Oct 31, 2024 · 5 comments

Comments

@gotzmann

🐛 Describe the bug

I'm trying to train a LLaMA model with LoRA on all linear layers plus the embeddings and the head.

While the embeddings work fine with FSDP over Liger, there is always an exception when [ lm_head ] is added.

I've tried different versions, including the latest not-yet-merged patches, but I still get the error:

RuntimeError: size mismatch, got input (2), mat (2x4096), vec (65667072)

Reproduce

accelerate launch --config_file fsdp.yaml src/train.py sft.yaml

Versions

v3.1 and others, too

@ByronHsu
Collaborator

Can you provide reproducible code? Liger works fine with FSDP in https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface

@gotzmann
Author

gotzmann commented Nov 1, 2024

Yep, it works fine when training on the linear and embedding layers, but not on lm_head. I'll try with the newest commits today.

@gotzmann
Author

gotzmann commented Nov 1, 2024

Nope, still fails:

[rank3]:   File "/home/git/Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py", line 59, in fused_linear_cross_entropy_forward
[rank3]:     logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
[rank3]: RuntimeError: size mismatch, got input (2), mat (2x4096), vec (65667072)

config.yaml for LLaMA-Factory:

stage: sft
do_train: true
finetuning_type: lora
lora_target: all
additional_target: embed_tokens,lm_head
lora_rank: 128
lora_alpha: 16
lora_dropout: 0.1
use_rslora: true

@ByronHsu
Collaborator

ByronHsu commented Nov 1, 2024

@gotzmann I think this is due to a constraint of FSDP-1. FSDP-2 should resolve the issue, but it is still experimental. Is there any specific reason for you to split lm_head too?
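For context, here is a minimal sketch (my illustration, not code from this thread) of the per-module sharding that the FSDP-2 composable API enables; it assumes torch >= 2.4, where fully_shard lives under torch.distributed._composable.fsdp, and a model plus process group already set up by a launcher such as torchrun. Because lm_head is sharded as its own unit, a standalone call on it can all-gather just its own parameters:

from torch.distributed._composable.fsdp import fully_shard

# `model` is assumed to be a LlamaForCausalLM already constructed on this rank.
for layer in model.model.layers:
    fully_shard(layer)      # each decoder layer becomes its own shard group
fully_shard(model.lm_head)  # lm_head gets its own shard group as well
fully_shard(model)          # the root call covers embeddings, norms, etc.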

@yundai424
Collaborator

yundai424 commented Nov 1, 2024

I believe that's because lm_head is not independently wrapped as an FSDP module (unlike LlamaDecoderLayer), so it relies on the FSDP root module's (i.e. the entire LlamaForCausalLM's) pre-forward/pre-backward hooks for its forward/backward to work properly. Essentially, FSDP1 doesn't allow something like this out of the box:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_config(...)
model = FullyShardedDataParallel(model, ...)
dummy_hidden_states = torch.randn((1, 2, 8))
logit = model.lm_head(dummy_hidden_states)  # fails: lm_head's weight is still a flattened shard

because at this point lm_head's parameters are flattened and have not been properly summoned (all-gathered).

The embedding layer is fine because neither the Liger kernel nor any other code in LLaMA-Factory performs a standalone call on it. FusedLinearCrossEntropy, however, pulls out lm_head's weight and runs a custom operation on it, which is why it fails.
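To make the failure mode concrete, here is a toy illustration (my own, with made-up sizes and no FSDP involved): the weight that the fused kernel sees is still a 1-D flattened FSDP shard rather than a (vocab_size x hidden_size) matrix, so the chunked matmul from the traceback above blows up:

import torch

hidden_chunk = torch.randn(2, 4096)     # chunk_size x hidden_size, as in the error above
proper_weight = torch.randn(128, 4096)  # what the kernel expects: vocab_size x hidden_size (toy vocab)
flat_shard = torch.randn(1024)          # stand-in for the 1-D flattened FSDP-1 shard

logits_chunk = hidden_chunk @ proper_weight.t()  # works: chunk_size x vocab_size
try:
    hidden_chunk @ flat_shard.t()                # .t() is a no-op on a 1-D tensor
except RuntimeError as err:
    print(err)  # size-mismatch error analogous to the one in the traceback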

I know Lightning does an interesting trick (https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648) to make sure these kinds of operations work smoothly inside train_step, and the HF trainer might have a workaround as well, but we'd need to look at LLaMA-Factory's code to see what a good fix would be. Otherwise, the workarounds I can think of are:

  1. disable FusedLinearCrossEntropy and switch to Liger's fused CrossEntropy (at the cost of higher memory, though)
  2. if this can be configured through LLaMA-Factory, have the auto_wrap_policy for FSDP return True when the module is lm_head (see the sketch below)
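For option 2, a hedged sketch of what that policy could look like when wiring FSDP up by hand (the checkpoint name and use_orig_params are illustrative assumptions, not something verified in this thread):

import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")  # hypothetical checkpoint

def should_wrap(module) -> bool:
    # Wrap each decoder layer (the usual transformer policy) and, crucially,
    # lm_head itself, so a standalone call on it can gather its own parameters.
    return isinstance(module, LlamaDecoderLayer) or module is model.lm_head

sharded_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(lambda_auto_wrap_policy, lambda_fn=should_wrap),
    use_orig_params=True,  # typically needed when LoRA/PEFT parameters are mixed in
)

Note that if the model ties lm_head to embed_tokens (tie_word_embeddings), wrapping lm_head separately needs extra care, in which case option 1 is the simpler path.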
