diff --git a/deeplink_ext/internlm_ops/rotary/__init__.py b/deeplink_ext/internlm_ops/rotary/__init__.py
index 8ca5f9e2..16eb639e 100644
--- a/deeplink_ext/internlm_ops/rotary/__init__.py
+++ b/deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,13 +1,16 @@
 # Copyright (c) 2024, DeepLink.
 
 try:
-    from .deeplink import apply_rotary
-except:
+    from .deeplink import apply_rotary, RotaryEmbedding_AscendSpeed
+except Exception:
     print(
         "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
         end="",
     )
     from .fallback import apply_rotary
+    # Only apply_rotary is taken from the fallback module; the autograd wrapper
+    # is set to None so the name stays importable and callers can test for it.
+    RotaryEmbedding_AscendSpeed = None
 from . import fallback
 
-__all__ = ["apply_rotary", "fallback"]
+__all__ = ["apply_rotary", "fallback", "RotaryEmbedding_AscendSpeed"]
diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary/deeplink.py
index 670a47b9..11e21423 100644
--- a/deeplink_ext/internlm_ops/rotary/deeplink.py
+++ b/deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -62,4 +62,67 @@ def apply_rotary(
         conjugate,
         interleaved,
     )
     return output
+
+
+def apply_rotary_for_ascend_speed(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    interleaved=False,
+    inplace=False,
+    conjugate=False,
+) -> torch.Tensor:
+    """Run the rotary kernel on ``x`` and return the result in a new tensor.
+
+    Allocates an output tensor with ``torch.empty_like(x)`` and hands it to
+    ``ext.apply_rotary``, which is expected to fill it.
+
+    NOTE(review): ``seqlen_offsets``, ``cu_seqlens``, ``max_seqlen`` and
+    ``inplace`` are accepted but never read in this function (presumably kept
+    to mirror ``apply_rotary``); confirm that no caller relies on them.
+
+    Args:
+        x: Input tensor handed to the kernel.
+        cos: Cosine table, forwarded unchanged.
+        sin: Sine table, forwarded unchanged.
+        interleaved: Forwarded unchanged to ``ext.apply_rotary``.
+        conjugate: Forwarded unchanged to ``ext.apply_rotary``.
+
+    Returns:
+        The newly allocated tensor, shaped like ``x``.
+    """
+    output = torch.empty_like(x)
+    ext.apply_rotary(
+        output,
+        x,
+        cos,
+        sin,
+        conjugate,
+        interleaved,
+    )
+    return output
+
+
+class RotaryEmbedding_AscendSpeed(torch.autograd.Function):
+    """Autograd function built on ``apply_rotary_for_ascend_speed``."""
+
+    @staticmethod
+    def forward(ctx, t, cos, sin):
+        """Rotate ``t`` and save ``cos``/``sin`` for the backward pass."""
+        ctx.save_for_backward(cos, sin)
+        return apply_rotary_for_ascend_speed(t, cos, sin)
+
+    @staticmethod
+    def backward(ctx, t):
+        """Return the gradient for the forward input ``t``.
+
+        ``t`` is the gradient flowing in from the forward output; it goes
+        through the same kernel with ``conjugate=True``.
+        """
+        cos, sin = ctx.saved_tensors
+        # One entry per forward input: cos and sin receive no gradient.
+        return apply_rotary_for_ascend_speed(t, cos, sin, conjugate=True), None, None