From 945321db7a0e45719b8b589e8e1d48b2aebf0861 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 4 Sep 2024 02:29:48 +0000 Subject: [PATCH 1/4] fix rotary embedding in ascend speed --- .../ascend_speed/_rotary_embedding_npu.py | 43 +++---------------- 1 file changed, 5 insertions(+), 38 deletions(-) diff --git a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py index 650da29..646f7e9 100644 --- a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py +++ b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py @@ -6,38 +6,6 @@ __all__ = ["RotaryEmbedding"] -def _unsqueeze_to_4d(x: torch.Tensor): - while x.dim() < 4: - x = x.unsqueeze(0) - return x - - -def apply_rotary(x: torch.Tensor, cos, sin, confj=False, interleaved=False): - assert interleaved == False, "interleaved not support by torch_npu" - - x_view = _unsqueeze_to_4d(x) - cos_view = _unsqueeze_to_4d(cos) - sin_view = _unsqueeze_to_4d(sin) - - cos_cat = torch.cat([cos_view, cos_view], -1) - sin_cat = torch.cat([sin_view, sin_view], -1) - - if confj: - sin_cat.neg_() - - x_view_chunks = x_view.chunk(2, -1) - x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1) - - print(cos_cat.shape) - print(x_view.shape) - - cos_x = torch.mul(cos_cat, x_view) - sin_x = torch.mul(sin_cat, x_view_new) - out = cos_x + sin_x - - return out - - class RotaryEmbedding(torch.autograd.Function): """ Apply rotary positional embedding to input tensor x. @@ -52,12 +20,11 @@ class RotaryEmbedding(torch.autograd.Function): @staticmethod def forward(ctx, x, cos, sin): - cos, _ = torch.chunk(cos, 2, -1) - sin, _ = torch.chunk(sin, 2, -1) - ctx.save_for_backward(cos, sin) - return apply_rotary(x, cos, sin) + out = torch_npu.npu_rotary_mul(x, cos, sin) + ctx.save_for_backward(out, cos, sin) + return out @staticmethod def backward(ctx, grad_output): - cos, sin = ctx.saved_tensors - return apply_rotary(grad_output, cos, sin, conjugate=True), None, None + out, cos, sin = ctx.saved_tensors + return torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], None, None From 51e48640917b62c13a732e5d2a954a61701d54e7 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 4 Sep 2024 02:38:39 +0000 Subject: [PATCH 2/4] refact test cases --- .../ascend_speed/test_rotary_embedding.py | 0 tests/conftest.py | 15 + .../{dipu => }/easyllm/test_rms_norm_dipu.py | 0 tests/fusion_result.json | 33 ++ .../internevo/test_flash_attention.py | 0 .../internevo/test_rotary_embedding.py | 0 .../internevo/test_varlen_flash_attention.py | 0 .../{dipu => }/interntrain/test_adamw_dipu.py | 0 .../interntrain/test_flash_attention.py | 0 tests/{dipu => }/interntrain/test_rms_norm.py | 3 - .../interntrain/test_rotary_embedding.py | 0 .../test_varlen_flash_attention.py | 0 tests/npu/easyllm/test_rms_norm_npu.py | 48 -- .../npu/internevo/test_flash_attention_npu.py | 130 ----- .../internevo/test_rotary_embedding_npu.py | 47 -- .../test_varlen_flash_attention_npu.py | 354 -------------- tests/npu/interntrain/test_adamw_npu.py | 63 --- .../interntrain/test_flash_attention_npu.py | 191 -------- tests/npu/interntrain/test_rms_norm_npu.py | 26 - .../interntrain/test_rotary_embedding_npu.py | 84 ---- .../test_varlen_flash_attention_npu.py | 458 ------------------ 21 files changed, 48 insertions(+), 1404 deletions(-) rename tests/{dipu => }/ascend_speed/test_rotary_embedding.py (100%) create mode 100644 tests/conftest.py rename tests/{dipu => }/easyllm/test_rms_norm_dipu.py (100%) create mode 100644 tests/fusion_result.json rename tests/{dipu 
=> }/internevo/test_flash_attention.py (100%) rename tests/{dipu => }/internevo/test_rotary_embedding.py (100%) rename tests/{dipu => }/internevo/test_varlen_flash_attention.py (100%) rename tests/{dipu => }/interntrain/test_adamw_dipu.py (100%) rename tests/{dipu => }/interntrain/test_flash_attention.py (100%) rename tests/{dipu => }/interntrain/test_rms_norm.py (98%) rename tests/{dipu => }/interntrain/test_rotary_embedding.py (100%) rename tests/{dipu => }/interntrain/test_varlen_flash_attention.py (100%) delete mode 100644 tests/npu/easyllm/test_rms_norm_npu.py delete mode 100644 tests/npu/internevo/test_flash_attention_npu.py delete mode 100644 tests/npu/internevo/test_rotary_embedding_npu.py delete mode 100644 tests/npu/internevo/test_varlen_flash_attention_npu.py delete mode 100644 tests/npu/interntrain/test_adamw_npu.py delete mode 100644 tests/npu/interntrain/test_flash_attention_npu.py delete mode 100644 tests/npu/interntrain/test_rms_norm_npu.py delete mode 100644 tests/npu/interntrain/test_rotary_embedding_npu.py delete mode 100644 tests/npu/interntrain/test_varlen_flash_attention_npu.py diff --git a/tests/dipu/ascend_speed/test_rotary_embedding.py b/tests/ascend_speed/test_rotary_embedding.py similarity index 100% rename from tests/dipu/ascend_speed/test_rotary_embedding.py rename to tests/ascend_speed/test_rotary_embedding.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..cd2aae5 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest +import torch + +from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + +@pytest.fixture(scope='session', autouse=True) +def import_module(): + platform = deeplink_ext_get_platform_type() + if platform == PlatformType.TORCH_NPU: + import torch_npu + from torch_npu.contrib import transfer_to_npu + elif platform == PlatformType.TORCH_DIPU: + import torch_dipu + else: + raise ValueError("backend platform does not supported by deeplink_ext") diff --git a/tests/dipu/easyllm/test_rms_norm_dipu.py b/tests/easyllm/test_rms_norm_dipu.py similarity index 100% rename from tests/dipu/easyllm/test_rms_norm_dipu.py rename to tests/easyllm/test_rms_norm_dipu.py diff --git a/tests/fusion_result.json b/tests/fusion_result.json new file mode 100644 index 0000000..030df50 --- /dev/null +++ b/tests/fusion_result.json @@ -0,0 +1,33 @@ +[{ + "graph_fusion": { + "RefreshInt64ToInt32FusionPass": { + "effect_times": "1", + "match_times": "1" + } + }, + "session_and_graph_id": "0_0" +},{ + "graph_fusion": { + "RefreshInt64ToInt32FusionPass": { + "effect_times": "1", + "match_times": "1" + } + }, + "session_and_graph_id": "1_1" +},{ + "graph_fusion": { + "RefreshInt64ToInt32FusionPass": { + "effect_times": "1", + "match_times": "1" + } + }, + "session_and_graph_id": "2_2" +},{ + "graph_fusion": { + "RefreshInt64ToInt32FusionPass": { + "effect_times": "1", + "match_times": "1" + } + }, + "session_and_graph_id": "3_3" +}] \ No newline at end of file diff --git a/tests/dipu/internevo/test_flash_attention.py b/tests/internevo/test_flash_attention.py similarity index 100% rename from tests/dipu/internevo/test_flash_attention.py rename to tests/internevo/test_flash_attention.py diff --git a/tests/dipu/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py similarity index 100% rename from tests/dipu/internevo/test_rotary_embedding.py rename to tests/internevo/test_rotary_embedding.py diff --git a/tests/dipu/internevo/test_varlen_flash_attention.py 
b/tests/internevo/test_varlen_flash_attention.py similarity index 100% rename from tests/dipu/internevo/test_varlen_flash_attention.py rename to tests/internevo/test_varlen_flash_attention.py diff --git a/tests/dipu/interntrain/test_adamw_dipu.py b/tests/interntrain/test_adamw_dipu.py similarity index 100% rename from tests/dipu/interntrain/test_adamw_dipu.py rename to tests/interntrain/test_adamw_dipu.py diff --git a/tests/dipu/interntrain/test_flash_attention.py b/tests/interntrain/test_flash_attention.py similarity index 100% rename from tests/dipu/interntrain/test_flash_attention.py rename to tests/interntrain/test_flash_attention.py diff --git a/tests/dipu/interntrain/test_rms_norm.py b/tests/interntrain/test_rms_norm.py similarity index 98% rename from tests/dipu/interntrain/test_rms_norm.py rename to tests/interntrain/test_rms_norm.py index 2ea2bce..ac24b2c 100644 --- a/tests/dipu/interntrain/test_rms_norm.py +++ b/tests/interntrain/test_rms_norm.py @@ -34,6 +34,3 @@ def test_MixedFusedRMSNorm(): assert allclose( grad_ref, grad_ext, rtol=1e-2, atol=1e-2 ), f"When input dtype is {input_dtype} and weight dtype is {weight_dtype}, MixedRMSNorm fails to pass the backward test!" - - -test_MixedFusedRMSNorm() diff --git a/tests/dipu/interntrain/test_rotary_embedding.py b/tests/interntrain/test_rotary_embedding.py similarity index 100% rename from tests/dipu/interntrain/test_rotary_embedding.py rename to tests/interntrain/test_rotary_embedding.py diff --git a/tests/dipu/interntrain/test_varlen_flash_attention.py b/tests/interntrain/test_varlen_flash_attention.py similarity index 100% rename from tests/dipu/interntrain/test_varlen_flash_attention.py rename to tests/interntrain/test_varlen_flash_attention.py diff --git a/tests/npu/easyllm/test_rms_norm_npu.py b/tests/npu/easyllm/test_rms_norm_npu.py deleted file mode 100644 index b56f2ae..0000000 --- a/tests/npu/easyllm/test_rms_norm_npu.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -import torch_npu -from tests.core import calculate_fwd_and_bwd, allclose -from deeplink_ext.easyllm_ops.rms_norm import rms_norm -from deeplink_ext.easyllm_ops.rms_norm_fallback import rms_norm_torch - - -def test_rms_norm(): - input_dtype_list = [torch.float16, torch.bfloat16] - weight_dtype_list = [torch.float16, torch.bfloat16] - for input_dtype, weight_dtype in zip(input_dtype_list, weight_dtype_list): - hidden_states_ref = torch.randn( - 1, 64, 32, 64, dtype=input_dtype, device="npu", requires_grad=True - ) - hidden_states_ext = hidden_states_ref.clone().detach().requires_grad_(True) - - weight_ref = torch.nn.Parameter( - torch.ones( - list(hidden_states_ref.shape)[-1], dtype=weight_dtype, device="npu" - ), - requires_grad=True, - ) - weight_ext = weight_ref.clone().detach().requires_grad_(True) - - epsilon = 1e-5 - - output_ref, grad_ref = calculate_fwd_and_bwd( - rms_norm_torch, - hidden_states_ref, - weight_ref, - epsilon, - ) - - output_ext, grad_ext = calculate_fwd_and_bwd( - rms_norm, - hidden_states_ext, - weight_ext, - epsilon, - ) - - assert allclose( - output_ref, output_ext, rtol=1e-05, atol=1e-5 - ), f"When input dtype is {input_dtype} and weight dtype is {weight_dtype}, RMSNorm fails to pass the forward test!" - assert allclose( - grad_ref, grad_ext, rtol=1e-2, atol=1e-2 - ), f"When input dtype is {input_dtype} and weight dtype is {weight_dtype}, RMSNorm fails to pass the backward test!" 
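The NPU-specific test copies deleted in this patch are superseded by the renamed shared tests plus the session-scoped fixture added in tests/conftest.py above. As a rough illustration only (a hypothetical test body, not taken from this series), a backend-agnostic check could look like the sketch below. It assumes the fixture's torch_npu/torch_dipu import lets tensors created with the generic "cuda" device string land on the active accelerator, and it reuses the rms_norm / rms_norm_torch entry points imported by the deleted easyllm test.

import torch
from deeplink_ext.easyllm_ops.rms_norm import rms_norm
from deeplink_ext.easyllm_ops.rms_norm_fallback import rms_norm_torch


def test_rms_norm_backend_agnostic():
    # The autouse conftest fixture has already imported torch_npu (with
    # transfer_to_npu) or torch_dipu, so "cuda" is assumed here to resolve to
    # the active backend device rather than a literal CUDA GPU.
    x = torch.randn(1, 64, 32, 64, dtype=torch.float16, device="cuda")
    weight = torch.ones(x.shape[-1], dtype=torch.float16, device="cuda")
    epsilon = 1e-5
    # Compare the fused extension op against the pure-torch fallback.
    out_ref = rms_norm_torch(x, weight, epsilon)
    out_ext = rms_norm(x, weight, epsilon)
    assert torch.allclose(out_ref, out_ext, rtol=1e-2, atol=1e-2)
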
diff --git a/tests/npu/internevo/test_flash_attention_npu.py b/tests/npu/internevo/test_flash_attention_npu.py deleted file mode 100644 index 7da08f8..0000000 --- a/tests/npu/internevo/test_flash_attention_npu.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -from tests.core import copy_to_cpu, allclose, calculate_fwd_and_bwd - -from deeplink_ext.internevo_ops.flash_attention_fallback import ( - flash_attn_qkvpacked_func_torch, - flash_attn_kvpacked_func_torch, - flash_attn_func_torch, -) -from deeplink_ext.internevo_ops.flash_attention import ( - flash_attn_qkvpacked_func, - flash_attn_kvpacked_func, - flash_attn_func, -) - - -def test_flash_attn_qkvpacked_func_mha(): - batch, seqlen, num_heads, headdim = [8, 32, 32, 64] - - qkv_gpu = torch.rand( - [batch, seqlen, 3, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - qkv_cpu = copy_to_cpu( - [ - qkv_gpu, - ] - ) - - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_qkvpacked_func_torch, - qkv_cpu[0], - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_qkvpacked_func, - qkv_gpu, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) - - -def test_flash_attn_kvpacked_func_gqa(): - batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64] - num_kv_heads = 8 - - q_gpu = torch.rand( - [batch, seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_gpu = torch.rand( - [batch, seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_kvpacked_func_torch, - q_cpu, - kv_cpu, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_kvpacked_func, - q_gpu, - kv_gpu, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) - - -def test_flash_attn_func_gqa(): - batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64] - num_kv_heads = 8 - - q_gpu = torch.rand( - [batch, seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - k_gpu = torch.rand( - [batch, seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - v_gpu = torch.rand( - [batch, seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu]) - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_func_torch, - q_cpu, - k_cpu, - v_cpu, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_func, - q_gpu, - k_gpu, - v_gpu, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) diff --git a/tests/npu/internevo/test_rotary_embedding_npu.py b/tests/npu/internevo/test_rotary_embedding_npu.py deleted file mode 100644 index 9ba4a2b..0000000 --- a/tests/npu/internevo/test_rotary_embedding_npu.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -import torch -from tests.core import call_autograd_func, allclose -from deeplink_ext.internevo_ops.rotary_embedding import ApplyRotaryEmb -from deeplink_ext.internevo_ops.rotary_embedding_fallback import ApplyRotaryEmbTorch - - -def test_ApplyRotaryEmb(): - input_dtype_list = [torch.float16, torch.bfloat16] - interleaved = False - in_place_options = [False, True] - for input_dtype in input_dtype_list: - for in_place in in_place_options: - input_ref = torch.randn( - 1, 64, 32, 64, dtype=input_dtype, device="npu", requires_grad=True - ) - input_ext = input_ref.clone().detach().requires_grad_() - cos = torch.randn(64, 32, dtype=input_dtype, device="npu") - sin = torch.randn(64, 32, dtype=input_dtype, device="npu") - - output_ref, grad_ref = call_autograd_func( - ApplyRotaryEmbTorch, - "npu", - input_dtype, - input_ref, - cos, - sin, - interleaved, - in_place, - ) - output_ext, grad_ext = call_autograd_func( - ApplyRotaryEmb, - "npu", - input_dtype, - input_ext, - cos, - sin, - interleaved, - in_place, - ) - assert allclose( - output_ref, output_ext, rtol=1e-2, atol=5e-2 - ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!" - assert allclose( - grad_ref, grad_ext - ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!" diff --git a/tests/npu/internevo/test_varlen_flash_attention_npu.py b/tests/npu/internevo/test_varlen_flash_attention_npu.py deleted file mode 100644 index cb222e2..0000000 --- a/tests/npu/internevo/test_varlen_flash_attention_npu.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -from tests.core import allclose, calculate_fwd_and_bwd, copy_to_cpu - -from deeplink_ext.internevo_ops.flash_attention_fallback import ( - flash_attn_varlen_qkvpacked_func_torch, - flash_attn_varlen_kvpacked_func_torch, - flash_attn_varlen_func_torch, -) -from deeplink_ext.internevo_ops.flash_attention import ( - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func, -) - -# fmt: off -# latest sequence length is 20206-16110=4096 -cu_seqlens_max_length_4096 = [ - 0, 186, 382, 1259, 1464, 2547, 2705, 3495, 3854, 4696, 4762, 4885, 5118, 5355, 5503, 5760, 6168, 6353, - 8272, 8461, 9273, 9531, 9763, 9871, 10234, 10370, 10574, 10712, 11022, 11236, 11599, 11837, 12179, 12320, - 12560, 12731, 13038, 13180, 13477, 14025, 14742, 14872, 15131, 15773, 15967, 16110, 20206, -] -# fmt: on - - -def test_flash_attn_varlen_qkvpacked_func_mha(): - total_seqlen, num_heads, headdim = [256, 32, 64] - - qkv_gpu = torch.randn( - [total_seqlen, 3, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - qkv_cpu = copy_to_cpu( - [ - qkv_gpu, - ] - ) - - cu_seqlens_cpu = torch.tensor([0, 32, 64, 128, 256], dtype=torch.int32) - cu_seqlens_gpu = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - max_seqlen = 128 - - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_varlen_qkvpacked_func_torch, - qkv_cpu[0], - cu_seqlens_cpu, - max_seqlen, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_varlen_qkvpacked_func, - qkv_gpu, - cu_seqlens_gpu, - max_seqlen, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) - assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) - - -def test_flash_attn_varlen_qkvpacked_func_mha_long_max_seqlen(): - # Test 
function to verify if the module behaves correctly when the maximum sequence length exceeds 2048. - total_seqlen, num_heads, headdim = [20206, 2, 64] - - qkv_gpu = torch.randn( - [total_seqlen, 3, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - qkv_cpu = copy_to_cpu( - [ - qkv_gpu, - ] - ) - - cu_seqlens_cpu = torch.tensor(cu_seqlens_max_length_4096, dtype=torch.int32) - cu_seqlens_gpu = torch.tensor( - cu_seqlens_max_length_4096, dtype=torch.int32, device="npu" - ) - # the maximum sequence length is 4096 - max_seqlen = 4096 - - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_varlen_qkvpacked_func_torch, - qkv_cpu[0], - cu_seqlens_cpu, - max_seqlen, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_varlen_qkvpacked_func, - qkv_gpu, - cu_seqlens_gpu, - max_seqlen, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) - assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) - - -def test_flash_attn_varlen_kvpacked_func_gqa(): - total_seqlen, num_q_heads, headdim = [256, 32, 64] - num_kv_heads = 8 - - q_gpu = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_gpu = torch.randn( - [total_seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) - - cu_seqlens_q_cpu = torch.tensor([0, 32, 64, 128, 256], dtype=torch.int32) - cu_seqlens_k_cpu = torch.tensor([0, 32, 64, 128, 256], dtype=torch.int32) - cu_seqlens_q_gpu = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - cu_seqlens_k_gpu = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - max_seqlen_q = 128 - max_seqlen_k = 128 - - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_varlen_kvpacked_func_torch, - q_cpu, - kv_cpu, - cu_seqlens_q_cpu, - cu_seqlens_k_cpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_varlen_kvpacked_func, - q_gpu, - kv_gpu, - cu_seqlens_q_gpu, - cu_seqlens_k_gpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-2) - - -def test_flash_attn_varlen_kvpacked_func_gqa_long_max_seqlen(): - # Test function to verify if the module behaves correctly when the maximum sequence length exceeds 2048. 
- total_seqlen, num_q_heads, headdim = [20206, 6, 64] - num_kv_heads = 2 - - q_gpu = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_gpu = torch.randn( - [total_seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) - - cu_seqlens_q_cpu = torch.tensor(cu_seqlens_max_length_4096, dtype=torch.int32) - cu_seqlens_k_cpu = torch.tensor(cu_seqlens_max_length_4096, dtype=torch.int32) - cu_seqlens_q_gpu = torch.tensor( - cu_seqlens_max_length_4096, dtype=torch.int32, device="npu" - ) - cu_seqlens_k_gpu = torch.tensor( - cu_seqlens_max_length_4096, dtype=torch.int32, device="npu" - ) - # the maximum sequence length is 4096 - max_seqlen_q = 4096 - max_seqlen_k = 4096 - - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_varlen_kvpacked_func_torch, - q_cpu, - kv_cpu, - cu_seqlens_q_cpu, - cu_seqlens_k_cpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_varlen_kvpacked_func, - q_gpu, - kv_gpu, - cu_seqlens_q_gpu, - cu_seqlens_k_gpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-2) - - -def test_flash_attn_varlen_func_gqa(): - total_seqlen, num_q_heads, headdim = [256, 32, 64] - num_kv_heads = 8 - - q_gpu = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - k_gpu = torch.randn( - [total_seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - v_gpu = torch.randn( - [total_seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu]) - - cu_seqlens_q_cpu = torch.tensor([0, 32, 64, 128, 256], dtype=torch.int32) - cu_seqlens_k_cpu = torch.tensor([0, 32, 64, 128, 256], dtype=torch.int32) - cu_seqlens_q_gpu = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - cu_seqlens_k_gpu = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - max_seqlen_q = 128 - max_seqlen_k = 128 - - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_varlen_func_torch, - q_cpu, - k_cpu, - v_cpu, - cu_seqlens_q_cpu, - cu_seqlens_k_cpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_varlen_func, - q_gpu, - k_gpu, - v_gpu, - cu_seqlens_q_gpu, - cu_seqlens_k_gpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) - assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) - - -def test_flash_attn_varlen_func_gqa_long_max_seqlen(): - # Test function to verify if the module behaves correctly when the maximum sequence length exceeds 2048. 
- total_seqlen, num_q_heads, headdim = [20206, 6, 64] - num_kv_heads = 2 - - q_gpu = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - k_gpu = torch.randn( - [total_seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - v_gpu = torch.randn( - [total_seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu]) - - cu_seqlens_q_cpu = torch.tensor(cu_seqlens_max_length_4096, dtype=torch.int32) - cu_seqlens_k_cpu = torch.tensor(cu_seqlens_max_length_4096, dtype=torch.int32) - cu_seqlens_q_gpu = torch.tensor( - cu_seqlens_max_length_4096, dtype=torch.int32, device="npu" - ) - cu_seqlens_k_gpu = torch.tensor( - cu_seqlens_max_length_4096, dtype=torch.int32, device="npu" - ) - # the maximum sequence length is 4096 - max_seqlen_q = 4096 - max_seqlen_k = 4096 - - ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( - flash_attn_varlen_func_torch, - q_cpu, - k_cpu, - v_cpu, - cu_seqlens_q_cpu, - cu_seqlens_k_cpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( - flash_attn_varlen_func, - q_gpu, - k_gpu, - v_gpu, - cu_seqlens_q_gpu, - cu_seqlens_k_gpu, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) - assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) diff --git a/tests/npu/interntrain/test_adamw_npu.py b/tests/npu/interntrain/test_adamw_npu.py deleted file mode 100644 index cd92e12..0000000 --- a/tests/npu/interntrain/test_adamw_npu.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import copy -import torch -import torch_npu -from torch import nn -from deeplink_ext.interntrain_ops.adamw import AdamW - - -def test_AdamW(): - - class MlpModel(nn.Module): - - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(128, 256) - self.linear2 = nn.Linear(256, 512) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - - dtype = torch.float32 - device = "npu" - input_data_cpu = torch.rand(16, 128, dtype=dtype) - input_data_device = input_data_cpu.to(device) - cpu_model = MlpModel().to(dtype) - device_model = copy.deepcopy(cpu_model).to(device) - - adamW_cpu = torch.optim.AdamW( - params=cpu_model.parameters(), - lr=1e-4, - betas=(0.9, 0.95), - eps=1e-8, - amsgrad=True, - ) - - adamW_ext = AdamW( - params=device_model.parameters(), - lr=1e-4, - betas=(0.9, 0.95), - eps=1e-8, - amsgrad=True, - ) - - steps = 15 - for step in range(steps): - adamW_cpu.zero_grad() - adamW_ext.zero_grad() - - output_cpu = cpu_model(input_data_cpu) - output_device = device_model(input_data_device) - - output_cpu.mean().backward() - output_device.mean().backward() - - adamW_cpu.step() - adamW_ext.step() - - params_zip = zip(list(cpu_model.parameters()), list(device_model.parameters())) - for cpu_param, device_param in params_zip: - assert torch.allclose(cpu_param, device_param.cpu(), rtol=1e-4, atol=1e-4) diff --git a/tests/npu/interntrain/test_flash_attention_npu.py b/tests/npu/interntrain/test_flash_attention_npu.py deleted file mode 100644 index 086b0e8..0000000 --- a/tests/npu/interntrain/test_flash_attention_npu.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -import torch -import torch_npu -from tests.core import copy_to_cpu, allclose, call_module - -from deeplink_ext.interntrain_ops.flash_attention import ( - FlashSelfAttention, - FlashCrossAttention, -) -from deeplink_ext.interntrain_ops.flash_attention_fallback import ( - SelfAttention, - CrossAttention, -) - - -def test_self_attention_qkv_mha(): - batch, seqlen, num_heads, headdim = [8, 32, 32, 64] - - qkv_gpu = torch.rand( - [batch, seqlen, 3, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - qkv_cpu = copy_to_cpu( - [ - qkv_gpu, - ] - ) - ouput_forward_cpu, grads_cpu = call_module( - SelfAttention(), - qkv_cpu[0], - None, - None, - None, - None, - ) - ouput_forward_gpu, grads_gpu = call_module( - FlashSelfAttention().npu(), - qkv_gpu, - None, - None, - None, - None, - ) - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) - - -def test_self_attention_q_k_v_gqa(): - batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64] - num_kv_heads = 8 - - q_gpu = torch.rand( - [batch, seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - k_gpu = torch.rand( - [batch, seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - v_gpu = torch.rand( - [batch, seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu]) - ouput_forward_cpu, grads_cpu = call_module( - SelfAttention(), - None, - q_cpu, - k_cpu, - v_cpu, - None, - ) - ouput_forward_gpu, grads_gpu = call_module( - FlashSelfAttention().npu(), - None, - q_gpu, - k_gpu, - v_gpu, - None, - ) - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) - - -def test_self_attention_q_kv_gqa(): - batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64] - num_kv_heads = 8 - - q_gpu = torch.rand( - [batch, seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_gpu = torch.rand( - [batch, seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) - ouput_forward_cpu, grads_cpu = call_module( - SelfAttention(), - None, - q_cpu, - None, - None, - kv_cpu, - ) - ouput_forward_gpu, grads_gpu = call_module( - FlashSelfAttention().npu(), - None, - q_gpu, - None, - None, - kv_gpu, - ) - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) - - -def test_cross_attention_q_kv_mha(): - batch, seqlen, num_heads, headdim = [8, 32, 32, 64] - - q_gpu = torch.rand( - [batch, seqlen, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_gpu = torch.rand( - [batch, seqlen, 2, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) - ouput_forward_cpu, grads_cpu = call_module( - CrossAttention(), - q_cpu, - kv_cpu, - ) - ouput_forward_gpu, grads_gpu = call_module( - FlashCrossAttention().npu(), - q_gpu, - kv_gpu, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) - - -def test_cross_attention_q_kv_gqa(): - batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64] - 
num_kv_heads = 8 - q_gpu = torch.rand( - [batch, seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_gpu = torch.rand( - [batch, seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - - q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) - ouput_forward_cpu, grads_cpu = call_module(CrossAttention(), q_cpu, kv_cpu) - ouput_forward_gpu, grads_gpu = call_module( - FlashCrossAttention().npu(), - q_gpu, - kv_gpu, - ) - - assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) - assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) diff --git a/tests/npu/interntrain/test_rms_norm_npu.py b/tests/npu/interntrain/test_rms_norm_npu.py deleted file mode 100644 index a506b12..0000000 --- a/tests/npu/interntrain/test_rms_norm_npu.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch - -from deeplink_ext.interntrain_ops.rms_norm import MixedFusedRMSNorm -from deeplink_ext.interntrain_ops.rms_norm_fallback import MixedRMSNormTorch - - -def test_rms_norm_npu(): - input_dtype_list = [torch.float32, torch.bfloat16, torch.float32, torch.float32] - weight_dtype_list = [torch.float32, torch.bfloat16, torch.float16, torch.bfloat16] - - for input_dtype, weight_dtype in zip(input_dtype_list, weight_dtype_list): - x = torch.randn(1, 64, 32, 64, dtype=input_dtype).requires_grad_() - m_cpu = MixedRMSNormTorch([x.shape[-1]], 1e-5).to(weight_dtype) - out = m_cpu(x) - out.backward(torch.ones_like(out)) - - y = x.detach().clone().npu().requires_grad_() - m_npu = MixedFusedRMSNorm([y.shape[-1]], 1e-5).npu().to(weight_dtype) - out2 = m_npu(y) - - out2.backward(torch.ones_like(out2)) - - torch.allclose(out, out2.cpu(), atol=1e-4, rtol=1e-4) - torch.allclose(x.grad, y.grad.cpu(), atol=1e-4, rtol=1e-4) diff --git a/tests/npu/interntrain/test_rotary_embedding_npu.py b/tests/npu/interntrain/test_rotary_embedding_npu.py deleted file mode 100644 index 541de84..0000000 --- a/tests/npu/interntrain/test_rotary_embedding_npu.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -import torch_npu -from tests.core import call_autograd_func, allclose - -from deeplink_ext.interntrain_ops.rotary_embedding import ( - ApplyRotaryEmb, - ApplyRotaryEmbQKV_, -) -from deeplink_ext.interntrain_ops.rotary_embedding_fallback import ( - ApplyRotaryEmbTorch, - ApplyRotaryEmbQKV_Torch, -) - - -def test_ApplyRotaryEmb(): - input_dtype_list = [torch.float16, torch.bfloat16, torch.float32] - interleaved = False - for input_dtype in input_dtype_list: - input_ref = torch.randn( - 1, 64, 32, 64, dtype=input_dtype, device="npu", requires_grad=True - ) - input_ext = input_ref.clone().detach().requires_grad_() - cos = torch.randn(64, 32, dtype=input_dtype, device="npu") - sin = torch.randn(64, 32, dtype=input_dtype, device="npu") - - output_ref, grad_ref = call_autograd_func( - ApplyRotaryEmbTorch, "npu", input_dtype, input_ref, cos, sin, interleaved - ) - output_ext, grad_ext = call_autograd_func( - ApplyRotaryEmb, "npu", input_dtype, input_ext, cos, sin, interleaved - ) - assert allclose( - output_ref, output_ext, rtol=1e-2, atol=5e-2 - ), f"When input dtype is {input_dtype}, ApplyRotaryEmb fails to pass the forward test!" - assert allclose( - grad_ref, grad_ext - ), f"When input dtype is {input_dtype}, ApplyRotaryEmb fails to pass the backward test!" 
- - -def test_ApplyRotaryEmbQKV__qkv(): - # Note: For ascend, when dtype of input is fp32, the difference in calculation results is significant. - input_dtype_list = [torch.float16, torch.bfloat16] - interleaved = False - for input_dtype in input_dtype_list: - input_ref = torch.randn( - 1, 64, 3, 32, 64, dtype=input_dtype, device="npu", requires_grad=True - ) - input_ext = input_ref.clone().detach().requires_grad_() - cos = torch.randn(64, 32, dtype=input_dtype, device="npu") - sin = torch.randn(64, 32, dtype=input_dtype, device="npu") - - output_ref, grad_ref = call_autograd_func( - ApplyRotaryEmbQKV_Torch, - "npu", - input_dtype, - input_ref, - cos, - sin, - None, - None, - interleaved, - ) - output_ext, grad_ext = call_autograd_func( - ApplyRotaryEmbQKV_, - "npu", - input_dtype, - input_ext, - cos, - sin, - None, - None, - interleaved, - ) - - assert allclose( - output_ref, output_ext, rtol=1e-2, atol=5e-2 - ), f"When input dtype is {input_dtype}, ApplyRotaryEmbQKV_ fails to pass the forward test!" - - assert allclose( - grad_ref, - grad_ext, - ), f"When input dtype is {input_dtype}, ApplyRotaryEmbQKV_ fails to pass the backward test!" diff --git a/tests/npu/interntrain/test_varlen_flash_attention_npu.py b/tests/npu/interntrain/test_varlen_flash_attention_npu.py deleted file mode 100644 index ded765c..0000000 --- a/tests/npu/interntrain/test_varlen_flash_attention_npu.py +++ /dev/null @@ -1,458 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -import torch_npu -from tests.core import allclose, call_module - -from deeplink_ext.interntrain_ops.flash_attention import ( - FlashSelfAttention, - FlashCrossAttention, -) -from deeplink_ext.interntrain_ops.flash_attention_fallback import ( - SelfAttention, - CrossAttention, -) - -# fmt: off -g_cu_seqlens = [ - 0, 186, 382, 1259, 1464, 2547, 2705, 3495, 3854, 4696, 4762, 4885, 5118, 5355, 5503, 5760, 6168, 6353, - 8272, 8461, 9273, 9531, 9763, 9871, 10234, 10370, 10574, 10712, 11022, 11236, 11599, 11837, 12179, 12320, - 12560, 12731, 13038, 13180, 13477, 14025, 14742, 14872, 15131, 15773, 15967, 16110, 16384, -] -# fmt on - -class TestFlashSelfAttention: - def test_self_attention_varlen_qkv_mha(self): - total_seqlen, num_heads, headdim = [256, 32, 64] - - qkv_ref = torch.randn( - [total_seqlen, 3, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - qkv_ext = qkv_ref.clone().detach().requires_grad_(True) - - cu_seqlens_ref = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - max_seqlen = 128 - - ouput_forward_ref, grads_ref = call_module( - SelfAttention().npu(), - qkv_ref, - None, - None, - None, - None, - True, - cu_seqlens_ref, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashSelfAttention().npu(), - qkv_ext, - None, - None, - None, - None, - True, - cu_seqlens_ref, - max_seqlen, - ) - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=1e-2) - - - def test_self_attention_varlen_q_k_v_gqa(self): - total_seqlen, num_q_heads, headdim = [256, 32, 64] - num_kv_heads = 8 - - q_ref = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - k_ref = torch.randn( - [total_seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - v_ref = torch.randn( - [total_seqlen, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_ext = 
q_ref.clone().detach().requires_grad_(True) - k_ext = k_ref.clone().detach().requires_grad_(True) - v_ext = v_ref.clone().detach().requires_grad_(True) - - cu_seqlens_q_ref = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - cu_seqlens_k_ref = torch.tensor( - [0, 32, 64, 128, 256], dtype=torch.int32, device="npu" - ) - max_seqlen = 128 - - ouput_forward_ref, grads_ref = call_module( - SelfAttention().npu(), - None, - q_ref, - k_ref, - v_ref, - None, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashSelfAttention().npu(), - None, - q_ext, - k_ext, - v_ext, - None, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=1e-2) - - - def test_self_attention_varlen_q_kv_mha(self): - total_seqlen, num_heads, headdim = [16384, 6, 64] - - q_ref = torch.randn( - [total_seqlen, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_ref = torch.randn( - [total_seqlen, 2, num_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_ext = q_ref.clone().detach().requires_grad_(True) - kv_ext = kv_ref.clone().detach().requires_grad_(True) - - cu_seqlens_q_ref = torch.tensor(g_cu_seqlens, dtype=torch.int32, device="npu") - cu_seqlens_k_ref = torch.tensor(g_cu_seqlens, dtype=torch.int32, device="npu") - max_seqlen = 1919 - - ouput_forward_ref, grads_ref = call_module( - SelfAttention().npu(), - None, - q_ref, - None, - None, - kv_ref, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashSelfAttention().npu(), - None, - q_ext, - None, - None, - kv_ext, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=1e-2) - - - def test_self_attention_varlen_q_kv_gqa(self): - total_seqlen, num_q_heads, headdim = [16384, 6, 64] - num_kv_heads = 2 - - q_ref = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_ref = torch.randn( - [total_seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_ext = q_ref.clone().detach().requires_grad_(True) - kv_ext = kv_ref.clone().detach().requires_grad_(True) - - cu_seqlens_q_ref = torch.tensor(g_cu_seqlens, dtype=torch.int32, device="npu") - cu_seqlens_k_ref = torch.tensor(g_cu_seqlens, dtype=torch.int32, device="npu") - max_seqlen = 1919 - - ouput_forward_ref, grads_ref = call_module( - SelfAttention().npu(), - None, - q_ref, - None, - None, - kv_ref, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashSelfAttention().npu(), - None, - q_ext, - None, - None, - kv_ext, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=1e-2) - - - def test_self_attention_varlen_q_kv_gqa_long_max_seqlen(self): - # Test function to verify if the module behaves correctly when the maximum 
sequence length exceeds 2048. - total_seqlen, num_q_heads, headdim = [20206, 6, 64] - num_kv_heads = 2 - - q_ref = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_ref = torch.randn( - [total_seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_ext = q_ref.clone().detach().requires_grad_(True) - kv_ext = kv_ref.clone().detach().requires_grad_(True) - - # fmt: off - # the new sequence lengths for the test case, latest sequence length is 20206-16110=4096 - cu_seqlens_max_length_4096 = [ - 0, 186, 382, 1259, 1464, 2547, 2705, 3495, 3854, 4696, 4762, 4885, 5118, 5355, 5503, 5760, 6168, 6353, - 8272, 8461, 9273, 9531, 9763, 9871, 10234, 10370, 10574, 10712, 11022, 11236, 11599, 11837, 12179, 12320, - 12560, 12731, 13038, 13180, 13477, 14025, 14742, 14872, 15131, 15773, 15967, 16110, 20206, - ] - # fmt: on - - cu_seqlens_q_ref = torch.tensor(cu_seqlens_max_length_4096, dtype=torch.int32, device="npu") - cu_seqlens_k_ref = torch.tensor(cu_seqlens_max_length_4096, dtype=torch.int32, device="npu") - # the maximum sequence length is 4096 - max_seqlen = 4096 - - ouput_forward_ref, grads_ref = call_module( - SelfAttention().npu(), - None, - q_ref, - None, - None, - kv_ref, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashSelfAttention().npu(), - None, - q_ext, - None, - None, - kv_ext, - True, - None, - None, - cu_seqlens_q_ref, - cu_seqlens_k_ref, - max_seqlen, - max_seqlen, - ) - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=1e-2) - - -class TestFlashCrossAttention: - - def test_cross_attention_varlen_q_kv_mha(self): - total_seqlen, num_heads, headdim = [16384, 6, 64] - - q_ref = torch.randn( - [total_seqlen, num_heads, headdim], - dtype=torch.bfloat16, - requires_grad=True, - device="npu", - ) - kv_ref = torch.randn( - [total_seqlen, 2, num_heads, headdim], - dtype=torch.bfloat16, - requires_grad=True, - device="npu", - ) - q_ext = q_ref.clone().detach().requires_grad_(True) - kv_ext = kv_ref.clone().detach().requires_grad_(True) - - cu_seqlens_ref = torch.tensor(g_cu_seqlens, dtype=torch.int32, device="npu") - cu_seqlens_k_ref = torch.tensor(g_cu_seqlens, dtype=torch.int32, device="npu") - max_seqlen = 1919 - - ouput_forward_ref, grads_ref = call_module( - CrossAttention().npu(), - q_ref, - kv_ref, - True, - cu_seqlens_ref, - max_seqlen, - cu_seqlens_k_ref, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashCrossAttention().npu(), - q_ext, - kv_ext, - True, - cu_seqlens_ref, - max_seqlen, - cu_seqlens_k_ref, - max_seqlen, - ) - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=5e-2) - - def test_cross_attention_varlen_q_kv_gqa(self): - total_seqlen, num_q_heads, headdim = [256, 32, 64] - num_kv_heads = 8 - - q_ref = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_ref = torch.randn( - [total_seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_ext = q_ref.clone().detach().requires_grad_(True) - kv_ext = kv_ref.clone().detach().requires_grad_(True) - - cu_seqlens_ref = torch.tensor([0, 32, 64, 128, 256], dtype=torch.int32, device="npu") - cu_seqlens_k_ref = 
torch.tensor([0, 32, 64, 128, 256], dtype=torch.int32, device="npu") - max_seqlen = 128 - - ouput_forward_ref, grads_ref = call_module( - CrossAttention().npu(), - q_ref, - kv_ref, - True, - cu_seqlens_ref, - max_seqlen, - cu_seqlens_k_ref, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashCrossAttention().npu(), - q_ext, - kv_ext, - True, - cu_seqlens_ref, - max_seqlen, - cu_seqlens_k_ref, - max_seqlen, - ) - - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=1e-2) - - def test_cross_attention_varlen_q_kv_gqa_long_max_seqlen(self): - # Test function to verify if the module behaves correctly when the maximum sequence length exceeds 2048. - total_seqlen, num_q_heads, headdim = [4224, 32, 64] - num_kv_heads = 8 - - q_ref = torch.randn( - [total_seqlen, num_q_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - kv_ref = torch.randn( - [total_seqlen, 2, num_kv_heads, headdim], - dtype=torch.float16, - requires_grad=True, - device="npu", - ) - q_ext = q_ref.clone().detach().requires_grad_(True) - kv_ext = kv_ref.clone().detach().requires_grad_(True) - - # last sequence length is 4224-128=4096 - cu_seqlens_ref = torch.tensor([0, 32, 64, 128, 4224], dtype=torch.int32, device="npu") - cu_seqlens_k_ref = torch.tensor([0, 32, 64, 128, 4224], dtype=torch.int32, device="npu") - # the maximum sequence length is 4096 - max_seqlen = 4096 - - ouput_forward_ref, grads_ref = call_module( - CrossAttention().npu(), - q_ref, - kv_ref, - True, - cu_seqlens_ref, - max_seqlen, - cu_seqlens_k_ref, - max_seqlen, - ) - ouput_forward_ext, grads_ext = call_module( - FlashCrossAttention().npu(), - q_ext, - kv_ext, - True, - cu_seqlens_ref, - max_seqlen, - cu_seqlens_k_ref, - max_seqlen, - ) - - assert allclose(ouput_forward_ref, ouput_forward_ext, rtol=1e-5, atol=1e-5) - assert allclose(grads_ref, grads_ext, rtol=1e-5, atol=1e-2) From 99b1e76932b9d796ac9fae6990113418248057f6 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 4 Sep 2024 02:39:47 +0000 Subject: [PATCH 3/4] fix run cmd --- .github/workflows/static.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index dcc4638..bd4fb7b 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -107,10 +107,11 @@ jobs: run: | source /mnt/cache/share/platform/cienv/dipu_latest_ci cd ${DEEPLINK_PATH}/${{ github.run_number }}/DeepLinkExt - export PYTHONPATH=$PWD:$PYTHONPATH + + cd tests/ export DEEPLINK_EXT_PLATFORM_TYPE=torch_dipu - python -m pytest tests/dipu + python -m pytest -v ./ export DEEPLINK_EXT_PLATFORM_TYPE=torch_npu - python -m pytest tests/npu \ No newline at end of file + python -m pytest -v ./ \ No newline at end of file From 2b5e0326c7acbf78cb9838acbe23468812a13d3b Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 4 Sep 2024 02:57:44 +0000 Subject: [PATCH 4/4] fix py format --- deeplink_ext/ascend_speed/_rotary_embedding_npu.py | 6 +++++- tests/conftest.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py index 646f7e9..2b84b3e 100644 --- a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py +++ b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py @@ -27,4 +27,8 @@ def forward(ctx, x, cos, sin): @staticmethod def backward(ctx, grad_output): out, cos, sin = ctx.saved_tensors - return 
torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], None, None + return ( + torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], + None, + None, + ) diff --git a/tests/conftest.py b/tests/conftest.py index cd2aae5..98906e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,8 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def import_module(): platform = deeplink_ext_get_platform_type() if platform == PlatformType.TORCH_NPU: