Skip to content

Commit

Permalink
refactor!(internlm): a new way to patch rotary based on flash_attn 2.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
lljbash committed Mar 8, 2024
1 parent 3375758 commit 00a75fe
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 399 deletions.
9 changes: 3 additions & 6 deletions deeplink_ext/internlm_ops/rotary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
# Copyright (c) 2024, DeepLink.

try:
from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
from .deeplink import apply_rotary
except:
print(
"[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
end="",
)
from .fallback import (
ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
)
from .fallback import apply_rotary
from . import fallback

__all__ = ["DeepLinkApplyRotaryEmbQKV_", "DeepLinkApplyRotaryEmb", "fallback"]
__all__ = ["apply_rotary", "fallback"]
172 changes: 53 additions & 119 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
Original file line number Diff line number Diff line change
@@ -1,131 +1,65 @@
# Copyright (c) 2024, DeepLink.

from typing import Optional, Union
import torch
from einops import rearrange
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "apply_rotary")

__all__ = ["apply_rotary"]

class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert (
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
)
q_ro = qkv[:, :, 0, :, :rotary_dim]
ext.apply_rotary(
q_ro,
q_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
k_ro,
k_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv

@staticmethod
def backward(ctx, dqkv):
cos, sin, cos_k, sin_k = ctx.saved_tensors
interleaved = ctx.interleaved
_, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
ext.apply_rotary(
dq_ro,
dq_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
interleaved,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
dk_ro,
dk_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
interleaved,
)
return dqkv, None, None, None, None, None


class DeepLinkApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]

ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None:
raise NotImplementedError(
"apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
)
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
dx = torch.empty_like(do) if not inplace else do
dx_ro = dx[..., :rotary_dim]
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
ctx.interleaved,
)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
ext.apply_rotary(
output,
x,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
conjugate,
interleaved,
)
return output
98 changes: 98 additions & 0 deletions deeplink_ext/internlm_ops/rotary/fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) 2024, DeepLink.

from typing import Optional, Union
import torch
from einops import rearrange, repeat

__all__ = ["apply_rotary"]


def _rotate_half(x: torch.Tensor, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)


def _apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved=False
):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""

data_type = x.dtype
x = x.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)

ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
).to(data_type)


def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None:
raise NotImplementedError(
"apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
)
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

if conjugate:
sin = -sin
out = _apply_rotary_emb_torch(x, cos[:seqlen], sin[:seqlen], interleaved)
if inplace:
x.copy_(out)
out = x
return out
5 changes: 0 additions & 5 deletions deeplink_ext/internlm_ops/rotary/fallback/__init__.py

This file was deleted.

Loading

0 comments on commit 00a75fe

Please sign in to comment.