From f8ed456e79c726a490b5223d9ecf4bcbc1811648 Mon Sep 17 00:00:00 2001
From: Li Zhang
Date: Thu, 17 Aug 2023 20:31:28 +0800
Subject: [PATCH] [Fix] Implement movmatrix using warp shuffling for CUDA < 11.8 (#267)

---
 .../kernels/gemm_s_f16/gemm_template.h | 35 +++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/src/turbomind/kernels/gemm_s_f16/gemm_template.h b/src/turbomind/kernels/gemm_s_f16/gemm_template.h
index 2c43ade3f5..a429ba8536 100644
--- a/src/turbomind/kernels/gemm_s_f16/gemm_template.h
+++ b/src/turbomind/kernels/gemm_s_f16/gemm_template.h
@@ -26,7 +26,26 @@ mma_m16n8k16_row_col(Array& d, const Array& a, const Array= 11) && (__CUDACC_VER_MINOR__ >= 8)
+__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
 {
 #if TURBOMIND_ARCH_SM75
     uint d;
@@ -37,6 +56,18 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a)
     return 0;
 #endif
 }
+#endif
+
+__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
+{
+
+#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
+    (void)lane_id;
+    return transpose_m8n8_b16_movmatrix(a);
+#else
+    return transpose_m8n8_b16_warp_shuffle(a, lane_id);
+#endif
+}
 
 namespace ops {
 
@@ -246,7 +277,7 @@ struct Gemm {
                     // convert to half
                     half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
                     // transpose 8x8 accum tile
-                    uint trans_C = transpose_m8n8_b16((uint&)half_C);
+                    uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
                     // store to global memory
                     OutputOps::template apply(trans_C, mm, nn, C, m, n);
                 }