diff --git a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
index 221e0573..abf63af6 100644
--- a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
+++ b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -115,7 +115,79 @@ def backward(ctx, dqkv):
         return dqkv, None, None, None, None, None
 
 
-class ApplyRotaryEmb:
+class ApplyRotaryEmb(torch.autograd.Function):
     @staticmethod
-    def apply(*args, **kwargs):
-        raise NotImplementedError("fallback.ApplyRotaryEmb is not implemented yet")
+    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
+        """
+        x: (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        rotary_dim must be <= headdim
+        Apply rotary embedding to the first rotary_dim of x.
+        """
+
+        assert (
+            interleaved == False
+        ), "Interleaved rotary embedding fallback is not supported yet"
+        assert (
+            inplace == False
+        ), "Inplace rotary embedding fallback is not supported yet"
+
+        batch, seqlen, nheads, headdim = x.shape
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x_ro = x[..., :rotary_dim]
+        x1, x2 = (
+            x_ro.chunk(2, dim=-1)
+            if not interleaved
+            else (x_ro[..., ::2], x_ro[..., 1::2])
+        )
+        out = torch.empty_like(x)
+
+        o1, o2 = apply_rotary(
+            x1,
+            x2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            False,
+        )
+
+        out[..., :rotary_dim] = torch.cat((o1, o2), dim=-1)
+        if rotary_dim < headdim:
+            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
+        ctx.inplace = inplace
+        return out
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, headdim = do.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        do_ro = do[..., :rotary_dim]
+        do1, do2 = (
+            do_ro.chunk(2, dim=-1)
+            if not ctx.interleaved
+            else (do_ro[..., ::2], do_ro[..., 1::2])
+        )
+        dx = torch.empty_like(do)
+
+        dx1, dx2 = apply_rotary(
+            do1,
+            do2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            False,
+        )
+
+        dx[..., :rotary_dim] = torch.cat((dx1, dx2), dim=-1)
+
+        if rotary_dim < headdim:
+            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+        return dx, None, None, None, None
diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py
index d98c43fa..6e4e0a9f 100644
--- a/deeplink_ext/patch_internlm.py
+++ b/deeplink_ext/patch_internlm.py
@@ -37,7 +37,15 @@ def _force_fallback():
     print(
         "[deeplink_ext] force_fallback is set, removing everything from cpp_extensions"
     )
-    import deeplink_ext.cpp_extensions as cpp_ext
+    try:
+        import deeplink_ext.cpp_extensions as cpp_ext
+    except Exception as e:
+        print(
+            "[deeplink_ext] WARNING: failed to import deeplink_ext.cpp_extensions, "
+            "so everything will fall back to the pure Python implementation. "
+            "Please check this import failure if you are using torch_dipu."
+        )
+        return
     for attr in dir(cpp_ext):
         if not attr.startswith("__") and callable(getattr(cpp_ext, attr)):
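
A minimal usage sketch of the new fallback (not part of the diff above). It assumes the module is importable as deeplink_ext.internlm_ops.rotary.fallback.fallback, that the apply_rotary helper already present in that file performs the usual half-dimension rotation (o1 = x1 * cos - x2 * sin, o2 = x1 * sin + x2 * cos), and uses made-up tensor shapes purely for illustration.

    # Illustrative only: exercises the fallback ApplyRotaryEmb added in this diff.
    import torch
    from deeplink_ext.internlm_ops.rotary.fallback import fallback  # import path assumed

    batch, seqlen, nheads, headdim = 2, 16, 4, 64
    rotary_dim = 32  # must be <= headdim; cos/sin carry rotary_dim // 2 frequencies each

    x = torch.randn(batch, seqlen, nheads, headdim, requires_grad=True)
    angles = torch.rand(seqlen, rotary_dim // 2)
    cos, sin = angles.cos(), angles.sin()

    # Only the non-interleaved (GPT-NeoX style), out-of-place path is implemented,
    # so interleaved=False and inplace=False are the only accepted values.
    out = fallback.ApplyRotaryEmb.apply(x, cos, sin, False, False)
    out.sum().backward()  # also runs the new backward
    print(out.shape, x.grad.shape)  # both torch.Size([2, 16, 4, 64])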