[Kernel] Triton Paged Attn Decode Kernel
Signed-off-by: Rahul Batra <[email protected]>
Rahul Batra committed Dec 10, 2024
1 parent d05f886 commit e8a7e95
Showing 4 changed files with 1,177 additions and 82 deletions.
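
For context, paged-attention decode scores one query token per sequence against K/V entries that live in fixed-size cache blocks addressed through a per-sequence block table. A minimal PyTorch reference for a single sequence (shapes and names here are illustrative, loosely following `ref_single_query_cached_kv_attention` in the test file below; it is not the Triton kernel itself):

```python
import torch

def ref_paged_attn_decode_single_seq(
    query: torch.Tensor,        # [num_heads, head_size], one decode token
    key_cache: torch.Tensor,    # [num_blocks, num_heads, block_size, head_size]
    value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
    block_table: torch.Tensor,  # [max_blocks_per_seq], physical block ids
    seq_len: int,
    scale: float,
) -> torch.Tensor:
    block_size = key_cache.shape[2]
    keys, values = [], []
    # Walk the logical sequence position by position through the block table.
    for pos in range(seq_len):
        block_id = int(block_table[pos // block_size])
        offset = pos % block_size
        keys.append(key_cache[block_id, :, offset, :])       # [num_heads, head_size]
        values.append(value_cache[block_id, :, offset, :])
    k = torch.stack(keys)    # [seq_len, num_heads, head_size]
    v = torch.stack(values)
    # Scaled dot-product attention for the single query token.
    logits = scale * torch.einsum("hd,shd->hs", query, k)    # [num_heads, seq_len]
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("hs,shd->hd", probs, v)               # [num_heads, head_size]
```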
94 changes: 70 additions & 24 deletions tests/kernels/test_attention.py
@@ -6,6 +6,12 @@

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.attention.ops.triton_paged_attn_decode import (
_SEQ_PARTITION_SIZE as TRITON_PAGED_ATTN_DECODE_PARTITION_SIZE)
from vllm.attention.ops.triton_paged_attn_decode import (
paged_attn_decode_v1 as triton_paged_attn_decode_v1)
from vllm.attention.ops.triton_paged_attn_decode import (
paged_attn_decode_v2 as triton_paged_attn_decode_v2)
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes

@@ -117,7 +123,8 @@ def ref_single_query_cached_kv_attention(

@pytest.mark.parametrize(
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
["v1", "v2", "triton_v1", "triton_v2"] if not current_platform.is_rocm()
else ["v1", "v2", "rocm", "triton_v1", "triton_v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -146,6 +153,7 @@ def test_paged_attention(

current_platform.seed_everything(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
@@ -157,7 +165,13 @@
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
if version == "triton_v2":
seq_lens = [
random.randint(TRITON_PAGED_ATTN_DECODE_PARTITION_SIZE + 1,
MAX_SEQ_LEN) for _ in range(num_seqs)
]
else:
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
@@ -180,38 +194,58 @@
kv_cache_dtype, dtype, seed,
device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Using default kv_scale
k_scale = v_scale = 1.0

# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
if version in ("v1", "triton_v1"):
if version == "v1":
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)

opcheck(torch.ops._C.paged_attention_v1,
opcheck(
torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, 1024),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
key_cache_tri = key_cache.permute(0, 1, 3, 2,
4).flatten(3, 4).contiguous()
value_cache_tri = value_cache.permute(0, 1, 3, 2).contiguous()
triton_paged_attn_decode_v1(
output,
query,
key_cache_tri,
value_cache_tri,
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
alibi_slopes,
k_scale,
v_scale,
)

elif version in ("v2", "rocm"):
elif version in ("v2", "rocm", "triton_v2"):
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
@@ -253,6 +287,19 @@ def test_paged_attention(
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

elif version == "triton_v2":
key_cache_tri = key_cache.permute(0, 1, 3, 2,
4).flatten(3, 4).contiguous()
value_cache_tri = value_cache.permute(0, 1, 3, 2).contiguous()
num_partitions = (
(max_seq_len + TRITON_PAGED_ATTN_DECODE_PARTITION_SIZE - 1) //
TRITON_PAGED_ATTN_DECODE_PARTITION_SIZE)
assert TRITON_PAGED_ATTN_DECODE_PARTITION_SIZE % block_size == 0
triton_paged_attn_decode_v2(output, query, key_cache_tri,
value_cache_tri, block_tables,
seq_lens, max_seq_len, kv_cache_dtype,
num_kv_heads, scale, alibi_slopes,
k_scale, v_scale, num_partitions)
else:
ops.paged_attention_rocm(
output,
@@ -281,7 +328,6 @@ def test_paged_attention(
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

else:
raise AssertionError(f"Unknown version: {version}")

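Before invoking the Triton kernels, the test repacks vLLM's default split-head cache layout into the contiguous `[num_blocks, num_kv_heads, block_size, head_size]` layout the new kernels read. A standalone sketch of that transform, assuming the usual cache shapes produced by `create_kv_caches_with_random` and an illustrative x-factor of 8:

```python
import torch

num_blocks, num_kv_heads, head_size, block_size, x = 128, 4, 64, 16, 8

# Default vLLM layouts consumed by the CUDA paged-attention kernels.
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x)
value_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size)

# Contiguous layout expected by the Triton decode kernels.
key_cache_tri = key_cache.permute(0, 1, 3, 2, 4).flatten(3, 4).contiguous()
value_cache_tri = value_cache.permute(0, 1, 3, 2).contiguous()

assert key_cache_tri.shape == (num_blocks, num_kv_heads, block_size, head_size)
assert value_cache_tri.shape == (num_blocks, num_kv_heads, block_size, head_size)
```
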
136 changes: 78 additions & 58 deletions vllm/attention/ops/paged_attn.py
@@ -3,11 +3,14 @@

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.attention.ops.triton_paged_attn_decode import (
paged_attn_decode_v1, paged_attn_decode_v2)

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@@ -125,68 +128,85 @@ def forward_decode(
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))

use_triton_pa_decode = envs.VLLM_USE_TRITON_PAGED_ATTN_DECODE
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
if use_triton_pa_decode == 1:
paged_attn_decode_v1(
output, query, key_cache, value_cache, block_tables,
seq_lens, max_seq_len, kv_cache_dtype, num_kv_heads, scale,
alibi_slopes, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
else:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)

if use_triton_pa_decode == 1:
paged_attn_decode_v2(
output, query, key_cache, value_cache, block_tables,
seq_lens, max_seq_len, kv_cache_dtype, num_kv_heads, scale,
alibi_slopes, k_scale, v_scale, max_num_partitions,
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
else:
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output

@staticmethod
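
Outside the tests, the new path in `PagedAttention.forward_decode` is gated by `envs.VLLM_USE_TRITON_PAGED_ATTN_DECODE`. A minimal opt-in sketch, assuming the flag is read through `vllm.envs` at import time and that a value of 1 selects the Triton kernels (as the comparison in the diff suggests); the model name is only a placeholder:

```python
import os

# Hypothetical opt-in: set the flag before importing vLLM so vllm.envs picks it up.
os.environ["VLLM_USE_TRITON_PAGED_ATTN_DECODE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
out = llm.generate(["Paged attention decodes"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```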
