
Commit

Merge branch 'main' of https://github.com/DeepLink-org/DeepLinkExt into wx/support_flash_attention_for_ascend
POI-WX committed Mar 14, 2024
2 parents 5bc1450 + 0528ab2 commit 21698b1
Showing 10 changed files with 238 additions and 435 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -16,7 +16,9 @@ pip install -e .

## Usage

### InternLM
### InternEvo

Supported version: https://github.com/DeepLink-org/InternEvo/tree/deeplinkext

```python
import deeplink_ext.patch_internlm
```
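The README change points users at the InternEvo deeplinkext branch while keeping the patch-based usage. A minimal usage sketch follows; it assumes the training framework's package is importable as `internlm` and that `deeplink_ext.patch_internlm` works by import side effects, so the patch import must come first.

```python
# Sketch only: `internlm` as the framework package name and the import-order
# requirement are assumptions, not confirmed by this diff.
import deeplink_ext.patch_internlm  # installs DeepLink-backed ops via monkey-patching

import internlm  # noqa: E402  # now resolves to the patched implementations
```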
34 changes: 17 additions & 17 deletions deeplink_ext/internlm_ops/mha/mha.py
@@ -87,12 +87,12 @@ def forward(
q,
kv,
causal=None,
cu_seqlens_q=None,
max_seqlen_q=None,
cu_seqlens=None,
max_seqlen=None,
cu_seqlens_k=None,
max_seqlen_k=None,
):
if cu_seqlens_q is None:
if cu_seqlens is None:
# padded
# return DeepLinkMultiHeadAttentionKVPackedFunc.apply(
# q,
@@ -110,17 +110,17 @@ def forward(
self.softmax_scale,
causal if causal is not None else self.causal,
)
# else:
# # unpadded
# return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply(
# q,
# kv,
# cu_seqlens_q,
# cu_seqlens_k,
# max_seqlen_q,
# max_seqlen_k,
# self.dropout_p if self.training else 0.0,
# self.softmax_scale,
# causal if causal is not None else self.causal,
# False,
# )
else:
# unpadded
return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply(
q,
kv,
cu_seqlens,
cu_seqlens_k,
max_seqlen,
max_seqlen_k,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
causal if causal is not None else self.causal,
False,
)
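For context on the padded/unpadded dispatch enabled above, here is a small sketch (not part of the diff) of how the variable-length arguments are typically constructed; the example lengths are made up for illustration.

```python
import torch

# Variable-length ("unpadded") attention packs sequences of different lengths
# into one flat token dimension; cu_seqlens holds the cumulative sequence
# boundaries and max_seqlen the longest sequence.
lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0)
)  # -> tensor([0, 3, 8, 10], dtype=torch.int32)
max_seqlen = int(lengths.max())  # -> 5

# Passing cu_seqlens=None selects the padded branch in forward(); passing
# cu_seqlens/max_seqlen (plus cu_seqlens_k/max_seqlen_k) selects the
# DeepLinkMultiHeadAttentionVarLenKVPackedFunc path that this commit enables.
```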
19 changes: 10 additions & 9 deletions deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py
@@ -25,8 +25,8 @@ def forward(
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd(
q,
kv[:, :, 0],
kv[:, :, 1],
kv[:, 0],
kv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
@@ -40,7 +40,8 @@ def forward(
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state()
)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen_q
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -63,20 +64,20 @@ def backward(ctx, dout):
ext.mha_varlen_bwd(
dout,
q,
kv[:, :, 0],
kv[:, :, 1],
kv[:, 0],
kv[:, 1],
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen,
ctx.max_seqlen,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.causal,
rng,
ctx.softmax_scale,
dq,
dkv[:, :, 0],
dkv[:, :, 1],
dkv[:, 0],
dkv[:, 1],
)
return dq, dkv, None, None, None, None, None, None, None, None
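The indexing change from kv[:, :, 0] to kv[:, 0] (and likewise for dkv) matches a flattened varlen layout in which batch and sequence are packed into a single token dimension. A shape sketch under that assumption:

```python
import torch

# Assumed varlen kv-packed layout: (total_tokens, 2, nheads, headdim),
# i.e. batch and seqlen flattened, with k and v stacked on dim 1.
total_tokens, nheads, headdim = 10, 8, 64
kv = torch.randn(total_tokens, 2, nheads, headdim, dtype=torch.float16)
k, v = kv[:, 0], kv[:, 1]  # each (total_tokens, nheads, headdim)
assert k.shape == v.shape == (total_tokens, nheads, headdim)
```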
9 changes: 3 additions & 6 deletions deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,16 +1,13 @@
# Copyright (c) 2024, DeepLink.

try:
from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
from .deeplink import apply_rotary
except:
print(
"[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
end="",
)
from .fallback import (
ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
)
from .fallback import apply_rotary
from . import fallback

__all__ = ["DeepLinkApplyRotaryEmbQKV_", "DeepLinkApplyRotaryEmb", "fallback"]
__all__ = ["apply_rotary", "fallback"]
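After this change the package exports the functional apply_rotary (DIOPI-backed when available, pure-PyTorch fallback otherwise) instead of the old autograd classes. A caller-side sketch, assuming deeplink_ext is installed:

```python
# Callers import the same name whether the DIOPI kernel or the fallback is active.
from deeplink_ext.internlm_ops.rotary import apply_rotary, fallback  # noqa: F401
```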
172 changes: 53 additions & 119 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -1,131 +1,65 @@
# Copyright (c) 2024, DeepLink.

from typing import Optional, Union
import torch
from einops import rearrange
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "apply_rotary")

__all__ = ["apply_rotary"]

class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert (
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
)
q_ro = qkv[:, :, 0, :, :rotary_dim]
ext.apply_rotary(
q_ro,
q_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
k_ro,
k_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv

@staticmethod
def backward(ctx, dqkv):
cos, sin, cos_k, sin_k = ctx.saved_tensors
interleaved = ctx.interleaved
_, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
ext.apply_rotary(
dq_ro,
dq_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
interleaved,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
dk_ro,
dk_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
interleaved,
)
return dqkv, None, None, None, None, None


class DeepLinkApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]

ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None:
raise NotImplementedError(
"apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
)
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
dx = torch.empty_like(do) if not inplace else do
dx_ro = dx[..., :rotary_dim]
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
ctx.interleaved,
)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
ext.apply_rotary(
output,
x,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
conjugate,
interleaved,
)
return output
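A short call sketch for the rewritten functional apply_rotary, following the shapes in its docstring; the sizes and dtype below are assumptions for illustration, and in practice the tensors should live on the DeepLink/DIPU device rather than the CPU.

```python
import torch
from deeplink_ext.internlm_ops.rotary import apply_rotary

# Shapes follow the docstring: x is (batch, seqlen, nheads, headdim),
# cos/sin are (seqlen_ro, rotary_dim // 2) with seqlen_ro >= seqlen.
batch, seqlen, nheads, headdim = 2, 16, 8, 64
rotary_dim = headdim  # rotate the full head dimension in this sketch
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16)
cos = torch.randn(seqlen, rotary_dim // 2, dtype=torch.float16)
sin = torch.randn(seqlen, rotary_dim // 2, dtype=torch.float16)

y = apply_rotary(x, cos, sin, interleaved=False, inplace=False)
assert y.shape == x.shape
```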