From 131b217cf13a8defeb06c6a779160849d140fc4a Mon Sep 17 00:00:00 2001 From: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:47:38 -0700 Subject: [PATCH] adds wvSpltK optimization for skinny gemm. (#54) * adds wvSpltK optimization for skinny gemm. --------- Co-authored-by: Hashem Hashemi --- csrc/custom/custom.cu | 13 + csrc/custom/custom_kernels.cu | 1582 ++++++++++++++++++++++ vllm/model_executor/layers/tuned_gemm.py | 12 +- 3 files changed, 1605 insertions(+), 2 deletions(-) diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index 3da25ece3e87c..9e92187967d47 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -39,6 +39,18 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, at::cuda::getCurrentCUDAStream(), rows_per_block); } +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, + const int N, cudaStream_t stream, const int CuCount); + +void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in, + const int CuCount) { + int M = in_a.size(0); + int K = in_a.size(1); + int N = N_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); @@ -90,5 +102,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("LLZZ", &LLZZ); m.def("paged_attention_custom", &paged_attention_custom, "PagedAttention LL4Mi Custom."); + m.def("wvSpltK", &wvSpltK); // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index afecf82eb3d77..2c4698533332e 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -309,3 +309,1585 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, if (cudaSuccess != err) throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } + +///////////////////////////////////////////// + +using half8 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + +/*template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); + //return *((T*)addr); +}*/ + +#define THRDS 64 +#define YTILE 2 +#define WvPrGrp 16 +#define A_CHUNK 8 +#define UNRL 2 +#define M 1 +#define DTYPE half + +__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + __shared__ half s[1024 * 32]; + + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + while (n < N) { + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * 
K]))); +#if (YTILE >= 2) + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif + } + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + for (int m = 0; m < M; m++) { + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! +#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + } +} + +__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! 
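+  // (Editorial note) commitColumn[i] == 1 means column (n + i) of C should
+  // be written back; entries are cleared below for columns that an earlier
+  // wave already owns once the final, partially overlapping tile is shifted
+  // left so that it still fits inside N.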
+ //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! 
+ // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
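+          // (Editorial sketch, not part of the patch) v_dot2c_f32_f16 is a
+          // packed fp16 dot product with fp32 accumulate; each asm below is
+          // roughly equivalent to:
+          //   sum[m][y] += (float)a.h[2 * b] * (float)bY.h[2 * b] +
+          //                (float)a.h[2 * b + 1] * (float)bY.h[2 * b + 1];
+          // where a = bigA[m][k2], bY = bigBY[k2], and .f[b] reinterprets one
+          // half2 pair of the union as the 32-bit operand the instruction
+          // expects. The per-lane partial sums are combined by the shuffle /
+          // DPP reduction at the end of the while loop.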
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
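+    // (Editorial example) With YTILE = 2 and N = 11, the wave that lands on
+    // n = 10 is shifted back to startColumn = 9 and commitColumn[0] is
+    // cleared: it recomputes column 9 (already committed by the n = 8 wave)
+    // but only writes column 10, so no output column is stored twice.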
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#undef YTILE +#undef UNRL +#undef M + +#define YTILE 2 +#define UNRL 2 +#define M 2 + +__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
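+      // (Editorial note) Per-lane addressing of the K dimension: each lane
+      // loads one contiguous A_CHUNK of 8 halfs (16 B), so a wave covers
+      // THRDS * A_CHUNK = 512 consecutive elements of the column and one k1
+      // step consumes THRDS * A_CHUNK * UNRL = 1024 of them. For example,
+      // lane 63 of the k2 = 1 slice reads elements k1 + 1016 .. k1 + 1023.
+      // The same k_ offset indexes both B and the copy of A staged in LDS.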
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
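+          // (Editorial note) In this M = 2 variant each lane carries two
+          // activation rows: bigA[0][k2] and bigA[1][k2] hold the same K
+          // offset of rows 0 and 1, and both accumulate against the single
+          // set of bigB columns fetched above, so the weight traffic is
+          // shared across the small batch.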
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#undef YTILE +#undef UNRL +#undef M + +#define YTILE 5 +#define UNRL 2 +#define M 3 + +__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
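+      // (Editorial note) This variant is compiled with YTILE = 5, so only
+      // bigB0..bigB4 and the matching sum[m][0..4] updates survive the
+      // #if (YTILE >= ...) guards below; the YTILE >= 6 branches drop out
+      // at preprocessing time.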
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
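+          // (Editorial note) The bigA values used below were fetched just
+          // above: elements whose flat index k_ + K * m falls inside the
+          // 32 K-half (64 KB) LDS window come from s[], anything beyond it
+          // (possible here since M = 3, e.g. whenever K > 10922) is read
+          // straight from global memory.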
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#undef YTILE +#undef UNRL +#undef M + +#define YTILE 7 +#define UNRL 1 +#define M 4 + +__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
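+      // (Editorial note) This M = 4 variant is built with UNRL = 1, i.e. no
+      // extra manual unrolling of the K loop; presumably this keeps register
+      // pressure in check, since each lane already tracks 4 x 7 fp32
+      // accumulators plus the staged bigA/bigB chunks. Treat that rationale
+      // as an assumption; the patch itself does not state it.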
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
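+          // (Editorial note) With M = 4 and YTILE = 7 each lane accumulates
+          // into 28 fp32 sums; for the current m, the b loop below issues
+          // YTILE = 7 packed dot products per pair of halfs, i.e. 4 * 7 = 28
+          // v_dot2c_f32_f16 instructions per staged 8-half chunk.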
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + dim3 block(THRDS, WvPrGrp); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + switch (N_in) { + case 1: + if ((K_in <= 32 * 1024) && (M_in % 2 == 0)) { + wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + } else { + wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + } + break; + case 2: + wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + break; + case 3: + wvSpltK_hf_m3_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + break; + case 4: + wvSpltK_hf_m4_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index a3d299c05caef..3ecacafd31977 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -23,6 +23,8 @@ def __init__(self): self.bestsols = {} self.load_best_sols() self.create_ds() + self.CuCount = torch.cuda.get_device_properties( + device='cuda').multi_processor_count if (self.save_gemm == 1): self.tuned_df = pd.DataFrame(columns=['M', 'N', 'K']) @@ -75,7 +77,6 @@ def mm(self, inp, weights): #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) else: - if (self.save_gemm == 1): #print('>>>Tgemm Default',inp_view.shape, # inp.shape,weights.shape,soltype,solidx) @@ -89,7 +90,14 @@ def mm(self, inp, weights): ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - if n == 1 and inp_view.dtype == torch.float16: + if ((n == 4 or n == 3 or n == 2 or n == 1) and k % 8 == 0 + and inp_view.dtype == torch.float16): + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + _custom_C.wvSpltK(weights, inp_view, out, n, self.CuCount) + elif n == 1 and inp_view.dtype == torch.float16: out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype,