diff --git a/src/turbomind/kernels/gemm_s_f16/gemm_template.h b/src/turbomind/kernels/gemm_s_f16/gemm_template.h index 2c43ade3f..a429ba853 100644 --- a/src/turbomind/kernels/gemm_s_f16/gemm_template.h +++ b/src/turbomind/kernels/gemm_s_f16/gemm_template.h @@ -26,7 +26,26 @@ mma_m16n8k16_row_col(Array& d, const Array& a, const Array= 11) && (__CUDACC_VER_MINOR__ >= 8) +__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a) { #if TURBOMIND_ARCH_SM75 uint d; @@ -37,6 +56,18 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a) return 0; #endif } +#endif + +__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id) +{ + +#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8) + (void)lane_id; + return transpose_m8n8_b16_movmatrix(a); +#else + return transpose_m8n8_b16_warp_shuffle(a, lane_id); +#endif +} namespace ops { @@ -246,7 +277,7 @@ struct Gemm { // convert to half half2 half_C = __float22half2_rn(frag_C[j * 2 + x]); // transpose 8x8 accum tile - uint trans_C = transpose_m8n8_b16((uint&)half_C); + uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id); // store to global memory OutputOps::template apply(trans_C, mm, nn, C, m, n); }