refactor: RMSNorm (#59)
Refactor the RMS norm op, along with rotary embedding and MHA.

---------

Co-authored-by: root <[email protected]>
zhangzefeng92 and yangbofun authored Apr 1, 2024
1 parent 778fc8c commit 8261278
Showing 29 changed files with 663 additions and 659 deletions.
csrc/extensions.cpp (9 additions, 34 deletions)
@@ -3,7 +3,6 @@
 #include <cstdint>
 #include <tuple>
 #include <utility>
-#include <vector>

 #include <ATen/core/ATen_fwd.h>
 #include <ATen/core/Generator.h>
@@ -27,43 +26,22 @@

 namespace dipu::dipu_ext {

-namespace {
-
-at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault(
-    const OptionalIntArray& opt, at::IntArrayRef def) {
-  if (opt) {
-    return {*opt};
-  }
-  return def;
-}
-
-} // namespace
-
-auto extRmsNorm(const at::Tensor& input,
+auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
+                const at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
                 const at::Tensor& weight, const at::Tensor& bias, double eps) {
-  at::OptionalIntArrayRef normalized_shape_at =
-      optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-  auto input_shape = input.sizes();
-  std::vector<int64_t> input_size(input_shape.begin(), input_shape.end());
-  input_size.back() = 1;
-  auto inv_rms = at::empty(input_size, input.options());
-  auto output = at::empty_like(input);
+  at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
   callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
             bias, eps);
   return std::make_tuple(std::move(output), std::move(inv_rms));
 }

-auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output,
-                        const at::Tensor& inv_rms,
-                        const OptionalIntArray& normalized_shape,
-                        const at::Tensor& weight, const at::Tensor& bias,
-                        double eps) {
-  at::OptionalIntArrayRef normalized_shape_at =
-      optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-  auto grad_input = at::empty_like(grad_output);
-  auto grad_weight = at::empty_like(weight);
-  auto grad_bias = at::empty_like(bias);
+auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
+                        at::Tensor& grad_bias, const at::Tensor& grad_output,
+                        const at::Tensor& input, const at::Tensor& weight,
+                        const at::Tensor& bias, const at::Tensor& inv_rms,
+                        const OptionalIntArray& normalized_shape, double eps) {
+  at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
   callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias,
             grad_output, input, weight, bias, inv_rms, normalized_shape_at,
             eps);
@@ -241,9 +219,6 @@ auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   if (&diopiRMSNorm != nullptr) {  // Check if weak symbol defined
     m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
-    m.def("rms_norm_lightllm", &extRmsNormLightllm,
-          "deeplink ext_rms_norm for lightllm", py::arg("x"), py::arg("weight"),
-          py::arg("eps"));
   }
   if (&diopiRMSNormBackward != nullptr) {
     m.def("rms_norm_backward", &extRmsNormBackward,
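These hunks turn extRmsNorm and extRmsNormBackward into out-variant ops: the bindings no longer allocate output and inv_rms (or the three gradient tensors) internally, and the caller passes preallocated buffers instead. A minimal sketch of the new calling convention from Python (shapes, dtype, and device are illustrative assumptions, not part of this commit):

import torch
import deeplink_ext.cpp_extensions as cpp_ext

# Hypothetical input; any trailing hidden dimension works the same way.
input = torch.randn(4, 8, 16, device="cuda", dtype=torch.float16)
weight = torch.ones(16, device=input.device, dtype=input.dtype)
bias = torch.zeros_like(weight)

# The caller now owns the buffers: output matches input, and inv_rms keeps
# input's shape except for a trailing dimension of 1.
output = torch.empty_like(input)
inv_rms = torch.empty(list(input.shape[:-1]) + [1], device=input.device, dtype=input.dtype)

cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-5)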
csrc/pybind_type_cast.h (0 additions, 24 deletions)
@@ -21,28 +21,4 @@ using OptionalIntArray = c10::optional<IntArray>;

 } // namespace dipu::dipu_ext

-namespace pybind11::detail {
-
-namespace py = pybind11;
-
-template <>
-struct type_caster<at::OptionalIntArrayRef> {
- public:
-  PYBIND11_TYPE_CASTER(dipu::dipu_ext::OptionalIntArray, _("OptionalIntArray"));
-
-  bool load(py::handle src, bool /*unused*/) {
-    if (PyList_Check(src.ptr())) {
-      value = py::cast<dipu::dipu_ext::IntArray>(src);
-      return true;
-    }
-    if (src.is_none()) {
-      value = c10::nullopt;
-      return true;
-    }
-    return false;
-  }
-};
-
-} // namespace pybind11::detail
-
 #endif /* end of include guard: PYBIND_TYPE_CAST_H_PXMGELYW */
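With this custom caster deleted, normalized_shape reaches C++ as a plain c10::optional<IntArray>, and the list-or-None dispatch that load() used to perform moves into the Python wrappers (see deeplink_ext/common/rms_norm.py below). A sketch of the equivalent Python-side check, with a hypothetical helper name:

def _resolve_normalized_shape(normalized_shape, weight):
    # Mirrors the deleted caster: pass an explicit shape through,
    # and fall back to the weight's shape when given None.
    return weight.shape if normalized_shape is None else normalized_shape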
deeplink_ext/common/__init__.py (4 additions, 0 deletions)
@@ -0,0 +1,4 @@
+from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward
+
+
+__all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
deeplink_ext/common/rms_norm.py (78 additions, 0 deletions)
@@ -0,0 +1,78 @@
+import torch
+import deeplink_ext.cpp_extensions as cpp_ext
+
+
+def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
+    if None == normalized_shape:
+        cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
+    else:
+        cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)
+
+
+def rms_norm(input, normalized_shape, weight, bias, eps):
+    output = torch.empty_like(input)
+    inv_rms_shape = list(input.shape[:-1]) + [1]
+    inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
+    rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)
+
+    return [output, inv_rms]
+
+
+def rms_norm_backward_out(
+    grad_input,
+    grad_weight,
+    grad_bias,
+    grad_output,
+    input,
+    weight,
+    bias,
+    inv_rms,
+    normalized_shape,
+    eps,
+):
+    if None == normalized_shape:
+        cpp_ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output,
+            input,
+            weight,
+            bias,
+            inv_rms,
+            weight.shape,
+            eps,
+        )
+    else:
+        cpp_ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output,
+            input,
+            weight,
+            bias,
+            inv_rms,
+            normalized_shape,
+            eps,
+        )
+
+
+def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
+    grad_input = torch.empty_like(input)
+    grad_weight = torch.empty_like(weight)
+    grad_bias = torch.empty_like(bias)
+    rms_norm_backward_out(
+        grad_input,
+        grad_weight,
+        grad_bias,
+        grad_output,
+        input,
+        weight,
+        bias,
+        inv_rms,
+        normalized_shape,
+        eps,
+    )
+
+    return [grad_input, grad_weight, grad_bias]
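rms_norm/rms_norm_out and rms_norm_backward/rms_norm_backward_out form allocate-and-call pairs, which slot naturally into torch.autograd.Function. A hedged sketch of the wiring (the class below is illustrative, not part of this commit):

import torch
from deeplink_ext.common.rms_norm import rms_norm, rms_norm_backward


class _RMSNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, normalized_shape, weight, bias, eps):
        output, inv_rms = rms_norm(input, normalized_shape, weight, bias, eps)
        ctx.save_for_backward(input, weight, bias, inv_rms)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias, inv_rms = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = rms_norm_backward(
            input, grad_output, inv_rms, ctx.normalized_shape, weight, bias, ctx.eps
        )
        # No gradients flow to normalized_shape or eps.
        return grad_input, None, grad_weight, grad_bias, None

Usage would then be _RMSNormFunction.apply(x, None, weight, bias, 1e-5), where None selects weight.shape as the normalized shape, per rms_norm_out above.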
deeplink_ext/internlm_ops/__init__.py (37 additions, 2 deletions)
@@ -1,5 +1,40 @@
 # Copyright (c) 2024, DeepLink.

-from . import mha, rms_norm, rotary
+from . import mha

-__all__ = ["mha", "rms_norm", "rotary"]
+
+_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
+
+
+try:
+    from .rms_norm import RMSNorm, RMSNormWithNormalizedShape
+except:
+    print(
+        _not_impl.format(op_name="RMSNorm or RMSNormWithNormalizedShape"),
+    )
+    from .rms_norm_fallback import (
+        RMSNorm,
+        RMSNormWithNormalizedShape,
+    )
+
+
+try:
+    from .rotary_embedding import apply_rotary
+except:
+    print(_not_impl.format(op_name="apply_rotary"))
+    from .rotary_embeddinig_fallback import apply_rotary
+
+
+try:
+    from .mha import SelfAttention, CrossAttention
+except Exception as e:
+    print(_not_impl.format(op_name="mha"))
+    from .mha_fallback import SelfAttention, CrossAttention
+
+__all__ = [
+    "SelfAttention",
+    "CrossAttention",
+    "RMSNorm",
+    "RMSNormWithNormalizedShape",
+    "apply_rotary",
+]
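Because every op is re-exported at package level, importing code stays identical whichever branch of the try/except chain wins. One way to check at runtime which implementation was picked (a sketch, assuming the package imports cleanly):

import deeplink_ext.internlm_ops as ops

# 'deeplink_ext.internlm_ops.rms_norm' when the diopi kernel is available,
# 'deeplink_ext.internlm_ops.rms_norm_fallback' otherwise.
print(ops.RMSNorm.__module__)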
