From 8261278680ae5822b597a94029cbc0cdbcb38c0a Mon Sep 17 00:00:00 2001 From: Zhangzefeng Date: Mon, 1 Apr 2024 16:47:01 +0800 Subject: [PATCH] refactor: RMSNorm (#59) refactor rms norm op, and rotary_embeding and mha. --------- Co-authored-by: root --- csrc/extensions.cpp | 43 +- csrc/pybind_type_cast.h | 24 - deeplink_ext/common/__init__.py | 4 + deeplink_ext/common/rms_norm.py | 78 +++ deeplink_ext/internlm_ops/__init__.py | 39 +- deeplink_ext/internlm_ops/mha.py | 473 ++++++++++++++++++ deeplink_ext/internlm_ops/mha/__init__.py | 16 - .../internlm_ops/mha/fallback/__init__.py | 5 - deeplink_ext/internlm_ops/mha/mha.py | 109 ---- deeplink_ext/internlm_ops/mha/mha_func.py | 49 -- .../internlm_ops/mha/mha_kvpacked_func.py | 51 -- .../internlm_ops/mha/mha_qkvpacked_func.py | 50 -- .../internlm_ops/mha/mha_varlen_func.py | 83 --- .../mha/mha_varlen_kvpacked_func.py | 83 --- .../mha/mha_varlen_qkvpacked_func.py | 68 --- .../fallback/fallback.py => mha_fallback.py} | 2 + .../{rms_norm/deeplink.py => rms_norm.py} | 40 +- .../internlm_ops/rms_norm/__init__.py | 16 - .../rms_norm/fallback/__init__.py | 5 - .../fallback.py => rms_norm_fallback.py} | 5 + deeplink_ext/internlm_ops/rotary/__init__.py | 13 - .../deeplink.py => rotary_embedding.py} | 0 ...llback.py => rotary_embedding_fallback.py} | 0 deeplink_ext/patch_internlm.py | 14 +- deeplink_ext/patch_lightllm.py | 4 +- tests/test_mha_internlm.py | 4 +- tests/test_rms_internlm.py | 12 +- tests/test_rms_lightlm.py | 15 +- tests/test_rotary_emb_internlm.py | 17 +- 29 files changed, 663 insertions(+), 659 deletions(-) create mode 100644 deeplink_ext/common/__init__.py create mode 100644 deeplink_ext/common/rms_norm.py create mode 100644 deeplink_ext/internlm_ops/mha.py delete mode 100644 deeplink_ext/internlm_ops/mha/__init__.py delete mode 100644 deeplink_ext/internlm_ops/mha/fallback/__init__.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_varlen_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py rename deeplink_ext/internlm_ops/{mha/fallback/fallback.py => mha_fallback.py} (98%) rename deeplink_ext/internlm_ops/{rms_norm/deeplink.py => rms_norm.py} (70%) delete mode 100644 deeplink_ext/internlm_ops/rms_norm/__init__.py delete mode 100644 deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py rename deeplink_ext/internlm_ops/{rms_norm/fallback/fallback.py => rms_norm_fallback.py} (89%) delete mode 100644 deeplink_ext/internlm_ops/rotary/__init__.py rename deeplink_ext/internlm_ops/{rotary/deeplink.py => rotary_embedding.py} (100%) rename deeplink_ext/internlm_ops/{rotary/fallback.py => rotary_embedding_fallback.py} (100%) diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index 19cdfe29..8b57071f 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include @@ -27,43 +26,22 @@ namespace dipu::dipu_ext { -namespace { - -at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault( - const OptionalIntArray& opt, at::IntArrayRef def) { - if (opt) { - return {*opt}; - } - return def; -} - -} // namespace - -auto extRmsNorm(const at::Tensor& input, +auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms, + const 
at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
                 const at::Tensor& weight, const at::Tensor& bias, double eps) {
-  at::OptionalIntArrayRef normalized_shape_at =
-      optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-  auto input_shape = input.sizes();
-  std::vector<int64_t> input_size(input_shape.begin(), input_shape.end());
-  input_size.back() = 1;
-  auto inv_rms = at::empty(input_size, input.options());
-  auto output = at::empty_like(input);
+  at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
   callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
             bias, eps);
   return std::make_tuple(std::move(output), std::move(inv_rms));
 }
 
-auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output,
-                        const at::Tensor& inv_rms,
-                        const OptionalIntArray& normalized_shape,
-                        const at::Tensor& weight, const at::Tensor& bias,
-                        double eps) {
-  at::OptionalIntArrayRef normalized_shape_at =
-      optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-  auto grad_input = at::empty_like(grad_output);
-  auto grad_weight = at::empty_like(weight);
-  auto grad_bias = at::empty_like(bias);
+auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
+                        at::Tensor& grad_bias, const at::Tensor& grad_output,
+                        const at::Tensor& input, const at::Tensor& weight,
+                        const at::Tensor& bias, const at::Tensor& inv_rms,
+                        const OptionalIntArray& normalized_shape, double eps) {
+  at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
   callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias,
             grad_output, input, weight, bias, inv_rms, normalized_shape_at,
             eps);
@@ -241,9 +219,6 @@ auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   if (&diopiRMSNorm != nullptr) {  // Check if weak symbol defined
     m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
-    m.def("rms_norm_lightllm", &extRmsNormLightllm,
-          "deeplink ext_rms_norm for lightllm", py::arg("x"), py::arg("weight"),
-          py::arg("eps"));
   }
   if (&diopiRMSNormBackward != nullptr) {
     m.def("rms_norm_backward", &extRmsNormBackward,
diff --git a/csrc/pybind_type_cast.h b/csrc/pybind_type_cast.h
index 6d128981..61e8f484 100644
--- a/csrc/pybind_type_cast.h
+++ b/csrc/pybind_type_cast.h
@@ -21,28 +21,4 @@ using OptionalIntArray = c10::optional;
 
 }  // namespace dipu::dipu_ext
 
-namespace pybind11::detail {
-
-namespace py = pybind11;
-
-template <>
-struct type_caster<dipu::dipu_ext::OptionalIntArray> {
- public:
-  PYBIND11_TYPE_CASTER(dipu::dipu_ext::OptionalIntArray, _("OptionalIntArray"));
-
-  bool load(py::handle src, bool /*unused*/) {
-    if (PyList_Check(src.ptr())) {
-      value = py::cast(src);
-      return true;
-    }
-    if (src.is_none()) {
-      value = c10::nullopt;
-      return true;
-    }
-    return false;
-  }
-};
-
-}  // namespace pybind11::detail
-
 #endif /* end of include guard: PYBIND_TYPE_CAST_H_PXMGELYW */
diff --git a/deeplink_ext/common/__init__.py b/deeplink_ext/common/__init__.py
new file mode 100644
index 00000000..2d3353d9
--- /dev/null
+++ b/deeplink_ext/common/__init__.py
@@ -0,0 +1,4 @@
+from .rms_norm import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward
+
+
+__all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
diff --git a/deeplink_ext/common/rms_norm.py b/deeplink_ext/common/rms_norm.py
new file mode 100644
index 00000000..f196d297
--- /dev/null
+++ b/deeplink_ext/common/rms_norm.py
@@ -0,0 +1,78 @@
+import torch
+import deeplink_ext.cpp_extensions as cpp_ext
+
+
+def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
+    if normalized_shape is None:
+        cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
+    else:
+        cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)
+
+
+def rms_norm(input, normalized_shape, weight, bias, eps):
+    output = torch.empty_like(input)
+    inv_rms_shape = list(input.shape[:-1]) + [1]
+    inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
+    rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)
+
+    return [output, inv_rms]
+
+
+def rms_norm_backward_out(
+    grad_input,
+    grad_weight,
+    grad_bias,
+    grad_output,
+    input,
+    weight,
+    bias,
+    inv_rms,
+    normalized_shape,
+    eps,
+):
+    if normalized_shape is None:
+        cpp_ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output,
+            input,
+            weight,
+            bias,
+            inv_rms,
+            weight.shape,
+            eps,
+        )
+    else:
+        cpp_ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output,
+            input,
+            weight,
+            bias,
+            inv_rms,
+            normalized_shape,
+            eps,
+        )
+
+
+def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
+    grad_input = torch.empty_like(input)
+    grad_weight = torch.empty_like(weight)
+    grad_bias = torch.empty_like(bias)
+    rms_norm_backward_out(
+        grad_input,
+        grad_weight,
+        grad_bias,
+        grad_output,
+        input,
+        weight,
+        bias,
+        inv_rms,
+        normalized_shape,
+        eps,
+    )
+
+    return [grad_input, grad_weight, grad_bias]
diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py
index f492916a..61f616ad 100644
--- a/deeplink_ext/internlm_ops/__init__.py
+++ b/deeplink_ext/internlm_ops/__init__.py
@@ -1,5 +1,40 @@
 # Copyright (c) 2024, DeepLink.
 
-from . import mha, rms_norm, rotary
+from . import mha
 
-__all__ = ["mha", "rms_norm", "rotary"]
+
+
+_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
+
+
+try:
+    from .rms_norm import RMSNorm, RMSNormWithNormalizedShape
+except:
+    print(
+        _not_impl.format(op_name="RMSNorm or RMSNormWithNormalizedShape"),
+    )
+    from .rms_norm_fallback import (
+        RMSNorm,
+        RMSNormWithNormalizedShape,
+    )
+
+
+try:
+    from .rotary_embedding import apply_rotary
+except:
+    print(_not_impl.format(op_name="apply_rotary"))
+    from .rotary_embedding_fallback import apply_rotary
+
+
+try:
+    from .mha import SelfAttention, CrossAttention
+except Exception as e:
+    print(_not_impl.format(op_name="mha"))
+    from .mha_fallback import SelfAttention, CrossAttention
+
+__all__ = [
+    "SelfAttention",
+    "CrossAttention",
+    "RMSNorm",
+    "RMSNormWithNormalizedShape",
+    "apply_rotary",
+]
diff --git a/deeplink_ext/internlm_ops/mha.py b/deeplink_ext/internlm_ops/mha.py
new file mode 100644
index 00000000..18b8af3d
--- /dev/null
+++ b/deeplink_ext/internlm_ops/mha.py
@@ -0,0 +1,473 @@
+# Copyright (c) 2023, DeepLink.
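+# This module reimplements, on top of the DIOPI kernels exposed through
+# deeplink_ext.cpp_extensions (mha_fwd/mha_bwd and their varlen variants), the
+# flash-attention style autograd functions together with the SelfAttention and
+# CrossAttention nn.Module wrappers that previously lived under internlm_ops/mha/.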
+ +import torch +import deeplink_ext.cpp_extensions as ext + +import torch.nn as nn + + +__all__ = [ + "MultiHeadAttention", + "MultiHeadAttentionKVPacked", + "MultiHeadAttentionQKVPacked", + "MultiHeadAttentionVarLen", + "MultiHeadAttentionVarLenKVPacked", + "MultiHeadAttentionVarLenQKVPacked", + "SelfAttention", + "CrossAttention", +] + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") + + +class MultiHeadAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_fwd( + q, + k, + v, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(q, k, v, out, softmax_lse, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + dq, dk, dv = ext.mha_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + None, + None, + None, + ) + return dq, dk, dv, None, None, None, None + + +class MultiHeadAttentionKVPacked(torch.autograd.Function): + @staticmethod + def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_fwd( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(q, kv, out, softmax_lse, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + q, kv, out, softmax_lse, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + ext.mha_bwd( + dout, + q, + kv[:, :, 0], + kv[:, :, 1], + out, + softmax_lse, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dq, + dkv[:, :, 0], + dkv[:, :, 1], + ) + return dq, dkv, None, None, None, None + + +class MultiHeadAttentionQKVPacked(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_fwd( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(qkv, out, softmax_lse, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + qkv, out, softmax_lse, rng_state = ctx.saved_tensors + dqkv = torch.empty_like(qkv) + rng = torch.Generator(device=qkv.device) + rng.set_state(rng_state) + ext.mha_bwd( + dout, + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + out, + softmax_lse, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dqkv[:, :, 0], + dqkv[:, :, 1], + dqkv[:, :, 2], + ) + return dqkv, None, None, None, None + + +assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, 
"mha_varlen_bwd") + + +class MultiHeadAttentionVarLen(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward( + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + ( + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + rng_state, + ) = ctx.saved_tensors + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + dq, dk, dv = ext.mha_varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + None, + None, + None, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +class MultiHeadAttentionVarLenKVPacked(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward( + q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + ( + q, + kv, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + rng_state, + ) = ctx.saved_tensors + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + ext.mha_varlen_bwd( + dout, + q, + kv[:, 0], + kv[:, 1], + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dq, + dkv[:, 0], + dkv[:, 1], + ) + return dq, dkv, None, None, None, None, None, None, None, None + + +class MultiHeadAttentionVarLenQKVPacked(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + 
return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + dqkv = torch.empty_like(qkv) + rng = torch.Generator(device=qkv.device) + rng.set_state(rng_state) + ext.mha_varlen_bwd( + dout, + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + out, + softmax_lse, + cu_seqlens, + cu_seqlens, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + ) + return dqkv, None, None, None, None, None, None, None, None + + +class SelfAttention(nn.Module): + """Performs self-attention with support for both padded and unpadded sequences. + + Args: + causal (bool, optional): If True, applies causal self-attention, meaning each + position can only attend to previous positions. Default is False. + softmax_scale (float, optional): Scaling factor applied to the softmax + operation. If not provided, will be D^{-0.5}. Default is None. + dropout_p (float, optional): Dropout probability applied to the attention + scores. Default is 0.0. + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): + """Performs self-attention on the input sequences. + + Args: + qkv (torch.Tensor): Input tensor representing queries, keys, and values + concatenated together. (B, S, 3, H, D) for padded; (total, 3, H, D) + for unpadded. + causal (bool, optional): If provided, overrides the class-level 'causal' + argument for this forward pass. Default is None. + cu_seqlens (torch.Tensor((batch_size + 1,), dtype=torch.int32), optional): + Sequence lengths tensor for unpadded sequences. If provided, performs + attention on unpadded sequences. Default is None. + max_seqlen (int, optional): Maximum sequence length for unpadded sequences. + If provided, defines the maximum length of the sequences. Default is + None. + + Returns: + torch.Tensor: Output tensor after applying self-attention. 
+ """ + if cu_seqlens is None: + # padded + return MultiHeadAttentionQKVPacked.apply( + qkv, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) + else: + # unpadded + return MultiHeadAttentionVarLenQKVPacked.apply( + qkv, + cu_seqlens, + max_seqlen, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) + + +class CrossAttention(nn.Module): + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward( + self, + q, + kv, + causal=None, + cu_seqlens=None, + max_seqlen=None, + cu_seqlens_k=None, + max_seqlen_k=None, + ): + if cu_seqlens is None: + # padded + return MultiHeadAttentionKVPacked.apply( + q, + kv, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) + else: + # unpadded + return MultiHeadAttentionVarLenKVPacked.apply( + q, + kv, + cu_seqlens, + cu_seqlens_k, + max_seqlen, + max_seqlen_k, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) diff --git a/deeplink_ext/internlm_ops/mha/__init__.py b/deeplink_ext/internlm_ops/mha/__init__.py deleted file mode 100644 index 212ddfd9..00000000 --- a/deeplink_ext/internlm_ops/mha/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -try: - from .mha import DeepLinkSelfAttention, DeepLinkCrossAttention -except Exception as e: - print( - "[deeplink_ext] mha is not implemented in diopi. Falling back to the slower implementation.\n", - end="", - ) - from .fallback import ( - SelfAttention as DeepLinkSelfAttention, - CrossAttention as DeepLinkCrossAttention, - ) -from . import fallback - -__all__ = ["DeepLinkSelfAttention", "DeepLinkCrossAttention", "fallback"] diff --git a/deeplink_ext/internlm_ops/mha/fallback/__init__.py b/deeplink_ext/internlm_ops/mha/fallback/__init__.py deleted file mode 100644 index 8f12c7d4..00000000 --- a/deeplink_ext/internlm_ops/mha/fallback/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from .fallback import SelfAttention, CrossAttention - -__all__ = ["SelfAttention", "CrossAttention"] diff --git a/deeplink_ext/internlm_ops/mha/mha.py b/deeplink_ext/internlm_ops/mha/mha.py deleted file mode 100644 index 00798027..00000000 --- a/deeplink_ext/internlm_ops/mha/mha.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -import torch.nn as nn -from .mha_qkvpacked_func import DeepLinkMultiHeadAttentionQKVPackedFunc -from .mha_varlen_qkvpacked_func import DeepLinkMultiHeadAttentionVarLenQKVPackedFunc -from .mha_kvpacked_func import DeepLinkMultiHeadAttentionKVPackedFunc -from .mha_varlen_kvpacked_func import DeepLinkMultiHeadAttentionVarLenKVPackedFunc - - -class DeepLinkSelfAttention(nn.Module): - """Performs self-attention with support for both padded and unpadded sequences. - - Args: - causal (bool, optional): If True, applies causal self-attention, meaning each - position can only attend to previous positions. Default is False. - softmax_scale (float, optional): Scaling factor applied to the softmax - operation. If not provided, will be D^{-0.5}. Default is None. - dropout_p (float, optional): Dropout probability applied to the attention - scores. Default is 0.0. 
- """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): - """Performs self-attention on the input sequences. - - Args: - qkv (torch.Tensor): Input tensor representing queries, keys, and values - concatenated together. (B, S, 3, H, D) for padded; (total, 3, H, D) - for unpadded. - causal (bool, optional): If provided, overrides the class-level 'causal' - argument for this forward pass. Default is None. - cu_seqlens (torch.Tensor((batch_size + 1,), dtype=torch.int32), optional): - Sequence lengths tensor for unpadded sequences. If provided, performs - attention on unpadded sequences. Default is None. - max_seqlen (int, optional): Maximum sequence length for unpadded sequences. - If provided, defines the maximum length of the sequences. Default is - None. - - Returns: - torch.Tensor: Output tensor after applying self-attention. - """ - if cu_seqlens is None: - # padded - return DeepLinkMultiHeadAttentionQKVPackedFunc.apply( - qkv, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) - else: - # unpadded - return DeepLinkMultiHeadAttentionVarLenQKVPackedFunc.apply( - qkv, - cu_seqlens, - max_seqlen, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) - - -class DeepLinkCrossAttention(nn.Module): - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward( - self, - q, - kv, - causal=None, - cu_seqlens=None, - max_seqlen=None, - cu_seqlens_k=None, - max_seqlen_k=None, - ): - if cu_seqlens is None: - # padded - return DeepLinkMultiHeadAttentionKVPackedFunc.apply( - q, - kv, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) - else: - # unpadded - return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply( - q, - kv, - cu_seqlens, - cu_seqlens_k, - max_seqlen, - max_seqlen_k, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) diff --git a/deeplink_ext/internlm_ops/mha/mha_func.py b/deeplink_ext/internlm_ops/mha/mha_func.py deleted file mode 100644 index 3efecb5d..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_func.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") - - -class DeepLinkMultiHeadAttentionFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_fwd( - q, - k, - v, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(q, k, v, out, softmax_lse, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - dq, dk, dv = ext.mha_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - None, - None, - None, - ) - return dq, dk, dv, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py deleted file mode 100644 index 33e248f1..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") - - -class DeepLinkMultiHeadAttentionKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_fwd( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(q, kv, out, softmax_lse, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - q, kv, out, softmax_lse, rng_state = ctx.saved_tensors - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - ext.mha_bwd( - dout, - q, - kv[:, :, 0], - kv[:, :, 1], - out, - softmax_lse, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dq, - dkv[:, :, 0], - dkv[:, :, 1], - ) - return dq, dkv, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py deleted file mode 100644 index 61527adb..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") - - -class DeepLinkMultiHeadAttentionQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_fwd( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(qkv, out, softmax_lse, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - qkv, out, softmax_lse, rng_state = ctx.saved_tensors - dqkv = torch.empty_like(qkv) - rng = torch.Generator(device=qkv.device) - rng.set_state(rng_state) - ext.mha_bwd( - dout, - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - out, - softmax_lse, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ) - return dqkv, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_varlen_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_func.py deleted file mode 100644 index 194a458d..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_varlen_func.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, "mha_varlen_bwd") - - -class DeepLinkMultiHeadAttentionVarLenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward( - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - ( - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - rng_state, - ) = ctx.saved_tensors - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - dq, dk, dv = ext.mha_varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - None, - None, - None, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py deleted file mode 100644 index 18569def..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, "mha_varlen_bwd") - - -class DeepLinkMultiHeadAttentionVarLenKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward( - q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - ( - q, - kv, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - rng_state, - ) = ctx.saved_tensors - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - ext.mha_varlen_bwd( - dout, - q, - kv[:, 0], - kv[:, 1], - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dq, - dkv[:, 0], - dkv[:, 1], - ) - return dq, dkv, None, None, None, None, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py deleted file mode 100644 index 562d0047..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, "mha_varlen_bwd") - - -class DeepLinkMultiHeadAttentionVarLenQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors - dqkv = torch.empty_like(qkv) - rng = torch.Generator(device=qkv.device) - rng.set_state(rng_state) - ext.mha_varlen_bwd( - dout, - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - out, - softmax_lse, - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - ) - return dqkv, None, None, None, None, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/fallback/fallback.py b/deeplink_ext/internlm_ops/mha_fallback.py similarity index 98% rename from deeplink_ext/internlm_ops/mha/fallback/fallback.py rename to deeplink_ext/internlm_ops/mha_fallback.py index 9c0a4c90..b14de68c 100644 --- a/deeplink_ext/internlm_ops/mha/fallback/fallback.py +++ b/deeplink_ext/internlm_ops/mha_fallback.py @@ -4,6 +4,8 @@ import torch.nn as nn import einops +__all__ = ["SelfAttention", "CrossAttention"] + class SelfAttention(nn.Module): def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm.py similarity index 70% rename from deeplink_ext/internlm_ops/rms_norm/deeplink.py rename to deeplink_ext/internlm_ops/rms_norm.py index 406dc4aa..0595f269 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm.py @@ -1,17 +1,17 @@ # Copyright (c) 2024, DeepLink. 
import torch
-import deeplink_ext.cpp_extensions as ext
+from deeplink_ext.common.rms_norm import rms_norm, rms_norm_backward
 
-assert hasattr(ext, "rms_norm")
+
+__all__ = ["RMSNorm", "RMSNormWithNormalizedShape"]
 
 
 # 定义自定义的 autograd 函数
 class _DeepLinkRMSNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, hidden_states, weight, bias, eps):
-        output, inv_rms = ext.rms_norm(hidden_states, None, weight, bias, eps)
-
+        output, inv_rms = rms_norm(hidden_states, None, weight, bias, eps)
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
         return output
 
@@ -19,7 +19,7 @@ def forward(ctx, hidden_states, weight, bias, eps):
     def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
-        grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
+        grad_input, grad_weight, grad_bias = rms_norm_backward(
             hidden_states, grad_output, inv_rms, None, weight, bias, eps
         )
         return grad_input, grad_weight, grad_bias, None
@@ -28,15 +28,12 @@ def backward(ctx, grad_output):
 class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
     @staticmethod
     def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
-        output, inv_rms = ext.rms_norm(
+        output, inv_rms = rms_norm(
             hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps
         )
         output = output.half()
         inv_rms = inv_rms.half()
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
-        hidden_states = hidden_states.half()
-        weight = weight.half()
-        bias = bias.half()
         ctx.intermediate_results = normalized_shape
         return output
 
@@ -45,24 +42,21 @@ def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
         normalized_shape = ctx.intermediate_results
-        hidden_states = hidden_states.float()
-        inv_rms = inv_rms.float()
-        weight = weight.float()
-        bias = bias.float()
-        grad_output = grad_output.float()
-        grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
-            hidden_states, grad_output, inv_rms, normalized_shape, weight, bias, eps
+
+        grad_input, grad_weight, grad_bias = rms_norm_backward(
+            hidden_states.float(),
+            grad_output.float(),
+            inv_rms.float(),
+            normalized_shape,
+            weight.float(),
+            bias.float(),
+            eps,
         )
-        grad_output = grad_output.half()
-        hidden_states = hidden_states.half()
-        inv_rms = inv_rms.half()
-        weight = weight.half()
-        bias = bias.half()
         return grad_input, grad_weight, grad_bias, None, None
 
 
 # 定义一个 nn.Module 包裹这个自定义函数
-class DeepLinkRMSNorm(torch.nn.Module):
+class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(hidden_size))
@@ -75,7 +69,7 @@ def forward(self, hidden_states):
         )
 
 
-class DeepLinkRMSNormWithNormalizedShape(torch.nn.Module):
+class RMSNormWithNormalizedShape(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(hidden_size))
diff --git a/deeplink_ext/internlm_ops/rms_norm/__init__.py b/deeplink_ext/internlm_ops/rms_norm/__init__.py
deleted file mode 100644
index 77909d1e..00000000
--- a/deeplink_ext/internlm_ops/rms_norm/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) 2024, DeepLink.
-
-try:
-    from .deeplink import DeepLinkRMSNorm, DeepLinkRMSNormWithNormalizedShape
-except:
-    print(
-        "[deeplink_ext] rms_norm is not implemented in diopi.
Falling back to the slower implementation.\n", - end="", - ) - from .fallback import ( - RMSNorm as DeepLinkRMSNorm, - RMSNorm as DeepLinkRMSNormWithNormalizedShape, # TODO(lljbash): check how this works - ) -from . import fallback - -__all__ = ["DeepLinkRMSNorm", "DeepLinkRMSNormWithNormalizedShape", "fallback"] diff --git a/deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py b/deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py deleted file mode 100644 index cb7c3ac4..00000000 --- a/deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from .fallback import RMSNorm - -__all__ = ["RMSNorm"] diff --git a/deeplink_ext/internlm_ops/rms_norm/fallback/fallback.py b/deeplink_ext/internlm_ops/rms_norm_fallback.py similarity index 89% rename from deeplink_ext/internlm_ops/rms_norm/fallback/fallback.py rename to deeplink_ext/internlm_ops/rms_norm_fallback.py index e58a83e7..6cd64fcb 100644 --- a/deeplink_ext/internlm_ops/rms_norm/fallback/fallback.py +++ b/deeplink_ext/internlm_ops/rms_norm_fallback.py @@ -2,6 +2,8 @@ import torch +__all__ = ["RMSNorm", "RMSNormWithNormalizedShape"] + # RMSNorm fallback from InternLM class RMSNorm(torch.nn.Module): @@ -22,3 +24,6 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states + + +RMSNormWithNormalizedShape = RMSNorm diff --git a/deeplink_ext/internlm_ops/rotary/__init__.py b/deeplink_ext/internlm_ops/rotary/__init__.py deleted file mode 100644 index 8ca5f9e2..00000000 --- a/deeplink_ext/internlm_ops/rotary/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -try: - from .deeplink import apply_rotary -except: - print( - "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n", - end="", - ) - from .fallback import apply_rotary -from . import fallback - -__all__ = ["apply_rotary", "fallback"] diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary_embedding.py similarity index 100% rename from deeplink_ext/internlm_ops/rotary/deeplink.py rename to deeplink_ext/internlm_ops/rotary_embedding.py diff --git a/deeplink_ext/internlm_ops/rotary/fallback.py b/deeplink_ext/internlm_ops/rotary_embedding_fallback.py similarity index 100% rename from deeplink_ext/internlm_ops/rotary/fallback.py rename to deeplink_ext/internlm_ops/rotary_embedding_fallback.py diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py index f06d7eb3..17bc373f 100644 --- a/deeplink_ext/patch_internlm.py +++ b/deeplink_ext/patch_internlm.py @@ -66,10 +66,10 @@ def CrossEntropyLossProxy(reduction, **_): import flash_attn.modules.mha # type: ignore - flash_attn.modules.mha.SelfAttention = ext.mha.DeepLinkSelfAttention - flash_attn.modules.mha.FlashSelfAttention = ext.mha.DeepLinkSelfAttention - flash_attn.modules.mha.CrossAttention = ext.mha.DeepLinkCrossAttention - flash_attn.modules.mha.FlashCrossAttention = ext.mha.DeepLinkCrossAttention + flash_attn.modules.mha.SelfAttention = ext.mha.SelfAttention + flash_attn.modules.mha.FlashSelfAttention = ext.mha.SelfAttention + flash_attn.modules.mha.CrossAttention = ext.mha.CrossAttention + flash_attn.modules.mha.FlashCrossAttention = ext.mha.CrossAttention def _patch_ops(): import deeplink_ext.internlm_ops as ext @@ -115,7 +115,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore # if isinstance(module, RMSNorm): # which will fail under this patch. 
Thus we need also trick `isinstance`. internlm.model.norm.RMSNormTorch.__new__ = lambda _, *args, **kwargs: ( - ext.rms_norm.DeepLinkRMSNormWithNormalizedShape(*args, **kwargs) + ext.rms_norm.RMSNormWithNormalizedShape(*args, **kwargs) ) isinstance_orig = builtins.isinstance builtins.isinstance = lambda obj, class_or_tuple: ( @@ -129,9 +129,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore else (class_or_tuple,) ) ) - and isinstance_orig( - obj, ext.rms_norm.DeepLinkRMSNormWithNormalizedShape - ) + and isinstance_orig(obj, ext.rms_norm.RMSNormWithNormalizedShape) ) ) diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py index 009868be..3b95024c 100644 --- a/deeplink_ext/patch_lightllm.py +++ b/deeplink_ext/patch_lightllm.py @@ -51,7 +51,9 @@ def patch_token_softmax_reducev_inference(): ) def patch_rms_norm_lightllm(): - rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm + from .common.rms_norm import rms_norm + + rms_norm_pack.rmsnorm_forward = rms_norm def patch_rotary_emb(): def rotary_emb(q, cos, sin): diff --git a/tests/test_mha_internlm.py b/tests/test_mha_internlm.py index b74ecc47..735be33b 100644 --- a/tests/test_mha_internlm.py +++ b/tests/test_mha_internlm.py @@ -29,7 +29,7 @@ def _run_cross_attention( D = 8 qkv = torch.randn(B, S, 3, H, D, dtype=torch.float16).cuda() output_gold, grad_gold = _run_self_attention(ext.fallback.SelfAttention, qkv) -output_ext, grad_ext = _run_self_attention(ext.DeepLinkSelfAttention, qkv) +output_ext, grad_ext = _run_self_attention(ext.SelfAttention, qkv) assert torch.allclose(output_gold, output_ext, atol=1e-3) print("SelfAttention forward test pass") assert torch.allclose(grad_gold, grad_ext, atol=2e-3) @@ -40,7 +40,7 @@ def _run_cross_attention( output_gold, dq_gold, dkv_gold = _run_cross_attention( ext.fallback.CrossAttention, q, kv ) -output_ext, dq_ext, dkv_ext = _run_cross_attention(ext.DeepLinkCrossAttention, q, kv) +output_ext, dq_ext, dkv_ext = _run_cross_attention(ext.CrossAttention, q, kv) assert torch.allclose(output_gold, output_ext, atol=1e-3) print("CrossAttention forward test pass") assert torch.allclose(dq_gold, dq_ext, atol=2e-3) diff --git a/tests/test_rms_internlm.py b/tests/test_rms_internlm.py index 72bca9ef..dbb02667 100644 --- a/tests/test_rms_internlm.py +++ b/tests/test_rms_internlm.py @@ -2,10 +2,14 @@ import torch import numpy as np -import deeplink_ext.internlm_ops.rms_norm as ext +from deeplink_ext.internlm_ops.rms_norm import RMSNorm, RMSNormWithNormalizedShape +from deeplink_ext.internlm_ops.rms_norm_fallback import ( + RMSNorm as RMSNorm_fb, + RMSNormWithNormalizedShape as RMSNormWithNormalizedShape_fb, +) -def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): +def rms_norm_test(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): x_base = torch.randn(5, 5, requires_grad=True).cuda() x_base.retain_grad() @@ -29,9 +33,9 @@ def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): print( "Test case: normalized_shape == None: grad_inputs closed ? ", - test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNorm), + rms_norm_test(RMSNorm_fb, RMSNorm), ) print( "Test case: normalized_shape == weight.size(): grad_inputs closed ? 
", - test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNormWithNormalizedShape), + rms_norm_test(RMSNormWithNormalizedShape_fb, RMSNormWithNormalizedShape), ) diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py index ea9f66b3..ba57369b 100644 --- a/tests/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -1,7 +1,7 @@ # Copyright (c) 2023, DeepLink. import torch -import deeplink_ext.cpp_extensions as ext +from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward # 定义输入张量 input = torch.randn(5, 5, requires_grad=True).cuda() @@ -17,10 +17,9 @@ normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda() print(input.is_dipu) -output, inv_rms = ext.rms_norm(input, None, weight, bias, 1e-6) +output, inv_rms = rms_norm(input, None, weight, bias, 1e-6) -# 使用 RMS normalization 反向传播 -grad_input, grad_weight, grad_bias = ext.rms_norm_backward( +grad_input, grad_weight, grad_bias = rms_norm_backward( input, grad_output, inv_rms, None, weight, bias, 1e-6 ) @@ -28,5 +27,13 @@ print("Grad Input:", grad_input) print("Grad Weight:", grad_weight) print("Grad Bias:", grad_bias) + +input.requires_grad_(True) +weight.requires_grad_(True) +bias.requires_grad_(True) b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight +grads = torch.autograd.grad(b, [input, weight, bias], grad_output, allow_unused=True) assert torch.allclose(output, b) +assert torch.allclose(grad_input, grads[0]) +assert torch.allclose(grad_weight, grads[1]) +# assert torch.allclose(grad_bias, grads[2]) diff --git a/tests/test_rotary_emb_internlm.py b/tests/test_rotary_emb_internlm.py index e212bd82..70e93d05 100644 --- a/tests/test_rotary_emb_internlm.py +++ b/tests/test_rotary_emb_internlm.py @@ -1,7 +1,10 @@ # Copyright (c) 2023, DeepLink. import torch -import deeplink_ext.internlm_ops.rotary as ext +from deeplink_ext.internlm_ops.rotary_embedding import apply_rotary +from deeplink_ext.internlm_ops.rotary_embeddinig_fallback import ( + apply_rotary as apply_rotary_fb, +) def RotaryEmbTestFloat16() -> bool: @@ -13,10 +16,8 @@ def RotaryEmbTestFloat16() -> bool: inplace = True interleaved = False - res1 = ext.fallback.apply_rotary( - input, cos, sin, interleaved=interleaved, inplace=inplace - ) - res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) + res1 = apply_rotary_fb(input, cos, sin, interleaved=interleaved, inplace=inplace) + res2 = apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) # there is a little calculated error with ascend when dtype is float16 return torch.allclose(res1, res2, atol=1e-2, rtol=1e-3) @@ -31,10 +32,8 @@ def RotaryEmbTestFloat32() -> bool: inplace = True interleaved = False - res1 = ext.fallback.apply_rotary( - input, cos, sin, interleaved=interleaved, inplace=inplace - ) - res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) + res1 = apply_rotary_fb(input, cos, sin, interleaved=interleaved, inplace=inplace) + res2 = apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) return torch.allclose(res1, res2)