diff --git a/benchmarks/benchmark_attention.py b/benchmarks/benchmark_attention.py index 264c5d7..c8a29a6 100644 --- a/benchmarks/benchmark_attention.py +++ b/benchmarks/benchmark_attention.py @@ -20,7 +20,7 @@ @util.report( - "attention forward", ["y_size"], [32 * i for i in range(1, 21)], {"num_batches": 64, "num_heads": 8, "x_size": 64} + "attention forward", ["y_size"], [2**i for i in range(5, 10)], {"num_batches": 64, "num_heads": 8, "x_size": 64} ) def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backend): factory_kwargs = {"device": "cuda", "dtype": dtype} @@ -39,8 +39,8 @@ def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backe @util.report( "attention backward", ["y_size"], - [64 * i for i in range(1, 21)], - {"num_batches": 64, "num_heads": 8, "x_size": 64}, + [2**i for i in range(5, 10)], + {"num_batches": 32, "num_heads": 8, "x_size": 64}, ) def bench_attention_backward(num_batches, num_heads, y_size, x_size, dtype, backend): factory_kwargs = {"device": "cuda", "dtype": dtype} diff --git a/trident/operation/attention.py b/trident/operation/attention.py index fa02549..ede7843 100644 --- a/trident/operation/attention.py +++ b/trident/operation/attention.py @@ -28,6 +28,16 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: raise ValueError("The dimension of query, key and value should be 4.") + if ( + not util.is_pow2(query.shape[-2]) + or not util.is_pow2(query.shape[-1]) + or not util.is_pow2(key.shape[-2]) + or not util.is_pow2(key.shape[-1]) + or not util.is_pow2(value.shape[-2]) + or not util.is_pow2(value.shape[-1]) + ): + raise ValueError("Attention supports only for power of 2 size tensors.") + if mask is not None: if is_causal: raise ValueError("Error because both attn_mask and is_causal are set.") @@ -185,7 +195,7 @@ def grid(meta): util.dtype(grad_query.dtype), 64, triton.next_power_of_2(x_size), - num_warps=4 if x_size <= 64 else 8, + num_warps=2, ) util.pop_trace() diff --git a/trident/util/util.py b/trident/util/util.py index 048ec74..7b6d916 100644 --- a/trident/util/util.py +++ b/trident/util/util.py @@ -58,6 +58,10 @@ def dtype(input): raise ValueError(f"Unable to convert the given input: '{input}'.") +def is_pow2(value): + return False if value == 0 else (value & (value - 1)) == 0 + + def size_and_stride(input: torch.Tensor, dim: int): if input.dim() == 2: if dim == 0: