refactor!(internlm): a new way to patch rotary based on flash_attn 2.2.1
Showing 7 changed files with 173 additions and 399 deletions.
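The `refactor!` marker flags a breaking change: the autograd-function classes `DeepLinkApplyRotaryEmbQKV_` and `DeepLinkApplyRotaryEmb` are dropped in favor of a single kernel-style function, `apply_rotary`, whose signature mirrors the rotary interface introduced in flash_attn 2.2.1. A minimal migration sketch follows; the import path is hypothetical (the diff view does not show file paths), and the gradient handling that previously lived in the classes' `backward` methods is now expected to be expressed through the new `conjugate` flag by whatever autograd wrapper calls this kernel.

```python
import torch

# Hypothetical import path -- the diff view does not show where the package lives.
from deeplink_ext.internlm_ops import rotary

batch, seqlen, nheads, headdim = 2, 128, 8, 64
rotary_dim = 64  # must be <= headdim

# float32 CPU tensors exercise the pure-PyTorch fallback; the DIOPI kernel
# would instead want device tensors appropriate to the backend.
x = torch.randn(batch, seqlen, nheads, headdim)
cos = torch.randn(seqlen, rotary_dim // 2)
sin = torch.randn(seqlen, rotary_dim // 2)

# Before: out = rotary.DeepLinkApplyRotaryEmb.apply(x, cos, sin, False, False)
# After:
out = rotary.apply_rotary(x, cos, sin, interleaved=False, inplace=False)
```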
```diff
@@ -1,16 +1,13 @@
 # Copyright (c) 2024, DeepLink.
 
 try:
-    from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
+    from .deeplink import apply_rotary
 except:
     print(
         "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
         end="",
     )
-    from .fallback import (
-        ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
-        ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
-    )
+    from .fallback import apply_rotary
 from . import fallback
 
-__all__ = ["DeepLinkApplyRotaryEmbQKV_", "DeepLinkApplyRotaryEmb", "fallback"]
+__all__ = ["apply_rotary", "fallback"]
```
```diff
@@ -1,131 +1,65 @@
 # Copyright (c) 2024, DeepLink.
 
+from typing import Optional, Union
 import torch
 from einops import rearrange
 import deeplink_ext.cpp_extensions as ext
 
 assert hasattr(ext, "apply_rotary")
 
+__all__ = ["apply_rotary"]
+
 
-class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
-        batch, seqlen, three, nheads, headdim = qkv.shape
-        assert three == 3
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        cos_k = cos if cos_k is None else cos_k
-        sin_k = sin if sin_k is None else sin_k
-        assert (
-            sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        )
-        q_ro = qkv[:, :, 0, :, :rotary_dim]
-        ext.apply_rotary(
-            q_ro,
-            q_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-            interleaved,
-        )
-        k_ro = qkv[:, :, 1, :, :rotary_dim]
-        ext.apply_rotary(
-            k_ro,
-            k_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-            interleaved,
-        )
-        ctx.save_for_backward(cos, sin, cos_k, sin_k)
-        ctx.interleaved = interleaved
-        return qkv
-
-    @staticmethod
-    def backward(ctx, dqkv):
-        cos, sin, cos_k, sin_k = ctx.saved_tensors
-        interleaved = ctx.interleaved
-        _, seqlen, _, _, headdim = dqkv.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
-        ext.apply_rotary(
-            dq_ro,
-            dq_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-            interleaved,
-        )
-        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
-        ext.apply_rotary(
-            dk_ro,
-            dk_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-            interleaved,
-        )
-        return dqkv, None, None, None, None, None
-
-
-class DeepLinkApplyRotaryEmb(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
-        """
-        x: (batch_size, seqlen, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding to the first rotary_dim of x.
-        """
-        batch, seqlen, nheads, headdim = x.shape
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
-        x_ro = x[..., :rotary_dim]
-        out = torch.empty_like(x) if not inplace else x
-        out_ro = out[..., :rotary_dim]
-
-        ext.apply_rotary(
-            out_ro,
-            x_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-            interleaved,
-        )
-
-        if not inplace and rotary_dim < headdim:
-            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
-        ctx.save_for_backward(cos, sin)
-        ctx.interleaved = interleaved
-        ctx.inplace = inplace
-        return out if not inplace else x
-
-    @staticmethod
-    def backward(ctx, do):
-        cos, sin = ctx.saved_tensors
-        _, seqlen, _, headdim = do.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        inplace = ctx.inplace
-        do_ro = do[..., :rotary_dim]
-        dx = torch.empty_like(do) if not inplace else do
-        dx_ro = dx[..., :rotary_dim]
-        ext.apply_rotary(
-            dx_ro,
-            do_ro,
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-            ctx.interleaved,
-        )
-        if not inplace and rotary_dim < headdim:
-            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
-        return dx, None, None, None, None
+
+def apply_rotary(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    interleaved=False,
+    inplace=False,
+    conjugate=False,
+) -> torch.Tensor:
+    """
+    Arguments:
+        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim).
+        cos: (seqlen_ro, rotary_dim / 2)
+        sin: (seqlen_ro, rotary_dim / 2)
+        seqlen_offsets: integer or integer tensor of size (batch,)
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Returns:
+        y: (batch, seqlen, nheads, headdim)
+    """
+    if seqlen_offsets != 0 or cu_seqlens is not None or max_seqlen is not None:
+        raise NotImplementedError(
+            "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
+        )
+    batch, seqlen, nheads, headdim = x.shape
+    seqlen_ro, rotary_dim = cos.shape
+    assert sin.shape == cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+    assert headdim <= 256, "Only support headdim <= 256"
+    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+    assert (
+        cos.dtype == sin.dtype
+    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+    assert (
+        x.dtype == cos.dtype
+    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+    output = torch.empty_like(x) if not inplace else x
+    if rotary_dim < headdim and not inplace:
+        output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+    ext.apply_rotary(
+        output,
+        x,
+        rearrange(cos[:seqlen], "s d -> s 1 d"),
+        rearrange(sin[:seqlen], "s d -> s 1 d"),
+        conjugate,
+        interleaved,
+    )
+    return output
```
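The two autograd classes collapse into this one wrapper because forward and backward rotary application differ only in the sign of the rotation: where the old `backward` methods passed a hard-coded `True` to `ext.apply_rotary`, the caller now sets `conjugate=True` on the gradient. A sketch of both directions, with a hypothetical import path and CPU float32 tensors for illustration (the DIOPI kernel would want device tensors appropriate to the backend):

```python
import torch
from deeplink_ext.internlm_ops import rotary  # hypothetical import path

x = torch.randn(2, 128, 8, 64)  # (batch, seqlen, nheads, headdim)
dy = torch.randn_like(x)        # upstream gradient in a hand-rolled backward
cos = torch.randn(128, 32)      # (seqlen_ro, rotary_dim / 2)
sin = torch.randn(128, 32)

# Forward rotation -- what DeepLinkApplyRotaryEmb.forward used to do:
y = rotary.apply_rotary(x, cos, sin, interleaved=False)

# The backward of a rotation is the rotation by the negated angle, so the old
# hard-coded True argument becomes the conjugate flag applied to the gradient:
dx = rotary.apply_rotary(dy, cos, sin, interleaved=False, conjugate=True)
```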
@@ -0,0 +1,98 @@ (new file)

```python
# Copyright (c) 2024, DeepLink.

from typing import Optional, Union
import torch
from einops import rearrange, repeat

__all__ = ["apply_rotary"]


def _rotate_half(x: torch.Tensor, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def _apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved=False
):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    data_type = x.dtype
    x = x.to(torch.float32)
    cos = cos.to(torch.float32)
    sin = sin.to(torch.float32)

    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    ).to(data_type)


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim).
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    if seqlen_offsets != 0 or cu_seqlens is not None or max_seqlen is not None:
        raise NotImplementedError(
            "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
        )
    batch, seqlen, nheads, headdim = x.shape
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    if conjugate:
        sin = -sin
    out = _apply_rotary_emb_torch(x, cos[:seqlen], sin[:seqlen], interleaved)
    if inplace:
        x.copy_(out)
        out = x
    return out
```
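Since rotating by an angle and then by its negation is the identity, and the dimensions beyond `rotary_dim` pass through untouched, the fallback admits a simple roundtrip sanity check; a sketch, again assuming a hypothetical import path:

```python
import torch
from deeplink_ext.internlm_ops.rotary import fallback  # hypothetical import path

torch.manual_seed(0)
x = torch.randn(2, 16, 4, 32)  # (batch, seqlen, nheads, headdim)
theta = torch.rand(16, 8)      # (seqlen_ro, rotary_dim / 2) => rotary_dim = 16
cos, sin = theta.cos(), theta.sin()

for interleaved in (False, True):
    y = fallback.apply_rotary(x, cos, sin, interleaved=interleaved)
    x_back = fallback.apply_rotary(y, cos, sin, interleaved=interleaved, conjugate=True)
    # Rotations are orthogonal, so the conjugate rotation undoes the forward one.
    assert torch.allclose(x_back, x, atol=1e-5)
    # Everything past rotary_dim is copied through unchanged.
    assert torch.equal(y[..., 16:], x[..., 16:])
```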
This file was deleted.