Skip to content

Commit

Permalink
feat: reimpl rotary embedding for npu (#127)
Browse files Browse the repository at this point in the history
1. Reimplement rotary embedding with `npu_rotary_mul` from torch_npu.
2. Temporarily use the combined rms_norm and rotary_embedding
implementations for accuracy.
  • Loading branch information
jingguo-st authored Sep 12, 2024
1 parent ad193f8 commit 22bffd4
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 144 deletions.
3 changes: 2 additions & 1 deletion deeplink_ext/ascend_speed/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._rotary_embedding_npu import RotaryEmbedding
# from ._rotary_embedding_npu import RotaryEmbedding
from .rotary_embedding_fallback import RotaryEmbeddingTorch as RotaryEmbedding
elif platform_type == PlatformType.TORCH_DIPU:
from ._rotary_embedding_dipu import RotaryEmbedding
else:
Expand Down
72 changes: 18 additions & 54 deletions deeplink_ext/internevo_ops/_rotary_embedding_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,6 @@
__all__ = ["ApplyRotaryEmb"]


def _unsqueeze_to_4d(x: torch.Tensor):
while x.dim() < 4:
x = x.unsqueeze(0)
return x


def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved):
assert interleaved == False, "interleaved not support by torch_npu"

x_view = _unsqueeze_to_4d(x)
cos_view = _unsqueeze_to_4d(cos)
sin_view = _unsqueeze_to_4d(sin)

cos_cat = torch.cat([cos_view, cos_view], -1)
sin_cat = torch.cat([sin_view, sin_view], -1)

if confj:
sin_cat.neg_()

x_view_chunks = x_view.chunk(2, -1)
x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)

cos_x = torch.mul(cos_cat, x_view)
sin_x = torch.mul(sin_cat, x_view_new)
out = cos_x + sin_x

