
Commit

modify rms norm
zhangzefeng92 committed Mar 28, 2024

1 parent 7fc26c7 commit 8eca230
Showing 5 changed files with 103 additions and 112 deletions.
18 changes: 0 additions & 18 deletions csrc/extensions.cpp
@@ -197,21 +197,6 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch);
}

// For lightllm, rms_norm reuses the diopi implementation of internlm
auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
float eps) {
at::ScalarType acc_type = x.scalar_type();
if (x.scalar_type() == at::kBFloat16 || x.scalar_type() == at::kHalf) {
acc_type = at::kFloat;
}
auto inv_rms = at::empty_like(x, acc_type);
auto out = at::empty_like(x);
auto bias = at::empty_like(weight);
at::OptionalIntArrayRef normalized_shape = weight.sizes();
callDiopi(diopiRMSNorm, out, inv_rms, x, normalized_shape, weight, bias, eps);
return out;
}

// For lightllm, rotary_embedding reuses the diopi implementation of internlm
void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
auto seq_len = q.size(0);
@@ -229,9 +214,6 @@ void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
if (&diopiRMSNorm != nullptr) { // Check if weak symbol defined
m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
m.def("rms_norm_lightllm", &extRmsNormLightllm,
"deeplink ext_rms_norm for lightllm", py::arg("x"), py::arg("weight"),
py::arg("eps"));
}
if (&diopiRMSNormBackward != nullptr) {
m.def("rms_norm_backward", &extRmsNormBackward,
23 changes: 0 additions & 23 deletions csrc/pybind_type_cast.h
@@ -21,28 +21,5 @@ using OptionalIntArray = c10::optional<IntArray>;

} // namespace dipu::dipu_ext

namespace pybind11::detail {

namespace py = pybind11;

template <>
struct type_caster<at::OptionalIntArrayRef> {
public:
PYBIND11_TYPE_CASTER(dipu::dipu_ext::OptionalIntArray, _("OptionalIntArray"));

bool load(py::handle src, bool /*unused*/) {
if (PyList_Check(src.ptr())) {
value = py::cast<dipu::dipu_ext::IntArray>(src);
return true;
}
if (src.is_none()) {
value = c10::nullopt;
return true;
}
return false;
}
};

} // namespace pybind11::detail

#endif /* end of include guard: PYBIND_TYPE_CAST_H_PXMGELYW */
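The custom pybind11 caster for at::OptionalIntArrayRef can go because, after this commit, the None case is resolved on the Python side before the extension is called, so the bindings only ever receive a concrete shape such as weight.shape. A one-line sketch of that dispatch (hypothetical helper name, mirroring rms_norm_out in deeplink.py below):

def _resolve_normalized_shape(normalized_shape, weight):
    # Hypothetical helper, not in the commit: same fallback as rms_norm_out uses.
    return weight.shape if normalized_shape is None else normalized_shape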
135 changes: 86 additions & 49 deletions deeplink_ext/internlm_ops/rms_norm/deeplink.py
@@ -1,64 +1,110 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext
import deeplink_ext.cpp_extensions as cpp_ext

assert hasattr(ext, "rms_norm")
assert hasattr(cpp_ext, "rms_norm")


# Define the custom autograd function
class _DeepLinkRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps):
output = torch.empty_like(hidden_states)
inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
inv_rms = torch.empty(
inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device
)
ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps)
def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
if normalized_shape is None:
cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
else:
cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)

ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
return output

@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
eps = eps_tensor.item()
def rms_norm(input, normalized_shape, weight, bias, eps):
output = torch.empty_like(input)
inv_rms_shape = list(input.shape[:-1]) + [1]
inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)

return [output, inv_rms]

grad_input = torch.empty_like(hidden_states)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
ext.rms_norm_backward(

def rms_norm_backward_out(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
normalized_shape,
eps,
):
if normalized_shape is None:
cpp_ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output,
hidden_states,
input,
weight,
bias,
inv_rms,
weight.shape,
eps,
)
else:
cpp_ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
normalized_shape,
eps,
)


def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
rms_norm_backward_out(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
normalized_shape,
eps,
)

return [grad_input, grad_weight, grad_bias]


# Define the custom autograd function
class _DeepLinkRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps):
output, inv_rms = rms_norm(hidden_states, None, weight, bias, eps)
ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
return output

@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
eps = eps_tensor.item()
grad_input, grad_weight, grad_bias = rms_norm_backward(
hidden_states, grad_output, inv_rms, None, weight, bias, eps
)
return grad_input, grad_weight, grad_bias, None


class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
output = torch.empty_like(hidden_states, dtype=torch.float32)
inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
inv_rms = torch.empty(
inv_rms_shape, dtype=torch.float32, device=hidden_states.device
)
ext.rms_norm(
output,
inv_rms,
hidden_states.float(),
normalized_shape,
weight.float(),
bias.float(),
eps,
output, inv_rms = rms_norm(
hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps
)
output = output.half()
inv_rms = inv_rms.half()
@@ -72,24 +118,15 @@ def backward(ctx, grad_output):
eps = eps_tensor.item()
normalized_shape = ctx.intermediate_results

grad_input = torch.empty_like(hidden_states, dtype=torch.float32)
grad_weight = torch.empty_like(weight, dtype=torch.float32)
grad_bias = torch.empty_like(bias, dtype=torch.float32)
ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output.float(),
grad_input, grad_weight, grad_bias = rms_norm_backward(
hidden_states.float(),
weight.float(),
bias.float(),
grad_output.float(),
inv_rms.float(),
normalized_shape,
weight.float(),
bias.float(),
eps,
)
grad_output = grad_output.half()
hidden_states = hidden_states.half()
inv_rms = inv_rms.half()
return grad_input, grad_weight, grad_bias, None, None
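
The refactor exposes rms_norm and rms_norm_backward as module-level helpers that both autograd Functions (and the test below) now call. A usage sketch along the lines of tests/test_rms_lightlm.py; shapes and eps are illustrative:

import torch
from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward

hidden = torch.randn(2, 16, 64, device="cuda")
weight = torch.ones(64, device="cuda")
bias = torch.zeros(64, device="cuda")

# Passing None for normalized_shape makes the wrappers fall back to weight.shape.
output, inv_rms = rms_norm(hidden, None, weight, bias, 1e-6)
grad_output = torch.ones_like(output)
grad_input, grad_weight, grad_bias = rms_norm_backward(
    hidden, grad_output, inv_rms, None, weight, bias, 1e-6
)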


13 changes: 12 additions & 1 deletion deeplink_ext/patch_lightllm.py
@@ -51,7 +51,18 @@ def patch_token_softmax_reducev_inference():
)

def patch_rms_norm_lightllm():
rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm
import torch

def rms_norm_lightllm(x, weight, eps):
output = torch.empty_like(x)
inv_rms_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype
inv_rms = torch.empty_like(x, dtype=inv_rms_dtype)
bias = torch.empty_like(weight)
ext.rms_norm(output, inv_rms, x, weight.shape, weight, bias, eps)

return output

rms_norm_pack.rmsnorm_forward = rms_norm_lightllm

def patch_rotary_emb():
rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb
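
After patch_rms_norm_lightllm() runs, lightllm's rmsnorm_forward is routed through the wrapper above instead of the removed rms_norm_lightllm binding. Roughly how a caller sees it (a sketch; rms_norm_pack is the lightllm module object imported elsewhere in patch_lightllm.py, which this diff does not show):

import torch

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
# Only the normalized output is returned; inv_rms and the dummy bias stay internal.
y = rms_norm_pack.rmsnorm_forward(x, weight=w, eps=1e-6)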
26 changes: 5 additions & 21 deletions tests/test_rms_lightlm.py
@@ -1,7 +1,7 @@
# Copyright (c) 2023, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext
from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward

# Define the input tensors
input = torch.randn(5, 5, requires_grad=True).cuda()
@@ -17,26 +17,10 @@
normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda()

print(input.is_dipu)
output = torch.empty_like(input)
inv_rms_shape = list(input.shape[:-1]) + [1]
inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device)
ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6)

# Run the RMS normalization backward pass
grad_input = torch.empty_like(grad_output)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
weight.shape,
1e-6,
output, inv_rms = rms_norm(input, None, weight, bias, 1e-6)

grad_input, grad_weight, grad_bias = rms_norm_backward(
input, grad_output, inv_rms, None, weight, bias, 1e-6
)

print("Output:", output)
