feat(internlm): Implement ApplyRotaryEmb fallback (#51)
* Add ApplyRotaryEmb fallback

* fix _force_fallback when failing to import cpp_extensions
jfxu-st authored Mar 8, 2024
1 parent 3d42957 commit 3375758
Showing 2 changed files with 84 additions and 4 deletions.
78 changes: 75 additions & 3 deletions deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -115,7 +115,79 @@ def backward(ctx, dqkv):
         return dqkv, None, None, None, None, None
 
 
-class ApplyRotaryEmb:
+class ApplyRotaryEmb(torch.autograd.Function):
     @staticmethod
-    def apply(*args, **kwargs):
-        raise NotImplementedError("fallback.ApplyRotaryEmb is not implemented yet")
+    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
+        """
+        x: (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        rotary_dim must be <= headdim
+        Apply rotary embedding to the first rotary_dim of x.
+        """
+
+        assert (
+            not interleaved
+        ), "Interleaved rotary embedding fallback is not supported yet"
+        assert (
+            not inplace
+        ), "Inplace rotary embedding fallback is not supported yet"
+
+        batch, seqlen, nheads, headdim = x.shape
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x_ro = x[..., :rotary_dim]
+        x1, x2 = (
+            x_ro.chunk(2, dim=-1)
+            if not interleaved
+            else (x_ro[..., ::2], x_ro[..., 1::2])
+        )
+        out = torch.empty_like(x)
+
+        o1, o2 = apply_rotary(
+            x1,
+            x2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            False,
+        )
+
+        out[..., :rotary_dim] = torch.cat((o1, o2), dim=-1)
+        if rotary_dim < headdim:
+            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
+        ctx.inplace = inplace
+        return out
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, headdim = do.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        do_ro = do[..., :rotary_dim]
+        do1, do2 = (
+            do_ro.chunk(2, dim=-1)
+            if not ctx.interleaved
+            else (do_ro[..., ::2], do_ro[..., 1::2])
+        )
+        dx = torch.empty_like(do)
+
+        dx1, dx2 = apply_rotary(
+            do1,
+            do2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            True,
+        )
+
+        dx[..., :rotary_dim] = torch.cat((dx1, dx2), dim=-1)
+
+        if rotary_dim < headdim:
+            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+        return dx, None, None, None, None
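
For reference, the non-interleaved path above rotates the two halves of the first rotary_dim channels: the forward pass computes (o1, o2) = (x1·cos − x2·sin, x1·sin + x2·cos), and the backward pass applies the conjugate (inverse) rotation (do1·cos + do2·sin, −do1·sin + do2·cos), which the final boolean argument of apply_rotary evidently toggles. A minimal smoke test of the fallback might look like the sketch below; the import path and the cos/sin construction are illustrative assumptions, not part of this commit.

# Hypothetical smoke test (import path and names are assumptions):
import torch
from deeplink_ext.internlm_ops.rotary.fallback import ApplyRotaryEmb

batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 8, 64, 32
x = torch.randn(batch, seqlen, nheads, headdim, requires_grad=True)
# Standard rotary frequency table; cos/sin have shape (seqlen, rotary_dim / 2).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
out = ApplyRotaryEmb.apply(x, freqs.cos(), freqs.sin())
out.sum().backward()  # exercises the backward pass defined above
assert out.shape == x.shape and x.grad is not None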
10 changes: 9 additions & 1 deletion deeplink_ext/patch_internlm.py
@@ -37,7 +37,15 @@ def _force_fallback():
     print(
         "[deeplink_ext] force_fallback is set, removing everything from cpp_extensions"
     )
-    import deeplink_ext.cpp_extensions as cpp_ext
+    try:
+        import deeplink_ext.cpp_extensions as cpp_ext
+    except Exception:
+        print(
+            "[deeplink_ext] WARNING: failed to import deeplink_ext.cpp_extensions, "
+            "so everything will fall back to the pure Python implementation. "
+            "Please check this import failure if you are using torch_dipu."
+        )
+        return
 
     for attr in dir(cpp_ext):
         if not attr.startswith("__") and callable(getattr(cpp_ext, attr)):
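
The attribute-stripping loop is truncated above. The idea behind it is that ops in deeplink_ext prefer a compiled kernel when the symbol exists on cpp_extensions and otherwise use the pure-Python fallback, so removing the attributes (or returning early when the import fails, as the new try/except does) routes everything to the fallbacks. A hypothetical illustration of that dispatch pattern, with all names assumed rather than taken from this commit:

# Hypothetical dispatch pattern (all names are assumptions):
try:
    import deeplink_ext.cpp_extensions as cpp_ext
except Exception:
    cpp_ext = None

def get_rotary_impl():
    # After _force_fallback() strips cpp_ext (or the import fails),
    # the pure-Python ApplyRotaryEmb fallback is selected instead.
    if cpp_ext is not None and hasattr(cpp_ext, "apply_rotary"):
        return cpp_ext.apply_rotary
    from deeplink_ext.internlm_ops.rotary.fallback import ApplyRotaryEmb
    return ApplyRotaryEmb.apply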
