Added support for normal MLA kernel #17624

Merged
merged 5 commits on Feb 20, 2025
73 changes: 61 additions & 12 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -180,6 +180,49 @@ def mla_absorbed(
)
).reshape(b, s, h_qo, kv_lora_rank)

def mla_normal(
self,
layer_id: int,
q: Tensor,
k: Tensor,
v: Tensor,
compressed_kv: Tensor,
k_pe: Tensor,
attn_score_scaling_factor: float = 1.0,
) -> Tensor:
"""Compute multi-head latent attention with the given data
on the specified layer using the normal flow (without weight absorption).
"""
# pylint: disable=protected-access
b, s, h_qo, d_qk = q._expr.struct_info.shape
d_v = v._expr.struct_info.shape[3]
kv_lora_rank = compressed_kv._expr.struct_info.shape[3]
qk_rope_head_dim = k_pe._expr.struct_info.shape[3]
q = q.reshape(b * s, h_qo, d_qk)
k = k.reshape(b * s, h_qo, d_qk)
v = v.reshape(b * s, h_qo, d_v)
compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank)
k_pe = k_pe.reshape(b * s, qk_rope_head_dim)

return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.call_dps_packed(
"vm.builtin.attention_kv_cache_mla_normal",
[
self._expr,
rx.PrimValue(layer_id), # type: ignore[arg-type]
rx.PrimValue(attn_score_scaling_factor),
q._expr,
k._expr,
v._expr,
compressed_kv._expr,
k_pe._expr,
],
out_sinfo=rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype),
)
)
).reshape(b, s, h_qo, d_v)

def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
"""Get the in-sequence positions of each slot in the query,
which are needed for applying positional embeddings in some models.
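
As background for the docstring above (the "normal flow", i.e. without weight absorption): the sketch below is an illustrative NumPy check, not part of this PR. `W_UK` and `W_UV` are hypothetical per-head up-projection weights, and softmax scaling and causal masking are omitted for brevity. It shows that attending over the up-projected k/v (the normal flow added here) produces the same output as the weight-absorbed flow used by `mla_absorbed`, which attends in the compressed latent space.

```python
import numpy as np

rng = np.random.default_rng(0)
h, d_nope, d_rope, d_v, r, n = 2, 4, 2, 4, 8, 5  # heads, head dims, latent rank, tokens

W_UK = rng.standard_normal((h, r, d_nope))  # hypothetical per-head up-projections
W_UV = rng.standard_normal((h, r, d_v))

q_nope = rng.standard_normal((n, h, d_nope))
q_pe = rng.standard_normal((n, h, d_rope))
c_kv = rng.standard_normal((n, r))       # compressed_kv (latent KV)
k_pe = rng.standard_normal((n, d_rope))  # RoPE key part, shared across heads

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Normal flow: up-project the latent KV to full k/v, attend in head space.
k_nope = np.einsum("nr,hrd->nhd", c_kv, W_UK)
v = np.einsum("nr,hrd->nhd", c_kv, W_UV)
scores = np.einsum("nhd,mhd->hnm", q_nope, k_nope) + np.einsum("nhd,md->hnm", q_pe, k_pe)
out_normal = np.einsum("hnm,mhd->nhd", softmax(scores), v)

# Absorbed flow: fold W_UK into q and W_UV into the output, attend in latent space.
q_lat = np.einsum("nhd,hrd->nhr", q_nope, W_UK)
scores = np.einsum("nhr,mr->hnm", q_lat, c_kv) + np.einsum("nhd,md->hnm", q_pe, k_pe)
out_absorbed = np.einsum("hnm,mr,hrd->nhd", softmax(scores), c_kv, W_UV)

assert np.allclose(out_normal, out_absorbed)
```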
@@ -591,7 +634,7 @@ def create_mla_kv_cache( # pylint: disable=too-many-locals
rx.PrimValue(0),
bb.add_func(_attention_prefill_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_prefill_mla"),
bb.add_func(_attention_decode_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_decode_mla"),
bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
bb.add_func(_attention_prefill_ragged_generic(num_key_value_heads, num_attention_heads, qk_rope_head_dim, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target), "tir_attention_prefill_ragged_mla_absorbed"),
bb.add_func(_merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), "tir_attention_merge_state"),
bb.add_func(llama_rope_with_position_map(10000, 1, qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), "tir_split_rotary"),
@@ -2420,6 +2463,12 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches


def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target):
return _attention_prefill_ragged_generic(h_kv, h_q, d, d, dtype, rope_scaling, target)


def _attention_prefill_ragged_generic(
h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any], target: Target
):
# pylint: disable=line-too-long
(
NUM_BLKS,
@@ -2431,7 +2480,7 @@
tile_x,
tile_y,
tile_z,
) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
) = _get_prefill_kernel_config(h_kv, h_q, d_qk, dtype, target)

# fmt: off
@T.prim_func
@@ -2459,14 +2508,14 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
q_rope_position_elem_offset = T.int32(is_size_var=True)
k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)

q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
q = T.match_buffer(var_q, (qo_len, h_q, d_qk), dtype)
q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
k = T.match_buffer(var_k, (kv_len, h_kv, d_qk), dtype)
v = T.match_buffer(var_v, (kv_len, h_kv, d_v), dtype)
kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset)
q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset)
k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
output = T.match_buffer(var_output, (qo_len, h_q, d_v), dtype)
lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable

# kernel code
@@ -2485,13 +2534,13 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
iterator = _var("int32")
kv_chunk_len = _var("int32")

Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared")
K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
Q_smem = T.alloc_buffer((tile_x, d_qk), dtype, scope="shared")
K_smem = T.alloc_buffer((tile_z, d_qk), dtype, scope="shared")
V_smem = T.alloc_buffer((tile_z, d_v), dtype, scope="shared")
S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared")

S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local")
O_local = T.alloc_buffer((tile_x, d), "float32", scope="local")
O_local = T.alloc_buffer((tile_x, d_v), "float32", scope="local")

m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
@@ -2548,7 +2597,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
if cur_L < q_indptr[b_idx + 1]:
Q_smem[i, j] = T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
_rope(q, q_rope_position[cur_L], d_qk, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
q[cur_L, cur_H_qo, j]
)
else:
@@ -2565,7 +2614,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
if cur_L < kv_chunk_len[0]:
K_smem[i, j] = T.if_then_else(
rotary_mode == 1,
_rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype, rope_scaling),
_rope(k, k_rope_pos_offset[b_idx] + cur_L, d_qk, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype, rope_scaling),
k[L_kv_base + cur_L, by, j]
)
else:
10 changes: 10 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
@@ -90,6 +90,16 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
});

TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
NDArray v_data, NDArray compressed_kv_data, NDArray k_pe_data,
NDArray o_data) {
kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
std::move(compressed_kv_data), std::move(k_pe_data), std::move(o_data),
attn_score_scaling_factor);
});

// RNN State methods
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
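
As a usage note (a sketch, not part of the PR): once TVM is built with this registration, the new builtin is reachable by name from Python through the global-function registry, with the argument order mirroring the lambda above.

```python
import tvm

# Assumes a TVM build that includes this registration; returns None otherwise.
f_mla_normal = tvm.get_global_func(
    "vm.builtin.attention_kv_cache_mla_normal", allow_missing=True
)
# Expected call shape (kv_cache: AttentionKVCache object, the rest: tvm.nd.NDArray):
# f_mla_normal(kv_cache, layer_id, attn_score_scaling_factor,
#              q_data, k_data, v_data, compressed_kv_data, k_pe_data, o_data)
```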
77 changes: 76 additions & 1 deletion src/runtime/relax_vm/paged_kv_cache.cc
@@ -2241,7 +2241,82 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
double attn_score_scaling_factor) {
// Todo(ruihang): implement it
// Part 1: Basic Checks and Setup.
int64_t local_layer_id = layer_id - layer_id_begin_offset_;
CHECK_GE(local_layer_id, 0);
CHECK_LT(local_layer_id, num_layers_);
NDArray pages = pages_[local_layer_id];
CHECK(q_data.DataType() == pages.DataType());
CHECK(k_data.DataType() == pages.DataType());
CHECK(v_data.DataType() == pages.DataType());
CHECK(compressed_kv_data.DataType() == pages.DataType());
CHECK(k_pe_data.DataType() == pages.DataType());
CHECK(o_data.DataType() == pages.DataType());
CHECK(attn_kinds_[layer_id] == AttnKind::kMLA);

// Expected shapes:
// q_data: (num_total_length, num_qo_heads, qk_head_dim)
// k_data: (num_total_length, num_qo_heads, qk_head_dim)
// v_data: (num_total_length, num_qo_heads, v_head_dim)
// compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim)
// k_pe_data: (num_total_length, qk_rope_head_dim)
// o_data: (num_total_length, num_qo_heads, v_head_dim)
CHECK_EQ(q_data->ndim, 3);
CHECK_EQ(k_data->ndim, 3);
CHECK_EQ(v_data->ndim, 3);
CHECK_EQ(compressed_kv_data->ndim, 2);
CHECK_EQ(k_pe_data->ndim, 2);
CHECK_EQ(o_data->ndim, 3);

int64_t total_seq_length = 0;
for (int64_t i = 0; i < cur_batch_size_; ++i) {
total_seq_length += cur_append_lengths_[i];
}
CHECK_LE(q_data->shape[0], total_seq_length);
CHECK_LE(k_data->shape[0], total_seq_length);
CHECK_LE(v_data->shape[0], total_seq_length);
CHECK_LE(compressed_kv_data->shape[0], total_seq_length);
CHECK_LE(k_pe_data->shape[0], total_seq_length);
CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_);
CHECK_LE(o_data->shape[0], total_seq_length);
CHECK_EQ(q_data->shape[1], num_qo_heads_);
CHECK_EQ(o_data->shape[1], num_qo_heads_);
CHECK_EQ(k_data->shape[1], num_qo_heads_);
CHECK_EQ(v_data->shape[1], num_qo_heads_);
CHECK_EQ(q_data->shape[2], qk_head_dim_);
CHECK_EQ(k_data->shape[2], qk_head_dim_);
CHECK_EQ(v_data->shape[2], v_head_dim_);
CHECK_EQ(o_data->shape[2], v_head_dim_);

// Part 2: Synchronize streams and update auxiliary data.
ComputeStreamWaitForCopyStream();
ICHECK(!dirty_aux_data_device_);

// Part 3: Append the compressed KV data to the KV cache if the "append_before_attn" flag is set.
if (append_before_attn_) {
f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
append_position_map_view_);
}

// Part 4: Call the ragged attention kernel.
// f_mla_prefill_ragged_normal_ is designed to handle both decode and normal prefill,
// so no separate decode path (e.g. switching on `use_decode_kernel_[0]`) is needed here.
f_mla_prefill_ragged_normal_(q_data, cur_append_length_indptr_view_, k_data, v_data,
cur_append_length_indptr_view_, q_rope_position_map_view_,
k_ragged_rope_pos_offset_view_,
o_data, // output tensor
merged_attn_scores_view_,
/*causal=*/1, static_cast<int>(RoPEMode::kNone),
0, // RoPE parameter (unused: rotary mode is RoPEMode::kNone)
0, // RoPE parameter (unused: rotary mode is RoPEMode::kNone)
attn_score_scaling_factor);

// Part 5: If appending is to occur after attention, call the append kernel.
if (!append_before_attn_) {
f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
append_position_map_view_);
}
}

void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
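
For reference, a minimal NumPy sketch (an illustration, not the TIR kernel) of the semantics the Part 4 call requests from the ragged kernel: per-sequence causal softmax attention over the up-projected q/k/v, with RoPE disabled and the score scaling assumed to be attn_score_scaling_factor / sqrt(d_qk). Shapes follow the "Expected shapes" comment above.

```python
import numpy as np

def mla_normal_reference(q, k, v, indptr, attn_score_scaling_factor=1.0):
    """q, k: (total_len, h_qo, d_qk); v: (total_len, h_qo, d_v);
    indptr: (batch_size + 1,) ragged sequence boundaries; returns (total_len, h_qo, d_v)."""
    total_len, h_qo, d_qk = q.shape
    out = np.zeros((total_len, h_qo, v.shape[2]), dtype=np.float32)
    scale = attn_score_scaling_factor / np.sqrt(d_qk)  # assumed softmax scaling
    for b in range(len(indptr) - 1):
        s, e = indptr[b], indptr[b + 1]
        causal = np.tril(np.ones((e - s, e - s), dtype=bool))  # causal=1 in the call above
        for h in range(h_qo):
            scores = (q[s:e, h] @ k[s:e, h].T) * scale
            scores = np.where(causal, scores, -np.inf)
            w = np.exp(scores - scores.max(axis=-1, keepdims=True))
            w /= w.sum(axis=-1, keepdims=True)
            out[s:e, h] = w @ v[s:e, h]
    return out
```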