fix bug of rope for npu and support the interleaved option
POI-WX committed Oct 21, 2024
1 parent 9eb52a2 commit 49d1a9d
Showing 3 changed files with 96 additions and 64 deletions.
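
For readers unfamiliar with the two layouts, the `interleaved` option added here selects how channel pairs are rotated: the non-interleaved layout (GPT-NeoX style) pairs the first half of the rotary dimension with the second half, while the interleaved layout (GPT-J style) pairs even channels with odd ones. The pure-PyTorch sketch below is not part of the commit; it only illustrates, under those assumed conventions, the rotation the NPU kernel is expected to reproduce.

import torch

def rotate_ref(x, cos, sin, interleaved=False):
    # x: (batch, seqlen, nheads, rotary_dim); cos/sin: (seqlen, rotary_dim // 2)
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    if not interleaved:
        # Non-interleaved: first half of the channels pairs with the second half.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # Interleaved: even channels pair with odd channels.
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out
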
90 changes: 61 additions & 29 deletions deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -2,7 +2,6 @@

 import torch
 import torch_npu
-from einops import rearrange

 __all__ = ["ApplyRotaryEmb"]

@@ -38,38 +37,71 @@ def forward(
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)

-        re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
-        re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
-
-        cat_cos = torch.cat([re_cos, re_cos], -1)
-        cat_sin = torch.cat([re_sin, re_sin], -1)
-
-        rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
-        ctx.save_for_backward(cat_cos, cat_sin)
+        # "s d -> 1 s 1 d"
+        cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
+        sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
+        x_ro = x[..., :rotary_dim]
+        ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
-        if in_place:
-            x[..., :rotary_dim].copy_(rot)
-            return x
-        else:
-            out = x.detach().clone()
-            if rotary_dim < head_dim and not in_place:
+        if interleaved:
+            x_in = torch.cat([x_ro[..., ::2], x_ro[..., 1::2]], dim=-1)
+            out_ro = torch_npu.npu_rotary_mul(x_in, cos, sin)
+            if in_place:
+                x_ro[..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)])
+                x_ro[..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :])
+                return x
+            out = torch.empty_like(x)
+            out[..., :rotary_dim][..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)])
+            out[..., :rotary_dim][..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :])
+            if rotary_dim < head_dim:
                 out[..., rotary_dim:].copy_(x[..., rotary_dim:])
             return out
+        else:
+            out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
+            if in_place:
+                x[..., :rotary_dim].copy_(out_ro)
+                return x
+            if rotary_dim < head_dim:
+                out = torch.empty_like(x)
+                out[..., :rotary_dim].copy_(out_ro)
+                out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+                return out
+            return out_ro

     @staticmethod
-    def backward(ctx, do):
-        cat_cos, cat_sin = ctx.saved_tensors
-        *_, seqlen, _, head_dim = do.shape
-        rotary_dim = cat_cos.shape[-1]
-
-        dx_out = torch_npu.npu_rotary_mul(
-            do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
-        )
-        if ctx.in_place:
-            do[..., :rotary_dim].copy_(dx_out)
-            return do, None, None, None, None
+    def backward(ctx, grad_out):
+        cos, sin = ctx.saved_tensors
+        rotary_dim = cos.shape[-1]
+        head_dim = grad_out.shape[-1]
+        grad_out_ro = grad_out[..., :rotary_dim]
+        if ctx.interleaved:
+            grad_out_in = torch.cat(
+                [grad_out_ro[..., ::2], grad_out_ro[..., 1::2]], dim=-1
+            )
+            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_in, cos, torch.neg(sin))
+            if ctx.in_place:
+                grad_out_ro[..., ::2].copy_(grad_input_ro[..., : int(rotary_dim / 2)])
+                grad_out_ro[..., 1::2].copy_(grad_input_ro[..., int(rotary_dim / 2) :])
+                return grad_out, None, None, None, None
+            grad_input = torch.empty_like(grad_out)
+            grad_input[..., :rotary_dim][..., ::2].copy_(
+                grad_input_ro[..., : int(rotary_dim / 2)]
+            )
+            grad_input[..., :rotary_dim][..., 1::2].copy_(
+                grad_input_ro[..., int(rotary_dim / 2) :]
+            )
+            if rotary_dim < head_dim:
+                grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
+            return grad_input, None, None, None, None
         else:
-            dx = do.detach().clone()
-            dx[..., :rotary_dim].copy_(dx_out)
-            return dx, None, None, None, None
+            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
+            if ctx.in_place:
+                grad_out[..., :rotary_dim].copy_(grad_input_ro)
+                return grad_out, None, None, None, None
+            if rotary_dim < head_dim:
+                grad_input = torch.empty_like(grad_out)
+                grad_input[..., :rotary_dim].copy_(grad_input_ro)
+                grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
+                return grad_input, None, None, None, None
+            return grad_input_ro, None, None, None, None
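
A minimal usage sketch of the updated autograd function follows. It is not taken from the repository: the tensor shapes mirror the test below, the five-argument `apply(x, cos, sin, interleaved, in_place)` call is inferred from the five gradients returned by `backward`, and an Ascend device reachable as `"npu"` via `torch_npu` is assumed.

import torch
import torch_npu  # noqa: F401  (registers the "npu" device; assumed available)
from deeplink_ext.internevo_ops.rotary_embedding import ApplyRotaryEmb

# (batch, seqlen, nheads, head_dim) activations and (seqlen, rotary_dim // 2) tables.
x = torch.randn(1, 64, 32, 64, dtype=torch.float16, device="npu", requires_grad=True)
cos = torch.randn(64, 32, dtype=torch.float16, device="npu")
sin = torch.randn(64, 32, dtype=torch.float16, device="npu")

# interleaved=True exercises the new even/odd path; in_place=False returns a new tensor.
out = ApplyRotaryEmb.apply(x, cos, sin, True, False)
out.sum().backward()  # runs the backward branch shown above
print(x.grad.shape)
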
3 changes: 1 addition & 2 deletions deeplink_ext/internevo_ops/rotary_embedding.py
@@ -4,8 +4,7 @@

 platform_type = deeplink_ext_get_platform_type()
 if platform_type == PlatformType.TORCH_NPU:
-    # from ._rotary_embedding_npu import ApplyRotaryEmb
-    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
+    from ._rotary_embedding_npu import ApplyRotaryEmb
 elif platform_type == PlatformType.TORCH_DIPU:
     from ._rotary_embedding_dipu import ApplyRotaryEmb
 else:
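
With this dispatcher change, Ascend platforms now resolve `ApplyRotaryEmb` to the NPU kernel above instead of the pure-torch fallback, while the fallback remains available as a reference implementation. A quick consistency check along the lines of the test below might look like this sketch, assuming both modules import on the target machine and that `ApplyRotaryEmbTorch` shares the same `apply` signature:

import torch
import torch_npu  # noqa: F401
from deeplink_ext.internevo_ops.rotary_embedding_fallback import ApplyRotaryEmbTorch
from deeplink_ext.internevo_ops._rotary_embedding_npu import ApplyRotaryEmb

x = torch.randn(1, 64, 32, 64, dtype=torch.float16, device="npu")
cos = torch.randn(64, 32, dtype=torch.float16, device="npu")
sin = torch.randn(64, 32, dtype=torch.float16, device="npu")

# Compare the fallback and the NPU kernel on the interleaved, out-of-place path.
ref = ApplyRotaryEmbTorch.apply(x.clone(), cos, sin, True, False)
ext = ApplyRotaryEmb.apply(x.clone(), cos, sin, True, False)
print(torch.allclose(ref.float(), ext.float(), rtol=1e-2, atol=5e-2))
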
67 changes: 34 additions & 33 deletions tests/internevo/test_rotary_embedding.py
@@ -8,40 +8,41 @@

 def test_ApplyRotaryEmb():
     input_dtype_list = [torch.float16, torch.bfloat16]
-    interleaved = False
     in_place_options = [False, True]
+    interleaved_options = [False, True]
     for input_dtype in input_dtype_list:
         for in_place in in_place_options:
-            input_ref = torch.randn(
-                1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
-            )
-            input_ext = input_ref.clone().detach().requires_grad_()
-            cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
-            sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+            for interleaved in interleaved_options:
+                input_ref = torch.randn(
+                    1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
+                )
+                input_ext = input_ref.clone().detach().requires_grad_()
+                cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+                sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")

-            output_ref, grad_ref = call_autograd_func(
-                ApplyRotaryEmbTorch,
-                "cuda",
-                input_dtype,
-                input_ref,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            output_ext, grad_ext = call_autograd_func(
-                ApplyRotaryEmb,
-                "cuda",
-                input_dtype,
-                input_ext,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            assert allclose(
-                output_ref, output_ext, rtol=1e-2, atol=5e-2
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
-            assert allclose(
-                grad_ref, grad_ext
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
+                output_ref, grad_ref = call_autograd_func(
+                    ApplyRotaryEmbTorch,
+                    "cuda",
+                    input_dtype,
+                    input_ref,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                output_ext, grad_ext = call_autograd_func(
+                    ApplyRotaryEmb,
+                    "cuda",
+                    input_dtype,
+                    input_ext,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                assert allclose(
+                    output_ref, output_ext, rtol=1e-2, atol=5e-2
+                ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
+                assert allclose(
+                    grad_ref, grad_ext
+                ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
