
Commit

optimize when interleaved is True
POI-WX committed Oct 22, 2024
1 parent 49d1a9d commit 0eebc4b
Showing 1 changed file with 52 additions and 30 deletions.
82 changes: 52 additions & 30 deletions deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -2,10 +2,44 @@
 
 import torch
 import torch_npu
+from einops import rearrange, repeat
 
 __all__ = ["ApplyRotaryEmb"]
 
 
+def rotate_half(x, interleaved=False):
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
 class ApplyRotaryEmb(torch.autograd.Function):
     """
@@ -37,27 +71,25 @@ def forward(
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
 
-        # "s d -> 1 s 1 d"
-        cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
-        sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
-        x_ro = x[..., :rotary_dim]
+        if interleaved:
+            cos = cos[:seqlen]
+            sin = sin[:seqlen]
+        else:
+            # "s d -> 1 s 1 d"
+            cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
+            sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
         ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
         if interleaved:
-            x_in = torch.cat([x_ro[..., ::2], x_ro[..., 1::2]], dim=-1)
-            out_ro = torch_npu.npu_rotary_mul(x_in, cos, sin)
+            out = apply_rotary_emb_torch(x, cos, sin, interleaved)
             if in_place:
-                x_ro[..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)])
-                x_ro[..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :])
+                x.copy_(out)
                 return x
-            out = torch.empty_like(x)
-            out[..., :rotary_dim][..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)])
-            out[..., :rotary_dim][..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :])
-            if rotary_dim < head_dim:
-                out[..., rotary_dim:].copy_(x[..., rotary_dim:])
-            return out
+            else:
+                return out
         else:
+            x_ro = x[..., :rotary_dim]
             out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
             if in_place:
                 x[..., :rotary_dim].copy_(out_ro)
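
With this change the interleaved forward no longer regroups even/odd channels for torch_npu.npu_rotary_mul and then scatters the halves back with strided copies; it builds the whole output in one pass through apply_rotary_emb_torch. A hedged usage sketch follows; the positional argument order (x, cos, sin, interleaved, in_place) is inferred from the five-element gradient tuple returned by backward and from the sin.shape assertion above, and the module still imports torch_npu, so an NPU environment is assumed:

    # x: (batch, seqlen, nheads, headdim); cos, sin: (rotary_seqlen, rotary_dim // 2)
    out = ApplyRotaryEmb.apply(x, cos, sin, True, False)      # interleaved=True, out-of-place
    x_rotated = ApplyRotaryEmb.apply(x, cos, sin, True, True)  # in_place=True writes back into x
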
@@ -74,27 +106,17 @@ def backward(ctx, grad_out):
         cos, sin = ctx.saved_tensors
         rotary_dim = cos.shape[-1]
         head_dim = grad_out.shape[-1]
-        grad_out_ro = grad_out[..., :rotary_dim]
         if ctx.interleaved:
-            grad_out_in = torch.cat(
-                [grad_out_ro[..., ::2], grad_out_ro[..., 1::2]], dim=-1
+            grad_input = apply_rotary_emb_torch(
+                grad_out, cos, torch.neg(sin), ctx.interleaved
             )
-            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_in, cos, torch.neg(sin))
             if ctx.in_place:
-                grad_out_ro[..., ::2].copy_(grad_input_ro[..., : int(rotary_dim / 2)])
-                grad_out_ro[..., 1::2].copy_(grad_input_ro[..., int(rotary_dim / 2) :])
+                grad_out.copy_(grad_input)
                 return grad_out, None, None, None, None
-            grad_input = torch.empty_like(grad_out)
-            grad_input[..., :rotary_dim][..., ::2].copy_(
-                grad_input_ro[..., : int(rotary_dim / 2)]
-            )
-            grad_input[..., :rotary_dim][..., 1::2].copy_(
-                grad_input_ro[..., int(rotary_dim / 2) :]
-            )
-            if rotary_dim < head_dim:
-                grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
-            return grad_input, None, None, None, None
+            else:
+                return grad_input, None, None, None, None
         else:
+            grad_out_ro = grad_out[..., :rotary_dim]
            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
            if ctx.in_place:
                grad_out[..., :rotary_dim].copy_(grad_input_ro)
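The backward pass can reuse the same helper with the sine negated because rotating each channel pair by -theta undoes rotating it by theta (rotate_half applied twice is negation, and cos^2 + sin^2 = 1). A small CPU round-trip check of that identity, again assuming the helpers from the first hunk are in scope and using the same illustrative sizes:

    import torch

    batch, seqlen, nheads, headdim, rotary_dim = 2, 8, 4, 64, 32
    x = torch.randn(batch, seqlen, nheads, headdim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
    cos, sin = freqs.cos(), freqs.sin()

    y = apply_rotary_emb_torch(x, cos, sin, interleaved=True)                   # rotate by theta
    x_back = apply_rotary_emb_torch(y, cos, torch.neg(sin), interleaved=True)   # rotate by -theta
    assert torch.allclose(x_back, x, atol=1e-5)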
