diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index 732545465d9c9..ae09f0aaad8f8 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -33,7 +33,6 @@ namespace {
 
 constexpr int kCUDANumThreads = 256;
 constexpr int kColwiseReduceTileSize = 32;
-constexpr int kWarpSize = 32;
 constexpr int vec_size = 4; //we could make it dependent on dtype, but that would lead to different results between float and low-p types
 
 // aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
@@ -556,108 +555,8 @@ __global__ void GammaBetaBackwardCUDAKernel1(
   }
 }
 
-template <typename T, typename T_ACC>
-__global__ void GammaBetaBackwardCUDAKernel_32x32(
-    int64_t M,
-    int64_t N,
-    const T* dY,
-    const T* X,
-    const T_ACC* mean,
-    const T_ACC* rstd,
-    T* dg,
-    T* db) {
-  alignas(sizeof(double)) extern __shared__ char s_data1[];
-  T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
-  T_ACC* s_dg;
-  T_ACC* s_db;
-  T_ACC dg_sum = 0;
-  T_ACC db_sum = 0;
-  const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (j < N) {
-    constexpr int unroll_factor = 8;
-    int laneId = threadIdx.x & 0x1f;
-
-    T_ACC mean_reg, mean_reg_tmp;
-    T_ACC rstd_reg, rstd_reg_tmp;
-    T dY_reg;
-    T X_reg;
-
-    // Main loop
-    int bcounter;
-    for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
-         bcounter++) {
-      int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
-
-      if (laneId < unroll_factor) {
-        mean_reg_tmp = mean[offset + laneId];
-        rstd_reg_tmp = rstd[offset + laneId];
-      }
-#if !defined(USE_ROCM)
-      // Volta and newer architectures allow lane divergence within a warp.
-      __syncwarp();
-#endif
-
-      #pragma unroll
-      for (int ii = 0; ii < unroll_factor; ++ii) {
-        dY_reg = dY[(offset + ii) * N + j];
-        X_reg = X[(offset + ii) * N + j];
-        mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
-        rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
-        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
-        db_sum += dY_reg;
-      }
-    }
-
-    // Remainder loop
-    int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
-    for (int ii = 0; ii < unroll_factor; ii++) {
-      if ((offset + ii) < M) {
-        mean_reg = mean[offset + ii];
-        rstd_reg = rstd[offset + ii];
-        dY_reg = dY[(offset + ii) * N + j];
-        X_reg = X[(offset + ii) * N + j];
-        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
-        db_sum += dY_reg;
-      }
-    }
-
-    // This kernel uses a block of (32 x 32) and gets called when M; N
-    // divide by 32. We can use warp shuffles for the final reduction
-    // step. This removes 4 shmem loads and stores with their
-    // corresponding __syncthreads()
-
-    // This greatly reduces bank conflicts at the expense of a little
-    // extra shared memory. It does not impact occupancy
-    int padded_bx = (1 + blockDim.x);
-
-    s_dg = s_data_typed;
-    s_db = s_data_typed + (padded_bx * blockDim.y);
-    s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
-    s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
-    __syncthreads();
-
-    // Load transposed so that a warp holds an entire column
-    T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
-    T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
-    for (int delta = 16; delta >= 1; delta /= 2) {
-      reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
-      reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
-    }
-
-    if (threadIdx.x == 0) {
-      const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
-      if (dg) {
-        dg[j] = reg_dg;
-      }
-      if (db) {
-        db[j] = reg_db;
-      }
-    }
-  }
-}
 
 template <typename T, typename T_ACC>
 __global__ void GammaBetaBackwardCUDAKernel(
@@ -670,75 +569,66 @@ __global__ void GammaBetaBackwardCUDAKernel(
     T* dg,
     T* db) {
   alignas(sizeof(double)) extern __shared__ char s_data1[];
-  T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
-  T_ACC* s_dg;
-  T_ACC* s_db;
-
+  T_ACC * s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
   const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
-
+  constexpr int unroll = 8;
+  T dYs[unroll];
+  T Xs[unroll];
+  T_ACC * means = s_data_typed;
+  T_ACC * rstds = s_data_typed + unroll * blockDim.y;
   T_ACC dg_sum = 0;
   T_ACC db_sum = 0;
-
   if (j < N) {
-    constexpr int unroll_factor = 8;
-
-    T_ACC mean_reg;
-    T_ACC rstd_reg;
-    T dY_reg;
-    T X_reg;
-
-    // Main Loop
     int bcounter;
-    for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
-      int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
+    for (bcounter = 0; bcounter < M/(blockDim.y * unroll); bcounter++){
+      int offset = (bcounter * blockDim.y + threadIdx.y) * unroll;
+      #pragma unroll
+      for (int ii=0; ii<unroll; ii++){
+        if (threadIdx.x == 0) {
+          means[ii*blockDim.y + threadIdx.y] = mean[offset + ii];
+          rstds[ii*blockDim.y + threadIdx.y] = rstd[offset + ii];
+        }
+        dYs[ii] = dY[(offset + ii) * N + j];
+        Xs[ii] = X[(offset + ii) * N + j];
+      }
+      __syncthreads();
       #pragma unroll
-      for (int ii = 0; ii < unroll_factor; ++ii) {
-        dY_reg = dY[(offset + ii) * N + j];
-        X_reg = X[(offset + ii) * N + j];
-        mean_reg = mean[offset + ii];
-        rstd_reg = rstd[offset + ii];
-        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
-        db_sum += dY_reg;
+      for (int ii=0; ii<unroll; ii++){
+        dg_sum += dYs[ii] * (Xs[ii] - means[ii*blockDim.y + threadIdx.y]) * rstds[ii*blockDim.y + threadIdx.y];
+        db_sum += dYs[ii];
       }
+      __syncthreads();
     }
 
     // Remainder loop
-    int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
-    for (int ii = 0; ii < unroll_factor; ii++ ){
+    int offset = (bcounter * blockDim.y + threadIdx.y) * unroll;
+    for (int ii=0; ii<unroll; ii++){
       if ((offset + ii) < M) {
-        dY_reg = dY[(offset + ii) * N + j];
-        X_reg = X[(offset + ii) * N + j];
-        mean_reg = mean[offset + ii];
-        rstd_reg = rstd[offset + ii];
-        dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
-        db_sum += dY_reg;
+        dYs[0] = dY[(offset + ii) * N + j];
+        Xs[0] = X[(offset + ii) * N + j];
+        dg_sum += dYs[0] * (Xs[0] - mean[offset + ii]) * rstd[offset + ii];
+        db_sum += dYs[0];
       }
     }
-
 
     // Do the final reduction in shared memory
-    s_dg = s_data_typed;
-    s_db = s_data_typed + blockDim.x * blockDim.y;
-    s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
-    s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
+    s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
+    s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
     __syncthreads();
-
-    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
+    for (int offset = blockDim.y/2; offset >=1; offset /= 2){
      if (threadIdx.y < offset) {
-        s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
-            s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
-        s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
-            s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
-      }
+        s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] += s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
+        s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] +=
+          s_data_typed[blockDim.x * blockDim.y + (threadIdx.y + offset) * blockDim.x + threadIdx.x];
+      }
       __syncthreads();
     }
-
     if (threadIdx.y == 0) {
       if (dg) {
-        dg[j] = s_dg[threadIdx.x];
+        dg[j] = s_data_typed[threadIdx.x];
       }
       if (db) {
-        db[j] = s_db[threadIdx.x];
+        db[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y];
       }
     }
   }
@@ -873,8 +763,7 @@ void LayerNormBackwardKernelImplInternal(
     T* dgamma_data =
         dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
     T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
-
-    if (M < 128) {
+    if (M < 512) {
       // For small batch size, do colwise reduce directly.
       const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
       GammaBetaBackwardSimpleCUDAKernel<T, T_ACC>
@@ -889,42 +778,19 @@ void LayerNormBackwardKernelImplInternal(
           dbeta_data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     } else {
-      if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
-        // This implementation relies on warp primitives and requires that M and N divide
-        // exactly to warp size.
-        dim3 threads{kWarpSize, kWarpSize};
-        int blocks = (N + threads.x - 1) / threads.x;
-
-        // If M and N divide by 32, we can use warp shuffles for the final reduction. That requires
-        // transposing values in shared memory, so we apply a padding to reduce bank conflicts.
-        size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
-        GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
-            <<<blocks, threads, shmem_sz, cuda_stream>>>(
-                M,
-                N,
-                dY_data,
-                X_data,
-                mean_data,
-                rstd_data,
-                dgamma_data,
-                dbeta_data);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      } else {
-        dim3 threads{16, 32};
-        int blocks = (N + threads.x - 1) / threads.x;
-        size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
-        GammaBetaBackwardCUDAKernel<T, T_ACC>
-            <<<blocks, threads, shmem_sz, cuda_stream>>>(
-                M,
-                N,
-                dY_data,
-                X_data,
-                mean_data,
-                rstd_data,
-                dgamma_data,
-                dbeta_data);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      }
+      dim3 threads{16, 32};
+      int blocks = (N + threads.x-1)/threads.x;
+      GammaBetaBackwardCUDAKernel<T, T_ACC>
+          <<<blocks, threads, 2 * sizeof(T_ACC) * threads.x * threads.y, cuda_stream>>>(
+              M,
+              N,
+              dY_data,
+              X_data,
+              mean_data,
+              rstd_data,
+              dgamma_data,
+              dbeta_data);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
   }
 }
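
For reference, the core pattern kept by this patch is the block-level shared-memory tree reduction over threadIdx.y that GammaBetaBackwardCUDAKernel uses for its dgamma/dbeta partial sums. The standalone sketch below is not part of the patch; the kernel name column_sum_kernel, the host driver, and the buffer names (partials, in, out, ncols) are made up for illustration, and it assumes the grid covers exactly blockDim.y rows per block. It only demonstrates the reduction shape (one partial per thread, halving over threadIdx.y, row 0 writing the per-column result), not the layer_norm math.

// Hypothetical, minimal CUDA sketch of a per-column tree reduction in
// dynamic shared memory, mirroring the `offset = blockDim.y/2 ... 1` loop
// in the patched kernel.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void column_sum_kernel(const float* in, float* out, int ncols) {
  extern __shared__ float partials[];  // blockDim.x * blockDim.y floats
  const int col = blockIdx.x * blockDim.x + threadIdx.x;

  // Each thread contributes one partial value (0 for out-of-range columns).
  float v = 0.f;
  if (col < ncols) {
    v = in[threadIdx.y * ncols + col];
  }
  partials[threadIdx.y * blockDim.x + threadIdx.x] = v;
  __syncthreads();

  // Tree reduction down the threadIdx.y dimension; after the loop, row 0
  // of the block holds the sum of all blockDim.y partials for its column.
  for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
    if (threadIdx.y < offset) {
      partials[threadIdx.y * blockDim.x + threadIdx.x] +=
          partials[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
    }
    __syncthreads();
  }

  if (threadIdx.y == 0 && col < ncols) {
    out[col] = partials[threadIdx.x];
  }
}

int main() {
  const int rows = 32, cols = 64;  // rows matches blockDim.y below
  float h_in[rows * cols], h_out[cols];
  for (int i = 0; i < rows * cols; ++i) h_in[i] = 1.f;  // every column sums to 32
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  dim3 threads{16, 32};
  int blocks = (cols + threads.x - 1) / threads.x;
  column_sum_kernel<<<blocks, threads, sizeof(float) * threads.x * threads.y>>>(d_in, d_out, cols);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("out[0] = %f\n", h_out[0]);  // expected: 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}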