diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py index 1019187..4e27d04 100644 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py @@ -7,35 +7,6 @@ __all__ = ["ApplyRotaryEmb"] -def _unsqueeze_to_4d(x: torch.Tensor): - while x.dim() < 4: - x = x.unsqueeze(0) - return x - - -def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved): - assert interleaved == False, "interleaved not support by torch_npu" - - x_view = _unsqueeze_to_4d(x) - cos_view = _unsqueeze_to_4d(cos) - sin_view = _unsqueeze_to_4d(sin) - - cos_cat = torch.cat([cos_view, cos_view], -1) - sin_cat = torch.cat([sin_view, sin_view], -1) - - if confj: - sin_cat.neg_() - - x_view_chunks = x_view.chunk(2, -1) - x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1) - - cos_x = torch.mul(cos_cat, x_view) - sin_x = torch.mul(sin_cat, x_view_new) - out = cos_x + sin_x - - return out - - # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 class ApplyRotaryEmb(torch.autograd.Function): """ @@ -67,45 +38,38 @@ def forward( assert seqlen <= rotary_seqlen assert sin.shape == (rotary_seqlen, rotary_dim // 2) - out = _apply_rotary( - x[..., :rotary_dim], - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - interleaved, - ) + re_cos = rearrange(cos[:seqlen], "s d -> s 1 d") + re_sin = rearrange(sin[:seqlen], "s d -> s 1 d") + + cat_cos = torch.cat([re_cos, re_cos], -1) + cat_sin = torch.cat([re_sin, re_sin], -1) - ctx.save_for_backward(cos, sin) + rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin) + ctx.save_for_backward(cat_cos, cat_sin) ctx.interleaved = interleaved ctx.in_place = in_place - if in_place: - x[..., :rotary_dim].copy_(out[..., :rotary_dim]) + x[..., :rotary_dim].copy_(rot) return x else: - if rotary_dim < head_dim: + out = x.detach().clone() + if rotary_dim < head_dim and not in_place: out[..., rotary_dim:].copy_(x[..., rotary_dim:]) return out @staticmethod def backward(ctx, do): - cos, sin = ctx.saved_tensors + cat_cos, cat_sin = ctx.saved_tensors *_, seqlen, _, head_dim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 + rotary_dim = cat_cos.shape[-1] - out = _apply_rotary( - do[..., :rotary_dim], - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ctx.interleaved, + dx_out = torch_npu.npu_rotary_mul( + do[..., :rotary_dim], cat_cos, torch.neg(cat_sin) ) - if ctx.in_place: - do[..., :rotary_dim].copy_(out[..., :rotary_dim]) + do[..., :rotary_dim].copy_(dx_out) return do, None, None, None, None else: - if rotary_dim < head_dim: - out[..., rotary_dim:].copy(do[..., rotary_dim:]) - return out, None, None, None, None + dx = do.detach().clone() + dx[..., :rotary_dim].copy_(dx_out) + return dx, None, None, None, None diff --git a/deeplink_ext/interntrain_ops/_rotary_embedding_npu.py b/deeplink_ext/interntrain_ops/_rotary_embedding_npu.py index 7bc546f..0784d21 100644 --- a/deeplink_ext/interntrain_ops/_rotary_embedding_npu.py +++ b/deeplink_ext/interntrain_ops/_rotary_embedding_npu.py @@ -1,5 +1,4 @@ # Copyright (c) 2024, DeepLink. -# Copyright (c) 2024, InternEvo. import torch import torch_npu @@ -8,38 +7,16 @@ __all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"] -def _unsqueeze_to_4d(x: torch.Tensor): - while x.dim() < 4: - x = x.unsqueeze(0) - return x - - -def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved): - assert interleaved == False, "interleaved not support by torch_npu" - - x_view = _unsqueeze_to_4d(x) - cos_view = _unsqueeze_to_4d(cos) - sin_view = _unsqueeze_to_4d(sin) - - cos_cat = torch.cat([cos_view, cos_view], -1) - sin_cat = torch.cat([sin_view, sin_view], -1) - - if confj: - sin_cat.neg_() - - x_view_chunks = x_view.chunk(2, -1) - x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1) - - cos_x = torch.mul(cos_cat, x_view) - sin_x = torch.mul(sin_cat, x_view_new) - out = cos_x + sin_x - - return out - - class ApplyRotaryEmb(torch.autograd.Function): """ - ApplyRotaryEmb + Apply rotary positional embedding to input tensor x. + Args: + x (Tensor): Input tensor x is of shape [seq_length, ... , dim] + cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim] + sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim] + + Returns: + Tensor: The input tensor after applying RoPE """ @staticmethod @@ -59,34 +36,34 @@ def forward(ctx, x, cos, sin, interleaved=False): assert seqlen <= rotary_seqlen assert sin.shape == (rotary_seqlen, rotary_dim // 2) out = torch.empty_like(x) + re_cos = rearrange(cos[:seqlen], "s d -> s 1 d") re_sin = rearrange(sin[:seqlen], "s d -> s 1 d") - out = _apply_rotary( - x[..., :rotary_dim], - re_cos, - re_sin, - False, - interleaved, - ) + + cat_cos = torch.cat([re_cos, re_cos], -1) + cat_sin = torch.cat([re_sin, re_sin], -1) + + rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin) + out[..., :rotary_dim].copy_(rot) if rotary_dim < headdim: out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(re_cos, re_sin) + + ctx.save_for_backward(cat_cos, cat_sin) ctx.interleaved = interleaved return out @staticmethod def backward(ctx, do): - re_cos, re_sin = ctx.saved_tensors + cat_cos, cat_sin = ctx.saved_tensors headdim = do.shape[-1] - rotary_dim = re_cos.shape[-1] - rotary_dim *= 2 - dx = _apply_rotary( - do[..., :rotary_dim], - re_cos, - re_sin, - True, - ctx.interleaved, + rotary_dim = cat_cos.shape[-1] + + dx = torch.empty_like(do) + dx_rot = torch_npu.npu_rotary_mul( + do[..., :rotary_dim], cat_cos, torch.neg(cat_sin) ) + dx.copy_(dx_rot) + if rotary_dim < headdim: dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) return dx, None, None, None @@ -141,16 +118,10 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): if len(qkv.shape) == 4 else rearrange(sin[:seqlen], "s d -> s 1 d") ) - - # qro - out = _apply_rotary( - q_ro, - re_cos, - re_sin, - False, - interleaved, - ) - q_ro.copy_(out) + cat_cos = torch.cat([re_cos, re_cos], -1) + cat_sin = torch.cat([re_sin, re_sin], -1) + q_out = torch_npu.npu_rotary_mul(q_ro, cat_cos, cat_sin) + q_ro.copy_(q_out) k_ro = ( qkv[:, 1, :, :rotary_dim] @@ -167,50 +138,34 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): if len(qkv.shape) == 4 else rearrange(sin_k[:seqlen], "s d -> s 1 d") ) - out = _apply_rotary( - k_ro, - re_cos_k, - re_sin_k, - False, - interleaved, - ) - k_ro.copy_(out) + cat_cos_k = torch.cat([re_cos_k, re_cos_k], -1) + cat_sin_k = torch.cat([re_sin_k, re_sin_k], -1) + k_out = torch_npu.npu_rotary_mul(k_ro, cat_cos_k, cat_sin_k) + k_ro.copy_(k_out) - ctx.save_for_backward(re_cos, re_sin, re_cos_k, re_sin_k) + ctx.save_for_backward(cat_cos, cat_sin, cat_cos_k, cat_sin_k) ctx.interleaved = interleaved return qkv @staticmethod def backward(ctx, dqkv): - re_cos, re_sin, re_cos_k, re_sin_k = ctx.saved_tensors - rotary_dim = re_cos.shape[-1] - rotary_dim *= 2 + cat_cos, cat_sin, cat_cos_k, cat_sin_k = ctx.saved_tensors + rotary_dim = cat_cos.shape[-1] dq_ro = ( dqkv[:, 0, :, :rotary_dim] if len(dqkv.shape) == 4 else dqkv[:, :, 0, :, :rotary_dim] ) - out = _apply_rotary( - dq_ro, - re_cos, - re_sin, - True, - ctx.interleaved, - ) - dq_ro.copy_(out) + dq_out = torch_npu.npu_rotary_mul(dq_ro, cat_cos, torch.neg(cat_sin)) + dq_ro.copy_(dq_out) dk_ro = ( dqkv[:, 1, :, :rotary_dim] if len(dqkv.shape) == 4 else dqkv[:, :, 1, :, :rotary_dim] ) - out = _apply_rotary( - dk_ro, - re_cos_k, - re_sin_k, - True, - ctx.interleaved, - ) - dk_ro.copy_(out) + dk_out = torch_npu.npu_rotary_mul(dk_ro, cat_cos_k, torch.neg(cat_sin_k)) + dk_ro.copy_(dk_out) + return dqkv, None, None, None, None, None