From d6823de786dd9251dfe9fced5d8625f616dcb91a Mon Sep 17 00:00:00 2001 From: POI-WX Date: Tue, 30 Jul 2024 12:11:42 +0800 Subject: [PATCH] update --- tests/internevo/test_flash_attention.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/internevo/test_flash_attention.py b/tests/internevo/test_flash_attention.py index cd37950..5126551 100644 --- a/tests/internevo/test_flash_attention.py +++ b/tests/internevo/test_flash_attention.py @@ -1,7 +1,7 @@ # Copyright (c) 2024, DeepLink. import torch -from tests.core import copy_to_cpu, allclose, call_normal_func +from tests.core import copy_to_cpu, allclose, calculate_fwd_and_bwd from deeplink_ext.internevo_ops.flash_attention_fallback import ( flash_attn_qkvpacked_func_torch, @@ -31,13 +31,13 @@ def test_flash_attn_qkvpacked_func_mha(): ] ) - ouput_forward_cpu, grads_cpu = call_normal_func( + ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( flash_attn_qkvpacked_func_torch, qkv_cpu[0], dropout_p=0.0, causal=True, ) - ouput_forward_gpu, grads_gpu = call_normal_func( + ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( flash_attn_qkvpacked_func, qkv_gpu, dropout_p=0.0, @@ -66,14 +66,14 @@ def test_flash_attn_kvpacked_func_gqa(): ) q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) - ouput_forward_cpu, grads_cpu = call_normal_func( + ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( flash_attn_kvpacked_func_torch, q_cpu, kv_cpu, dropout_p=0.0, causal=True, ) - ouput_forward_gpu, grads_gpu = call_normal_func( + ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( flash_attn_kvpacked_func, q_gpu, kv_gpu, @@ -109,7 +109,7 @@ def test_flash_attn_func_gqa(): ) q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu]) - ouput_forward_cpu, grads_cpu = call_normal_func( + ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd( flash_attn_func_torch, q_cpu, k_cpu, @@ -117,7 +117,7 @@ def test_flash_attn_func_gqa(): dropout_p=0.0, causal=True, ) - ouput_forward_gpu, grads_gpu = call_normal_func( + ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd( flash_attn_func, q_gpu, k_gpu,