Commit 3eb5ad6

[KVCache] TIR attention kernel support for MLA (#17618)
This PR introduces MLA attention kernels written in TIR and implements the KV cache MLA computation logic. A new unit test file is added to ensure the correctness of the TIR kernels. This PR also fixes the tile-size initialization of a few TIR prefill kernels.
1 parent 9898039 commit 3eb5ad6
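
As background for the kernels and cache logic in this change, below is a minimal NumPy sketch of what weight-absorbed MLA attention computes for one query position. It is an illustrative reference only: the function name `mla_absorbed_ref`, the shapes, and the assumption that each query head concatenates a latent (no-RoPE) part with a rotary part are not taken from this PR; the authoritative layouts live in kv_cache.py and the new unit tests.

# Minimal reference sketch of weight-absorbed MLA attention (single query
# position, causal mask omitted). All names and shapes are assumptions for
# illustration, not the exact layout used by the TIR kernels in this PR.
import numpy as np

def mla_absorbed_ref(q, compressed_kv, k_pe, scale):
    # q:             (num_qo_heads, kv_lora_rank + qk_rope_head_dim)
    # compressed_kv: (kv_len, kv_lora_rank)      -- compact latent KV cache
    # k_pe:          (kv_len, qk_rope_head_dim)  -- shared rotary key part
    kv_lora_rank = compressed_kv.shape[1]
    q_latent, q_rope = q[:, :kv_lora_rank], q[:, kv_lora_rank:]
    # Attention scores combine the latent and rotary contributions.
    scores = (q_latent @ compressed_kv.T + q_rope @ k_pe.T) * scale
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # The output stays in the compressed latent space; the model applies the
    # absorbed value projection afterwards.
    return weights @ compressed_kv  # (num_qo_heads, kv_lora_rank)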

File tree

8 files changed: +2024, -869 lines changed


python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 1317 additions & 578 deletions
Large diffs are not rendered by default.

python/tvm/relax/frontend/nn/llm/tree_attn.py

Lines changed: 22 additions & 2 deletions
@@ -320,7 +320,17 @@ def tree_attn(
 
     bdx = 32
     num_warps = 4
-    tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
+    tile_x, tile_y, tile_z = (
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+        d,
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+    )
+    original_tile_y = tile_y
+    original_tile_z = tile_z
+    while (tile_x * tile_z) % (bdx * num_warps) != 0:
+        tile_z += original_tile_z
+    while (tile_x * tile_y) % (bdx * num_warps) != 0:
+        tile_y += original_tile_y
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -881,7 +891,17 @@ def tree_attn_with_paged_kv_cache(
 
     bdx = 32
     num_warps = 4
-    tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
+    tile_x, tile_y, tile_z = (
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+        d,
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+    )
+    original_tile_y = tile_y
+    original_tile_z = tile_z
+    while (tile_x * tile_z) % (bdx * num_warps) != 0:
+        tile_z += original_tile_z
+    while (tile_x * tile_y) % (bdx * num_warps) != 0:
+        tile_y += original_tile_y
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
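
Both hunks above apply the same fix: instead of hard-coding tile_z = 16, the tiles start from the dtype- and head-dim-dependent size and are then grown until the tile area is a multiple of the threads per block. A standalone sketch of that logic, assuming the same bdx = 32 and num_warps = 4 constants as the kernels (the helper name init_tiles is made up for illustration):

# Standalone sketch of the tile-size initialization above; bdx and num_warps
# mirror the constants used in tree_attn.py.
from tvm.runtime import DataType

def init_tiles(dtype: str, d: int, bdx: int = 32, num_warps: int = 4):
    tile_x = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1)
    tile_y, tile_z = d, tile_x
    original_tile_y, original_tile_z = tile_y, tile_z
    # Grow tile_z / tile_y until each tile covers a whole number of
    # (bdx * num_warps)-thread blocks.
    while (tile_x * tile_z) % (bdx * num_warps) != 0:
        tile_z += original_tile_z
    while (tile_x * tile_y) % (bdx * num_warps) != 0:
        tile_y += original_tile_y
    return tile_x, tile_y, tile_z

print(init_tiles("float16", 128))  # (32, 128, 32): loops never fire
print(init_tiles("float32", 256))  # (8, 256, 16): tile_z doubled to reach 128 threads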

src/runtime/relax_vm/kv_state.cc

Lines changed: 9 additions & 0 deletions
@@ -74,12 +74,21 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla")
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKVMLA);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
     .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
                        double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
       kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
                                       attn_score_scaling_factor);
     });
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, NDArray compressed_kv_data,
+                       NDArray k_pe_data, NDArray o_data) {
+      kv_cache->MLAAbsorbed(layer_id, std::move(q_data), std::move(compressed_kv_data),
+                            std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
+    });
 
 // RNN State methods
 TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
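
For reference, the newly registered global can be looked up from Python like the other KV cache builtins. The snippet below is a hypothetical usage sketch: kv_cache and the NDArrays q, compressed_kv, k_pe, out are assumed to be constructed elsewhere with layouts matching the cache; only the global's name and argument order come from the registration above.

# Hypothetical Python-side lookup and call of the new packed function.
# The kv_cache object and the NDArray arguments are assumed to exist.
import tvm

f_mla_absorbed = tvm.get_global_func("vm.builtin.attention_kv_cache_mla_absorbed")

# Weight-absorbed MLA attention for layer 0 with scaling factor 1.0;
# the result is written into `out` in place:
#   f_mla_absorbed(kv_cache, 0, 1.0, q, compressed_kv, k_pe, out)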

src/runtime/relax_vm/kv_state.h

Lines changed: 10 additions & 14 deletions
@@ -181,20 +181,6 @@ class AttentionKVCacheObj : public KVStateObj {
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
                                      NDArray o_data, double attn_score_scaling_factor) = 0;
 
-  /*!
-   * \brief Compute attention with Q/K/V data.
-   * \param layer_id The model layer where the attention compute happens.
-   * \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`
-   * \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`
-   * \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`
-   * \param mask The input mask data, in layout `(total_sqr_length)`.
-   * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
-   * \param attn_score_scaling_factor The additional attention scaling factor.
-   */
-  virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data,
-                                        NDArray v_data, Optional<NDArray> mask, NDArray o_data,
-                                        double attn_score_scaling_factor) = 0;
-
   /*!
    * \brief Compute multi-head latent attention after applying weight absorption.
    * \param layer_id The model layer where the attention compute happens.
@@ -275,6 +261,16 @@ class AttentionKVCacheObj : public KVStateObj {
   virtual void DebugGetKV(int64_t seq_id,  //
                           int64_t start_pos, int64_t end_pos, NDArray k_data, NDArray v_data) = 0;
 
+  /*!
+   * \brief Fetch the compact K/V data of the given sequence for MLA cache.
+   * \param seq_id The sequence whose K/V data is to be fetched.
+   * \param start_pos The start position (inclusive) of the K/V data to fetch.
+   * \param end_pos The end position (exclusive) of the K/V data to fetch.
+   * \param kv_data The output KV data of the given sequence in layout elaborated above.
+   */
+  virtual void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos,
+                             NDArray kv_data) = 0;
+
   /*!
    * \brief Set the K/V data of the given sequence from input K/V data.
    * `start_pos` (inclusive) controls starting position of K/V data
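
A hypothetical test-side helper built on this debug hook could compare the fetched compact KV against a host-side reference, along the lines of the sketch below; the helper name, the reference array ref_kv, its layout, and the tolerances are all assumptions rather than code from this PR.

# Hypothetical verification helper for DebugGetKVMLA via its registered global.
# `ref_kv` is a host-side NumPy copy of everything appended for the sequence;
# its per-position layout is assumed to match the cache's compact KV layout.
import numpy as np
import tvm

def check_mla_cache(kv_cache, seq_id, ref_kv, start_pos, end_pos, dev):
    # start_pos is inclusive and end_pos is exclusive, so the fetched buffer
    # holds end_pos - start_pos entries.
    kv_out = tvm.nd.empty((end_pos - start_pos,) + ref_kv.shape[1:], str(ref_kv.dtype), dev)
    f_debug = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla")
    f_debug(kv_cache, seq_id, start_pos, end_pos, kv_out)
    np.testing.assert_allclose(kv_out.numpy(), ref_kv[start_pos:end_pos], rtol=1e-3, atol=1e-3)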
