Commit

refactor rms_norm
POI-WX committed Apr 2, 2024
1 parent 220857e commit bf3636d
Showing 3 changed files with 47 additions and 6 deletions.
7 changes: 2 additions & 5 deletions csrc/extensions.cpp
@@ -36,17 +36,16 @@ void extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
             beta1, beta2, epsilon, weight_decay, step, amsgrad);
 }
 
-auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
+void extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
                 const at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
                 const at::Tensor& weight, const at::Tensor& bias, double eps) {
   at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
   callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
             bias, eps);
-  return std::make_tuple(std::move(output), std::move(inv_rms));
 }
 
-auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
+void extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
                         at::Tensor& grad_bias, const at::Tensor& grad_output,
                         const at::Tensor& input, const at::Tensor& weight,
                         const at::Tensor& bias, const at::Tensor& inv_rms,
@@ -55,8 +54,6 @@ auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
   callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias,
             grad_output, input, weight, bias, inv_rms, normalized_shape_at,
             eps);
-  return std::make_tuple(std::move(grad_input), std::move(grad_weight),
-                         std::move(grad_bias));
 }
 
 void extApplyRotary(at::Tensor& output, const at::Tensor& input,
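
After this refactor, extRmsNorm and extRmsNormBackward no longer return a tuple: they fill the caller-provided output, inv_rms, and grad_* tensors in place and return void. RMS norm itself computes output = input / sqrt(mean(input^2) + eps) * weight, and inv_rms holds the reciprocal factor for reuse in backward. Below is a minimal sketch of the resulting calling convention from Python; it assumes the functions are exposed through deeplink_ext.cpp_extensions as ext.rms_norm with the same positional order as the C++ signature, and the shapes are illustrative only.

    # Sketch only, not part of the commit: the caller pre-allocates every output
    # tensor, the extension writes into them, and nothing is returned.
    import torch
    import deeplink_ext.cpp_extensions as ext

    x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
    weight = torch.ones(4096, dtype=torch.float16, device="cuda")
    bias = torch.Tensor().cuda()  # empty bias, mirroring rms_norm.py below
    output = torch.empty_like(x)
    inv_rms = torch.empty_like(x, dtype=torch.float32)  # accumulation dtype
    ext.rms_norm(output, inv_rms, x, None, weight, bias, 1e-6)  # fills output and inv_rms
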
3 changes: 2 additions & 1 deletion deeplink_ext/ascend_speed/__init__.py
@@ -1,4 +1,5 @@
 from .rotary_embedding import apply_rotary, RotaryEmbedding
 from .adamw import adamw
+from .rms_norm import RMSNorm
 
-__all__ = ["apply_rotary", "RotaryEmbedding", "adamw"]
+__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "RMSNorm"]
43 changes: 43 additions & 0 deletions deeplink_ext/ascend_speed/rms_norm.py
@@ -0,0 +1,43 @@
import torch
import deeplink_ext.cpp_extensions as ext


assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")


class RMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hidden_states, weight, eps):
        bias = torch.Tensor().cuda()
        output = torch.empty_like(hidden_states)
        input_dtype = hidden_states.dtype
        acc_dtype = (
            torch.float32
            if (input_dtype == torch.bfloat16 or input_dtype == torch.float16)
            else input_dtype
        )
        inv_rms = torch.empty_like(hidden_states, dtype=acc_dtype)
        ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps)
        ctx.save_for_backward(hidden_states, inv_rms, weight, bias)
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output):
        hidden_states, inv_rms, weight, bias = ctx.saved_tensors
        grad_input = torch.empty_like(hidden_states)
        grad_weight = torch.empty_like(weight)
        grad_bias = torch.empty_like(bias)
        ext.rms_norm_backward(
            grad_input,
            grad_weight,
            grad_bias,
            grad_output,
            hidden_states,
            weight,
            bias,
            inv_rms,
            None,
            ctx.eps,
        )
        return grad_input, grad_weight, None, None
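
A hedged sketch of how this Function might be wrapped in an nn.Module on the AscendSpeed side; the wrapper name, hidden size, and eps default below are assumptions, not part of this commit:

    # Hypothetical wrapper, not part of the commit.
    import torch
    from deeplink_ext.ascend_speed import RMSNorm

    class DeepLinkRMSNorm(torch.nn.Module):  # assumed name
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, hidden_states):
            # forward(ctx, hidden_states, weight, eps) above maps to these three args
            return RMSNorm.apply(hidden_states, self.weight, self.eps)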
