Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
POI-WX committed Jul 30, 2024
1 parent bb345a3 commit d6823de
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/internevo/test_flash_attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024, DeepLink.

import torch
from tests.core import copy_to_cpu, allclose, call_normal_func
from tests.core import copy_to_cpu, allclose, calculate_fwd_and_bwd

from deeplink_ext.internevo_ops.flash_attention_fallback import (
flash_attn_qkvpacked_func_torch,
Expand Down Expand Up @@ -31,13 +31,13 @@ def test_flash_attn_qkvpacked_func_mha():
]
)

ouput_forward_cpu, grads_cpu = call_normal_func(
ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
flash_attn_qkvpacked_func_torch,
qkv_cpu[0],
dropout_p=0.0,
causal=True,
)
ouput_forward_gpu, grads_gpu = call_normal_func(
ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
flash_attn_qkvpacked_func,
qkv_gpu,
dropout_p=0.0,
Expand Down Expand Up @@ -66,14 +66,14 @@ def test_flash_attn_kvpacked_func_gqa():
)

q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu])
ouput_forward_cpu, grads_cpu = call_normal_func(
ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
flash_attn_kvpacked_func_torch,
q_cpu,
kv_cpu,
dropout_p=0.0,
causal=True,
)
ouput_forward_gpu, grads_gpu = call_normal_func(
ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
flash_attn_kvpacked_func,
q_gpu,
kv_gpu,
Expand Down Expand Up @@ -109,15 +109,15 @@ def test_flash_attn_func_gqa():
)

q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu])
ouput_forward_cpu, grads_cpu = call_normal_func(
ouput_forward_cpu, grads_cpu = calculate_fwd_and_bwd(
flash_attn_func_torch,
q_cpu,
k_cpu,
v_cpu,
dropout_p=0.0,
causal=True,
)
ouput_forward_gpu, grads_gpu = call_normal_func(
ouput_forward_gpu, grads_gpu = calculate_fwd_and_bwd(
flash_attn_func,
q_gpu,
k_gpu,
Expand Down

0 comments on commit d6823de

Please sign in to comment.