-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
237 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
# Copyright (c) 2024, DeepLink. | ||
|
||
import torch | ||
from tests.core import copy_to_cpu, allclose, call_normal_func | ||
|
||
from deeplink_ext.internevo_ops.flash_attention_fallback import ( | ||
torch_attn_qkvpacked_func, | ||
torch_attn_kvpacked_func, | ||
torch_attn_func, | ||
) | ||
from deeplink_ext.internevo_ops.flash_attention import ( | ||
flash_attn_qkvpacked_func, | ||
flash_attn_kvpacked_func, | ||
flash_attn_func, | ||
) | ||
|
||
|
||
def test_flash_attn_qkvpacked_func_mha(): | ||
batch, seqlen, num_heads, headdim = [8, 32, 32, 64] | ||
|
||
qkv_gpu = torch.rand( | ||
[batch, seqlen, 3, num_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
|
||
qkv_cpu = copy_to_cpu( | ||
[ | ||
qkv_gpu, | ||
] | ||
) | ||
|
||
ouput_forward_cpu, grads_cpu = call_normal_func( | ||
torch_attn_qkvpacked_func, | ||
qkv_cpu[0], | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
ouput_forward_gpu, grads_gpu = call_normal_func( | ||
flash_attn_qkvpacked_func, | ||
qkv_gpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
|
||
assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) | ||
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) | ||
|
||
|
||
def test_flash_attn_kvpacked_func_mha(): | ||
batch, seqlen, num_heads, headdim = [8, 32, 32, 64] | ||
|
||
q_gpu = torch.rand( | ||
[batch, seqlen, num_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
kv_gpu = torch.rand( | ||
[batch, seqlen, 2, num_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
|
||
q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) | ||
ouput_forward_cpu, grads_cpu = call_normal_func( | ||
torch_attn_kvpacked_func, | ||
q_cpu, | ||
kv_cpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
ouput_forward_gpu, grads_gpu = call_normal_func( | ||
flash_attn_kvpacked_func, | ||
q_gpu, | ||
kv_gpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
|
||
assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) | ||
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) | ||
|
||
|
||
def test_flash_attn_kvpacked_func_gqa(): | ||
batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64] | ||
num_kv_heads = 8 | ||
|
||
q_gpu = torch.rand( | ||
[batch, seqlen, num_q_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
kv_gpu = torch.rand( | ||
[batch, seqlen, 2, num_kv_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
|
||
q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu]) | ||
ouput_forward_cpu, grads_cpu = call_normal_func( | ||
torch_attn_kvpacked_func, | ||
q_cpu, | ||
kv_cpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
ouput_forward_gpu, grads_gpu = call_normal_func( | ||
flash_attn_kvpacked_func, | ||
q_gpu, | ||
kv_gpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
|
||
assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) | ||
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) | ||
|
||
|
||
def test_flash_attn_func_mha(): | ||
batch, seqlen, num_heads, headdim = [8, 32, 32, 64] | ||
|
||
q_gpu = torch.rand( | ||
[batch, seqlen, num_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
k_gpu = torch.rand( | ||
[batch, seqlen, num_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
v_gpu = torch.rand( | ||
[batch, seqlen, num_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
|
||
q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu]) | ||
ouput_forward_cpu, grads_cpu = call_normal_func( | ||
torch_attn_func, | ||
q_cpu, | ||
k_cpu, | ||
v_cpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
ouput_forward_gpu, grads_gpu = call_normal_func( | ||
flash_attn_func, | ||
q_gpu, | ||
k_gpu, | ||
v_gpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
|
||
assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) | ||
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) | ||
|
||
|
||
def test_flash_attn_func_gqa(): | ||
batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64] | ||
num_kv_heads = 8 | ||
|
||
q_gpu = torch.rand( | ||
[batch, seqlen, num_q_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
k_gpu = torch.rand( | ||
[batch, seqlen, num_kv_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
v_gpu = torch.rand( | ||
[batch, seqlen, num_kv_heads, headdim], | ||
dtype=torch.float16, | ||
requires_grad=True, | ||
device="cuda", | ||
) | ||
|
||
q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu]) | ||
ouput_forward_cpu, grads_cpu = call_normal_func( | ||
torch_attn_func, | ||
q_cpu, | ||
k_cpu, | ||
v_cpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
ouput_forward_gpu, grads_gpu = call_normal_func( | ||
flash_attn_func, | ||
q_gpu, | ||
k_gpu, | ||
v_gpu, | ||
dropout_p=0.0, | ||
causal=True, | ||
) | ||
|
||
assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) | ||
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) |