Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 58 additions & 22 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,9 +995,13 @@ def batch_prefill_paged_kv_cpu(
for d_idx in T.serial(d):

Q_local[d_idx] = T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
q[curl_q, h_qo, d_idx]
q_rope_position[curl_q] != -1,
T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
q[curl_q, h_qo, d_idx]
),
0.0
)
for row_idx in T.serial(max_num_pages * page_size):
if row_idx < kv_chunk_len[0]:
Expand Down Expand Up @@ -1048,7 +1052,10 @@ def batch_prefill_paged_kv_cpu(
# Store Output
for d_idx in T.serial(d):
O_local[d_idx] = O_local[d_idx] /d_val[0]
output[curl_q, h_qo, d_idx] = O_local[d_idx]
if q_rope_position[curl_q] != -1:
output[curl_q, h_qo, d_idx] = O_local[d_idx]
else:
output[curl_q, h_qo, d_idx] = 0.0
lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0])
return batch_prefill_paged_kv_cpu

Expand Down Expand Up @@ -1358,9 +1365,13 @@ def batch_prefill_paged_kv(
cur_H_qo = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
Q_smem[i, j] = T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
q[cur_L, cur_H_qo, j]
q_rope_position[cur_L] != -1,
T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
q[cur_L, cur_H_qo, j]
),
0.0
)
else:
Q_smem[i, j] = 0.0
Expand Down Expand Up @@ -1477,7 +1488,10 @@ def batch_prefill_paged_kv(
cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]
if q_rope_position[cur_L] != -1:
output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]
else:
output[cur_L, cur_H_qo, j] = 0.0

# Store LSE to gmem
for li in T.grid(tile_x):
Expand Down Expand Up @@ -1607,9 +1621,13 @@ def batch_decode_paged_kv(

for d in T.serial(D):
Q_local[d] = T.if_then_else(
rotary_mode == 1,
_rope(Q, q_rope_position[b], head_dim, rope_theta, rope_scale, (b, h_qo, d), qkv_dtype, rope_scaling),
Q[b, h_qo, d],
q_rope_position[b] != -1,
T.if_then_else(
rotary_mode == 1,
_rope(Q, q_rope_position[b], head_dim, rope_theta, rope_scale, (b, h_qo, d), qkv_dtype, rope_scaling),
Q[b, h_qo, d],
),
0.0
)

for row_idx in T.serial(kv_chunk_len[0]):
Expand Down Expand Up @@ -1647,7 +1665,10 @@ def batch_decode_paged_kv(
O_local[d] = O_local[d] + V_local[d] * factor[0]
for d in T.serial(D):
O_local[d] = O_local[d] / d_val[0]
output[b, h_qo, d] = O_local[d]
if q_rope_position[b] != -1:
output[b, h_qo, d] = O_local[d]
else:
output[b, h_qo, d] = 0.0
lse[b, h_qo] = m_val[0] + T.log2(d_val[0])
# fmt: on
# pylint: enable=line-too-long
Expand Down Expand Up @@ -1796,9 +1817,13 @@ def batch_decode_paged_kv(
# load q
for vec in T.vectorized(VEC_SIZE):
Q_local[vec] = T.if_then_else(
rotary_mode == 1,
_rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype, rope_scaling),
Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec]
q_rope_position[batch_idx] != -1,
T.if_then_else(
rotary_mode == 1,
_rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype, rope_scaling),
Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec]
),
0.0
)

for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)):
Expand Down Expand Up @@ -1902,7 +1927,10 @@ def batch_decode_paged_kv(

# store O to global memory
for vec in T.vectorized(VEC_SIZE):
output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec]
if q_rope_position[batch_idx] != -1:
output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec]
else:
output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = 0.0

