[cp] apply fsdp to model when CP is enabled without DP for correct loss and lower mem usage (#685)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #684
* __->__ #685

**Summary**
Previously, CP forgot to shard the model via `apply_fsdp` when DP was not combined with CP. This led to high peak memory usage and diverging loss.

**Test**
1. Modify `train_configs/llama3_8b.toml`:
```
steps = 20
context_parallel_degree = 8
```
2. Run training on 8x H100 GPUs:
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`

Before: CUDA OutOfMemory
After: successful 20-step training
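The idea behind the fix, as a minimal sketch: even when the data-parallel degree is 1, a pure-CP run still needs the model wrapped with FSDP over the CP ranks so parameters are sharded (lower peak memory) and gradients are reduced across the sequence chunks (correct loss). The function name `apply_fsdp_if_needed`, the flat `dp_cp` mesh, and the per-block sharding loop below are illustrative assumptions, not the exact torchtitan code.

```python
# Hypothetical sketch of the fix; assumes torch.distributed is already
# initialized and uses PyTorch's FSDP2 ``fully_shard`` API.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard


def apply_fsdp_if_needed(model: torch.nn.Module, dp_degree: int, cp_degree: int) -> torch.nn.Module:
    # Before the fix (conceptually): FSDP was only applied when dp_degree > 1,
    # so a CP-only run kept a full parameter replica per rank (high peak memory)
    # and never reduced gradients across CP ranks (diverging loss).
    if dp_degree > 1 or cp_degree > 1:
        # Shard over all dp*cp ranks with a single flattened mesh dimension.
        mesh = init_device_mesh(
            "cuda", (dp_degree * cp_degree,), mesh_dim_names=("dp_cp",)
        )
        # Shard each transformer block, then the remaining top-level params
        # (embeddings, final norm/output), mirroring the usual FSDP2 pattern.
        for block in model.children():
            fully_shard(block, mesh=mesh)
        fully_shard(model, mesh=mesh)
    return model
```

With this wrapping in place, the CP-only configuration above (`context_parallel_degree = 8`, `NGPU=8`) shards the 8B model across all 8 GPUs instead of replicating it, which is what turns the CUDA OutOfMemory failure into a successful run.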
Showing 2 changed files with 23 additions and 40 deletions.