@@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
   }
 }
 
+template <uint32_t block_size,
+          uint32_t num_frags_z,
+          uint32_t NUM_WARP_Q,
+          typename T>
+__device__ __forceinline__ void produce_k_dynamic_scale(
+    T* k_smem_scale,
+    T* cache_k_reg,
+    const int* block_table_now,
+    const T* cache_k_scale,
+    const uint32_t kv_idx,
+    const uint32_t kv_num_heads,
+    const uint32_t kv_head_idx,
+    const uint32_t chunk_end
+) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  if constexpr (NUM_WARP_Q == 4) {
+    // 4 warps cooperatively stage the block's per-token K scales in smem
+    const uint32_t tid = ty * 32 + tx;
+    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    if (tid < block_size) {
+      k_smem_scale[tid] = cache_k_scale_now[tid];
+    }
+    __syncthreads();
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
+    }
+  } else {
+    // each warp handles its own 32 tokens
+    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
+    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
+    if (kv_idx_this_thread < chunk_end) {
+      k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
+    } else {
+      k_smem_scale[ty * 32 + tx] = 0;
+    }
+    __syncwarp();
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
+    }
+  }
+}
+
+template <uint32_t block_size,
+          uint32_t num_frags_z,
+          uint32_t NUM_WARP_Q,
+          typename T>
+__device__ __forceinline__ void produce_v_dynamic_scale(
+    T* v_smem_scale,
+    T* cache_v_reg,
+    const int* block_table_now,
+    const T* cache_v_scale,
+    const uint32_t kv_idx,
+    const uint32_t kv_num_heads,
+    const uint32_t kv_head_idx,
+    const uint32_t chunk_end
+) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+
+  if constexpr (NUM_WARP_Q == 4) {
+    // 4 warps cooperatively stage the block's per-token V scales in smem
+    const uint32_t tid = ty * 32 + tx;
+    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    if (tid < block_size) {
+      v_smem_scale[tid] = cache_v_scale_now[tid];
+    }
+    __syncthreads();
+    const uint32_t row_id = tx % 4 * 2;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
+    }
+  } else {
+    // each warp handles its own 32 tokens
+    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
+    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
+    if (kv_idx_this_thread < chunk_end) {
+      v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
+    } else {
+      v_smem_scale[ty * 32 + tx] = 0;
+    }
+    __syncwarp();
+    const uint32_t row_id = tx % 4 * 2;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
+    }
+  }
+}
+
 template <SharedMemFillMode fill_mode,
           uint32_t num_warps,
           uint32_t block_size,
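Note (not part of the patch): the lane-to-token mapping that the two producers above build in cache_k_reg / cache_v_reg can be sanity-checked on the host. The sketch below only replays the NUM_WARP_Q == 4 index arithmetic visible in the diff (K rows tx / 4 and tx / 4 + 8, V rows (tx % 4) * 2, +1, +8, +9 per 16-token fragment); the loop bound for num_frags_z is an illustrative value.

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_frags_z = 1;  // illustrative value only
  for (uint32_t lane = 0; lane < 32; ++lane) {
    for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
      // K: two scales per fragment -> cache_k_reg[fz * 2], cache_k_reg[fz * 2 + 1]
      const uint32_t k_row = lane / 4;
      // V: four scales per fragment -> cache_v_reg[fz * 4] .. cache_v_reg[fz * 4 + 3]
      const uint32_t v_row = lane % 4 * 2;
      std::printf("lane %2u fz %u | K tokens %2u,%2u | V tokens %2u,%2u,%2u,%2u\n",
                  lane, fz,
                  fz * 16 + k_row, fz * 16 + k_row + 8,
                  fz * 16 + v_row, fz * 16 + v_row + 1,
                  fz * 16 + v_row + 8, fz * 16 + v_row + 9);
    }
  }
  return 0;
}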
@@ -816,7 +923,8 @@ template <uint32_t num_frags_x,
           typename T,
           typename CacheT,
           bool is_scale_channel_wise = false,
-          bool IsFP8=false>
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
 __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
                                               uint32_t* q_smem_offset_r,
                                               smem_t* k_smem,
@@ -860,20 +968,27 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
       convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
       convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
       // scale zp
-      if constexpr (is_scale_channel_wise) {
-        const int scale_col = (ky * 2 + fy) * 4;
-        b_frag_dq_T[0] *= cache_k_scale[scale_col];
-        b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
-        b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
-        b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
-        b_frag_dq_T[4] *= cache_k_scale[scale_col];
-        b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
-        b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
-        b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
+      if constexpr (!IsDynamicC8) {
+        if constexpr (is_scale_channel_wise) {
+          const int scale_col = (ky * 2 + fy) * 4;
+          b_frag_dq_T[0] *= cache_k_scale[scale_col];
+          b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
+          b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
+          b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
+          b_frag_dq_T[4] *= cache_k_scale[scale_col];
+          b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
+          b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
+          b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
+        } else {
+#pragma unroll
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_k_scale[0];
+          }
+        }
       } else {
 #pragma unroll
         for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_k_scale[0];
+          b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
         }
       }
 #pragma unroll
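In the new IsDynamicC8 branch above, cache_k_scale presumably points at the two-per-fragment registers filled by produce_k_dynamic_scale, so fz * 2 + b_i / 4 applies the first cached per-token scale to b_frag_dq_T[0..3] and the second to b_frag_dq_T[4..7]. A scalar sketch of just that indexing, with illustrative names and not part of the patch:

#include <cstdint>

// Reference for the dynamic K dequant indexing: 2 cached scales per fragment fz.
template <typename T>
void dequant_k_frag_dynamic_ref(T (&b_frag_dq_T)[8],
                                const T* cache_k_scale,  // 2 scales per fragment
                                uint32_t fz) {
  for (uint32_t b_i = 0; b_i < 8; ++b_i) {
    // b_i / 4 selects the first cached scale for elements 0..3, the second for 4..7
    b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
  }
}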
@@ -1093,7 +1208,9 @@ template <uint32_t num_frags_x,
           uint32_t block_size,
           typename T,
           typename CacheT,
-          bool is_scale_channel_wise = false, bool IsFP8=false>
+          bool is_scale_channel_wise = false,
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
 __device__ __forceinline__ void compute_sfm_v_c8(
     smem_t* v_smem,
     uint32_t* v_smem_offset_r,
@@ -1135,16 +1252,28 @@ __device__ __forceinline__ void compute_sfm_v_c8(
       convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
       convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
       // scale zp
-      if constexpr (is_scale_channel_wise) {
+      if constexpr (!IsDynamicC8) {
+        if constexpr (is_scale_channel_wise) {
 #pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
-        }
-      } else {
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+          }
+        } else {
 #pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[0];
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[0];
+          }
         }
+      } else {
+        const int scale_col = (kz * 2 + fz) * 4;
+        b_frag_dq_T[0] *= cache_v_scale[scale_col];
+        b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
+        b_frag_dq_T[4] *= cache_v_scale[scale_col];
+        b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
       }
 #pragma unroll
       for (uint32_t fx = 0; fx < num_frags_x; ++fx) {  // m: num_frags_x * 16
@@ -1171,7 +1300,9 @@ template <uint32_t num_frags_x,
           uint32_t block_size,
           typename T,
           typename CacheT,
-          bool is_scale_channel_wise = false, bool IsFP8=false>
+          bool is_scale_channel_wise = false,
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
 __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
     smem_t* v_smem,
     uint32_t* v_smem_offset_r,
@@ -1215,16 +1346,28 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
       convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
       convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
       // scale zp
-      if constexpr (is_scale_channel_wise) {
+      if constexpr (!IsDynamicC8) {
+        if constexpr (is_scale_channel_wise) {
 #pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+          }
+        } else {
+#pragma unroll
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[0];
+          }
         }
       } else {
-#pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[0];
-        }
+        const int scale_col = (kz * 2 + fz) * 4;
+        b_frag_dq_T[0] *= cache_v_scale[scale_col];
+        b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
+        b_frag_dq_T[4] *= cache_v_scale[scale_col];
+        b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
       }
 #pragma unroll
       for (uint32_t fx = 0; fx < num_frags_x; ++fx) {  // m: num_frags_x * 16
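The dynamic V path in both compute_sfm_v_c8 variants mirrors the K side: cache_v_scale presumably holds the four per-token scales per fragment filled by produce_v_dynamic_scale, the fragment index (kz * 2 + fz) selects the group of four, and the eight dequantized values reuse those four scales in the pattern 0,1,2,3,0,1,2,3. A scalar sketch of that multiply, with illustrative names and not part of the patch:

#include <cstdint>

// Reference for the dynamic V dequant: 4 cached scales per fragment (kz * 2 + fz).
template <typename T>
void dequant_v_frag_dynamic_ref(T (&b_frag_dq_T)[8],
                                const T* cache_v_scale,  // 4 scales per fragment
                                uint32_t kz,
                                uint32_t fz) {
  const int scale_col = (kz * 2 + fz) * 4;
  for (uint32_t b_i = 0; b_i < 8; ++b_i) {
    // elements 0..3 and 4..7 apply the same four per-token scales in order
    b_frag_dq_T[b_i] *= cache_v_scale[scale_col + b_i % 4];
  }
}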