feat(internlm): implement patch for non-legacy rotary_emb #38

Merged: 1 commit, merged on Jan 30, 2024
5 changes: 3 additions & 2 deletions csrc/extensions.cpp
@@ -72,7 +72,7 @@ auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output,

void extApplyRotary(at::Tensor output, const at::Tensor& input,
const at::Tensor& cos, const at::Tensor& sin,
const bool conj, const bool interleaved = false) {
const bool conj, const bool interleaved) {
callDiopi(diopiRotaryEmbedding, output, input, cos, sin, conj, interleaved);
}

@@ -239,7 +239,8 @@ void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
auto dim = q.size(-1);
auto cos_view = cos.view({seq_len, 1, dim / 2});
auto sin_view = sin.view({seq_len, 1, dim / 2});
callDiopi(diopiRotaryEmbedding, q, q, cos_view, sin_view, false, false);
callDiopi(diopiRotaryEmbedding, q, q, cos_view, sin_view, /*conj=*/false,
/*interleaved=*/false);
}

// Check whether a corresponding diopi implementation exists:
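Note on the two boolean flags: assuming diopiRotaryEmbedding follows the flash-attn apply_rotary convention (the DIOPI kernel itself is not shown in this diff), conj selects rotation by the negative angle (what the backward pass needs) and interleaved selects GPT-J-style even/odd pairing instead of GPT-NeoX-style first-half/second-half splitting. A minimal PyTorch sketch of those semantics:

import torch

def rotary_reference(x, cos, sin, conj=False, interleaved=False):
    # conj=True rotates by the negative angle, undoing a forward rotation.
    if conj:
        sin = -sin
    if interleaved:
        # GPT-J style: pair even and odd channels.
        x1, x2 = x[..., ::2], x[..., 1::2]
    else:
        # GPT-NeoX style: pair the first half with the second half.
        x1, x2 = x.chunk(2, dim=-1)
    o1 = x1 * cos - x2 * sin
    o2 = x1 * sin + x2 * cos
    if interleaved:
        return torch.stack((o1, o2), dim=-1).flatten(-2)
    return torch.cat((o1, o2), dim=-1)

With the default argument removed from the definition, every C++ call site now has to pass interleaved explicitly, as the updated extRotaryEmb call above does.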
7 changes: 2 additions & 5 deletions deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,14 +1,11 @@
# Copyright (c) 2024, DeepLink.

try:
from .deeplink import DeepLinkApplyRotaryEmb, DeepLinkApplyRotaryEmbQKV_
from .deeplink import DeepLinkApplyRotaryEmbQKV_
except:
print(
"[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
end="",
)
from .fallback import (
ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
)
from .fallback import ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_
from . import fallback
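After this change the rotary package only exports the QKV variant; DeepLinkApplyRotaryEmb is gone from both the DIOPI-backed and fallback paths. A hypothetical call site, with the qkv layout inferred from the indexing in deeplink.py (batch, seqlen, 3, nheads, headdim) and purely illustrative sizes:

import torch
from deeplink_ext.internlm_ops.rotary import DeepLinkApplyRotaryEmbQKV_

batch, seqlen, nheads, headdim = 2, 16, 8, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, requires_grad=True)
freqs = torch.randn(seqlen, headdim // 2)
cos, sin = freqs.cos(), freqs.sin()

# cos_k/sin_k default to None and interleaved defaults to False.
out = DeepLinkApplyRotaryEmbQKV_.apply(qkv, cos, sin)

Whichever implementation was importable (DIOPI or fallback) is the one bound to the name, so the call site itself does not change.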
107 changes: 4 additions & 103 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -28,7 +28,7 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
interleaved,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
@@ -37,7 +37,7 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
interleaved,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
@@ -57,7 +57,7 @@ def backward(ctx, dqkv):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
interleaved,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
@@ -66,105 +66,6 @@ def backward(ctx, dqkv):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
interleaved,
)
return dqkv, None, None, None, None, None


class DeepLinkApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = (
x_ro.chunk(2, dim=-1)
if not interleaved
else (x_ro[..., ::2], x_ro[..., 1::2])
)
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]

if inplace:
ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
)
else:
ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
)

if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (
do_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
if inplace:
ext.apply_rotary(
do_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
)
else:
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
)

if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
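The backward hunks above flip only the conj flag to True relative to forward: rotating by the negative angle is the exact inverse of the forward rotation, so the same kernel serves both directions. A small self-contained check of that identity under the half-split (non-interleaved) convention:

import torch

def rotate_half(x, cos, sin, conj=False):
    # conj=True negates the angle, giving the inverse rotation.
    if conj:
        sin = -sin
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

angles = torch.linspace(0.0, 1.0, steps=8)
cos, sin = angles.cos(), angles.sin()
x = torch.randn(3, 16)
roundtrip = rotate_half(rotate_half(x, cos, sin), cos, sin, conj=True)
assert torch.allclose(roundtrip, x, atol=1e-5)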
2 changes: 1 addition & 1 deletion deeplink_ext/internlm_ops/rotary/fallback/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) 2024, DeepLink.

from .fallback import ApplyRotaryEmb, ApplyRotaryEmbQKV_
from .fallback import ApplyRotaryEmbQKV_
75 changes: 0 additions & 75 deletions deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -114,78 +114,3 @@ def backward(ctx, dqkv):
)
dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1)
return dqkv, None, None, None, None, None


class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = (
x_ro.chunk(2, dim=-1)
if not interleaved
else (x_ro[..., ::2], x_ro[..., 1::2])
)
out = torch.empty_like(x)
out_ro = out[..., :rotary_dim]

ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
)

if rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
return out

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
do_ro = do[..., :rotary_dim]
do1, do2 = (
do_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do)

dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
)

if rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
29 changes: 24 additions & 5 deletions deeplink_ext/patch_internlm.py
@@ -6,6 +6,7 @@ def _patch_internlm():
import os
import sys
import unittest.mock as mock
import torch
import deeplink_ext.internlm_ops as ext

def _find_or_mock_module(module_name):
@@ -47,15 +48,33 @@ def CrossEntropyLossProxy(reduction, **_):
def _patch_ops():
import internlm.model.embedding # type: ignore

# TODO(lljbash,gongqiwei): implement a module aligned with rotary_emb
def NotImplementedRotaryEnb(*args, **kwargs):
def NotImplementedLegacyRotaryEmb(*args, **kwargs):
raise NotImplementedError(
"the patch for apply_rotary_emb_qkv_ (requires rotary_emb) has not been implemented in deeplink_ext yet"
"we assume that legacy_apply_rotary_embed is not used in internlm"
)

internlm.model.embedding.apply_rotary_emb_qkv_ = NotImplementedRotaryEnb
class NonLegacyRotaryEmbQKV_(torch.autograd.Function):
"""the first 2 dims of qkv has been squeezed"""

@staticmethod
def forward(ctx, qkv: torch.Tensor, *args, **kwargs):
unsqueezed_qkv = qkv.view([1] + list(qkv.shape))
out: torch.Tensor = ext.rotary.DeepLinkApplyRotaryEmbQKV_.forward(
ctx, unsqueezed_qkv, *args, **kwargs
)
return out.view(out.shape[1:])

@staticmethod
def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):
unsqueezed_dqkv = dqkv.view([1] + list(dqkv.shape))
out: tuple = ext.rotary.DeepLinkApplyRotaryEmbQKV_.backward(
ctx, unsqueezed_dqkv, *args, **kwargs
)
return (out[0].view(out[0].shape[1:]),) + out[1:]

internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply
internlm.model.embedding.legacy_apply_rotary_embed = (
ext.rotary.DeepLinkApplyRotaryEmb.apply
NotImplementedLegacyRotaryEmb
)
internlm.model.embedding.legacy_apply_rotary_embed_qkv = (
ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply
)
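The core of the patch is a shape adaptation: per the docstring, the non-legacy apply_rotary_emb_qkv_ in internlm hands over a qkv tensor whose first two dimensions (presumably batch and seqlen) are already squeezed into one packed dimension, while DeepLinkApplyRotaryEmbQKV_ expects them separate. The wrapper therefore adds a dummy batch dimension before delegating and strips it again afterwards. A sketch with illustrative sizes (the real packed length depends on internlm's batching):

import torch

packed_len, nheads, headdim = 128, 8, 64
qkv = torch.randn(packed_len, 3, nheads, headdim)

# forward: add a leading batch dim so the wrapped function sees (b, s, 3, h, d)
unsqueezed = qkv.view([1] + list(qkv.shape))        # (1, 128, 3, 8, 64)
# ... DeepLinkApplyRotaryEmbQKV_.forward(ctx, unsqueezed, cos, sin, ...) ...
# afterwards the wrapper drops the dummy dim again:
restored = unsqueezed.view(unsqueezed.shape[1:])    # (128, 3, 8, 64)
assert restored.shape == qkv.shape

backward mirrors this for the incoming gradient dqkv, re-squeezing only the first element of the returned tuple (the gradient with respect to qkv).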