Commit

add test for flash attention
POI-WX committed Jul 24, 2024
1 parent 046b4fa commit 3a8a581
Showing 2 changed files with 237 additions and 3 deletions.
30 changes: 27 additions & 3 deletions tests/core.py
@@ -1,9 +1,15 @@
# Copyright (c) 2024, DeepLink.

import torch
import typing
from typing import Callable, Any

__all__ = ["call_module", "call_func", "copy_to_cpu", "allclose"]
__all__ = [
"call_module",
"call_autograd_func",
"call_normal_func",
"copy_to_cpu",
"allclose",
]


def call_module(module: torch.nn.Module, *forward_args):
@@ -24,7 +30,7 @@ def call_module(module: torch.nn.Module, *forward_args):
return output_forward, grads


def call_func(f: torch.autograd.Function, device, dtype, *args: list):
def call_autograd_func(f: torch.autograd.Function, device, dtype, *args):
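    """Wrap an autograd.Function in a temporary nn.Module, move it to the given
    device and dtype, and run it through call_module, returning the forward
    output and gradients."""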
class Module(torch.nn.Module):
def __init__(self, func):
super(Module, self).__init__()
@@ -36,6 +42,24 @@ def forward(self, *args):
return call_module(Module(f).to(device).to(dtype), *args)


def call_normal_func(func: Callable[..., Any], *args, **kwargs):
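    """Run func(*args, **kwargs), backpropagate a tensor of ones through the
    output (or its first element if the output is a list/tuple of tensors), and
    return the forward output together with the gradients of all positional
    tensor arguments that require grad."""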
output_forward = func(*args, **kwargs)
grads = []
if torch.is_tensor(output_forward):
output_forward.backward(torch.ones_like(output_forward))
elif isinstance(output_forward, (list, tuple)):
assert torch.is_tensor(output_forward[0]), "output_forward[0] is not a tensor"
output_forward[0].backward(torch.ones_like(output_forward[0]))
else:
        raise RuntimeError(
            "the forward output must be a tensor or a list/tuple of tensors"
        )
for arg in args:
if torch.is_tensor(arg) and arg.requires_grad:
grads.append(arg.grad)
return output_forward, grads


def copy_to_cpu(tensors: list[torch.Tensor], dtype=None):
if dtype is None:
dtype = torch.float32
210 changes: 210 additions & 0 deletions tests/internevo/test_flash_attention.py
@@ -0,0 +1,210 @@
# Copyright (c) 2024, DeepLink.

import torch
from tests.core import copy_to_cpu, allclose, call_normal_func

from deeplink_ext.internevo_ops.flash_attention_fallback import (
torch_attn_qkvpacked_func,
torch_attn_kvpacked_func,
torch_attn_func,
)
from deeplink_ext.internevo_ops.flash_attention import (
flash_attn_qkvpacked_func,
flash_attn_kvpacked_func,
flash_attn_func,
)
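
# Each test below runs a flash attention variant on the device with fp16 inputs,
# runs the corresponding torch_attn_* fallback on fp32 CPU copies of the same
# inputs, and checks that the forward outputs and input gradients match within
# rtol/atol of 1e-3, covering both MHA and GQA head configurations.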


def test_flash_attn_qkvpacked_func_mha():
batch, seqlen, num_heads, headdim = [8, 32, 32, 64]
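    # qkv packs q, k and v along dim 2: [batch, seqlen, 3, num_heads, headdim].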

qkv_gpu = torch.rand(
[batch, seqlen, 3, num_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)

qkv_cpu = copy_to_cpu(
[
qkv_gpu,
]
)

    output_forward_cpu, grads_cpu = call_normal_func(
torch_attn_qkvpacked_func,
qkv_cpu[0],
dropout_p=0.0,
causal=True,
)
    output_forward_gpu, grads_gpu = call_normal_func(
flash_attn_qkvpacked_func,
qkv_gpu,
dropout_p=0.0,
causal=True,
)

    assert allclose(output_forward_cpu, output_forward_gpu, rtol=1e-3, atol=1e-3)
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)


def test_flash_attn_kvpacked_func_mha():
batch, seqlen, num_heads, headdim = [8, 32, 32, 64]

q_gpu = torch.rand(
[batch, seqlen, num_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)
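    # k and v are packed along dim 2: [batch, seqlen, 2, num_heads, headdim].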
kv_gpu = torch.rand(
[batch, seqlen, 2, num_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)

q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu])
    output_forward_cpu, grads_cpu = call_normal_func(
torch_attn_kvpacked_func,
q_cpu,
kv_cpu,
dropout_p=0.0,
causal=True,
)
    output_forward_gpu, grads_gpu = call_normal_func(
flash_attn_kvpacked_func,
q_gpu,
kv_gpu,
dropout_p=0.0,
causal=True,
)

    assert allclose(output_forward_cpu, output_forward_gpu, rtol=1e-3, atol=1e-3)
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)


def test_flash_attn_kvpacked_func_gqa():
batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64]
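    # GQA: the 32 query heads share 8 key/value heads (4 query heads per KV head).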
num_kv_heads = 8

q_gpu = torch.rand(
[batch, seqlen, num_q_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)
kv_gpu = torch.rand(
[batch, seqlen, 2, num_kv_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)

q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu])
    output_forward_cpu, grads_cpu = call_normal_func(
torch_attn_kvpacked_func,
q_cpu,
kv_cpu,
dropout_p=0.0,
causal=True,
)
    output_forward_gpu, grads_gpu = call_normal_func(
flash_attn_kvpacked_func,
q_gpu,
kv_gpu,
dropout_p=0.0,
causal=True,
)

    assert allclose(output_forward_cpu, output_forward_gpu, rtol=1e-3, atol=1e-3)
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)


def test_flash_attn_func_mha():
batch, seqlen, num_heads, headdim = [8, 32, 32, 64]

q_gpu = torch.rand(
[batch, seqlen, num_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)
k_gpu = torch.rand(
[batch, seqlen, num_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)
v_gpu = torch.rand(
[batch, seqlen, num_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)

q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu])
    output_forward_cpu, grads_cpu = call_normal_func(
torch_attn_func,
q_cpu,
k_cpu,
v_cpu,
dropout_p=0.0,
causal=True,
)
    output_forward_gpu, grads_gpu = call_normal_func(
flash_attn_func,
q_gpu,
k_gpu,
v_gpu,
dropout_p=0.0,
causal=True,
)

    assert allclose(output_forward_cpu, output_forward_gpu, rtol=1e-3, atol=1e-3)
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)


def test_flash_attn_func_gqa():
batch, seqlen, num_q_heads, headdim = [8, 32, 32, 64]
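    # GQA: k and v use num_kv_heads (8) while q keeps num_q_heads (32).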
num_kv_heads = 8

q_gpu = torch.rand(
[batch, seqlen, num_q_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)
k_gpu = torch.rand(
[batch, seqlen, num_kv_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)
v_gpu = torch.rand(
[batch, seqlen, num_kv_heads, headdim],
dtype=torch.float16,
requires_grad=True,
device="cuda",
)

q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu])
    output_forward_cpu, grads_cpu = call_normal_func(
torch_attn_func,
q_cpu,
k_cpu,
v_cpu,
dropout_p=0.0,
causal=True,
)
    output_forward_gpu, grads_gpu = call_normal_func(
flash_attn_func,
q_gpu,
k_gpu,
v_gpu,
dropout_p=0.0,
causal=True,
)

    assert allclose(output_forward_cpu, output_forward_gpu, rtol=1e-3, atol=1e-3)
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)
