Feat: Support TP for long-context draft model training #117
base: main
Conversation
Summary of Changes
Hello @yd-oom, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've introduced Tensor Parallelism (TP) to the draft model, which is essential for training large Llama-3.1 models with long context lengths without encountering Out-of-Memory errors on multi-GPU systems. This involved modifying key components like the attention and MLP layers to distribute computations across devices. Additionally, I've refined the `AllReduce` operation to improve its integration with PyTorch's autograd system and added new tests to validate the numerical accuracy of the parallelized model.
Highlights
- Tensor Parallelism Implementation: I've implemented Tensor Parallelism (TP) support for the draft model, specifically targeting Llama-3.1 Eagle models. This crucial change allows training these large models with extended context lengths (e.g., 8K, 16K, 32K) on multi-GPU setups, resolving previous Out-of-Memory (OOM) errors. (A minimal sketch of the parallel layer layout follows this list.)
- AllReduce Operation Refinement: I've refactored the `AllReduce` operation in `specforge/layers/linear.py` by introducing a custom `torch.autograd.Function`. This addresses a persistent PyTorch `UserWarning` related to autograd kernel registration, ensuring more robust and silent backpropagation for distributed operations.
- Correctness Test Suite: To ensure the integrity of the TP implementation, I've added comprehensive correctness tests in `tests/test_draft_modeling_tp.py`. These tests rigorously verify that the output of the TP-enabled attention and MLP layers is numerically identical to their original, non-parallelized counterparts, confirming the accuracy of the distributed computations.
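For reviewers less familiar with the layout, here is a minimal sketch of the Megatron-style column/row-parallel pattern the first highlight refers to. The class names (`ColumnParallelLinear`, `RowParallelLinear`, `TPLlamaMLP`) and their constructor signatures are illustrative assumptions and may not match the actual specforge API; the sketch assumes `torch.distributed` has already been initialized with the TP group as the default process group.

```python
# Minimal sketch of tensor-parallel gate/up/down projections. Names and
# signatures are illustrative, not the specforge API.
import torch
import torch.nn as nn
import torch.distributed as dist


class ColumnParallelLinear(nn.Module):
    """Splits the output dimension across TP ranks (no gather: output stays sharded)."""

    def __init__(self, in_features: int, out_features: int, tp_size: int):
        super().__init__()
        assert out_features % tp_size == 0
        self.linear = nn.Linear(in_features, out_features // tp_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)  # each rank produces its own output shard


class RowParallelLinear(nn.Module):
    """Splits the input dimension across TP ranks and all-reduces the partial sums."""

    def __init__(self, in_features: int, out_features: int, tp_size: int):
        super().__init__()
        assert in_features % tp_size == 0
        self.linear = nn.Linear(in_features // tp_size, out_features, bias=False)

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        partial = self.linear(x_shard)
        # NOTE: the PR wraps this collective in a custom autograd.Function
        # (see the AllReduce discussion below) so backprop through it is well defined.
        dist.all_reduce(partial)  # sum the partial results from all TP ranks
        return partial


class TPLlamaMLP(nn.Module):
    """SwiGLU MLP with column-parallel gate/up and row-parallel down projections."""

    def __init__(self, hidden_size: int, intermediate_size: int, tp_size: int):
        super().__init__()
        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size, tp_size)
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, tp_size)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size, tp_size)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The intermediate activation stays sharded; the all-reduce inside
        # down_proj restores the full hidden state on every rank.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
```

The point of this arrangement is that intermediate activations stay sharded and only the row-parallel projection needs a single all-reduce per MLP (or attention) block; the attention QKV/output-projection pair follows the same column-then-row pattern.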
Code Review
This pull request introduces Tensor Parallelism (TP) support for the draft model to address out-of-memory errors during training with long context lengths. The changes are well-structured, replacing standard `nn.Linear` layers with `ColumnParallelLinear` and `RowParallelLinear` in the `LlamaAttention` and `LlamaMLP` modules. A custom `_AllReduce` autograd function is correctly implemented to handle backpropagation with distributed operations. The addition of a comprehensive correctness test suite is excellent for verifying the TP implementation. My review focuses on improving the robustness of these new tests.
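For context on the `_AllReduce` refactor the review mentions, below is a minimal sketch of a differentiable all-reduce expressed as a custom `torch.autograd.Function`. It assumes the Megatron-style convention in which the downstream computation and loss are replicated on every TP rank, so the backward pass simply passes the gradient through; the actual code in `specforge/layers/linear.py` may differ in details such as process-group plumbing.

```python
# Sketch of a differentiable all-reduce wrapped in a custom autograd.Function,
# which keeps autograd from backpropagating through the raw c10d::allreduce_ op
# (the source of the UserWarning quoted in the PR description).
import torch
import torch.distributed as dist


class _AllReduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group=None) -> torch.Tensor:
        output = tensor.clone()  # avoid the in-place collective on the autograd input
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Assumption: the loss is replicated on every TP rank, so the gradient
        # of the sum-all-reduce is passed through unchanged.
        return grad_output, None  # None: no gradient w.r.t. the process group


def tensor_parallel_all_reduce(tensor: torch.Tensor, group=None) -> torch.Tensor:
    return _AllReduce.apply(tensor, group)
```

Wrapping the collective this way registers a proper backward for the operation, which is what silences the autograd-kernel warning while keeping gradients correct for the row-parallel projections.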
@yd-oom This feature is really exciting! Could you please resolve the conflicts? Also, did you test it with Llama 3.1 8B? Is the accept length good?
@zyksir Hi, conflicts resolved. This was tested on Llama 3.1 8B: the results with TP=2 are identical to the baseline (non-TP) after two epochs on ShareGPT. Our team has been using this feature internally for a month.
Motivation
Training Llama-3.1 models (8B and 70B) in offline mode with long context lengths (e.g., 8K, 16K, or 32K) currently fails with Out-of-Memory (OOM) errors, even on multi-GPU setups.
Modifications
- Add TP support in `specforge/modeling/draft/llama3_eagle.py`.
- Rewrite `AllReduce` in `linear.py` to avoid the PyTorch warning: "UserWarning: c10d::allreduce_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)"
- Added correctness tests: new tests verify that the output of the TP-enabled implementation is numerically identical to the original single-GPU implementation (an illustrative sketch follows this list).
- Implemented a robust `save_pretrained` method in the `Eagle3DraftModel` base class (`specforge/modeling/draft/base.py`).
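As a rough illustration of what such a correctness test checks, the sketch below shards a single reference linear layer column-wise, gathers the per-rank outputs, and compares them against the unsharded output. All names and tolerances are hypothetical and do not necessarily mirror `tests/test_draft_modeling_tp.py`; launch with `torchrun --standalone --nproc_per_node 2`.

```python
# Illustrative TP-correctness check: shard a reference nn.Linear column-wise,
# run both versions on the same input, and compare the results.
import os
import torch
import torch.nn as nn
import torch.distributed as dist


def main() -> None:
    dist.init_process_group("nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.manual_seed(0)  # identical weights and inputs on every rank

    hidden, inter = 512, 2048
    ref = nn.Linear(hidden, inter, bias=False).cuda()
    x = torch.randn(2, 16, hidden, device="cuda")

    # Column-parallel shard: each rank owns a contiguous slice of the output dim.
    shard = nn.Linear(hidden, inter // world, bias=False).cuda()
    with torch.no_grad():
        shard.weight.copy_(ref.weight.chunk(world, dim=0)[rank])

    # Gather the per-rank output shards and compare against the reference.
    local = shard(x)
    gathered = [torch.empty_like(local) for _ in range(world)]
    dist.all_gather(gathered, local)
    tp_out = torch.cat(gathered, dim=-1)

    torch.testing.assert_close(tp_out, ref(x), rtol=1e-5, atol=1e-5)
    if rank == 0:
        print("TP output matches the single-GPU reference.")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```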
Related Issues
#112
Accuracy Test
Benchmark & Profiling
Before (Original):
Training Llama-3.1-8B with an 8192 context length on 2*H20 fails with an OOM error.
torchrun \
  --standalone \
  --nproc_per_node $NUM_GPUS \
  $ROOT_DIR/scripts/train_eagle3_offline.py \
  --target-model-path /mnt/model/Meta-Llama-3.1-8B-Instruct \
  --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
  --train-data-path $ROOT_DIR/cache/dataset/longwriter.jsonl \
  --train-hidden-states-path $ROOT_DIR/cache/hidden_states/longwriter \
  --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3 \
  --num-epochs 1 \
  --batch-size 1 \
  --learning-rate 1e-4 \
  --max-length 8192 \
  --chat-template llama3 \
  --cache-dir $ROOT_DIR/cache \
  --report-to swanlab \
  --swanlab-project eagle3 \
  --swanlab-key xxx
This fails with OOM:

[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 95.22 GiB of which 302.56 MiB is free. Including non-PyTorch memory, this process has 94.91 GiB memory in use. Of the allocated memory 86.89 GiB is allocated by PyTorch, and 6.61 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0]:[W807 04:12:51.599551331 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
After (with --tp-size added):
torchrun \
  --standalone \
  --nproc_per_node $NUM_GPUS \
  $ROOT_DIR/scripts/train_eagle3_offline.py \
  --target-model-path /mnt/model/Meta-Llama-3.1-8B-Instruct/main \
  --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
  --train-data-path $ROOT_DIR/cache/dataset/longwriter.jsonl \
  --train-hidden-states-path $ROOT_DIR/cache/hidden_states/longwriter \
  --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3 \
  --num-epochs 1 \
  --batch-size 1 \
  --learning-rate 1e-4 \
  --max-length 8192 \
  --chat-template llama3 \
  --cache-dir $ROOT_DIR/cache \
  --tp-size $NUM_GPUS
Training now runs successfully.
Todo
Add comprehensive benchmark results for several TP training scenarios.
Checklist