Implement fused modules #747
Merged
mkeoliya pushed a commit to mkeoliya/axolotl that referenced this pull request on Dec 15, 2023:
* MLP: Memory saving
* Remove RMSNorm restrictions
* Map packed weights to original
* FusedAttention module
* Simplify code
* Move fused modules
* Fix critical typo
* Split inplace
* Add FFT config
* Add validation of fused arguments
* Add fused arguments to config
* Update docs
* Fix validation logic
* Add fused modules to flash attn
* Only fuse during training
* Remove timing
* Formatting
* chore: lint
* add e2e tests for fused llama
* no lora for tests

Co-authored-by: Wing Lian <[email protected]>
djsaunde pushed a commit that referenced this pull request on Dec 17, 2024 (same commit message as above).
There are two common ways to fuse layers in Llama/Mistral-type models. Speed and memory are measured on an RTX 3090 with TinyLlama 1.1B.

- MLP: fuse `gate_proj` and `up_proj` together.
- Attention: fuse the `q_proj`, `k_proj`, and `v_proj` projections together.

All fusing of layers must happen AFTER the model is loaded, in order to load the pretrained weights into the fused modules.
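For illustration, here is a minimal sketch of the MLP case, assuming the standard Hugging Face Llama layer layout; `FusedLlamaMLP` and `fuse_mlp` are hypothetical names for this example, not the PR's actual classes:

```python
# Sketch: fuse gate_proj and up_proj into one nn.Linear AFTER the pretrained
# weights have been loaded, so the fused layer can be initialized from them.
import torch
import torch.nn as nn


class FusedLlamaMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # A single matmul produces both the gate and up activations.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()
        self.intermediate_size = intermediate_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).split(self.intermediate_size, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


def fuse_mlp(mlp: nn.Module) -> FusedLlamaMLP:
    """Map the packed weights of an already-loaded Llama MLP onto the fused module."""
    hidden_size = mlp.gate_proj.in_features
    intermediate_size = mlp.gate_proj.out_features
    fused = FusedLlamaMLP(hidden_size, intermediate_size)
    with torch.no_grad():
        # Stack the original weight matrices row-wise into the fused projection.
        fused.gate_up_proj.weight.copy_(
            torch.cat([mlp.gate_proj.weight, mlp.up_proj.weight], dim=0)
        )
        fused.down_proj.weight.copy_(mlp.down_proj.weight)
    return fused
```

Attention fusion follows the same pattern: concatenate the loaded `q_proj`, `k_proj`, and `v_proj` weights into a single linear, then split its output before computing attention.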
TinyLlama 1.1B - A6000
Conclusion: Fusing the MLP can save roughly 27% of cache memory. Fusing attention seems to do nothing for speed but increases memory by about 1GB.
Llama-2-7B - A100
Conclusion: Saves enough memory to load using `adamw_torch`.

None fused (main):
- `adamw_torch`: 37.732GB (+39.764GB cache, +1.366GB misc)
- `adamw_torch_fused`: 37.732GB (+14.506GB cache, +1.366GB misc)
- `adamw_bnb_8bit`: 25.393GB (+14.494GB cache, +1.366GB misc)

MLP fused (PR):
- `adamw_torch`: 37.732GB (+38.813GB cache, +1.366GB misc)
- `adamw_torch_fused`: 37.732GB (+14.647GB cache, +1.366GB misc)
- `adamw_bnb_8bit`: 25.269GB (+14.137GB cache, +1.366GB misc)

MLP + Attention fused (PR):
- `adamw_bnb_8bit`: 31.332GB (+13.752GB cache, +1.366GB misc)
- `adamw_torch`: OOM

QLoRA
Currently, it is not compatible with QLoRA, but there is potential to make it so. In bitsandbytes, you can import the 4-bit and 8-bit linear layers and use them instead of `nn.Linear`: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/nn/modules.py#L258
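As a hedged sketch of that idea (not something this PR implements), a fused gate/up projection could be backed by the bitsandbytes quantized linears instead of `nn.Linear`; the shapes below assume Llama-2-7B dimensions:

```python
# Hypothetical example: quantized drop-in replacements for a fused projection.
import torch
import bitsandbytes as bnb

hidden_size, intermediate_size = 4096, 11008  # Llama-2-7B dimensions

# 4-bit (QLoRA-style) replacement for nn.Linear
gate_up_4bit = bnb.nn.Linear4bit(
    hidden_size, 2 * intermediate_size, bias=False, compute_dtype=torch.bfloat16
)

# 8-bit alternative
gate_up_8bit = bnb.nn.Linear8bitLt(
    hidden_size, 2 * intermediate_size, bias=False, has_fp16_weights=False
)

# In both cases the weights are quantized when the module is moved to the GPU,
# e.g. gate_up_4bit.cuda(), after the pretrained weights have been copied in.
```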