Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor rms_norm for ascend speed #69

Merged
merged 8 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions csrc/extensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,16 @@ void extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
beta1, beta2, epsilon, weight_decay, step, amsgrad);
}

auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
void extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
const at::Tensor& input,
const OptionalIntArray& normalized_shape,
const at::Tensor& weight, const at::Tensor& bias, double eps) {
at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
bias, eps);
return std::make_tuple(std::move(output), std::move(inv_rms));
}

auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
void extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
at::Tensor& grad_bias, const at::Tensor& grad_output,
const at::Tensor& input, const at::Tensor& weight,
const at::Tensor& bias, const at::Tensor& inv_rms,
Expand All @@ -58,8 +57,6 @@ auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias,
grad_output, input, weight, bias, inv_rms, normalized_shape_at,
eps);
return std::make_tuple(std::move(grad_input), std::move(grad_weight),
std::move(grad_bias));
}

void extApplyRotary(at::Tensor& output, const at::Tensor& input,
Expand Down
11 changes: 10 additions & 1 deletion deeplink_ext/ascend_speed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from .rotary_embedding import apply_rotary, RotaryEmbedding
from .adamw import adamw
from .scaled_masked_softmax import ScaledMaskedSoftmax
from .rms_norm import RMSNorm
from .flash_attention import FlashSelfAttention

__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "ScaledMaskedSoftmax"]
__all__ = [
"apply_rotary",
"RotaryEmbedding",
"adamw",
"ScaledMaskedSoftmax",
"RMSNorm",
"FlashSelfAttention",
]
47 changes: 47 additions & 0 deletions deeplink_ext/ascend_speed/rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import deeplink_ext.cpp_extensions as ext


assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")


class RMSNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, eps):
bias = torch.Tensor().cuda()
output = torch.empty_like(hidden_states)
input_dtype = hidden_states.dtype
acc_dtype = (
torch.float32
if input_dtype in [torch.bfloat16, torch.float16]
else input_dtype
)
inv_rms = torch.empty(
list(hidden_states.shape[:-1]) + [1],
dtype=acc_dtype,
device=hidden_states.device,
)
ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps)
ctx.save_for_backward(hidden_states, inv_rms, weight, bias)
ctx.eps = eps
return output

@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias = ctx.saved_tensors
grad_input = torch.empty_like(hidden_states)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output,
hidden_states,
weight,
bias,
inv_rms,
weight.shape,
ctx.eps,
)
return grad_input, grad_weight, None, None
2 changes: 0 additions & 2 deletions deeplink_ext/internlm_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Copyright (c) 2024, DeepLink.

from . import mha


_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."

Expand Down
Loading