return out


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
class ApplyRotaryEmb(torch.autograd.Function):
"""
Expand Down Expand Up @@ -67,45 +38,38 @@ def forward(
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)

out = _apply_rotary(
x[..., :rotary_dim],
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
)
re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")

cat_cos = torch.cat([re_cos, re_cos], -1)
cat_sin = torch.cat([re_sin, re_sin], -1)

ctx.save_for_backward(cos, sin)
rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
ctx.save_for_backward(cat_cos, cat_sin)
ctx.interleaved = interleaved
ctx.in_place = in_place

if in_place:
x[..., :rotary_dim].copy_(out[..., :rotary_dim])
x[..., :rotary_dim].copy_(rot)
return x
else:
if rotary_dim < head_dim:
out = x.detach().clone()
if rotary_dim < head_dim and not in_place:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
return out

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
cat_cos, cat_sin = ctx.saved_tensors
*_, seqlen, _, head_dim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
rotary_dim = cat_cos.shape[-1]

out = _apply_rotary(
do[..., :rotary_dim],
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
ctx.interleaved,
dx_out = torch_npu.npu_rotary_mul(
do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
)

if ctx.in_place:
do[..., :rotary_dim].copy_(out[..., :rotary_dim])
do[..., :rotary_dim].copy_(dx_out)
return do, None, None, None, None
else:
if rotary_dim < head_dim:
out[..., rotary_dim:].copy(do[..., rotary_dim:])
return out, None, None, None, None
dx = do.detach().clone()
dx[..., :rotary_dim].copy_(dx_out)
return dx, None, None, None, None
3 changes: 2 additions & 1 deletion deeplink_ext/internevo_ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._rotary_embedding_npu import ApplyRotaryEmb
# from ._rotary_embedding_npu import ApplyRotaryEmb
from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
elif platform_type == PlatformType.TORCH_DIPU:
from ._rotary_embedding_dipu import ApplyRotaryEmb
else:
Expand Down
127 changes: 41 additions & 86 deletions deeplink_ext/interntrain_ops/_rotary_embedding_npu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) 2024, DeepLink.
# Copyright (c) 2024, InternEvo.

import torch
import torch_npu
Expand All @@ -8,38 +7,16 @@
__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"]


def _unsqueeze_to_4d(x: torch.Tensor):
while x.dim() < 4:
x = x.unsqueeze(0)
return x


def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved):
assert interleaved == False, "interleaved not support by torch_npu"

x_view = _unsqueeze_to_4d(x)
cos_view = _unsqueeze_to_4d(cos)
sin_view = _unsqueeze_to_4d(sin)

cos_cat = torch.cat([cos_view, cos_view], -1)
sin_cat = torch.cat([sin_view, sin_view], -1)

if confj:
sin_cat.neg_()

x_view_chunks = x_view.chunk(2, -1)
x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)

cos_x = torch.mul(cos_cat, x_view)
sin_x = torch.mul(sin_cat, x_view_new)
out = cos_x + sin_x

return out


class ApplyRotaryEmb(torch.autograd.Function):
"""
ApplyRotaryEmb
Apply rotary positional embedding to input tensor x.
Args:
x (Tensor): Input tensor x is of shape [seq_length, ... , dim]
cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim]
sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim]
Returns:
Tensor: The input tensor after applying RoPE
"""

@staticmethod
Expand All @@ -59,34 +36,34 @@ def forward(ctx, x, cos, sin, interleaved=False):
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
out = torch.empty_like(x)

re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
out = _apply_rotary(
x[..., :rotary_dim],
re_cos,
re_sin,
False,
interleaved,
)

cat_cos = torch.cat([re_cos, re_cos], -1)
cat_sin = torch.cat([re_sin, re_sin], -1)

rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
out[..., :rotary_dim].copy_(rot)
if rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(re_cos, re_sin)

ctx.save_for_backward(cat_cos, cat_sin)
ctx.interleaved = interleaved
return out

@staticmethod
def backward(ctx, do):
re_cos, re_sin = ctx.saved_tensors
cat_cos, cat_sin = ctx.saved_tensors
headdim = do.shape[-1]
rotary_dim = re_cos.shape[-1]
rotary_dim *= 2
dx = _apply_rotary(
do[..., :rotary_dim],
re_cos,
re_sin,
True,
ctx.interleaved,
rotary_dim = cat_cos.shape[-1]

dx = torch.empty_like(do)
dx_rot = torch_npu.npu_rotary_mul(
do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
)
dx.copy_(dx_rot)

if rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None
Expand Down Expand Up @@ -141,16 +118,10 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
if len(qkv.shape) == 4
else rearrange(sin[:seqlen], "s d -> s 1 d")
)

# qro
out = _apply_rotary(
q_ro,
re_cos,
re_sin,
False,
interleaved,
)
q_ro.copy_(out)
cat_cos = torch.cat([re_cos, re_cos], -1)
cat_sin = torch.cat([re_sin, re_sin], -1)
q_out = torch_npu.npu_rotary_mul(q_ro, cat_cos, cat_sin)
q_ro.copy_(q_out)

k_ro = (
qkv[:, 1, :, :rotary_dim]
Expand All @@ -167,50 +138,34 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
if len(qkv.shape) == 4
else rearrange(sin_k[:seqlen], "s d -> s 1 d")
)
out = _apply_rotary(
k_ro,
re_cos_k,
re_sin_k,
False,
interleaved,
)
k_ro.copy_(out)
cat_cos_k = torch.cat([re_cos_k, re_cos_k], -1)
cat_sin_k = torch.cat([re_sin_k, re_sin_k], -1)
k_out = torch_npu.npu_rotary_mul(k_ro, cat_cos_k, cat_sin_k)
k_ro.copy_(k_out)

ctx.save_for_backward(re_cos, re_sin, re_cos_k, re_sin_k)
ctx.save_for_backward(cat_cos, cat_sin, cat_cos_k, cat_sin_k)
ctx.interleaved = interleaved
return qkv

@staticmethod
def backward(ctx, dqkv):
re_cos, re_sin, re_cos_k, re_sin_k = ctx.saved_tensors
rotary_dim = re_cos.shape[-1]
rotary_dim *= 2
cat_cos, cat_sin, cat_cos_k, cat_sin_k = ctx.saved_tensors
rotary_dim = cat_cos.shape[-1]

dq_ro = (
dqkv[:, 0, :, :rotary_dim]
if len(dqkv.shape) == 4
else dqkv[:, :, 0, :, :rotary_dim]
)
out = _apply_rotary(
dq_ro,
re_cos,
re_sin,
True,
ctx.interleaved,
)
dq_ro.copy_(out)
dq_out = torch_npu.npu_rotary_mul(dq_ro, cat_cos, torch.neg(cat_sin))
dq_ro.copy_(dq_out)

dk_ro = (
dqkv[:, 1, :, :rotary_dim]
if len(dqkv.shape) == 4
else dqkv[:, :, 1, :, :rotary_dim]
)
out = _apply_rotary(
dk_ro,
re_cos_k,
re_sin_k,
True,
ctx.interleaved,
)
dk_ro.copy_(out)
dk_out = torch_npu.npu_rotary_mul(dk_ro, cat_cos_k, torch.neg(cat_sin_k))
dk_ro.copy_(dk_out)

return dqkv, None, None, None, None, None
3 changes: 2 additions & 1 deletion deeplink_ext/interntrain_ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._mixed_rms_norm_npu import MixedFusedRMSNorm
# from ._mixed_rms_norm_npu import MixedFusedRMSNorm
from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
elif platform_type == PlatformType.TORCH_DIPU:
from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
else:
Expand Down
4 changes: 3 additions & 1 deletion deeplink_ext/interntrain_ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
# from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_
elif platform_type == PlatformType.TORCH_DIPU:
from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
else:
Expand Down

0 comments on commit 22bffd4

Please sign in to comment.