-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: refactor rms_norm for ascend speed #69
Merged
Merged
Changes from 7 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
bf3636d
refactor rms_norm
POI-WX c2a0130
update
POI-WX af9e76d
update
POI-WX defac19
Merge branch 'main' of https://github.com/DeepLink-org/DeepLinkExt in…
POI-WX f55ac87
add for flash attention
POI-WX 084f053
fix bug of construct inv_rms due to shape
POI-WX 61fd608
update
POI-WX bd22e7d
update accorging to review
POI-WX File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,14 @@ | ||
from .rotary_embedding import apply_rotary, RotaryEmbedding | ||
from .adamw import adamw | ||
from .scaled_masked_softmax import ScaledMaskedSoftmax | ||
from .rms_norm import RMSNorm | ||
from .flash_attention import FlashSelfAttention | ||
|
||
__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "ScaledMaskedSoftmax"] | ||
__all__ = [ | ||
"apply_rotary", | ||
"RotaryEmbedding", | ||
"adamw", | ||
"ScaledMaskedSoftmax", | ||
"RMSNorm", | ||
"FlashSelfAttention", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
import deeplink_ext.cpp_extensions as ext | ||
|
||
|
||
assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward") | ||
|
||
|
||
class RMSNorm(torch.autograd.Function): | ||
@staticmethod | ||
def forward(ctx, hidden_states, weight, eps): | ||
bias = torch.Tensor().cuda() | ||
output = torch.empty_like(hidden_states) | ||
input_dtype = hidden_states.dtype | ||
acc_dtype = ( | ||
torch.float32 | ||
if input_dtype in [torch.bfloat16, torch.float16] | ||
else input_dtype | ||
) | ||
inv_rms = torch.empty( | ||
hidden_states.shape[:-1] + (1,), | ||
dtype=acc_dtype, | ||
device=hidden_states.device, | ||
) | ||
ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps) | ||
ctx.save_for_backward(hidden_states, inv_rms, weight, bias) | ||
ctx.eps = eps | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output): | ||
hidden_states, inv_rms, weight, bias = ctx.saved_tensors | ||
grad_input = torch.empty_like(hidden_states) | ||
grad_weight = torch.empty_like(weight) | ||
grad_bias = torch.empty_like(bias) | ||
ext.rms_norm_backward( | ||
grad_input, | ||
grad_weight, | ||
grad_bias, | ||
grad_output, | ||
hidden_states, | ||
weight, | ||
bias, | ||
inv_rms, | ||
weight.shape, | ||
ctx.eps, | ||
) | ||
return grad_input, grad_weight, None, None |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.