Skip to content

Commit

Permalink
[Fix] Implement movmatrix using warp shuffling for CUDA < 11.8 (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
lzhangzz authored Aug 17, 2023
1 parent 903707b commit f8ed456
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions src/turbomind/kernels/gemm_s_f16/gemm_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,26 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
#endif
}

__inline__ __device__ uint transpose_m8n8_b16(uint a)
__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
{
int src_lane = lane_id / 8 + lane_id % 4 * 8;
uint u0 = __shfl_sync(0xffffffff, value, src_lane);
uint u1 = __shfl_sync(0xffffffff, value, src_lane + 4);
short2 r;

if (lane_id % 8 < 4) {
r.x = ((short2&)u0).x;
r.y = ((short2&)u1).x;
}
else {
r.x = ((short2&)u0).y;
r.y = ((short2&)u1).y;
}
return (uint&)r;
}

#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
{
#if TURBOMIND_ARCH_SM75
uint d;
Expand All @@ -37,6 +56,18 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a)
return 0;
#endif
}
#endif

__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
{

#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
(void)lane_id;
return transpose_m8n8_b16_movmatrix(a);
#else
return transpose_m8n8_b16_warp_shuffle(a, lane_id);
#endif
}

namespace ops {

Expand Down Expand Up @@ -246,7 +277,7 @@ struct Gemm {
// convert to half
half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
// transpose 8x8 accum tile
uint trans_C = transpose_m8n8_b16((uint&)half_C);
uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
// store to global memory
OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
}
Expand Down

0 comments on commit f8ed456

Please sign in to comment.