[SFT] Support context parallelism for SFT #132

Merged
merged 68 commits into volcengine:main on Jan 27, 2025

Conversation

xingyaoww
Contributor

@xingyaoww xingyaoww commented Jan 25, 2025

Add Sequence Parallelism and Padding Removal to SFT Trainer

This PR adds sequence parallelism (SP) and padding removal optimizations to the SFT trainer, which can help improve training efficiency for large language models.

Key Changes

Core Features

  1. Sequence Parallelism: Added support for sequence parallelism through the Ulysses framework

    • Configurable via ulysses_sequence_parallel_size parameter
    • Properly handles data distribution across SP ranks
    • Maintains consistent loss computation across the distributed setup
  2. Padding Removal: Added support for efficient handling of variable-length sequences (a minimal sketch of the unpad/re-pad round trip follows this list)

    • Enabled via use_remove_padding flag (requires SP to be enabled)
    • Uses flash-attention's padding removal utilities
    • Handles proper re-padding and loss computation
  3. Training Improvements:

    • Added label smoothing support to loss computation
    • Added progress bar with epoch information
    • Added RoPE scaling configuration support
    • Improved error messages for batch size validation
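
For reference, here is a minimal, self-contained sketch of the padding-removal round trip using flash-attention's `bert_padding` helpers (unpad before the forward pass, re-pad before loss computation). The tensors are toy data and the variable names are illustrative, not the trainer's exact code:

```python
import torch
from flash_attn.bert_padding import unpad_input, pad_input

# Toy batch: 2 sequences padded to length 6 (attention_mask: 1 = real token, 0 = pad).
input_ids = torch.tensor([[5, 6, 7, 0, 0, 0],
                          [8, 9, 10, 11, 12, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
batch_size, seqlen = input_ids.shape

# Remove padding: pack all real tokens into a single (total_nnz, 1) tensor.
# Newer flash-attn versions return an extra value, hence the *_ catch-all.
input_ids_rmpad, indices, cu_seqlens, max_seqlen, *_ = unpad_input(
    input_ids.unsqueeze(-1), attention_mask)
print(input_ids_rmpad.squeeze(-1))  # the 8 real tokens, packed: [5, 6, 7, 8, 9, 10, 11, 12]
print(cu_seqlens)                   # [0, 3, 8] -- cumulative sequence offsets for varlen attention

# A model would run on the packed tokens here (flash-attn varlen kernels consume
# cu_seqlens / max_seqlen). We simply re-pad the packed values to show the round
# trip back to (batch, seqlen, ...) before the loss is computed.
repadded = pad_input(input_ids_rmpad, indices, batch_size, seqlen)
assert torch.equal(repadded.squeeze(-1), input_ids)
```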

Testing

  • Added comprehensive test suite (test_trainer.py) to verify:
    • Forward pass consistency between original and SP+rmpad implementations
    • Loss computation correctness across the distributed setup
    • Proper handling of micro-batches (a toy sketch of the micro-batch loss pattern follows this list)
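
As a rough illustration of what the loss-correctness and micro-batch checks are about, here is a self-contained toy of the underlying pattern: label-smoothed cross-entropy accumulated over equal-sized micro-batches, each micro-batch loss scaled by 1/n_micro_batches so the accumulated gradient matches a single full-batch step. This uses a toy model and data, not the trainer's code:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(16, 8)            # toy stand-in for the language model
batch_x = torch.randn(4, 16)
batch_y = torch.randint(0, 8, (4,))

n_micro_batches = 2
model.zero_grad()
for mb_x, mb_y in zip(batch_x.chunk(n_micro_batches), batch_y.chunk(n_micro_batches)):
    logits = model(mb_x)
    # label_smoothing is a standard F.cross_entropy option; dividing by the number
    # of (equal-sized) micro-batches makes the accumulated gradients equal to those
    # of one forward/backward over the full batch.
    loss = F.cross_entropy(logits, mb_y, label_smoothing=0.1) / n_micro_batches
    loss.backward()
# optimizer.step() would follow once all micro-batches have been accumulated.
```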

Example Usage

Added example script examples/sft/gsm8k/run_qwen_05_sp2.sh demonstrating how to use the new features with a Qwen2.5-0.5B model.

Implementation Details

  • Uses device mesh for proper distributed training setup
  • Handles data distribution so that ranks within the same SP group receive the same sequences while different DP groups receive different data (see the sketch after this list)
  • Carefully manages backward pass timing with gradient checkpointing
  • Maintains compatibility with existing FSDP features
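
To make the data-distribution rule concrete, here is a small sketch of one way to set it up: a 2-D device mesh plus a DistributedSampler keyed on the DP rank only. The mesh dimension names, the toy dataset, and the sampler wiring are illustrative assumptions, not necessarily the trainer's exact implementation:

```python
# Launch with: torchrun --nproc_per_node=8 this_sketch.py
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.utils.data import DistributedSampler, TensorDataset

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

sp_size = 2                                  # e.g. ulysses_sequence_parallel_size=2
dp_size = dist.get_world_size() // sp_size

# 2-D mesh: outer dimension = data parallel, inner dimension = sequence parallel.
# In a real trainer this mesh would be handed to the FSDP / Ulysses setup; here it
# only fixes the rank layout used below.
mesh = init_device_mesh("cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"))

# With this layout, global rank r belongs to DP group r // sp_size and holds SP rank r % sp_size.
dp_rank = dist.get_rank() // sp_size

# Shard the dataset by DP rank only: both ranks of an SP group load the same
# sequences (each later processes its own slice of every sequence), while
# different DP groups load different sequences.
train_dataset = TensorDataset(torch.arange(1024))  # toy stand-in for the SFT dataset
sampler = DistributedSampler(train_dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True)
```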

Testing Instructions

  1. Run the example script with sequence parallelism:
     bash examples/sft/gsm8k/run_qwen_05_sp2.sh <nproc_per_node> <save_path>
  2. Run the test suite:
     bash tests/sft/run_sft_sp_loss_match.sh

^^ This PR description was generated by OpenHands

xingyaoww and others added 30 commits January 15, 2025 19:22
@xingyaoww
Contributor Author

OK, the training script is working now!

[screenshot]

@xingyaoww
Contributor Author

This PR should be ready for review! I just added a CI check that verifies the loss match.

[screenshot]

@xingyaoww xingyaoww marked this pull request as ready for review January 25, 2025 20:23
@xingyaoww xingyaoww changed the title [WIP, SFT] Support context parallelism for SFT [SFT] Support context parallelism for SFT Jan 25, 2025
@@ -165,6 +193,14 @@ def _build_model_optimizer(self):
        trust_remote_code = self.config.model.trust_remote_code
        # load config first
        config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
        if self.use_remove_padding:
            assert self.config.ulysses_sequence_parallel_size > 1, "Remove padding is only supported with sequence parallel"
Collaborator
I assume it should be the opposite? Sequence parallelism is only supported when remove_padding is enabled?

loss.backward()
if self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1:
    # micro_batch = micro_batch.to('cuda')
    loss = self._compute_loss_and_backward_sp(batch=micro_batch) / n_micro_batches
Collaborator

Can we combine the two functions? There is a lot of replicated code.

@vermouth1992
Collaborator

Nice work!

@vermouth1992
Collaborator

Could you run the formatting script?

@xingyaoww
Contributor Author

done! @vermouth1992

@vermouth1992
Collaborator

There are conflicts with the newly merged MR.

@vermouth1992 vermouth1992 merged commit 077173f into volcengine:main Jan 27, 2025
10 checks passed
Chendong98 pushed a commit to Chendong98/verl that referenced this pull request Feb 4, 2025