From 00a75feee2b05d45b0f22a5b9e74f8eea16314a7 Mon Sep 17 00:00:00 2001 From: Lingjie Li Date: Thu, 7 Mar 2024 18:42:03 +0800 Subject: [PATCH 1/4] refactor!(internlm): a new way to patch rotary based on flash_attn 2.2.1 --- deeplink_ext/internlm_ops/rotary/__init__.py | 9 +- deeplink_ext/internlm_ops/rotary/deeplink.py | 172 +++++----------- deeplink_ext/internlm_ops/rotary/fallback.py | 98 +++++++++ .../internlm_ops/rotary/fallback/__init__.py | 5 - .../internlm_ops/rotary/fallback/fallback.py | 193 ------------------ deeplink_ext/patch_internlm.py | 17 +- tests/test_rotary_emb_internlm.py | 78 +------ 7 files changed, 173 insertions(+), 399 deletions(-) create mode 100644 deeplink_ext/internlm_ops/rotary/fallback.py delete mode 100644 deeplink_ext/internlm_ops/rotary/fallback/__init__.py delete mode 100644 deeplink_ext/internlm_ops/rotary/fallback/fallback.py diff --git a/deeplink_ext/internlm_ops/rotary/__init__.py b/deeplink_ext/internlm_ops/rotary/__init__.py index e1534637..8ca5f9e2 100644 --- a/deeplink_ext/internlm_ops/rotary/__init__.py +++ b/deeplink_ext/internlm_ops/rotary/__init__.py @@ -1,16 +1,13 @@ # Copyright (c) 2024, DeepLink. try: - from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb + from .deeplink import apply_rotary except: print( "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n", end="", ) - from .fallback import ( - ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_, - ApplyRotaryEmb as DeepLinkApplyRotaryEmb, - ) + from .fallback import apply_rotary from . import fallback -__all__ = ["DeepLinkApplyRotaryEmbQKV_", "DeepLinkApplyRotaryEmb", "fallback"] +__all__ = ["apply_rotary", "fallback"] diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary/deeplink.py index d065cba8..670a47b9 100644 --- a/deeplink_ext/internlm_ops/rotary/deeplink.py +++ b/deeplink_ext/internlm_ops/rotary/deeplink.py @@ -1,131 +1,65 @@ # Copyright (c) 2024, DeepLink. 
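For orientation on the refactor named in the subject line: the two autograd wrapper classes are removed, and the module now exports a single kernel-style `apply_rotary` whose interface is intended to match the low-level rotary function of flash_attn 2.2.1, with an automatic fallback when the DIOPI extension is unavailable. A short illustrative sketch, not part of the patch itself:

```python
# Illustrative sketch (not from this patch): what the refactored module
# exports and how the implementation is selected.
from deeplink_ext.internlm_ops.rotary import apply_rotary
# -> the DIOPI-backed version when deeplink_ext.cpp_extensions provides
#    apply_rotary, otherwise the pure-PyTorch fallback under the same name.

# Both implementations share one signature (intended to match flash_attn
# 2.2.1's low-level apply_rotary):
# apply_rotary(x, cos, sin, seqlen_offsets=0, cu_seqlens=None, max_seqlen=None,
#              interleaved=False, inplace=False, conjugate=False)
```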
+from typing import Optional, Union import torch from einops import rearrange import deeplink_ext.cpp_extensions as ext assert hasattr(ext, "apply_rotary") +__all__ = ["apply_rotary"] -class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert ( - sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - ) - q_ro = qkv[:, :, 0, :, :rotary_dim] - ext.apply_rotary( - q_ro, - q_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - interleaved, - ) - k_ro = qkv[:, :, 1, :, :rotary_dim] - ext.apply_rotary( - k_ro, - k_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - interleaved, - ) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - interleaved = ctx.interleaved - _, seqlen, _, _, headdim = dqkv.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, :, 0, :, :rotary_dim] - ext.apply_rotary( - dq_ro, - dq_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - interleaved, - ) - dk_ro = dqkv[:, :, 1, :, :rotary_dim] - ext.apply_rotary( - dk_ro, - dk_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - interleaved, - ) - return dqkv, None, None, None, None, None - - -class DeepLinkApplyRotaryEmb(torch.autograd.Function): - @staticmethod - def forward(ctx, x, cos, sin, interleaved=False, inplace=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - batch, seqlen, nheads, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - out = torch.empty_like(x) if not inplace else x - out_ro = out[..., :rotary_dim] - ext.apply_rotary( - out_ro, - x_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - interleaved, +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
+ cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None: + raise NotImplementedError( + "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet" ) + batch, seqlen, nheads, headdim = x.shape + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - if not inplace and rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.inplace = inplace - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - inplace = ctx.inplace - do_ro = do[..., :rotary_dim] - dx = torch.empty_like(do) if not inplace else do - dx_ro = dx[..., :rotary_dim] - ext.apply_rotary( - dx_ro, - do_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ctx.interleaved, - ) - if not inplace and rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + ext.apply_rotary( + output, + x, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + conjugate, + interleaved, + ) + return output diff --git a/deeplink_ext/internlm_ops/rotary/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback.py new file mode 100644 index 00000000..7ce7c652 --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary/fallback.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024, DeepLink. + +from typing import Optional, Union +import torch +from einops import rearrange, repeat + +__all__ = ["apply_rotary"] + + +def _rotate_half(x: torch.Tensor, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def _apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved=False +): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + + data_type = x.dtype + x = x.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ).to(data_type) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None: + raise NotImplementedError( + "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet" + ) + batch, seqlen, nheads, headdim = x.shape + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + if conjugate: + sin = -sin + out = _apply_rotary_emb_torch(x, cos[:seqlen], sin[:seqlen], interleaved) + if inplace: + x.copy_(out) + out = x + return out diff --git a/deeplink_ext/internlm_ops/rotary/fallback/__init__.py b/deeplink_ext/internlm_ops/rotary/fallback/__init__.py deleted file mode 100644 index f0166042..00000000 --- a/deeplink_ext/internlm_ops/rotary/fallback/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from .fallback import ApplyRotaryEmbQKV_, ApplyRotaryEmb - -__all__ = ["ApplyRotaryEmbQKV_", "ApplyRotaryEmb"] diff --git a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py deleted file mode 100644 index abf63af6..00000000 --- a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -from einops import rearrange - - -# Rotary_emb -# torch 绕过实现函数 -def apply_rotary(x1, x2, cos, sin, conj): - data_dtype = x1.dtype - x1 = x1.to(torch.float32) - x2 = x2.to(torch.float32) - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - if not conj: - out1 = x1 * cos - x2 * sin - out2 = x1 * sin + x2 * cos - else: - out1 = x1 * cos + x2 * sin - out2 = -x1 * sin + x2 * cos - out1 = out1.to(data_dtype) - out2 = out2.to(data_dtype) - return out1, out2 - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - """ - qkv: (batch_size, seqlen, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of q and k. 
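The `interleaved` and `conjugate` flags of the new `apply_rotary`, described in the surrounding docstrings, can be summarized in a few lines of plain PyTorch. The sketch below is standalone and illustrative, not code from the package: `interleaved` selects GPT-J style even/odd pairing instead of GPT-NeoX style half splitting, and `conjugate` (negating `sin`) applies the inverse rotation, which is why the backward pass can reuse the same kernel.

```python
# Standalone illustration (not from the package) of the two flags.
import torch

def rotate(x, cos, sin, interleaved=False, conjugate=False):
    if conjugate:          # inverse rotation, used for the backward pass
        sin = -sin
    if not interleaved:    # GPT-NeoX style: pair the 1st half with the 2nd half
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    # GPT-J style: pair even and odd dimensions
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return out.flatten(-2)

theta = torch.rand(8, 1, 4)                    # (seqlen, 1, rotary_dim / 2)
cos, sin = torch.cos(theta), torch.sin(theta)
x = torch.randn(8, 2, 8)                       # (seqlen, nheads, rotary_dim)

y = rotate(x, cos, sin)                        # forward rotation
x_back = rotate(y, cos, sin, conjugate=True)   # conjugate undoes it
assert torch.allclose(x_back, x, atol=1e-6)
```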
- """ - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert ( - sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - ) - q_ro = qkv[:, :, 0, :, :rotary_dim] - q1, q2 = ( - q_ro.chunk(2, dim=-1) - if not interleaved - else (q_ro[..., ::2], q_ro[..., 1::2]) - ) - q1, q2 = apply_rotary( - q1, - q2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - qkv[:, :, 0, :, :rotary_dim] = torch.cat((q1, q2), dim=-1) - k_ro = qkv[:, :, 1, :, :rotary_dim] - k1, k2 = ( - k_ro.chunk(2, dim=-1) - if not interleaved - else (k_ro[..., ::2], k_ro[..., 1::2]) - ) - k1, k2 = apply_rotary( - k1, - k2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - qkv[:, :, 1, :, :rotary_dim] = torch.cat((k1, k2), dim=-1) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - _, seqlen, _, _, headdim = dqkv.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, :, 0, :, :rotary_dim] - dq1, dq2 = ( - dq_ro.chunk(2, dim=-1) - if not ctx.interleaved - else (dq_ro[..., ::2], dq_ro[..., 1::2]) - ) - dq1, dq2 = apply_rotary( - dq1, - dq2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ) - dqkv[:, :, 0, :, :rotary_dim] = torch.cat((dq1, dq2), dim=-1) - dk_ro = dqkv[:, :, 1, :, :rotary_dim] - dk1, dk2 = ( - dk_ro.chunk(2, dim=-1) - if not ctx.interleaved - else (dk_ro[..., ::2], dk_ro[..., 1::2]) - ) - dk1, dk2 = apply_rotary( - dk1, - dk2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ) - dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1) - return dqkv, None, None, None, None, None - - -class ApplyRotaryEmb(torch.autograd.Function): - @staticmethod - def forward(ctx, x, cos, sin, interleaved=False, inplace=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - - assert ( - interleaved == False - ), "Interleaved rotary embedding fallback is not supported yet" - assert ( - inplace == False - ), "Inplace rotary embedding fallback is not supported yet" - - batch, seqlen, nheads, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = ( - x_ro.chunk(2, dim=-1) - if not interleaved - else (x_ro[..., ::2], x_ro[..., 1::2]) - ) - out = torch.empty_like(x) - - o1, o2 = apply_rotary( - x1, - x2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - - out[..., :rotary_dim] = torch.cat((o1, o2), dim=-1) - if rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.inplace = inplace - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - do_ro = do[..., :rotary_dim] - do1, do2 = ( - do_ro.chunk(2, dim=-1) - if not ctx.interleaved - else (do_ro[..., ::2], do_ro[..., 1::2]) - ) - dx = torch.empty_like(do) - - dx1, dx2 = apply_rotary( - do1, - do2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - - dx[..., :rotary_dim] = torch.cat((dx1, dx2), dim=-1) - - if rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py index 6e4e0a9f..d5ac32b2 100644 --- a/deeplink_ext/patch_internlm.py +++ b/deeplink_ext/patch_internlm.py @@ -71,34 +71,33 @@ def CrossEntropyLossProxy(reduction, **_): def _patch_ops(): import deeplink_ext.internlm_ops as ext + import flash_attn.layers.rotary # type: ignore import internlm.model.embedding # type: ignore + flash_attn.layers.rotary.apply_rotary = ext.rotary.apply_rotary + class NonLegacyRotaryEmbQKV_(torch.autograd.Function): """the first 2 dims of qkv has been squeezed""" @staticmethod def forward(ctx, qkv: torch.Tensor, *args, **kwargs): # type: ignore unsqueezed_qkv = qkv.view([1] + list(qkv.shape)) - out: torch.Tensor = ext.rotary.DeepLinkApplyRotaryEmbQKV_.forward( - ctx, unsqueezed_qkv, *args, **kwargs + out: torch.Tensor = ( + internlm.model.embedding.LegacyApplyRotaryEmbQKV_.forward( + ctx, unsqueezed_qkv, *args, **kwargs + ) ) return out.view(out.shape[1:]) @staticmethod def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore unqueezed_dqkv = dqkv.view([1] + list(dqkv.shape)) - out: tuple = ext.rotary.DeepLinkApplyRotaryEmbQKV_.backward( + out: tuple = internlm.model.embedding.LegacyApplyRotaryEmbQKV_.backward( ctx, unqueezed_dqkv, *args, **kwargs ) return (out[0].view(out[0].shape[1:]),) + out[1:] internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply - internlm.model.embedding.legacy_apply_rotary_embed = ( - ext.rotary.DeepLinkApplyRotaryEmb.apply - ) - internlm.model.embedding.legacy_apply_rotary_embed_qkv = ( - ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply - ) import internlm.model.norm # type: ignore diff --git a/tests/test_rotary_emb_internlm.py b/tests/test_rotary_emb_internlm.py index 7cc4b5e2..b84ea165 100644 --- a/tests/test_rotary_emb_internlm.py +++ b/tests/test_rotary_emb_internlm.py @@ -4,77 +4,21 @@ import deeplink_ext.internlm_ops.rotary as ext -def 
RotaryEmbTest(func_name): - if func_name == "RotaryEmbQKV": - torch_apply = ext.fallback.ApplyRotaryEmbQKV_.apply - dipu_apply = ext.DeepLinkApplyRotaryEmbQKV_.apply - input = torch.randn( - 1, 125, 3, 16, 32, dtype=torch.float16, requires_grad=True - ).cuda() - elif func_name == "RotaryEmb": - torch_apply = ext.fallback.ApplyRotaryEmb.apply - dipu_apply = ext.DeepLinkApplyRotaryEmb.apply - input = torch.randn( - 1, 125, 16, 32, dtype=torch.float16, requires_grad=True - ).cuda() - else: - print(f"{func_name} is not supported.") - return False +def RotaryEmbTest() -> bool: + input = torch.randn(1, 125, 16, 32, dtype=torch.float16).cuda() - loss_fn = torch.nn.MSELoss() - cos = torch.randn(257, 16, dtype=torch.float16).cuda() - sin = torch.randn(257, 16, dtype=torch.float16).cuda() + cos = torch.randn(217, 16, dtype=torch.float16).cuda() + sin = torch.randn(217, 16, dtype=torch.float16).cuda() input1 = input.detach().clone() - input1.requires_grad = True - cos1 = cos.clone() - sin1 = sin.clone() - cos_k = None - sin_k = None + inplace = True interleaved = False - # 调用前向传播 - if func_name == "RotaryEmbQKV": - res1 = torch_apply(input, cos, sin, cos_k, sin_k, interleaved) - res2 = dipu_apply(input1, cos1, sin1, cos_k, sin_k, interleaved) - elif func_name == "RotaryEmb": - res1 = torch_apply(input, cos, sin, interleaved) - res2 = dipu_apply(input1, cos1, sin1, interleaved) - else: - print(f"{func_name} is not supported.") - return False + res1 = ext.fallback.apply_rotary( + input, cos, sin, interleaved=interleaved, inplace=inplace + ) + res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) - # 验证前向传播结果 - forward_correct = torch.allclose(res1, res2) + return torch.allclose(res1, res2) - # 计算第一个损失 - c = torch.ones_like(res1) - loss1 = loss_fn(res1, c) # 将输出的元素求和,得到标量 - input.retain_grad() - loss1.backward() - # 计算第二个损失 - c2 = torch.ones_like(res1) - loss2 = loss_fn(res2, c2) # 将输出的元素求和,得到标量 - input1.retain_grad() - loss2.backward() - - # 验证第一个反向传播梯度 - grad1 = input.grad - grad2 = input1.grad - backward_correct = torch.allclose(grad1, grad2) - # 判断前向和反向传播结果是否都正确 - if forward_correct and backward_correct: - print(f"{func_name} both forward and backward pass tests passed.") - return True - else: - print( - f"{func_name} tests failed: Forward pass:", - forward_correct, - "Backward pass:", - backward_correct, - ) - return False - - -assert RotaryEmbTest("RotaryEmbQKV") -assert RotaryEmbTest("RotaryEmb") +assert RotaryEmbTest() From 08f96f4efb232baad3b597af3441ae74f3851056 Mon Sep 17 00:00:00 2001 From: Lingjie Li Date: Tue, 12 Mar 2024 17:03:52 +0800 Subject: [PATCH 2/4] fix(internevo): trick isinstance for RMSNorm patch for latest internevo --- deeplink_ext/patch_internlm.py | 42 +++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py index d5ac32b2..65d22f91 100644 --- a/deeplink_ext/patch_internlm.py +++ b/deeplink_ext/patch_internlm.py @@ -2,6 +2,8 @@ import os +__all__ = [] + _force_fallback = os.environ.get("DEEPLINK_EXT_FORCE_FALLBACK", "0") != "0" @@ -99,18 +101,44 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply + import builtins import internlm.model.norm # type: ignore - # NOTE: RMSNormTorch class object has been assigned to RMSNorm via + # HACK: RMSNormTorch class object has been assigned to RMSNorm via # RMSNorm = try_import_RMSNorm() - # everywhere (e.g. 
in modeling_llama.py). - # Thus simply reassigning RMSNormTorch to DeepLinkRMSNorm won't work. - # And we don't want to reassign every RMSNorm to DeepLinkRMSNorm. - # So we patch RMSNormTorch.__new__ to create a DeepLinkRMSNorm instance - # whenever RMSNorm(...) is called. + # everywhere (e.g. in modeling_llama.py). Thus simply reassigning + # RMSNormTorch to DeepLinkRMSNorm won't work. But we don't want to + # reassign every RMSNorm to DeepLinkRMSNorm. So we patch + # RMSNormTorch.__new__ to create a DeepLinkRMSNorm instance whenever + # RMSNorm(...) is called. + # This is not enough though. In latest internevo, there are checks like + # if isinstance(module, RMSNorm): + # which will fail under this patch. Thus we need also trick `isinstance`. internlm.model.norm.RMSNormTorch.__new__ = lambda _, *args, **kwargs: ( ext.rms_norm.DeepLinkRMSNormWithNormalizedShape(*args, **kwargs) ) + isinstance_orig = builtins.isinstance + builtins.isinstance = lambda obj, class_or_tuple: ( + isinstance_orig(obj, class_or_tuple) + or ( + ( + internlm.model.norm.RMSNormTorch + in ( + class_or_tuple + if isinstance_orig(class_or_tuple, tuple) + else (class_or_tuple,) + ) + ) + and isinstance_orig( + obj, ext.rms_norm.DeepLinkRMSNormWithNormalizedShape + ) + ) + ) + + import fused_dense_lib # type: ignore + import internlm.model.utils # type: ignore + + fused_dense_lib.linear_bias_wgrad = internlm.model.utils.linear_bias_wgrad_torch cpp_ext_found = _find_or_mock_module("deeplink_ext.cpp_extensions") if not cpp_ext_found: @@ -130,5 +158,3 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore _patch_internlm(force_fallback=_force_fallback) - -__all__ = [] From ef6b9eb6f07b504d26dc68e5f1a1d80209cb3c97 Mon Sep 17 00:00:00 2001 From: Lingjie Li Date: Tue, 12 Mar 2024 17:05:06 +0800 Subject: [PATCH 3/4] fix(internevo): fix mha varlen kvpacked func --- deeplink_ext/internlm_ops/mha/mha.py | 10 +++++----- .../mha/mha_varlen_kvpacked_func.py | 19 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/deeplink_ext/internlm_ops/mha/mha.py b/deeplink_ext/internlm_ops/mha/mha.py index c718ecd2..00798027 100644 --- a/deeplink_ext/internlm_ops/mha/mha.py +++ b/deeplink_ext/internlm_ops/mha/mha.py @@ -78,12 +78,12 @@ def forward( q, kv, causal=None, - cu_seqlens_q=None, - max_seqlen_q=None, + cu_seqlens=None, + max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None, ): - if cu_seqlens_q is None: + if cu_seqlens is None: # padded return DeepLinkMultiHeadAttentionKVPackedFunc.apply( q, @@ -98,9 +98,9 @@ def forward( return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply( q, kv, - cu_seqlens_q, + cu_seqlens, cu_seqlens_k, - max_seqlen_q, + max_seqlen, max_seqlen_k, self.dropout_p if self.training else 0.0, self.softmax_scale, diff --git a/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py index b69a0c44..18569def 100644 --- a/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py +++ b/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py @@ -25,8 +25,8 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( q, - kv[:, :, 0], - kv[:, :, 1], + kv[:, 0], + kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -40,7 +40,8 @@ def forward( q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() ) ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen_q + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = 
softmax_scale
         ctx.causal = causal
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -63,20 +64,20 @@ def backward(ctx, dout):
         ext.mha_varlen_bwd(
             dout,
             q,
-            kv[:, :, 0],
-            kv[:, :, 1],
+            kv[:, 0],
+            kv[:, 1],
             out,
             softmax_lse,
             cu_seqlens_q,
             cu_seqlens_k,
-            ctx.max_seqlen,
-            ctx.max_seqlen,
+            ctx.max_seqlen_q,
+            ctx.max_seqlen_k,
             ctx.dropout_p,
             ctx.causal,
             rng,
             ctx.softmax_scale,
             dq,
-            dkv[:, :, 0],
-            dkv[:, :, 1],
+            dkv[:, 0],
+            dkv[:, 1],
         )
         return dq, dkv, None, None, None, None, None, None, None, None

From ce04bfefcf6366778a47655cf62086c6112a4dff Mon Sep 17 00:00:00 2001
From: Lingjie Li
Date: Tue, 12 Mar 2024 17:28:08 +0800
Subject: [PATCH 4/4] docs: update README.md about InternEvo

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 59b2e512..5b67a423 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,9 @@ pip install -e .
 
 ## Usage
 
-### InternLM
+### InternEvo
+
+Adapted version: https://github.com/DeepLink-org/InternEvo/tree/deeplinkext
 
 ```python
 import deeplink_ext.patch_internlm
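A closing note on the `mha_varlen_kvpacked` fix in PATCH 3/4 above: in the varlen path the sequences are concatenated along a single token axis delimited by `cu_seqlens`, so the packed kv tensor has no batch dimension, k/v sit on dim 1 rather than dim 2, and the backward pass needs the separate `max_seqlen_q`/`max_seqlen_k` values saved in forward. A small shape sketch with illustrative sizes:

```python
# Illustrative shapes only.
import torch

nheads, headdim = 16, 64

# Padded kv-packed layout: (batch, seqlen, 2, nheads, headdim) -> k/v on dim 2.
kv_padded = torch.randn(4, 128, 2, nheads, headdim)
k, v = kv_padded[:, :, 0], kv_padded[:, :, 1]

# Varlen kv-packed layout: (total_seqlen, 2, nheads, headdim) -> k/v on dim 1,
# with sequence boundaries carried separately in cu_seqlens.
cu_seqlens = torch.tensor([0, 3, 8, 128], dtype=torch.int32)
kv_varlen = torch.randn(int(cu_seqlens[-1]), 2, nheads, headdim)
k, v = kv_varlen[:, 0], kv_varlen[:, 1]   # the indexing the fix switches to

assert k.shape == (128, nheads, headdim)
```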