# store lse to global memory
lse[batch_idx, by * GROUP_SIZE + bz * bdy + ty] = st_m[0] + T.log2(st_d[0])
Expand Down Expand Up @@ -2367,9 +2395,13 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
result[0] = 0.0
for d_idx in T.serial(d_qk):
query_val[0] = T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[q_indptr[b] + q_idx], d_qk, rope_theta, rope_scale, (q_indptr[b] + q_idx, h, d_idx), dtype, rope_scaling),
q[q_indptr[b] + q_idx, h, d_idx],
q_rope_position[q_indptr[b] + q_idx] != -1,
T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[q_indptr[b] + q_idx], d_qk, rope_theta, rope_scale, (q_indptr[b] + q_idx, h, d_idx), dtype, rope_scaling),
q[q_indptr[b] + q_idx, h, d_idx],
),
0.0,
)

key_val[0] = T.if_then_else(
Expand Down Expand Up @@ -2543,9 +2575,13 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
cur_H_qo = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
Q_smem[i, j] = T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[cur_L], d_qk, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
q[cur_L, cur_H_qo, j]
q_rope_position[cur_L] != -1,
T.if_then_else(
rotary_mode == 1,
_rope(q, q_rope_position[cur_L], d_qk, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
q[cur_L, cur_H_qo, j]
),
0.0
)
else:
Q_smem[i, j] = 0.0
Expand Down
11 changes: 6 additions & 5 deletions src/runtime/vm/kv_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,17 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_method("vm.builtin.kv_state_popn", &KVStateObj::PopN)
.def_packed("vm.builtin.kv_state_begin_forward",
[](ffi::PackedArgs args, ffi::Any* rv) {
CHECK(args.size() == 3 || args.size() == 4)
<< "KVState BeginForward only accepts 3 or 4 arguments";
CHECK(args.size() == 4 || args.size() == 5)
<< "KVState BeginForward only accepts 4 or 5 arguments";
KVState kv_state = args[0].cast<KVState>();
ffi::Shape seq_ids = args[1].cast<ffi::Shape>();
ffi::Shape append_lengths = args[2].cast<ffi::Shape>();
int64_t seqlen_padding_factor = args[3].cast<int64_t>();
ffi::Optional<ffi::Shape> token_tree_parent_ptr;
if (args.size() == 4) {
token_tree_parent_ptr = args[3].cast<ffi::Optional<ffi::Shape>>();
if (args.size() == 5) {
token_tree_parent_ptr = args[4].cast<ffi::Optional<ffi::Shape>>();
}
kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr);
kv_state->BeginForward(seq_ids, append_lengths, seqlen_padding_factor, token_tree_parent_ptr);
})
.def_method("vm.builtin.kv_state_end_forward", &KVStateObj::EndForward);
}
Expand Down
9 changes: 5 additions & 4 deletions src/runtime/vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,14 @@ class KVStateObj : public Object {
* in the model forward function.
* \param seq_ids The ids of the sequence to run in the incoming model forward.
* \param append_lengths The sequence lengths to run forward for each sequence.
* \param seqlen_padding_factor The padding factor of the sequences in the current round of forwarding.
* \param token_tree_parent_ptr The parent idx array of the token trees. Its length
* is the sum of "append_lengths". Nullptr means the token tree of each sequence
* is a chain.
*/
virtual void BeginForward(
const IntTuple& seq_ids, const IntTuple& append_lengths,
const ffi::Optional<IntTuple>& token_tree_parent_ptr = std::nullopt) = 0;
virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
const int64_t seqlen_padding_factor = 0,
const ffi::Optional<IntTuple>& token_tree_parent_ptr = std::nullopt) = 0;

/*!
* \brief Mark the start of the forward function.
Expand Down Expand Up @@ -159,7 +160,7 @@ class AttentionKVCacheObj : public KVStateObj {
const IntTuple& leaf_indices) = 0;

/*! \brief Prepare for the disaggregation KV data receive for the specified sequence and length.*/
virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;
virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length, int64_t seqlen_padding_factor = 0) = 0;

/*! \brief Mark which tokens' KV cache needs to be sent to other devices */
virtual void DisaggMarkSend(int64_t seq_id, int64_t begin,
Expand Down
Loading