refactor!(internlm): migrate to internevo #50

Merged: 4 commits, Mar 13, 2024
4 changes: 3 additions & 1 deletion README.md
@@ -16,7 +16,9 @@ pip install -e .

## Usage

### InternLM
### InternEvo

Compatible version: https://github.com/DeepLink-org/InternEvo/tree/deeplinkext

```python
import deeplink_ext.patch_internlm
```
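For context on the updated README: the patch module is meant to be imported before the training framework so the DeepLink ops replace the originals at import time. A minimal usage sketch, where the `internlm` import is illustrative and not taken from this diff:

```python
# Apply DeepLink patches first so the monkey-patched ops are picked up
# (assumption: patching must precede the framework import).
import deeplink_ext.patch_internlm  # noqa: F401

import internlm  # illustrative framework import; replace with your entry point
```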
10 changes: 5 additions & 5 deletions deeplink_ext/internlm_ops/mha/mha.py
@@ -78,12 +78,12 @@ def forward(
q,
kv,
causal=None,
cu_seqlens_q=None,
max_seqlen_q=None,
cu_seqlens=None,
max_seqlen=None,
cu_seqlens_k=None,
max_seqlen_k=None,
):
if cu_seqlens_q is None:
if cu_seqlens is None:
# padded
return DeepLinkMultiHeadAttentionKVPackedFunc.apply(
q,
@@ -98,9 +98,9 @@ def forward(
return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply(
q,
kv,
cu_seqlens_q,
cu_seqlens,
cu_seqlens_k,
max_seqlen_q,
max_seqlen,
max_seqlen_k,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
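The renamed arguments select between a padded path (`cu_seqlens is None`) and a variable-length path. A sketch of the tensor layouts each path expects, assuming the flash-attention style conventions this wrapper follows; the shapes below are assumptions, not taken from the diff:

```python
import torch

nheads, headdim = 8, 64

# Padded path (cu_seqlens is None): batched (batch, seqlen, ...) tensors.
batch, seqlen = 2, 128
q_padded = torch.randn(batch, seqlen, nheads, headdim)
kv_padded = torch.randn(batch, seqlen, 2, nheads, headdim)

# Varlen path: sequences packed along the token axis; boundaries in cu_seqlens.
seqlens = torch.tensor([100, 28])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).int()  # tensor([0, 100, 128])
total_tokens = int(cu_seqlens[-1])
q_varlen = torch.randn(total_tokens, nheads, headdim)
kv_varlen = torch.randn(total_tokens, 2, nheads, headdim)
max_seqlen = int(seqlens.max())
```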
19 changes: 10 additions & 9 deletions deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py
@@ -25,8 +25,8 @@ def forward(
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd(
q,
kv[:, :, 0],
kv[:, :, 1],
kv[:, 0],
kv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
@@ -40,7 +40,8 @@ def forward(
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state()
)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen_q
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -63,20 +64,20 @@ def backward(ctx, dout):
ext.mha_varlen_bwd(
dout,
q,
kv[:, :, 0],
kv[:, :, 1],
kv[:, 0],
kv[:, 1],
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen,
ctx.max_seqlen,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.causal,
rng,
ctx.softmax_scale,
dq,
dkv[:, :, 0],
dkv[:, :, 1],
dkv[:, 0],
dkv[:, 1],
)
return dq, dkv, None, None, None, None, None, None, None, None
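The indexing change (`kv[:, :, 0]` to `kv[:, 0]`) reflects that the varlen kernel receives token-packed kv without a batch dimension, so the k/v axis moves from dim 2 to dim 1. A small sketch under that assumption:

```python
import torch

total_tokens, nheads, headdim = 7, 4, 8
kv = torch.randn(total_tokens, 2, nheads, headdim)  # assumed varlen kv layout

k, v = kv[:, 0], kv[:, 1]  # indexing used after this change
assert k.shape == v.shape == (total_tokens, nheads, headdim)
```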
9 changes: 3 additions & 6 deletions deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,16 +1,13 @@
# Copyright (c) 2024, DeepLink.

try:
from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
from .deeplink import apply_rotary
except:
print(
"[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
end="",
)
from .fallback import (
ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
)
from .fallback import apply_rotary
from . import fallback

__all__ = ["DeepLinkApplyRotaryEmbQKV_", "DeepLinkApplyRotaryEmb", "fallback"]
__all__ = ["apply_rotary", "fallback"]
172 changes: 53 additions & 119 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -1,131 +1,65 @@
# Copyright (c) 2024, DeepLink.

from typing import Optional, Union
import torch
from einops import rearrange
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "apply_rotary")

__all__ = ["apply_rotary"]

class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert (
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
)
q_ro = qkv[:, :, 0, :, :rotary_dim]
ext.apply_rotary(
q_ro,
q_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
k_ro,
k_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv

@staticmethod
def backward(ctx, dqkv):
cos, sin, cos_k, sin_k = ctx.saved_tensors
interleaved = ctx.interleaved
_, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
ext.apply_rotary(
dq_ro,
dq_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
interleaved,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
dk_ro,
dk_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
interleaved,
)
return dqkv, None, None, None, None, None


class DeepLinkApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]

ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
interleaved,
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None:
raise NotImplementedError(
"apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
)
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
dx = torch.empty_like(do) if not inplace else do
dx_ro = dx[..., :rotary_dim]
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
ctx.interleaved,
)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
ext.apply_rotary(
output,
x,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
conjugate,
interleaved,
)
return output
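For reference, the non-interleaved rotation that both the DIOPI kernel and the fallback below are expected to compute, written here as the standard RoPE formulation (stated as an assumption), with the first `rotary_dim` channels split into halves x⁽¹⁾ and x⁽²⁾:

```latex
% conjugate=True flips the sign of sin, i.e. rotates by -theta.
y^{(1)}_i = x^{(1)}_i \cos\theta_i - x^{(2)}_i \sin\theta_i, \qquad
y^{(2)}_i = x^{(2)}_i \cos\theta_i + x^{(1)}_i \sin\theta_i
```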
98 changes: 98 additions & 0 deletions deeplink_ext/internlm_ops/rotary/fallback.py
@@ -0,0 +1,98 @@
# Copyright (c) 2024, DeepLink.

from typing import Optional, Union
import torch
from einops import rearrange, repeat

__all__ = ["apply_rotary"]


def _rotate_half(x: torch.Tensor, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)


def _apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved=False
):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""

data_type = x.dtype
x = x.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)

ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
).to(data_type)


def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None:
raise NotImplementedError(
"apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet"
)
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

if conjugate:
sin = -sin
out = _apply_rotary_emb_torch(x, cos[:seqlen], sin[:seqlen], interleaved)
if inplace:
x.copy_(out)
out = x
return out
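A quick self-check of the fallback path, relying on the property that `conjugate=True` applies the inverse rotation; this assumes cos and sin come from the same angles, so cos² + sin² = 1:

```python
import torch
from deeplink_ext.internlm_ops.rotary import fallback

batch, seqlen, nheads, headdim = 2, 8, 3, 16
x = torch.randn(batch, seqlen, nheads, headdim)
theta = torch.rand(seqlen, headdim // 2)   # rotary_dim == headdim here
cos, sin = torch.cos(theta), torch.sin(theta)

y = fallback.apply_rotary(x, cos, sin)
x_back = fallback.apply_rotary(y, cos, sin, conjugate=True)
torch.testing.assert_close(x_back, x, rtol=1e-5, atol=1e-5)
```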
5 changes: 0 additions & 5 deletions deeplink_ext/internlm_ops/rotary/fallback/__init__.py

This file was deleted.
