Made changes to the runtime to support normal kernel
annanyapr committed Feb 17, 2025
1 parent e3ac7b5 commit 7569674
Showing 3 changed files with 135 additions and 2 deletions.
45 changes: 44 additions & 1 deletion python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -180,6 +180,49 @@ def mla_absorbed(
            )
        ).reshape(b, s, h_qo, kv_lora_rank)

    def mla_normal(
        self,
        layer_id: int,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        compressed_kv: Tensor,
        k_pe: Tensor,
        attn_score_scaling_factor: float = 1.0,
    ) -> Tensor:
        """Compute multi-head latent attention with the given data
        on the specified layer using the normal flow (WITHOUT weight absorption).
        """
        # pylint: disable=protected-access
        b, s, h_qo, d_qk = q._expr.struct_info.shape
        d_v = v._expr.struct_info.shape[3]
        kv_lora_rank = compressed_kv._expr.struct_info.shape[3]
        qk_rope_head_dim = k_pe._expr.struct_info.shape[3]
        q = q.reshape(b * s, h_qo, d_qk)
        k = k.reshape(b * s, h_qo, d_qk)
        v = v.reshape(b * s, h_qo, d_v)
        compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank)
        k_pe = k_pe.reshape(b * s, qk_rope_head_dim)

        return Tensor(
            _expr=rx.BlockBuilder.current().emit(
                rx.call_dps_packed(
                    "vm.builtin.attention_kv_cache_mla_normal",
                    [
                        self._expr,
                        rx.PrimValue(layer_id),  # type: ignore[arg-type]
                        rx.PrimValue(attn_score_scaling_factor),
                        q._expr,
                        k._expr,
                        v._expr,
                        compressed_kv._expr,
                        k_pe._expr,
                    ],
                    out_sinfo=rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype),
                )
            )
        ).reshape(b, s, h_qo, d_v)
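
For illustration, a minimal sketch of how a model definition might call this method. The wrapper below is hypothetical (not part of this commit); the shape comments follow the reshapes above, and the assumption that compressed_kv and k_pe carry a singleton head dimension is inferred from the 4-D struct_info indexing:

# Hypothetical wrapper showing the expected ranks when calling mla_normal
# from a model built with tvm.relax.frontend.nn. Not part of this commit.
from tvm.relax.frontend.nn import Tensor

def run_mla_normal(kv_cache, layer_id: int,
                   q: Tensor,              # (b, s, h_qo, qk_head_dim)
                   k: Tensor,              # (b, s, h_qo, qk_head_dim)
                   v: Tensor,              # (b, s, h_qo, v_head_dim)
                   compressed_kv: Tensor,  # (b, s, 1, kv_lora_rank)
                   k_pe: Tensor,           # (b, s, 1, qk_rope_head_dim)
                   ) -> Tensor:
    # Only compressed_kv and k_pe are persisted into the paged cache; the
    # full q/k/v tensors feed the non-absorbed attention kernel directly.
    # Returns a tensor of shape (b, s, h_qo, v_head_dim).
    return kv_cache.mla_normal(layer_id, q, k, v, compressed_kv, k_pe,
                               attn_score_scaling_factor=1.0)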

    def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
        """Get the in-sequence positions of each slot in the query,
        which are needed for applying positional embeddings in some models.
@@ -591,7 +634,7 @@ def create_mla_kv_cache(  # pylint: disable=too-many-locals
            rx.PrimValue(0),
            bb.add_func(_attention_prefill_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_prefill_mla"),
            bb.add_func(_attention_decode_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_decode_mla"),
-           bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
+           bb.add_func(_attention_prefill_ragged_generic(num_key_value_heads, num_attention_heads, qk_rope_head_dim, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
            bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target), "tir_attention_prefill_ragged_mla_absorbed"),
            bb.add_func(_merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), "tir_attention_merge_state"),
            bb.add_func(llama_rope_with_position_map(10000, 1, qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), "tir_split_rotary"),
8 changes: 8 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
@@ -90,6 +90,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
                         std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
    });

TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
                       double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
                       NDArray v_data, NDArray compressed_kv_data, NDArray k_pe_data,
                       NDArray o_data) {
      kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
                          std::move(compressed_kv_data), std::move(k_pe_data),
                          std::move(o_data), attn_score_scaling_factor);
    });
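
As context for this registration: `rx.call_dps_packed` on the Python side lowers to a destination-passing call, so the output buffer arrives as the trailing `o_data` argument here. A hypothetical direct invocation from Python (the wrapper name is illustrative, not part of this commit; the argument order matches the lambda above):

import tvm

# Resolve the packed function registered above.
f_mla_normal = tvm.get_global_func("vm.builtin.attention_kv_cache_mla_normal")

def call_mla_normal(cache, layer_id, scale, q, k, v, compressed_kv, k_pe, out):
    # `out` is pre-allocated by the caller and written in place; the builtin
    # returns nothing (destination-passing style).
    f_mla_normal(cache, layer_id, scale, q, k, v, compressed_kv, k_pe, out)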

// RNN State methods
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
84 changes: 83 additions & 1 deletion src/runtime/relax_vm/paged_kv_cache.cc
@@ -2241,7 +2241,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
  void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
                 NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
                 double attn_score_scaling_factor) {
-   // Todo(ruihang): implement it
    // Part 1: Basic checks and setup.
    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
    CHECK_GE(local_layer_id, 0);
    CHECK_LT(local_layer_id, num_layers_);
    NDArray pages = pages_[local_layer_id];
    CHECK(q_data.DataType() == pages.DataType());
    CHECK(k_data.DataType() == pages.DataType());
    CHECK(v_data.DataType() == pages.DataType());
    CHECK(compressed_kv_data.DataType() == pages.DataType());
    CHECK(k_pe_data.DataType() == pages.DataType());
    CHECK(o_data.DataType() == pages.DataType());
    CHECK(attn_kinds_[layer_id] == AttnKind::kMLA);

    // Expected shapes:
    // q_data: (num_total_length, num_qo_heads, qk_head_dim)
    // k_data: (num_total_length, num_qo_heads, qk_head_dim)
    // v_data: (num_total_length, num_qo_heads, v_head_dim)
    // compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim)
    // k_pe_data: (num_total_length, qk_rope_head_dim)
    // o_data: (num_total_length, num_qo_heads, v_head_dim)
    CHECK_EQ(q_data->ndim, 3);
    CHECK_EQ(k_data->ndim, 3);
    CHECK_EQ(v_data->ndim, 3);
    CHECK_EQ(compressed_kv_data->ndim, 2);
    CHECK_EQ(k_pe_data->ndim, 2);
    CHECK_EQ(o_data->ndim, 3);

    int64_t total_seq_length = 0;
    for (int64_t i = 0; i < cur_batch_size_; ++i) {
      total_seq_length += cur_append_lengths_[i];
    }
    CHECK_LE(q_data->shape[0], total_seq_length);
    CHECK_LE(k_data->shape[0], total_seq_length);
    CHECK_LE(v_data->shape[0], total_seq_length);
    CHECK_LE(compressed_kv_data->shape[0], total_seq_length);
    CHECK_LE(k_pe_data->shape[0], total_seq_length);
    CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_);
    CHECK_LE(o_data->shape[0], total_seq_length);
    CHECK_EQ(q_data->shape[1], num_qo_heads_);
    CHECK_EQ(o_data->shape[1], num_qo_heads_);
    CHECK_EQ(k_data->shape[1], num_qo_heads_);
    CHECK_EQ(v_data->shape[1], num_qo_heads_);
    CHECK_EQ(q_data->shape[2], qk_head_dim_);
    CHECK_EQ(k_data->shape[2], qk_head_dim_);
    CHECK_EQ(v_data->shape[2], v_head_dim_);
    CHECK_EQ(o_data->shape[2], v_head_dim_);

    // Part 2: Synchronize streams and update auxiliary data.
    ComputeStreamWaitForCopyStream();
    ICHECK(!dirty_aux_data_device_);

    // Part 3: Append the compressed KV data (compressed_kv and k_pe) to the
    // cache now if the "append before attention" flag is set.
    if (append_before_attn_) {
      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
                              append_position_map_view_);
    }

    // Part 4: Call the ragged attention kernel.
    // f_mla_prefill_ragged_normal_ is assumed to handle both the decode and the
    // normal prefill cases internally, so no dispatch on `use_decode_kernel_[0]`
    // is needed here.
    f_mla_prefill_ragged_normal_(q_data,
                                 cur_append_length_indptr_view_,
                                 k_data,
                                 v_data,
                                 cur_append_length_indptr_view_,
                                 q_rope_position_map_view_,
                                 k_ragged_rope_pos_offset_view_,
                                 o_data,  // output tensor
                                 merged_attn_scores_view_,
                                 /*causal=*/1,
                                 RoPEMode::kNone,  // RoPE was already applied before this kernel
                                 0,  // RoPE parameter, unused with RoPEMode::kNone
                                 0,  // RoPE parameter, unused with RoPEMode::kNone
                                 attn_score_scaling_factor);

    // Part 5: If appending is to occur after attention, call the append kernel now.
    if (!append_before_attn_) {
      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
                              append_position_map_view_);
    }
  }
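
To make the contract above concrete, here is a small Python mirror of the shape invariants MLANormal enforces. The helper is hypothetical (not part of this commit): plain tuples stand in for NDArray shapes, and the `<=` comparisons follow the CHECK_LEs above:

# Hypothetical host-side mirror of MLANormal's shape checks, usable as a
# pre-flight test before hitting the C++ CHECKs. Shapes are plain tuples.
def check_mla_normal_shapes(q, k, v, compressed_kv, k_pe, o,
                            num_qo_heads, qk_head_dim, v_head_dim,
                            qk_rope_head_dim, append_lengths):
    total = sum(append_lengths)  # total_seq_length over the current batch
    for name, shape in [("q", q), ("k", k), ("v", v),
                        ("compressed_kv", compressed_kv),
                        ("k_pe", k_pe), ("o", o)]:
        assert shape[0] <= total, f"{name} exceeds total appended length"
    assert q[1:] == (num_qo_heads, qk_head_dim)
    assert k[1:] == (num_qo_heads, qk_head_dim)
    assert v[1:] == (num_qo_heads, v_head_dim)
    assert o[1:] == (num_qo_heads, v_head_dim)
    assert k_pe[1:] == (qk_rope_head_dim,)
    assert len(compressed_kv) == 2  # rank-2: (num_total_length, latent dim)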

  void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
