fix format and support rope

POI-WX committed Jul 16, 2024 · 1 parent 35d3b14 · commit 540ad6e
Showing 2 changed files with 103 additions and 1 deletion.
14 changes: 13 additions & 1 deletion deeplink_ext/internevo_ops/__init__.py
@@ -20,6 +20,17 @@
except Exception as e:
    print(_not_impl.format(op_name="flash attention"))

try:
    from .rms_norm import MixedFusedRMSNorm
except Exception:
    print(_not_impl.format(op_name="RMSNorm"))

try:
    from .rotary_embedding import ApplyRotaryEmb
except Exception:
    print(_not_impl.format(op_name="rotary embedding"))

__all__ = [
    "AdamW",
@@ -29,5 +40,6 @@
    "flash_attn_varlen_qkvpacked_func",
    "flash_attn_varlen_kvpacked_func",
    "flash_attn_varlen_func",
    "MixedFusedRMSNorm",  # replaces the removed "rms_norm" entry
    "ApplyRotaryEmb",
]
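
Because each op is imported behind a try/except, downstream code can probe for the fused implementations at import time and fall back when the underlying cpp extension is missing. A minimal caller-side sketch follows; the flag name and fallback behaviour are illustrative and not part of this commit.

# Hypothetical caller-side guard; only the package path comes from this commit.
try:
    from deeplink_ext.internevo_ops import ApplyRotaryEmb, MixedFusedRMSNorm

    HAS_DEEPLINK_FUSED_OPS = True
except ImportError:
    HAS_DEEPLINK_FUSED_OPS = False  # fall back to plain PyTorch implementations
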
90 changes: 90 additions & 0 deletions deeplink_ext/internevo_ops/rotary_embedding.py
@@ -0,0 +1,90 @@
# Copyright (c) 2024, DeepLink.

import torch
from einops import rearrange
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "apply_rotary")

__all__ = ["ApplyRotaryEmb"]


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
class ApplyRotaryEmb(torch.autograd.Function):
    """
    Apply rotary positional embedding (RoPE) to the first rotary_dim channels
    of the input, leaving the remaining channels untouched.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        interleaved: bool = False,
        in_place: bool = False,
    ):
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style)
            instead of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim.
        Apply rotary embedding to the first rotary_dim of x.
        """
        *_, seqlen, _, head_dim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2

        assert rotary_dim <= head_dim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)

        if in_place:
            out = x
        else:
            out = torch.empty_like(x)

        # The kernel mirrors flash-attention's apply_rotary; the fifth argument is
        # False here and True in backward, presumably selecting the conjugate
        # (inverse) rotation.
        ext.apply_rotary(
            out[..., :rotary_dim],
            x[..., :rotary_dim],
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            False,
            interleaved,
        )

        # Channels beyond rotary_dim pass through unchanged; when writing in place
        # they are already present in out.
        if rotary_dim < head_dim and not in_place:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])

        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.in_place = in_place

        return out

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        *_, seqlen, _, head_dim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2

        if ctx.in_place:
            dx = do
        else:
            dx = torch.empty_like(do)

        ext.apply_rotary(
            dx[..., :rotary_dim],
            do[..., :rotary_dim],
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            True,
            ctx.interleaved,
        )

        if rotary_dim < head_dim and not ctx.in_place:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])

        # One gradient per forward input: x, cos, sin, interleaved, in_place.
        return dx, None, None, None, None
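
For orientation, here is a usage sketch against the class above, paired with a plain-PyTorch rotation for the non-interleaved case. It assumes ext.apply_rotary matches the semantics of the flash-attention rotary layer this file is adapted from; the shapes and dtype are illustrative, and real runs need tensors on the DeepLink device rather than the CPU defaults used here.

import torch
from einops import rearrange

from deeplink_ext.internevo_ops.rotary_embedding import ApplyRotaryEmb


def rotate_half(x):
    # GPT-NeoX style: split the rotary channels into two halves, swap them, flip the sign.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_reference(x, cos, sin):
    # Reference non-interleaved RoPE on the first rotary_dim channels; the rest pass through.
    rotary_dim = cos.shape[-1] * 2
    cos = rearrange(cos, "s d -> s 1 d").repeat(1, 1, 2)
    sin = rearrange(sin, "s d -> s 1 d").repeat(1, 1, 2)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    return torch.cat((x_rot * cos + rotate_half(x_rot) * sin, x_pass), dim=-1)


# Illustrative sizes: batch=2, seqlen=16, nheads=8, headdim=64, rotary_dim=32.
x = torch.randn(2, 16, 8, 64, requires_grad=True)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 32, 2).float() / 32))
freqs = torch.outer(torch.arange(16).float(), inv_freq)  # (seqlen, rotary_dim / 2)
cos, sin = freqs.cos(), freqs.sin()

out = ApplyRotaryEmb.apply(x, cos, sin, False, False)  # interleaved=False, in_place=False
out.sum().backward()                                   # exercises the custom backward
print(torch.allclose(out, rope_reference(x, cos, sin), atol=1e-5))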
