From 49d1a9da35b696572212adbfe8d8b3006258bae6 Mon Sep 17 00:00:00 2001
From: POI-WX
Date: Mon, 21 Oct 2024 07:28:36 +0000
Subject: [PATCH] fix RoPE bug on NPU and support the interleaved option

---
 .../internevo_ops/_rotary_embedding_npu.py | 90 +++++++++++++------
 .../internevo_ops/rotary_embedding.py      |  3 +-
 tests/internevo/test_rotary_embedding.py   | 67 +++++++-------
 3 files changed, 96 insertions(+), 64 deletions(-)

diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py
index 4e27d04..d9cc45a 100644
--- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py
+++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -2,7 +2,6 @@
 
 import torch
 import torch_npu
-from einops import rearrange
 
 __all__ = ["ApplyRotaryEmb"]
 
@@ -38,38 +37,71 @@ def forward(
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
 
-        re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
-        re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
-
-        cat_cos = torch.cat([re_cos, re_cos], -1)
-        cat_sin = torch.cat([re_sin, re_sin], -1)
-
-        rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
-        ctx.save_for_backward(cat_cos, cat_sin)
+        # "s d -> 1 s 1 d"
+        cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
+        sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
+        x_ro = x[..., :rotary_dim]
+        ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
-        if in_place:
-            x[..., :rotary_dim].copy_(rot)
-            return x
-        else:
-            out = x.detach().clone()
-            if rotary_dim < head_dim and not in_place:
+        if interleaved:
+            x_in = torch.cat([x_ro[..., ::2], x_ro[..., 1::2]], dim=-1)
+            out_ro = torch_npu.npu_rotary_mul(x_in, cos, sin)
+            if in_place:
+                x_ro[..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)])
+                x_ro[..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :])
+                return x
+            out = torch.empty_like(x)
+            out[..., :rotary_dim][..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)])
+            out[..., :rotary_dim][..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :])
+            if rotary_dim < head_dim:
                 out[..., rotary_dim:].copy_(x[..., rotary_dim:])
             return out
+        else:
+            out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
+            if in_place:
+                x[..., :rotary_dim].copy_(out_ro)
+                return x
+            if rotary_dim < head_dim:
+                out = torch.empty_like(x)
+                out[..., :rotary_dim].copy_(out_ro)
+                out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+                return out
+            return out_ro
 
     @staticmethod
-    def backward(ctx, do):
-        cat_cos, cat_sin = ctx.saved_tensors
-        *_, seqlen, _, head_dim = do.shape
-        rotary_dim = cat_cos.shape[-1]
-
-        dx_out = torch_npu.npu_rotary_mul(
-            do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
-        )
-        if ctx.in_place:
-            do[..., :rotary_dim].copy_(dx_out)
-            return do, None, None, None, None
+    def backward(ctx, grad_out):
+        cos, sin = ctx.saved_tensors
+        rotary_dim = cos.shape[-1]
+        head_dim = grad_out.shape[-1]
+        grad_out_ro = grad_out[..., :rotary_dim]
+        if ctx.interleaved:
+            grad_out_in = torch.cat(
+                [grad_out_ro[..., ::2], grad_out_ro[..., 1::2]], dim=-1
+            )
+            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_in, cos, torch.neg(sin))
+            if ctx.in_place:
+                grad_out_ro[..., ::2].copy_(grad_input_ro[..., : int(rotary_dim / 2)])
+                grad_out_ro[..., 1::2].copy_(grad_input_ro[..., int(rotary_dim / 2) :])
+                return grad_out, None, None, None, None
+            grad_input = torch.empty_like(grad_out)
+            grad_input[..., :rotary_dim][..., ::2].copy_(
+                grad_input_ro[..., : int(rotary_dim / 2)]
+            )
+            grad_input[..., :rotary_dim][..., 1::2].copy_(
+                grad_input_ro[..., int(rotary_dim / 2) :]
+            )
+            if rotary_dim < head_dim:
+                grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
+            return grad_input, None, None, None, None
         else:
-            dx = do.detach().clone()
-            dx[..., :rotary_dim].copy_(dx_out)
-            return dx, None, None, None, None
+            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
+            if ctx.in_place:
+                grad_out[..., :rotary_dim].copy_(grad_input_ro)
+                return grad_out, None, None, None, None
+            if rotary_dim < head_dim:
+                grad_input = torch.empty_like(grad_out)
+                grad_input[..., :rotary_dim].copy_(grad_input_ro)
+                grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
+                return grad_input, None, None, None, None
+            return grad_input_ro, None, None, None, None
diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py
index 1a2a36d..7764b9b 100644
--- a/deeplink_ext/internevo_ops/rotary_embedding.py
+++ b/deeplink_ext/internevo_ops/rotary_embedding.py
@@ -4,8 +4,7 @@
 platform_type = deeplink_ext_get_platform_type()
 
 if platform_type == PlatformType.TORCH_NPU:
-    # from ._rotary_embedding_npu import ApplyRotaryEmb
-    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
+    from ._rotary_embedding_npu import ApplyRotaryEmb
 elif platform_type == PlatformType.TORCH_DIPU:
     from ._rotary_embedding_dipu import ApplyRotaryEmb
 else:
diff --git a/tests/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py
index 981c2f0..722577e 100644
--- a/tests/internevo/test_rotary_embedding.py
+++ b/tests/internevo/test_rotary_embedding.py
@@ -8,40 +8,41 @@
 
 def test_ApplyRotaryEmb():
     input_dtype_list = [torch.float16, torch.bfloat16]
-    interleaved = False
     in_place_options = [False, True]
+    interleaved_options = [False, True]
     for input_dtype in input_dtype_list:
         for in_place in in_place_options:
-            input_ref = torch.randn(
-                1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
-            )
-            input_ext = input_ref.clone().detach().requires_grad_()
-            cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
-            sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+            for interleaved in interleaved_options:
+                input_ref = torch.randn(
+                    1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
+                )
+                input_ext = input_ref.clone().detach().requires_grad_()
+                cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+                sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
 
-            output_ref, grad_ref = call_autograd_func(
-                ApplyRotaryEmbTorch,
-                "cuda",
-                input_dtype,
-                input_ref,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            output_ext, grad_ext = call_autograd_func(
-                ApplyRotaryEmb,
-                "cuda",
-                input_dtype,
-                input_ext,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            assert allclose(
-                output_ref, output_ext, rtol=1e-2, atol=5e-2
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
-            assert allclose(
-                grad_ref, grad_ext
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
+                output_ref, grad_ref = call_autograd_func(
+                    ApplyRotaryEmbTorch,
+                    "cuda",
+                    input_dtype,
+                    input_ref,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                output_ext, grad_ext = call_autograd_func(
+                    ApplyRotaryEmb,
+                    "cuda",
+                    input_dtype,
+                    input_ext,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                assert allclose(
+                    output_ref, output_ext, rtol=1e-2, atol=5e-2
+                ), f"When input dtype is {input_dtype}, in_place is {in_place} and interleaved is {interleaved}, ApplyRotaryEmb fails to pass the forward test!"
+                assert allclose(
+                    grad_ref, grad_ext
+                ), f"When input dtype is {input_dtype}, in_place is {in_place} and interleaved is {interleaved}, ApplyRotaryEmb fails to pass the backward test!"
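
Note for reviewers (a sketch, not part of the diff above): the interleaved path added
in _rotary_embedding_npu.py relies on torch_npu.npu_rotary_mul applying the half-split
("GPT-NeoX style") rotation x * cos + rotate_half(x) * sin. The plain-torch CPU sketch
below checks the equivalence the patch uses: permuting the even/odd pairs of the rotary
slice into a half-split layout, rotating with cos/sin duplicated along the last
dimension, and scattering the two halves back to the even/odd positions reproduces
interleaved ("GPT-J style") RoPE. The helper names here are illustrative only.

    import torch

    def rotate_half(t):
        # (t1, t2) halves -> (-t2, t1): the rotation npu_rotary_mul applies
        t1, t2 = t.chunk(2, dim=-1)
        return torch.cat((-t2, t1), dim=-1)

    def interleaved_rope_reference(x, cos, sin):
        # cos/sin: (seqlen, head_dim // 2), one angle per even/odd pair
        c = cos[None, :, None, :].repeat_interleave(2, dim=-1)
        s = sin[None, :, None, :].repeat_interleave(2, dim=-1)
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
        return x * c + rotated * s

    def interleaved_via_half_layout(x, cos, sin):
        # Mirrors the patch: pairs -> half layout, rotate, scatter back.
        c = torch.cat((cos, cos), dim=-1)[None, :, None, :]
        s = torch.cat((sin, sin), dim=-1)[None, :, None, :]
        x_in = torch.cat((x[..., ::2], x[..., 1::2]), dim=-1)
        out_half = x_in * c + rotate_half(x_in) * s  # what npu_rotary_mul computes
        out = torch.empty_like(x)
        half = x.shape[-1] // 2
        out[..., ::2] = out_half[..., :half]
        out[..., 1::2] = out_half[..., half:]
        return out

    x = torch.randn(1, 64, 32, 64)  # (batch, seqlen, heads, head_dim)
    angles = torch.rand(64, 32)     # (seqlen, head_dim // 2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    assert torch.allclose(
        interleaved_rope_reference(x, cos, sin),
        interleaved_via_half_layout(x, cos, sin),
        atol=1e-5,
    )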