From 33181398a9052649fb6c79e3eab767b827e98166 Mon Sep 17 00:00:00 2001
From: Hosang Yoon
Date: Fri, 25 Oct 2024 13:29:31 -0400
Subject: [PATCH] Fix kernel cache miss and add RDNA configs

- added Navi configurations (Related PR: https://github.com/ROCm/triton/pull/640)
- resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0

---
 vllm/attention/ops/triton_flash_attention.py | 211 +++++++++++++------
 1 file changed, 148 insertions(+), 63 deletions(-)

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index f94211116a746..a0efc72d5e777 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -20,6 +20,8 @@
 
 """
 
+import subprocess
+
 import torch
 import triton
 import triton.language as tl
@@ -207,103 +209,186 @@ def _attn_fwd_inner(
     return acc, l_i, m_i
 
 
-@triton.autotune(
-    configs=[
+def get_gfx_version():
+    try:
+        # Run the rocminfo command
+        result = subprocess.run(['rocminfo'],
+                                stdout=subprocess.PIPE,
+                                stderr=subprocess.PIPE,
+                                text=True)
+        output = result.stdout
+
+        # Parse the output to find the gfx version
+        for line in output.splitlines():
+            line = line.strip()
+            if line.startswith("Name: gfx"):
+                gfx_version = line.split("Name:")[1].strip()
+                return gfx_version
+    except Exception as e:
+        print(f"Error: {e}")
+    return None
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+def is_cdna():
+    return is_hip() and triton.runtime.driver.active.get_current_target(
+    ).arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
+
+
+def is_rdna():
+    return is_hip() and triton.runtime.driver.active.get_current_target(
+    ).arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200",
+               "gfx1201")
+
+
+def get_cdna_autotune_configs():
+    return [
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 64,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': False
            },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 1,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": True,
+                'BLOCK_M': 128,
+                'BLOCK_N': 32,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
+        # Fall-back config.
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 16,
+                'BLOCK_N': 16,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']
+
+
+def get_rdna_autotune_configs():
+    return [
         triton.Config(
             {
-                "BLOCK_M": 64,
-                "BLOCK_N": 64,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=2),
         triton.Config(
             {
-                "BLOCK_M": 32,
-                "BLOCK_N": 32,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
-    # TODO: This config fails with head_size not pow2 with data mismatches.
-    # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
-    # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
+            num_warps=2),
         triton.Config(
             {
-                "BLOCK_M": 16,
-                "BLOCK_N": 16,
-                "waves_per_eu": 1,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
-    ],
-    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 16,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 16,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        # Fall-back config.
+        triton.Config(
+            {
+                'BLOCK_M': 16,
+                'BLOCK_N': 16,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']
+
+
+def get_autotune_configs():
+    if is_rdna():
+        return get_rdna_autotune_configs()
+    elif is_cdna():
+        return get_cdna_autotune_configs()
+    else:
+        raise ValueError("Unknown Device Type")
+
+
+autotune_configs, autotune_keys = get_autotune_configs()
+
+
+@triton.autotune(
+    configs=autotune_configs,
+    key=autotune_keys,
+    use_cuda_graph=True,
 )
 @triton.jit
 def attn_fwd(
@@ -795,8 +880,8 @@ def forward(
             HQ=nheads_q,
             HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
-            MAX_SEQLENS_Q=max_seqlens_q,
-            MAX_SEQLENS_K=max_seqlens_k,
+            MAX_SEQLENS_Q=0,
+            MAX_SEQLENS_K=0,
             IS_CAUSAL=causal,
             VARLEN=True,
             BLOCK_DMODEL=padded_d_model,
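
Note on the cache-miss fix (not part of the patch): MAX_SEQLENS_Q/K are declared as tl.constexpr parameters of attn_fwd in this file, so passing the batch-dependent max_seqlens_q/k compiled and cached a fresh kernel specialization for every new length combination. The varlen call path (VARLEN=True) takes per-sequence lengths from cu_seqlens_q/k, so the true maxima appear unnecessary there, and pinning them to 0 keeps the compile cache key stable. The toy Triton sketch below uses made-up names unrelated to this kernel and only illustrates the general constexpr behavior:

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def scale_kernel(x_ptr, out_ptr, n_elements, SCALE: tl.constexpr,
                     BLOCK: tl.constexpr):
        # Each distinct SCALE value produces (and caches) a separate
        # specialization of this kernel, much like a varying MAX_SEQLENS_*.
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * SCALE, mask=mask)


    def scale(x: torch.Tensor, s: int) -> torch.Tensor:
        # Expects a CUDA/ROCm tensor.
        out = torch.empty_like(x)
        n = x.numel()
        # Passing a different `s` on each call triggers a recompile; passing a
        # constant (as the patch does with MAX_SEQLENS_Q=0, MAX_SEQLENS_K=0)
        # does not.
        scale_kernel[(triton.cdiv(n, 1024), )](x, out, n, SCALE=s, BLOCK=1024)
        return out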