This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Add dropout feature to Attention
daemyung authored Sep 15, 2023
1 parent bed3c4d commit f23f891
Showing 8 changed files with 364 additions and 289 deletions.
49 changes: 20 additions & 29 deletions benchmarks/benchmark_attention.py
@@ -19,44 +19,35 @@
 import trident


-@util.report("attention forward", ["sequence_size"], [2**i for i in range(10, 15)])
-def bench_attention_forward(sequence_size, dtype, backend):
-    num_batches, num_heads, embedding_size = 4, 48, 64
+@util.report(
+    "attention forward", ["y_size"], [32 * i for i in range(1, 21)], {"num_batches": 64, "num_heads": 8, "x_size": 64}
+)
+def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backend):
     factory_kwargs = {"device": "cuda", "dtype": dtype}
-    query = torch.randn((num_batches, num_heads, sequence_size, embedding_size), **factory_kwargs)
-    key = torch.randn((num_batches, num_heads, sequence_size, embedding_size), **factory_kwargs)
-    value = torch.randn((num_batches, num_heads, sequence_size, embedding_size), **factory_kwargs)
+    query = torch.randn(num_batches, num_heads, y_size, x_size, **factory_kwargs)
+    key = torch.randn_like(query)
+    value = torch.randn_like(query)

     if backend == "torch":
-        return triton.testing.do_bench_cudagraph(
-            lambda: torch.nn.functional.scaled_dot_product_attention(query, key, value)
-        )
+        return triton.testing.do_bench(lambda: torch.nn.functional.scaled_dot_product_attention(query, key, value))
     else:
-        return triton.testing.do_bench_cudagraph(
+        return triton.testing.do_bench(
             lambda: trident.function.scaled_dot_product_attention(query, key, value, use_accelerator=True)
         )


-@util.report("attention backward", ["sequence_size"], [256 * i for i in range(1, 4)])
-def bench_attention_backward(sequence_size, dtype, backend):
-    num_batches, num_heads, embedding_size = 4, 48, 64
+@util.report(
+    "attention backward",
+    ["y_size"],
+    [64 * i for i in range(1, 21)],
+    {"num_batches": 64, "num_heads": 8, "x_size": 64},
+)
+def bench_attention_backward(num_batches, num_heads, y_size, x_size, dtype, backend):
     factory_kwargs = {"device": "cuda", "dtype": dtype}
-    query = torch.randn(
-        (num_batches, num_heads, sequence_size, embedding_size),
-        **factory_kwargs,
-        requires_grad=True,
-    )
-    key = torch.randn(
-        (num_batches, num_heads, sequence_size, embedding_size),
-        **factory_kwargs,
-        requires_grad=True,
-    )
-    value = torch.randn(
-        (num_batches, num_heads, sequence_size, embedding_size),
-        **factory_kwargs,
-        requires_grad=True,
-    )
-    grad_output = torch.randn((num_batches, num_heads, sequence_size, embedding_size), **factory_kwargs)
+    query = torch.randn(num_batches, num_heads, y_size, x_size, **factory_kwargs, requires_grad=True)
+    key = torch.randn_like(query, requires_grad=True)
+    value = torch.randn_like(query, requires_grad=True)
+    grad_output = torch.randn(num_batches, num_heads, y_size, x_size, **factory_kwargs)

     if backend == "torch":
         output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
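
For context, a minimal standalone latency check in the spirit of the updated forward benchmark might look like the sketch below. It is not part of this commit: it assumes a CUDA GPU with torch, triton, and trident importable, and the tensor shape is illustrative rather than one of the benchmarked configurations.

# Sketch only: compare PyTorch and Trident scaled dot product attention with triton.testing.do_bench.
import torch
import triton
import trident

query = torch.randn(64, 8, 256, 64, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

torch_ms = triton.testing.do_bench(
    lambda: torch.nn.functional.scaled_dot_product_attention(query, key, value)
)
trident_ms = triton.testing.do_bench(
    lambda: trident.function.scaled_dot_product_attention(query, key, value, use_accelerator=True)
)
print(f"torch: {torch_ms:.3f} ms, trident: {trident_ms:.3f} ms")
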
85 changes: 51 additions & 34 deletions tests/test_attention.py
@@ -19,35 +19,36 @@
 from tests import util


-@pytest.mark.parametrize("is_causal, embedding_size", [(True, 16), (False, 32), (False, 64)])
-def test_forward(is_causal, embedding_size, device):
-    num_batches, num_heads, sequence_size = 6, 9, 1024
-    factory_kwargs = {"device": device}
-    query = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-    key = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-    value = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-
-    a = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
-    b = trident.function.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
-    assert util.equal(a, b)
-
-
-@pytest.mark.parametrize("is_causal, embedding_size", [(True, 16), (False, 32), (False, 64)])
-def test_backward(is_causal, embedding_size, device):
-    num_batches, num_heads, sequence_size = 6, 9, 1024
-    factory_kwargs = {"device": device}
-    query = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-    key = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-    value = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-    grad_out = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
+@pytest.mark.parametrize(
+    "num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 128, 64, False)]
+)
+def test_forward(num_batches, num_heads, y_size, x_size, is_causal, device):
+    query = torch.randn(num_batches, num_heads, y_size, x_size, device=device)
+    key = torch.randn_like(query)
+    value = torch.randn_like(query)
+
+    assert util.equal(
+        torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal),
+        trident.function.scaled_dot_product_attention(query, key, value, is_causal=is_causal),
+    )
+
+
+@pytest.mark.parametrize(
+    "num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 128, 64, False)]
+)
+def test_backward(num_batches, num_heads, y_size, x_size, is_causal, device):
+    query = torch.rand(num_batches, num_heads, y_size, x_size, device=device)
+    key = torch.randn_like(query)
+    value = torch.randn_like(query)
+    grad_output = torch.randn_like(query)

     def train(func):
-        q = query.clone()
-        k = key.clone()
-        v = value.clone()
-        q.requires_grad = k.requires_grad = v.requires_grad = True
-        func(q, k, v, is_causal=is_causal).backward(grad_out, retain_graph=True)
-        return q.grad, k.grad, v.grad
+        i = query.clone()
+        j = key.clone()
+        k = value.clone()
+        i.requires_grad = j.requires_grad = k.requires_grad = True
+        func(i, j, k, is_causal=is_causal).backward(grad_output, retain_graph=True)
+        return i.grad, j.grad, k.grad

     (x, y, z) = train(torch.nn.functional.scaled_dot_product_attention)
     (a, b, c) = train(trident.function.scaled_dot_product_attention)
@@ -57,13 +58,29 @@ def train(func):
     assert util.equal(z, c)


-@pytest.mark.parametrize("is_causal, embedding_size", [(True, 16)])
-def test_attention(is_causal, embedding_size, device, dtype):
-    num_batches, num_heads, sequence_size = 6, 9, 1024
+@pytest.mark.parametrize(
+    "num_batches, num_heads, y_size, x_size, is_causal", [(1, 1, 1, 16, True), (1, 1, 1, 16, False)]
+)
+def test_attention(num_batches, num_heads, y_size, x_size, is_causal, device, dtype):
+    if dtype is torch.bfloat16:
+        pytest.skip("Triton has a bug.")
+
     factory_kwargs = {"device": device, "dtype": dtype}
-    query = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-    key = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
-    value = torch.rand(num_batches, num_heads, sequence_size, embedding_size, **factory_kwargs)
+    query = torch.rand(num_batches, num_heads, y_size, x_size, **factory_kwargs, requires_grad=True)
+    key = torch.randn_like(query, requires_grad=True)
+    value = torch.randn_like(query, requires_grad=True)
+    grad_output = torch.randn_like(query)

     output = trident.function.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
-    assert output is not None and output.dtype == dtype
+
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert query.grad is not None
+    assert query.grad.dtype == dtype
+    assert key.grad is not None
+    assert key.grad.dtype == dtype
+    assert value.grad is not None
+    assert value.grad.dtype == dtype
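
The tests above still call the kernel with the default dropout_p of 0.0. A hedged sketch of how the new argument could be exercised is shown below; it is not part of this commit, and the test name, shapes, and checks are assumptions that merely follow the conventions of the existing tests.

# Hypothetical follow-up test, not included in this commit. With dropout_p=0.0 the result
# should still match PyTorch; with dropout enabled only shape and dtype are checked,
# because the dropped positions are random.
@pytest.mark.parametrize("num_batches, num_heads, y_size, x_size", [(4, 8, 128, 64)])
def test_forward_with_dropout(num_batches, num_heads, y_size, x_size, device):
    query = torch.randn(num_batches, num_heads, y_size, x_size, device=device)
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    assert util.equal(
        torch.nn.functional.scaled_dot_product_attention(query, key, value, dropout_p=0.0),
        trident.function.scaled_dot_product_attention(query, key, value, dropout_p=0.0),
    )

    output = trident.function.scaled_dot_product_attention(query, key, value, dropout_p=0.5)
    assert output.shape == query.shape and output.dtype == query.dtype
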
8 changes: 6 additions & 2 deletions trident/function/function.py
@@ -230,16 +230,20 @@ def scaled_dot_product_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    dropout_p: float = 0.0,
     is_causal: bool = False,
     use_accelerator: bool = False,
 ):
     """
     Computes scaled dot product attention on query, key and value tensors,
     and applying dropout if a probability greater than 0.0 is specified.
     """
-    assert len(query.size()) == 4
+    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
+        raise ValueError("The dimension of query, key and value should be 4.")

-    return operation.Attention.apply(query, key, value, is_causal, 1.0 / math.sqrt(key.shape[-1]), use_accelerator)
+    return operation.Attention.apply(
+        query, key, value, dropout_p, is_causal, 1.0 / math.sqrt(key.shape[-1]), use_accelerator
+    )


 def softmax(input: torch.Tensor, dim: int = None):
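
To illustrate the signature change, a minimal usage sketch of the updated function follows. The device, dtype, and shapes are assumptions, not values from the repository; the parameter names match the signature shown above, and dropout_p defaults to 0.0 so existing call sites keep their behavior.

# Usage sketch for the new dropout_p argument (illustrative values only).
import torch
import trident

query = torch.randn(4, 8, 128, 64, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Inference-style call: no dropout, same behavior as before this commit.
output = trident.function.scaled_dot_product_attention(query, key, value)

# Training-style call: causal attention with 10% dropout applied to the attention weights.
output = trident.function.scaled_dot_product_attention(
    query, key, value, dropout_p=0.1, is_causal=True
)
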
