feat(rotary,internlm): add DeepLinkApplyRotaryEmb (fallback unimplemented)
lljbash committed Feb 1, 2024
1 parent a5d79ed commit 5befa86
Showing 5 changed files with 73 additions and 9 deletions.
7 changes: 5 additions & 2 deletions deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,11 +1,14 @@
 # Copyright (c) 2024, DeepLink.
 
 try:
-    from .deeplink import DeepLinkApplyRotaryEmbQKV_
+    from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
 except:
     print(
         "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
         end="",
     )
-    from .fallback import ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_
+    from .fallback import (
+        ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
+        ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
+    )
 from . import fallback
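
The effect of this try/except, as a hypothetical check that is not part of the diff: the public name stays stable while the backing class varies with DIOPI availability.

# Hypothetical check of which implementation was selected at import time.
from deeplink_ext.internlm_ops import rotary

print(rotary.DeepLinkApplyRotaryEmb)
# -> the class from .deeplink if the DIOPI kernel imported successfully,
# -> the fallback alias (currently a NotImplementedError stub) otherwise.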
60 changes: 60 additions & 0 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -69,3 +69,63 @@ def backward(ctx, dqkv):
             interleaved,
         )
         return dqkv, None, None, None, None, None
+
+
+class DeepLinkApplyRotaryEmb(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
+        """
+        x: (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        rotary_dim must be <= headdim
+        Apply rotary embedding to the first rotary_dim of x.
+        """
+        batch, seqlen, nheads, headdim = x.shape
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x_ro = x[..., :rotary_dim]
+        out = torch.empty_like(x) if not inplace else x
+        out_ro = out[..., :rotary_dim]
+
+        ext.apply_rotary(
+            out_ro,
+            x_ro,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            False,
+            interleaved,
+        )
+
+        if not inplace and rotary_dim < headdim:
+            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
+        ctx.inplace = inplace
+        return out if not inplace else x
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, headdim = do.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        inplace = ctx.inplace
+        do_ro = do[..., :rotary_dim]
+        dx = torch.empty_like(do) if not inplace else do
+        dx_ro = dx[..., :rotary_dim]
+        ext.apply_rotary(
+            dx_ro,
+            do_ro,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            True,
+            ctx.interleaved,
+        )
+        if not inplace and rotary_dim < headdim:
+            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+        return dx, None, None, None, None
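
A minimal end-to-end sketch of this Function, assuming the DIOPI kernel loaded; the shapes are illustrative, with rotary_dim (32) deliberately smaller than headdim (64) to exercise the pass-through copy of the non-rotary tail.

import torch
from deeplink_ext.internlm_ops.rotary import DeepLinkApplyRotaryEmb

x = torch.randn(2, 128, 8, 64, requires_grad=True)  # (batch, seqlen, nheads, headdim)
cos = torch.randn(128, 16)  # (seqlen, rotary_dim / 2) -> rotary_dim = 32
sin = torch.randn(128, 16)

out = DeepLinkApplyRotaryEmb.apply(x, cos, sin, False, False)  # interleaved=False, inplace=False
out.sum().backward()
# The non-rotary tail is copied through in backward, so its gradient is exactly 1.
assert torch.allclose(x.grad[..., 32:], torch.ones_like(x.grad[..., 32:]))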
2 changes: 1 addition & 1 deletion deeplink_ext/internlm_ops/rotary/fallback/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2024, DeepLink.
 
-from .fallback import ApplyRotaryEmbQKV_
+from .fallback import ApplyRotaryEmbQKV_, ApplyRotaryEmb
6 changes: 6 additions & 0 deletions deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -113,3 +113,9 @@ def backward(ctx, dqkv):
         )
         dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1)
         return dqkv, None, None, None, None, None
+
+
+class ApplyRotaryEmb:
+    @staticmethod
+    def apply(*args, **kwargs):
+        raise NotImplementedError("fallback.ApplyRotaryEmb is not implemented yet")
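
For reference, one way the missing fallback could eventually look, sketched in pure PyTorch for the non-interleaved (GPT-NeoX) case only. This is an assumption about a future implementation, not code from this commit.

import torch

def rotate_half(t):
    # Non-interleaved (GPT-NeoX) rotation: negate the second half, swap halves.
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat((-t2, t1), dim=-1)

def apply_rotary_emb_torch(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim / 2).
    rotary_dim = cos.shape[-1] * 2
    cos = torch.cat((cos, cos), dim=-1)[: x.shape[1], None, :]  # (seqlen, 1, rotary_dim)
    sin = torch.cat((sin, sin), dim=-1)[: x.shape[1], None, :]
    x_ro = x[..., :rotary_dim]
    out_ro = x_ro * cos + rotate_half(x_ro) * sin
    # Pass the non-rotary tail through unchanged, as the DeepLink version does.
    return torch.cat((out_ro, x[..., rotary_dim:]), dim=-1)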
7 changes: 1 addition & 6 deletions deeplink_ext/patch_internlm.py
@@ -48,11 +48,6 @@ def CrossEntropyLossProxy(reduction, **_):
     def _patch_ops():
         import internlm.model.embedding  # type: ignore
 
-        def NotImplementedLegacyRotaryEmb(*args, **kwargs):
-            raise NotImplementedError(
-                "we assume that legacy_apply_rotary_embed is not used in internlm"
-            )
-
         class NonLegacyRotaryEmbQKV_(torch.autograd.Function):
             """the first 2 dims of qkv has been squeezed"""
@@ -74,7 +69,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):
 
         internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply
         internlm.model.embedding.legacy_apply_rotary_embed = (
-            NotImplementedLegacyRotaryEmb
+            ext.rotary.DeepLinkApplyRotaryEmb.apply
         )
         internlm.model.embedding.legacy_apply_rotary_embed_qkv = (
             ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply
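
Net effect of the patch, as a hypothetical sketch (it assumes internlm is installed and that importing deeplink_ext applies these patches; neither is shown in this diff): internlm's legacy rotary entry point now dispatches to the DeepLink kernel instead of raising NotImplementedError.

import torch
import deeplink_ext  # assumption: importing the package triggers patch_internlm
import internlm.model.embedding as embedding  # type: ignore

x = torch.randn(2, 128, 8, 64)
cos, sin = torch.randn(128, 16), torch.randn(128, 16)
# Before this commit the patched-in stub raised NotImplementedError;
# now this resolves to ext.rotary.DeepLinkApplyRotaryEmb.apply.
out = embedding.legacy_apply_rotary_embed(x, cos, sin)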
