Use power-of-2-sized tensors in Attention benchmark
steve.an committed Oct 10, 2023
1 parent 96b0328 commit 74102c9
Showing 3 changed files with 18 additions and 4 deletions.
6 changes: 3 additions & 3 deletions benchmarks/benchmark_attention.py
@@ -20,7 +20,7 @@
 
 
 @util.report(
-    "attention forward", ["y_size"], [32 * i for i in range(1, 21)], {"num_batches": 64, "num_heads": 8, "x_size": 64}
+    "attention forward", ["y_size"], [2**i for i in range(5, 10)], {"num_batches": 64, "num_heads": 8, "x_size": 64}
 )
 def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backend):
     factory_kwargs = {"device": "cuda", "dtype": dtype}
@@ -39,8 +39,8 @@ def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backend):
 @util.report(
     "attention backward",
     ["y_size"],
-    [64 * i for i in range(1, 21)],
-    {"num_batches": 64, "num_heads": 8, "x_size": 64},
+    [2**i for i in range(5, 10)],
+    {"num_batches": 32, "num_heads": 8, "x_size": 64},
 )
 def bench_attention_backward(num_batches, num_heads, y_size, x_size, dtype, backend):
     factory_kwargs = {"device": "cuda", "dtype": dtype}
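
For reference, a quick sketch (plain Python, outside the benchmark harness) of how the swept y_size values change with this commit. The old sweeps covered multiples of 32 (forward) or 64 (backward), most of which are not powers of 2; the backward benchmark also halves num_batches from 64 to 32:

    # Old sweeps: arithmetic progressions, mostly non-powers of 2.
    old_forward_sizes = [32 * i for i in range(1, 21)]   # 32, 64, 96, ..., 640
    old_backward_sizes = [64 * i for i in range(1, 21)]  # 64, 128, 192, ..., 1280

    # New sweep (both benchmarks): powers of 2 only, matching the shape
    # guard added to trident/operation/attention.py below.
    new_sizes = [2**i for i in range(5, 10)]
    print(new_sizes)  # [32, 64, 128, 256, 512]
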
12 changes: 11 additions & 1 deletion trident/operation/attention.py
@@ -28,6 +28,16 @@ def forward(ctx: Any, *args: Any, **kwargs: Any):
         if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
             raise ValueError("The dimension of query, key and value should be 4.")
 
+        if (
+            not util.is_pow2(query.shape[-2])
+            or not util.is_pow2(query.shape[-1])
+            or not util.is_pow2(key.shape[-2])
+            or not util.is_pow2(key.shape[-1])
+            or not util.is_pow2(value.shape[-2])
+            or not util.is_pow2(value.shape[-1])
+        ):
+            raise ValueError("Attention only supports power-of-2 sized tensors.")
+
         if mask is not None:
             if is_causal:
                 raise ValueError("Error because both attn_mask and is_causal are set.")
@@ -185,7 +195,7 @@ def grid(meta):
             util.dtype(grad_query.dtype),
             64,
             triton.next_power_of_2(x_size),
-            num_warps=4 if x_size <= 64 else 8,
+            num_warps=2,
         )
         util.pop_trace()
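
To see the new guard in isolation, here is a minimal, self-contained sketch of the check. is_pow2 is copied locally for illustration rather than imported from trident, check_attention_shapes is a hypothetical stand-in for the validation inside forward, and the tensor sizes are the benchmark defaults:

    import torch

    def is_pow2(value):
        # Same bit trick as the helper added to trident/util/util.py below.
        return False if value == 0 else (value & (value - 1)) == 0

    def check_attention_shapes(query, key, value):
        # Only the last two dimensions (sequence length and head size)
        # must be powers of 2; batch and head counts are unconstrained.
        for tensor in (query, key, value):
            if not (is_pow2(tensor.shape[-2]) and is_pow2(tensor.shape[-1])):
                raise ValueError("Attention only supports power-of-2 sized tensors.")

    q = k = v = torch.empty(64, 8, 128, 64)  # (num_batches, num_heads, y_size, x_size)
    check_attention_shapes(q, k, v)  # passes: 128 and 64 are powers of 2
    check_attention_shapes(torch.empty(64, 8, 96, 64), k, v)  # raises: 96 is not
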
4 changes: 4 additions & 0 deletions trident/util/util.py
@@ -58,6 +58,10 @@ def dtype(input):
     raise ValueError(f"Unable to convert the given input: '{input}'.")
 
 
+def is_pow2(value):
+    return False if value == 0 else (value & (value - 1)) == 0
+
+
 def size_and_stride(input: torch.Tensor, dim: int):
     if input.dim() == 2:
         if dim == 0:
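
The helper relies on a standard bit trick: for a positive integer, value & (value - 1) clears the lowest set bit, so the result is 0 exactly when value has a single set bit, i.e. is a power of 2. The explicit value == 0 guard excludes zero, which would otherwise pass. A few illustrative calls (assuming the helper is reachable as util.is_pow2, as the call sites in attention.py suggest):

    from trident import util  # assumed import path for the new helper

    util.is_pow2(1)   # True:  0b1 & 0b0 == 0
    util.is_pow2(64)  # True:  0b1000000 & 0b0111111 == 0
    util.is_pow2(96)  # False: 0b1100000 & 0b1011111 == 0b1000000
    util.is_pow2(0)   # False: handled by the explicit zero check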
