
[misc][Long Context] feat: support ulysses for long context training #109

Merged
54 commits merged into volcengine:main
Jan 18, 2025

Conversation

PeterSH6
Collaborator

  • To support Ulysses, we implemented an FSDPUlyssesShardingManager to manage the SP (sequence-parallel) states of the different models, and we use a device mesh to manage the SP process groups.
  • Long-context training is supported through a monkey patch. Currently we support the Llama and Qwen2 architectures; support for other models will follow.
  • Before the model forward pass, we pad input_ids so the sequence length is divisible by the SP size and then slice input_ids across SP ranks. The position_ids are only padded, not sliced, so that the position embeddings match the qkv_states. This can be optimized.
  • Set shuffle to False in the mini_batch_iterator.
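The pad-then-slice step described above can be sketched as follows. This is a minimal illustration of the arithmetic only, with hypothetical helper names (`pad_to_multiple`, `slice_for_rank`); the actual verl implementation operates on tensors and differs in detail.

```python
def pad_to_multiple(ids, sp_size, pad_id=0):
    """Right-pad a token list so its length is divisible by sp_size."""
    rem = len(ids) % sp_size
    if rem:
        ids = ids + [pad_id] * (sp_size - rem)
    return ids

def slice_for_rank(ids, sp_rank, sp_size):
    """Each SP rank keeps only its contiguous chunk of the padded sequence."""
    chunk = len(ids) // sp_size
    return ids[sp_rank * chunk:(sp_rank + 1) * chunk]

seq = list(range(7))                                   # 7 tokens
padded = pad_to_multiple(seq, sp_size=4)               # padded to length 8
local = slice_for_rank(padded, sp_rank=1, sp_size=4)   # rank 1 keeps [2, 3]
```

Note that in the PR the position_ids would be passed through `pad_to_multiple` but not `slice_for_rank`, so each rank's sliced hidden states still line up with full-length position embeddings.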

@PeterSH6 PeterSH6 changed the title [misc] feat: support ulysses for long context training [misc][Long Context] feat: support ulysses for long context training Jan 16, 2025
@PeterSH6 PeterSH6 marked this pull request as ready for review January 17, 2025 13:47
@PeterSH6
Collaborator Author

Almost finished.

I wonder what kind of examples we should add. We can add some scripts in a follow-up PR.

@vermouth1992 vermouth1992 merged commit e8eb9e4 into volcengine:main Jan 18, 2025
8 checks passed
@PeterSH6 PeterSH6 mentioned this pull request Jan 16, 2025
@xingyaoww
Contributor

Quick question @PeterSH6: does this Ulysses PR support gradient checkpointing?

I'm trying to use the context parallelism implemented here for SFT, but I keep running into a shape-mismatch error during .backward() (not during forward). I'm not sure whether it's because this implementation doesn't support gradient checkpointing yet.

(screenshot of the error omitted)

@xingyaoww
Contributor

xingyaoww commented Jan 23, 2025

Yes, I'm able to reproduce it: with Ulysses context parallelism enabled, setting gradient_checkpointing_enable to False makes everything work, while turning it on produces the indexing error above.

Never mind, I figured it out: it happens when you call loss.backward() outside the `FSDPUlyssesShardingManager` context. The sequence-parallel info is then unavailable, so the patched forward function won't gather sequences correctly, hence this error.
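The failure mode above can be illustrated with a toy context manager. This is not verl's actual API, just a sketch of the pattern: the patched forward only sees sequence-parallel state while the sharding-manager context is active, so both the forward pass and loss.backward() must run inside it.

```python
from contextlib import contextmanager

# Stand-in for the SP process-group state the real sharding manager registers.
_SP_STATE = {"active": False}

@contextmanager
def ulysses_sharding_manager():
    """Toy analogue of FSDPUlyssesShardingManager: exposes SP state on entry."""
    _SP_STATE["active"] = True
    try:
        yield
    finally:
        _SP_STATE["active"] = False

def patched_forward():
    """Toy patched forward: returns whether SP state is visible.

    The real patched forward would use this state to all-to-all the
    attention heads; without it, gathered shapes no longer match.
    """
    return _SP_STATE["active"]

with ulysses_sharding_manager():
    ok_inside = patched_forward()   # SP info available here

ok_outside = patched_forward()      # mirrors calling backward() too late
```

In the real trainer the fix is simply to keep loss.backward() inside the sharding-manager `with` block rather than after it.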

@PeterSH6
Collaborator Author

@xingyaoww Cool! So you implemented Ulysses in the SFT trainer?

@xingyaoww
Contributor

xingyaoww commented Jan 24, 2025

@PeterSH6 yep! Most of the changes are here (along with a lot of unrelated changes, e.g. LoRA):
https://github.com/xingyaoww/verl/commits/dev

I'm still testing it :) but so far it seems to work pretty well.

I can send a PR later.

@PeterSH6
Collaborator Author

@xingyaoww It would be really nice.

I've seen your LoRA PR. It looks great.
It would be even better if you could create a PR for Ulysses + rmpad in SFTTrainer.
We really appreciate your effort!

@xingyaoww
Contributor

@PeterSH6 definitely, a draft PR is up in #132.

I'll clean up the code there once LoRA is merged :)
