@@ -1064,7 +1064,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_fp8(
1064
1064
double hi_freq_factor = 32 ,
1065
1065
bool write_k_back = false ,
1066
1066
bool k_norm = false ,
1067
- float * amax = nullptr ) {
1067
+ float * amax = nullptr ,
1068
+ bool * is_precalculated_qparam = nullptr ) {
1068
1069
// Launch b_t_(sum(h)) warps.
1069
1070
auto b_t_hh = blockIdx .x * blockDim .y +
1070
1071
threadIdx .y ; // Block = [kThreadsPerWarp, kWarpsPerBlock]
@@ -1167,7 +1168,10 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_fp8(
1167
1168
*reinterpret_cast <uint2 *>(&dst_bf16);
1168
1169
}
1169
1170
// This kernel does not write to xq_o
1170
- if (qkv != QKV::Q) {
1171
+ bool is_precalculated_qparam_b_t = is_precalculated_qparam
1172
+ ? is_precalculated_qparam[b_t ]
1173
+ : true ; // for decode it is true
1174
+ if (qkv != QKV::Q && is_precalculated_qparam_b_t ) {
1171
1175
// only write to cache if batch lane has a pre-calculated qparam
1172
1176
// quantize and write to dst_row
1173
1177
CUDA_KERNEL_ASSERT (qparam_k_ptr != nullptr )
@@ -1182,13 +1186,9 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_fp8(
1182
1186
}
1183
1187
quantize_fp8_kv<at::Float8_e4m3fn, KVQuantRecipe::perHeadScaling>(
1184
1188
dst, dst_row_q, qparam_row);
1185
- }
1186
- // amax calculation for QKV
1187
- // TODO: avoid amax calculation for KV if batch lane has a pre-calculated
1188
- // qparam if (qkv != QKV::Q && is_precalculated_qparam[b_t] ){
1189
- // return;
1190
- // }
1191
- if (amax != nullptr ) {
1189
+ } else {
1190
+ // qkv == Q or qparam is not precalculated
1191
+ CUDA_KERNEL_ASSERT (amax != nullptr );
1192
1192
// per_row_amax(dst, &amax[b_t * HH + hh]);
1193
1193
per_head_amax (dst, &amax[b * HH + hh]);
1194
1194
}
@@ -1211,7 +1211,8 @@ at::Tensor nope_qkv_varseq_prefill(
1211
1211
std::optional<at::Tensor> qparam_v = std::nullopt,
1212
1212
bool k_norm = false ,
1213
1213
bool update_kv = true ,
1214
- std::optional<at::Tensor> amax_qkv = std::nullopt) {
1214
+ std::optional<at::Tensor> amax_qkv = std::nullopt,
1215
+ std::optional<at::Tensor> kv_quant_scale_precomputed = std::nullopt) {
1215
1216
auto B_T = XQ.size (0 );
1216
1217
auto N_H = XQ.size (1 );
1217
1218
@@ -1292,9 +1293,14 @@ at::Tensor nope_qkv_varseq_prefill(
1292
1293
CUDA_KERNEL_ASSERT (num_groups_ == 1 );
1293
1294
if (cache_K.dtype () == at::kFloat8_e4m3fn ) {
1294
1295
float * amax_ptr = nullptr ;
1296
+ bool * is_precalculated_qparam = nullptr ;
1295
1297
if (amax_qkv.has_value ()) {
1296
1298
amax_ptr = static_cast <float *>(amax_qkv.value ().data_ptr ());
1297
1299
}
1300
+ if (kv_quant_scale_precomputed.has_value ()) {
1301
+ is_precalculated_qparam =
1302
+ static_cast <bool *>(kv_quant_scale_precomputed.value ().data_ptr ());
1303
+ }
1298
1304
rope_xpos_qkv_varseq_prefill_kernel_fp8<
1299
1305
PositionEmbeddingMode::NOPE,
1300
1306
CacheLogicalDtype::FP8,
@@ -1331,7 +1337,8 @@ at::Tensor nope_qkv_varseq_prefill(
1331
1337
0 ,
1332
1338
true , // write_k_back and q too if we are doing norm.
1333
1339
k_norm,
1334
- amax_ptr);
1340
+ amax_ptr,
1341
+ is_precalculated_qparam);
1335
1342
C10_CUDA_KERNEL_LAUNCH_CHECK ();
1336
1343
} else {
1337
1344
CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
@@ -1530,7 +1537,8 @@ at::Tensor nope_qkv_decoding(
1530
1537
0 ,
1531
1538
true , // write_k_back and q too if we are doing norm.
1532
1539
k_norm,
1533
- amax_ptr);
1540
+ amax_ptr,
1541
+ nullptr );
1534
1542
1535
1543
C10_CUDA_KERNEL_LAUNCH_CHECK ();
1536
1544
@@ -1618,7 +1626,8 @@ at::Tensor rope_qkv_varseq_prefill(
1618
1626
bool write_k_back = false ,
1619
1627
bool k_norm = false ,
1620
1628
bool update_kv = true ,
1621
- std::optional<at::Tensor> amax_qkv = std::nullopt) {
1629
+ std::optional<at::Tensor> amax_qkv = std::nullopt,
1630
+ std::optional<at::Tensor> kv_quant_scale_precomputed = std::nullopt) {
1622
1631
auto B_T = XQ.size (0 );
1623
1632
auto N_H = XQ.size (1 );
1624
1633
auto N_KVH = 0 ;
@@ -1707,9 +1716,14 @@ at::Tensor rope_qkv_varseq_prefill(
1707
1716
CUDA_KERNEL_ASSERT (num_groups_ == 1 );
1708
1717
if (cache_K.dtype () == at::kFloat8_e4m3fn ) {
1709
1718
float * amax_ptr = nullptr ;
1719
+ bool * is_precalculated_qparam = nullptr ;
1710
1720
if (amax_qkv.has_value ()) {
1711
1721
amax_ptr = static_cast <float *>(amax_qkv.value ().data_ptr ());
1712
1722
}
1723
+ if (kv_quant_scale_precomputed.has_value ()) {
1724
+ is_precalculated_qparam =
1725
+ static_cast <bool *>(kv_quant_scale_precomputed.value ().data_ptr ());
1726
+ }
1713
1727
rope_xpos_qkv_varseq_prefill_kernel_fp8<
1714
1728
PositionEmbeddingMode::ROPE,
1715
1729
CacheLogicalDtype::FP8,
@@ -1746,7 +1760,8 @@ at::Tensor rope_qkv_varseq_prefill(
1746
1760
hi_freq_factor,
1747
1761
true ,
1748
1762
k_norm,
1749
- amax_ptr);
1763
+ amax_ptr,
1764
+ is_precalculated_qparam);
1750
1765
C10_CUDA_KERNEL_LAUNCH_CHECK ();
1751
1766
1752
1767
} else {
@@ -2113,7 +2128,8 @@ at::Tensor rope_qkv_decoding(
2113
2128
hi_freq_factor,
2114
2129
true ,
2115
2130
k_norm,
2116
- amax_ptr);
2131
+ amax_ptr,
2132
+ nullptr );
2117
2133
2118
2134
C10_CUDA_KERNEL_LAUNCH_CHECK ();
2119
2135
} else {
@@ -2849,8 +2865,7 @@ __global__ void quantizeQKVPerHead(
2849
2865
at::BFloat16* xqkv, // [B_T, HH, D_H]
2850
2866
const int32_t * varseq_seqpos, // [B_T]
2851
2867
const int32_t * varseq_batch, // [B_T]
2852
- at::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits>
2853
- q_seqstarts, // [B+1]
2868
+ const bool * is_precalculated_qparam, // [B_T]
2854
2869
at::PackedTensorAccessor64<at::Float8_e4m3fn, 3 , at::RestrictPtrTraits>
2855
2870
XQ_O, // [B_T][N_H][D]
2856
2871
at::PackedTensorAccessor64<at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>
@@ -2913,8 +2928,9 @@ __global__ void quantizeQKVPerHead(
2913
2928
}
2914
2929
// Skip quantization if scale is pre-calculated for K/V
2915
2930
// as in decode and partial prefill cases
2916
- int start = q_seqstarts[b];
2917
- if (qkv != QKV::Q && varseq_seqpos[start] > 0 )
2931
+ bool is_precalculated_qparam_b_t =
2932
+ is_precalculated_qparam ? is_precalculated_qparam[b_t ] : true ;
2933
+ if (qkv != QKV::Q && is_precalculated_qparam_b_t )
2918
2934
return ;
2919
2935
2920
2936
CUDA_KERNEL_ASSERT (uintptr_t (qparam) % 4 == 0 );
@@ -2965,11 +2981,11 @@ at::Tensor quantize_qkv_per_head(
2965
2981
at::Tensor xqkv, // [B_T, HH, D_H]
2966
2982
at::Tensor varseq_seqpos, // [B_T]
2967
2983
std::optional<at::Tensor> varseq_batch, // [B_T]
2968
- at::Tensor q_seqstarts , // [B+1 ]
2984
+ std::optional< at::Tensor> is_precalculated_qparam , // [B_T ]
2969
2985
at::Tensor cache_K, // [B][MAX_T][N_KVH][D_H]
2970
2986
at::Tensor cache_V, // [B][MAX_T][N_KVH][D_H]
2971
2987
at::Tensor XQ_O, // [B_T][N_H][D]
2972
- int64_t max_seq_length , // Length of the sequence
2988
+ int64_t B , // Batch size
2973
2989
std::optional<at::Tensor> qparam_k,
2974
2990
std::optional<at::Tensor> qparam_v) {
2975
2991
auto B_T = XQ_O.size (0 );
@@ -2980,16 +2996,14 @@ at::Tensor quantize_qkv_per_head(
2980
2996
float * qparam_v_ptr = nullptr ;
2981
2997
if (qparam_k.has_value ()) {
2982
2998
// prefill case
2983
- // HH += N_KVH_L * 2;
2984
2999
qparam_k_ptr = qparam_k.value ().data_ptr <float >();
2985
3000
qparam_v_ptr = qparam_v.value ().data_ptr <float >();
2986
3001
}
2987
3002
auto num_warps = B_T * HH;
2988
3003
dim3 block_size (kThreadsPerWarp , kWarpsPerBlock );
2989
3004
dim3 grid_size (cuda_calc_xblock_count (num_warps, kWarpsPerBlock ));
2990
3005
2991
- auto scale_q = at::zeros (
2992
- {q_seqstarts.size (0 ) - 1 , N_KVH_L}, XQ_O.options ().dtype (at::kFloat ));
3006
+ auto scale_q = at::zeros ({B, N_KVH_L}, XQ_O.options ().dtype (at::kFloat ));
2993
3007
float * const scale_q_ptr = scale_q.data_ptr <float >();
2994
3008
// Launch the kernel
2995
3009
// TODO: Launch the kernel with B_T * N_H_L blocks only in case of decode.
@@ -3005,7 +3019,9 @@ at::Tensor quantize_qkv_per_head(
3005
3019
varseq_seqpos.data_ptr<int32_t>(),
3006
3020
varseq_batch.has_value() ? varseq_batch.value().data_ptr<int32_t>()
3007
3021
: nullptr, // not needed for decode
3008
- q_seqstarts.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
3022
+ is_precalculated_qparam.has_value()
3023
+ ? is_precalculated_qparam.value().data_ptr<bool>()
3024
+ : nullptr,
3009
3025
XQ_O.packed_accessor64<at::Float8_e4m3fn, 3, at::RestrictPtrTraits>(),
3010
3026
cache_K.packed_accessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>(),
3011
3027
cache_V.packed_accessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>(),
0 commit comments