Skip to content

Commit

Permalink
Fix kernel cache miss and add RDNA configs
Browse files Browse the repository at this point in the history
- added Navi configurations (Related PR: ROCm/triton#640)
- resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0
  • Loading branch information
hyoon1 committed Nov 12, 2024
1 parent 8f3bf8b commit 3318139
Showing 1 changed file with 148 additions and 63 deletions.
211 changes: 148 additions & 63 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"""

import subprocess

import torch
import triton
import triton.language as tl
Expand Down Expand Up @@ -207,103 +209,186 @@ def _attn_fwd_inner(
return acc, l_i, m_i


@triton.autotune(
configs=[
def get_gfx_version():
try:
# Run the rocminfo command
result = subprocess.run(['rocminfo'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
output = result.stdout

# Parse the output to find the gfx version
for line in output.splitlines():
line = line.strip()
if line.startswith("Name: gfx"):
gfx_version = line.split("Name:")[1].strip()
return gfx_version
except Exception as e:
print(f"Error: {e}")
return None


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target(
).arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')


def is_rdna():
return is_hip() and triton.runtime.driver.active.get_current_target(
).arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200",
"gfx1201")


def get_cdna_autotune_configs():
return [
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 128,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
'BLOCK_M': 128,
'BLOCK_N': 32,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
# Fall-back config.
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_rdna_autotune_configs():
return [
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=2),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
num_warps=2),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
# Fall-back config.
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_autotune_configs():
if is_rdna():
return get_rdna_autotune_configs()
elif is_cdna():
return get_cdna_autotune_configs()
else:
raise ValueError("Unknown Device Type")


autotune_configs, autotune_keys = get_autotune_configs()


@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
Expand Down Expand Up @@ -795,8 +880,8 @@ def forward(
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
MAX_SEQLENS_Q=0,
MAX_SEQLENS_K=0,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
Expand Down

0 comments on commit 3318139

Please sign in to comment.