diff --git a/deeplink_ext/internlm_ops/rotary/__init__.py b/deeplink_ext/internlm_ops/rotary/__init__.py
index 51831c03..5381737a 100644
--- a/deeplink_ext/internlm_ops/rotary/__init__.py
+++ b/deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,11 +1,14 @@
 # Copyright (c) 2024, DeepLink.
 
 try:
-    from .deeplink import DeepLinkApplyRotaryEmbQKV_
+    from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
 except:
     print(
         "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
         end="",
     )
-    from .fallback import ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_
+    from .fallback import (
+        ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
+        ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
+    )
 from . import fallback
diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary/deeplink.py
index ca0f5705..d065cba8 100644
--- a/deeplink_ext/internlm_ops/rotary/deeplink.py
+++ b/deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -69,3 +69,63 @@ def backward(ctx, dqkv):
             interleaved,
         )
         return dqkv, None, None, None, None, None
+
+
+class DeepLinkApplyRotaryEmb(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
+        """
+        x: (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        rotary_dim must be <= headdim
+        Apply rotary embedding to the first rotary_dim of x.
+        """
+        batch, seqlen, nheads, headdim = x.shape
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x_ro = x[..., :rotary_dim]
+        out = torch.empty_like(x) if not inplace else x
+        out_ro = out[..., :rotary_dim]
+
+        ext.apply_rotary(
+            out_ro,
+            x_ro,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            False,
+            interleaved,
+        )
+
+        if not inplace and rotary_dim < headdim:
+            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
+        ctx.inplace = inplace
+        return out if not inplace else x
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, headdim = do.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        inplace = ctx.inplace
+        do_ro = do[..., :rotary_dim]
+        dx = torch.empty_like(do) if not inplace else do
+        dx_ro = dx[..., :rotary_dim]
+        ext.apply_rotary(
+            dx_ro,
+            do_ro,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            True,
+            ctx.interleaved,
+        )
+        if not inplace and rotary_dim < headdim:
+            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+        return dx, None, None, None, None
diff --git a/deeplink_ext/internlm_ops/rotary/fallback/__init__.py b/deeplink_ext/internlm_ops/rotary/fallback/__init__.py
index 3e863e58..64ec2ffb 100644
--- a/deeplink_ext/internlm_ops/rotary/fallback/__init__.py
+++ b/deeplink_ext/internlm_ops/rotary/fallback/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2024, DeepLink.
 
-from .fallback import ApplyRotaryEmbQKV_
+from .fallback import ApplyRotaryEmbQKV_, ApplyRotaryEmb
diff --git a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
index 9b04ab0b..221e0573 100644
--- a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
+++ b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -113,3 +113,9 @@ def backward(ctx, dqkv):
         )
         dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1)
         return dqkv, None, None, None, None, None
+
+
+class ApplyRotaryEmb:
+    @staticmethod
+    def apply(*args, **kwargs):
+        raise NotImplementedError("fallback.ApplyRotaryEmb is not implemented yet")
diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py
index c2d90a8b..8651c6e5 100644
--- a/deeplink_ext/patch_internlm.py
+++ b/deeplink_ext/patch_internlm.py
@@ -48,11 +48,6 @@ def CrossEntropyLossProxy(reduction, **_):
     def _patch_ops():
         import internlm.model.embedding  # type: ignore
 
-        def NotImplementedLegacyRotaryEmb(*args, **kwargs):
-            raise NotImplementedError(
-                "we assume that legacy_apply_rotary_embed is not used in internlm"
-            )
-
         class NonLegacyRotaryEmbQKV_(torch.autograd.Function):
             """the first 2 dims of qkv has been squeezed"""
 
@@ -74,7 +69,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):
 
         internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply
         internlm.model.embedding.legacy_apply_rotary_embed = (
-            NotImplementedLegacyRotaryEmb
+            ext.rotary.DeepLinkApplyRotaryEmb.apply
        )
         internlm.model.embedding.legacy_apply_rotary_embed_qkv = (
             ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply
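
Reviewer note: a minimal usage sketch, not part of the patch, showing how the newly exported DeepLinkApplyRotaryEmb is expected to be driven once internlm.model.embedding.legacy_apply_rotary_embed is patched. The module path, argument order (x, cos, sin, interleaved, inplace), and shape contracts come from the diff above; the concrete sizes and CPU tensors are illustrative assumptions only. On a build without the diopi rotary kernel, the import falls back to fallback.ApplyRotaryEmb, which currently raises NotImplementedError.

import torch
from deeplink_ext.internlm_ops import rotary

# Assumed shapes: x is (batch, seqlen, nheads, headdim); cos/sin carry
# rotary_dim // 2 entries per position, matching the asserts in forward().
batch, seqlen, nheads, headdim = 2, 16, 8, 64
rotary_dim = 32  # must satisfy rotary_dim <= headdim; only this prefix is rotated

x = torch.randn(batch, seqlen, nheads, headdim, requires_grad=True)
angles = torch.rand(seqlen, rotary_dim // 2)
cos, sin = angles.cos(), angles.sin()

# autograd.Function.apply takes positional args: (x, cos, sin, interleaved, inplace).
# In a real run these tensors would live on the device ext.apply_rotary expects;
# CPU is used here purely to illustrate shapes.
out = rotary.DeepLinkApplyRotaryEmb.apply(x, cos, sin, False, False)

# backward returns dx plus None for cos, sin, interleaved, and inplace.
out.sum().backward()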