Commit a751d97

[Optimization] Fuse get_max_len and get_kv_max_len (#4369)
* opt split_q_block
* fuse max_lens and max kv_len
1 parent 425205b commit a751d97

15 files changed: +29 −116 lines changed
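
Before this change, GetBlockShapeAndSplitKVBlock launched a dedicated single-block kernel (get_max_len_kv_ernel) just to compute the maximum KV length, then copied that single int to the host through a separate pinned tensor (max_len_kv_cpu). This commit folds that reduction into GetMaxLenKernel, which already walks the batch once and block-reduces the other length statistics into the fused max_lens buffer; the KV maximum becomes slot 8, so one kernel launch and one device-to-host copy drop off the per-step path, and every caller now reads the value from set_max_lengths / max_len_tensor_cpu. The slot layout implied by the kernel's writes and the callers' reads is sketched below (the labels are illustrative; the code indexes the slots numerically):

// Illustrative labels for the fused max_lens layout; taken from the
// writes in GetMaxLenKernel and the reads in AppendAttentionKernel.
enum MaxLenSlot : int {
  kMaxLenThisTime = 0,            // max_lens[0] = total_max_len_this_time
  kMaxLenEncoder = 1,             // max_lens[1] = total_max_len_encoder
  kMaxDecLenThisTime = 2,         // read as max_dec_len_this_time
  kMaxEncDecLenThisTime = 3,      // read as max_enc_dec_len_this_time
  kMaxJustDecLenThisTime = 4,     // read as max_just_dec_len_this_time
  kMaxJustDecMergedLenThisTime = 5,
  kMaxSystemLen = 6,
  kMaxDecLenWithoutSystem = 7,
  kMaxKVLenThisTime = 8,          // new: absorbed from get_max_len_kv_ernel
};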

custom_ops/gpu_ops/append_attention.cu

Lines changed: 4 additions & 15 deletions
@@ -59,7 +59,6 @@ void AppendAttentionKernel(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     paddle::Tensor& fmha_out,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
@@ -103,6 +102,7 @@ void AppendAttentionKernel(
   int max_dec_len_this_time = set_max_lengths.data<int>()[2];
   int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
   int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
+  int max_kv_len_this_time = set_max_lengths.data<int>()[8];

   auto main_stream = qkv.stream();
   static cudaEvent_t main_event;
@@ -245,7 +245,6 @@ void AppendAttentionKernel(

   if (max_just_dec_len_this_time > 0) {
     int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
-    int max_len_kv_data = max_len_kv.data<int>()[0];

     cudaStream_t exec_stream;
     if (max_enc_len_this_time > 0) {
@@ -371,20 +370,20 @@ void AppendAttentionKernel(
       case paddle::DataType::INT8:{
         int8_t tmp;
         dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-            decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+            decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
         break;
       }
       case paddle::DataType::FLOAT8_E4M3FN:{
         phi::dtype::float8_e4m3fn tmp;
         dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-            decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+            decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
         break;
       }
     }
   } else {
     data_t tmp;
     dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-        decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+        decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
   }
   if (max_enc_len_this_time > 0) {
     cudaEventRecord(decoder_event, exec_stream);
@@ -413,7 +412,6 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
     const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -539,7 +537,6 @@ std::vector<paddle::Tensor> AppendAttention(
       decoder_tile_ids_per_batch,
       decoder_num_blocks,
       set_max_lengths,
-      max_len_kv,
       fmha_out,
       rotary_embs,
       attn_mask,
@@ -616,7 +613,6 @@ void AppendAttentionWithOutput(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     paddle::Tensor& fmha_out,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
@@ -695,7 +691,6 @@ void AppendAttentionWithOutput(
       decoder_tile_ids_per_batch,
       decoder_num_blocks,
       set_max_lengths,
-      max_len_kv,
       fmha_out,
       rotary_embs,
       attn_mask,
@@ -784,7 +779,6 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
    const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
    const std::vector<int64_t>& decoder_num_blocks_shape,
    const std::vector<int64_t>& set_max_lengths_shape,
-    const std::vector<int64_t>& max_len_kv_shape,
    const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
    const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
@@ -848,7 +842,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
    const paddle::DataType& decoder_tile_ids_per_batch_dtype,
    const paddle::DataType& decoder_num_blocks_dtype,
    const paddle::DataType& set_max_lengths_dtype,
-    const paddle::DataType& max_len_kv_dtype,
    const paddle::optional<paddle::DataType>& rotary_embs_dtype,
    const paddle::optional<paddle::DataType>& attn_mask_dtype,
    const paddle::optional<paddle::DataType>& qkv_bias_dtype,
@@ -930,7 +923,6 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
    const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
    const std::vector<int64_t>& decoder_num_blocks_shape,
    const std::vector<int64_t>& set_max_lengths_shape,
-    const std::vector<int64_t>& max_len_kv_shape,
    const std::vector<int64_t>& fmha_out_shape,
    const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -987,7 +979,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
    const paddle::DataType& decoder_tile_ids_per_batch_dtype,
    const paddle::DataType& decoder_num_blocks_dtype,
    const paddle::DataType& set_max_lengths_dtype,
-    const paddle::DataType& max_len_kv_dtype,
    const paddle::DataType& fmha_out_dtype,
    const paddle::optional<paddle::DataType>& rotary_embs_dtype,
    const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -1046,7 +1037,6 @@ PD_BUILD_STATIC_OP(append_attention)
             "decoder_tile_ids_per_batch",
             "decoder_num_blocks",
             "set_max_lengths",
-             "max_len_kv",
             paddle::Optional("rotary_embs"),
             paddle::Optional("attn_mask"),
             paddle::Optional("qkv_bias"),
@@ -1105,7 +1095,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
             "decoder_tile_ids_per_batch",
             "decoder_num_blocks",
             "set_max_lengths",
-             "max_len_kv",
             "fmha_out",
             paddle::Optional("rotary_embs"),
             paddle::Optional("attn_mask"),

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 16 additions & 60 deletions
@@ -19,7 +19,7 @@

 template <int THREADBLOCK_SIZE>
 __global__ void
-GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
+GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
                 const int *seq_lens_encoder,
                 const int *seq_lens_this_time_merged,
                 const int *seq_lens_encoder_merged, const int *seq_mapping,
@@ -37,41 +37,27 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
   int max_just_dec_merged_len_this_time_this_thread = 0;
   int max_system_len_this_thread = 0;
   int max_dec_len_without_system_this_thread = 0;
+  int max_len_kv_this_thread = 0;
   for (int i = tid; i < batch_size; i += blockDim.x) {
     const int seq_len_this_time = seq_lens_this_time[i];
+    const int seq_len_decoder = seq_lens_decoder[i];
     max_len_this_time_this_thread =
         max(seq_len_this_time, max_len_this_time_this_thread);
     max_len_encoder_this_thread =
         max(seq_lens_encoder[i], max_len_encoder_this_thread);
-    max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
+    max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread);
     if (seq_len_this_time <= 0)
       continue;
-    const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
+    const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
     max_len_this_thread =
-        max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
+        max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
     max_just_dec_len_this_thread =
         max(max_just_dec_len_this_thread, max_just_dec_len_now);
-    if (system_lens) {
-      const int real_bid = seq_mapping[i];
-      const int system_len_now = system_lens[real_bid];
-      max_system_len_this_thread =
-          max(max_system_len_this_thread, system_len_now);
-      max_dec_len_without_system_this_thread =
-          max(max_dec_len_without_system_this_thread,
-              max_just_dec_len_now - system_len_now);
-    }
-  }
-  if (system_lens) {
-    for (int i = tid; i < batch_size; i += blockDim.x) {
-      const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
-      if (ori_seq_len_this_time <= 0)
-        continue;
-      const int max_just_dec_merged_len_this_time_now =
-          seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
-      max_just_dec_merged_len_this_time_this_thread =
-          max(max_just_dec_merged_len_this_time_this_thread,
-              max_just_dec_merged_len_this_time_now);
-    }
+
+    if (seq_len_decoder == 0)
+      continue;
+    max_len_kv_this_thread =
+        max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
   }
   int total_max_len_this_time =
       BlockReduce(temp_storage)
@@ -94,6 +80,8 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
   int total_dec_len_without_system =
       BlockReduce(temp_storage)
           .Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
+  int total_max_len_kv =
+      BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
   if (tid == 0) {
     max_lens[0] = total_max_len_this_time;
     max_lens[1] = total_max_len_encoder;
@@ -103,6 +91,7 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
     max_lens[5] = total_just_dec_merged;
     max_lens[6] = total_system_len;
     max_lens[7] = total_dec_len_without_system;
+    max_lens[8] = total_max_len_kv;
   }
 }

@@ -256,29 +245,6 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
   }
 }

-template <int THREADBLOCK_SIZE>
-__global__ void
-get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
-                     const int *seq_lens_decoder, const int batch_size) {
-  const int tid = threadIdx.x;
-
-  typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-
-  int max_len_this_thread = 0;
-  for (int i = tid; i < batch_size; i += blockDim.x) {
-    if (seq_lens_decoder[i] == 0)
-      continue;
-    max_len_this_thread =
-        max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
-  }
-  int total =
-      BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
-  if (tid == 0) {
-    *max_seq_lens_out = total;
-  }
-}
-
 void GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
@@ -295,7 +261,6 @@ void GetBlockShapeAndSplitKVBlock(
     paddle::Tensor &kv_batch_ids,          // Inplace
     paddle::Tensor &kv_tile_ids_per_batch, // Inplace
     paddle::Tensor &kv_num_blocks_x_cpu,   // Inplace, CPU
-    paddle::Tensor &max_len_kv_cpu,        // Inplace, CPU
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
@@ -319,15 +284,7 @@ void GetBlockShapeAndSplitKVBlock(
   int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
   int max_system_len = max_len_cpu_ptr[6];
   int max_just_dec_len_without_system = max_len_cpu_ptr[7];
-
-  auto max_len_kv =
-      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
-  get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
-      max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
-      seq_lens_decoder.data<int>(), bsz);
-
-
-  max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
+  int max_kv_len_this_time = max_len_cpu_ptr[8];

   // decoder
   if (max_dec_len_this_time > 0) {
@@ -430,7 +387,7 @@ void GetBlockShapeAndSplitKVBlock(
       decoder_num_blocks_cpu.copy_(
           decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
       PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
+          decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
     }
   } else {
     PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
@@ -492,7 +449,6 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
         "kv_batch_ids",
         "kv_tile_ids_per_batch",
         "kv_num_blocks_x_cpu",
-        "max_len_kv_cpu"
     })
    .Outputs({
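
The deleted helper survives as a few lines inside GetMaxLenKernel's existing strided loop plus one extra BlockReduce. A self-contained sketch of the fused pattern, assuming CUB and illustrative names (the real kernel accumulates nine statistics, not two):

#include <cub/cub.cuh>

// Sketch: one strided pass accumulates several per-thread maxima, then one
// cub::BlockReduce per statistic collapses them. The shared TempStorage may
// be reused across reductions only with a __syncthreads() in between.
template <int THREADBLOCK_SIZE>
__global__ void fused_max_len_kernel(const int *seq_lens_this_time,
                                     const int *seq_lens_decoder,
                                     int *max_lens, const int batch_size) {
  typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  const int tid = threadIdx.x;

  int max_dec = 0, max_kv = 0;
  for (int i = tid; i < batch_size; i += blockDim.x) {
    const int dec = seq_lens_decoder[i];
    max_dec = max(dec, max_dec);
    if (dec == 0)
      continue;  // slot has no KV cache yet, as in the deleted kernel
    max_kv = max(seq_lens_this_time[i] + dec, max_kv);
  }
  const int total_dec = BlockReduce(temp_storage).Reduce(max_dec, cub::Max());
  __syncthreads();  // required before temp_storage is reused
  const int total_kv = BlockReduce(temp_storage).Reduce(max_kv, cub::Max());
  if (tid == 0) {  // Reduce() returns the aggregate only in thread 0
    max_lens[0] = total_dec;
    max_lens[1] = total_kv;  // the real kernel writes this to max_lens[8]
  }
}

// Launch with a single block, matching the deleted helper's configuration:
//   fused_max_len_kernel<128><<<1, 128, 0, stream>>>(
//       seq_lens_this_time_ptr, seq_lens_decoder_ptr, max_lens_ptr, bsz);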

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 2 additions & 3 deletions
@@ -64,7 +64,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::Tensor &decoder_batch_ids,
     const paddle::Tensor &decoder_tile_ids_per_batch,
     const paddle::Tensor &decoder_num_blocks_cpu,
-    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    const paddle::Tensor &set_max_lengths,
     const paddle::optional<paddle::Tensor> &rotary_embs,
     const paddle::optional<paddle::Tensor> &attn_mask,
     const paddle::optional<paddle::Tensor> &qkv_bias,
@@ -106,7 +106,7 @@ void AppendAttentionWithOutput(
     const paddle::Tensor &decoder_batch_ids,
     const paddle::Tensor &decoder_tile_ids_per_batch,
     const paddle::Tensor &decoder_num_blocks_cpu,
-    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    const paddle::Tensor &set_max_lengths,
     paddle::Tensor &fmha_out,
     const paddle::optional<paddle::Tensor> &rotary_embs,
     const paddle::optional<paddle::Tensor> &attn_mask,
@@ -315,7 +315,6 @@ void GetBlockShapeAndSplitKVBlock(
     paddle::Tensor &kv_batch_ids,          // Inplace
     paddle::Tensor &kv_tile_ids_per_batch, // Inplace
     paddle::Tensor &kv_num_blocks_x_cpu,   // Inplace, Pinned Memory
-    paddle::Tensor &max_len_kv_cpu,        // Inplace, Pinned Memory
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,

fastdeploy/model_executor/forward_meta.py

Lines changed: 0 additions & 2 deletions
@@ -119,8 +119,6 @@ class ForwardMeta:
     kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
     # The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x.
     kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
-    # The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
-    max_len_kv_cpu: Optional[paddle.Tensor] = None

     decoder_chunk_size_device: Optional[paddle.Tensor] = None

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 0 additions & 3 deletions
@@ -150,7 +150,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             forward_meta.kv_batch_ids,
             forward_meta.kv_tile_ids_per_batch,
             forward_meta.kv_num_blocks_x_cpu,
-            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -291,7 +290,6 @@ def forward_mixed(
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                forward_meta.max_len_kv_cpu,
                 res,
                 metadata.rotary_embs,
                 metadata.attn_mask,
@@ -347,7 +345,6 @@ def forward_mixed(
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                forward_meta.max_len_kv_cpu,
                 metadata.rotary_embs,
                 metadata.attn_mask,
                 layer.qkv_bias,

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 0 additions & 2 deletions
@@ -207,7 +207,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             forward_meta.kv_batch_ids,
             forward_meta.kv_tile_ids_per_batch,
             forward_meta.kv_num_blocks_x_cpu,
-            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -340,7 +339,6 @@ def forward_mixed(
                 forward_meta.decoder_tile_ids_per_batch,  # from buffer
                 forward_meta.decoder_num_blocks_cpu,
                 metadata.max_len_tensor_cpu_decoder,
-                forward_meta.max_len_kv_cpu,
                 metadata.rotary_embs,
                 forward_meta.attn_mask,
                 layer.qkv_bias,

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 4 additions & 3 deletions
@@ -83,6 +83,7 @@ class MLAAttentionMetadata(AttentionMetadata):

     max_enc_len_this_time: Optional[paddle.Tensor] = None
     max_dec_len_this_time: Optional[paddle.Tensor] = None
+    max_kv_len_this_time: Optional[paddle.Tensor] = None


 class MLAAttentionBackend(AttentionBackend):
@@ -199,7 +200,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             forward_meta.kv_batch_ids,
             forward_meta.kv_tile_ids_per_batch,
             forward_meta.kv_num_blocks_x_cpu,
-            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -210,6 +210,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
+        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -362,7 +363,7 @@ def forward_decode(
             forward_meta.decoder_num_blocks_device,
             forward_meta.decoder_chunk_size_device,
             metadata.max_dec_len_this_time,
-            forward_meta.max_len_kv_cpu,
+            metadata.max_kv_len_this_time,
             None,  # attn_mask
             None,  # qkv_bias
             None,  # qkv_out_scales
@@ -478,7 +479,7 @@ def forward_mixed(
             forward_meta.decoder_num_blocks_device,
             forward_meta.decoder_chunk_size_device,
             metadata.max_dec_len_this_time,
-            forward_meta.max_len_kv_cpu,
+            metadata.max_kv_len_this_time,
             None,  # attn_mask
             None,  # qkv_bias
             None,  # qkv_out_scales
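
On the MLA path the pinned max_len_kv_cpu tensor is likewise replaced by metadata.max_kv_len_this_time, which is just the slot-8 view of the already-copied forward_meta.max_len_tensor_cpu, so forward_decode and forward_mixed pass the same value without allocating an extra tensor or issuing an extra copy.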
