Commit

Revert "[pytorch] Layer norm backward speed gain with warp shuffles (p…
Browse files Browse the repository at this point in the history
…ytorch#87445)"

This reverts commit b6f2833.

Reverted pytorch#87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds with the MS compiler
pytorchmergebot committed Oct 26, 2022
1 parent 585d715 commit 9639cb8
Showing 1 changed file with 54 additions and 188 deletions.
242 changes: 54 additions & 188 deletions aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -33,7 +33,6 @@ namespace {

constexpr int kCUDANumThreads = 256;
constexpr int kColwiseReduceTileSize = 32;
constexpr int kWarpSize = 32;
constexpr int vec_size = 4; //we could make it dependent on dtype, but that would lead to different results between float and low-p types

// aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
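The struct that comment refers to is not shown in this hunk; a minimal sketch of the pattern (name and layout assumed, modeled on the MemoryAccess.cuh it cites):

// Over-aligned wrapper: the alignment lets the compiler collapse vec_size
// element accesses into a single vectorized load/store.
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

// Usage sketch: with float and vec_size = 4, one 128-bit transaction per group.
// using vec_t = aligned_vector<float, 4>;
// vec_t v = *reinterpret_cast<const vec_t*>(src + 4 * i);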
@@ -556,108 +555,8 @@ __global__ void GammaBetaBackwardCUDAKernel1(
}
}

template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel_32x32(
int64_t M,
int64_t N,
const T* dY,
const T* X,
const T_ACC* mean,
const T_ACC* rstd,
T* dg,
T* db) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;

T_ACC dg_sum = 0;
T_ACC db_sum = 0;

const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;

if (j < N) {
constexpr int unroll_factor = 8;
int laneId = threadIdx.x & 0x1f;

T_ACC mean_reg, mean_reg_tmp;
T_ACC rstd_reg, rstd_reg_tmp;
T dY_reg;
T X_reg;

// Main loop
int bcounter;
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
bcounter++) {
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;

if (laneId < unroll_factor) {
mean_reg_tmp = mean[offset + laneId];
rstd_reg_tmp = rstd[offset + laneId];
}
#if !defined(USE_ROCM)
// Volta and newer architectures allow lane divergence within a warp.
__syncwarp();
#endif

#pragma unroll
for (int ii = 0; ii < unroll_factor; ++ii) {
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}

// Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (int ii = 0; ii < unroll_factor; ii++) {
if ((offset + ii) < M) {
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}

// This kernel uses a block of (32 x 32) and gets called when M and N
// are divisible by 32. We can use warp shuffles for the final reduction
// step. This removes 4 shmem loads and stores with their
// corresponding __syncthreads().

// This greatly reduces bank conflicts at the expense of a little
// extra shared memory. It does not impact occupancy
int padded_bx = (1 + blockDim.x);

s_dg = s_data_typed;
s_db = s_data_typed + (padded_bx * blockDim.y);
s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
__syncthreads();

// Load transposed so that a warp holds an entire column
T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
for (int delta = 16; delta >= 1; delta /= 2) {
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
}

if (threadIdx.x == 0) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
if (dg) {
dg[j] = reg_dg;
}
if (db) {
db[j] = reg_db;
}
}
}
}
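The kernel above leans on two warp-shuffle idioms; minimal standalone sketches follow (hypothetical helpers, assuming a full 32-lane warp and that WARP_SHFL / WARP_SHFL_XOR wrap __shfl_sync / __shfl_xor_sync with a full mask):

// Broadcast: every lane reads the value held by lane src_lane. Used above to
// spread the mean/rstd values loaded by lanes 0..7 across the whole warp.
__device__ __forceinline__ float warp_broadcast(float val, int src_lane) {
  return __shfl_sync(0xffffffff, val, src_lane, 32);
}

// Butterfly reduction: after 5 XOR exchanges every lane holds the warp-wide sum.
// Used above for the final dg/db reduction instead of extra shared-memory passes.
__device__ __forceinline__ float warp_reduce_sum(float val) {
  for (int delta = 16; delta >= 1; delta /= 2) {
    val += __shfl_xor_sync(0xffffffff, val, delta, 32);
  }
  return val;
}

The (1 + blockDim.x) stride used when staging dg_sum/db_sum is the usual +1 padding trick: after the transpose, a warp reads down a column, and the extra element per row keeps those reads on distinct shared-memory banks.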

template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel(
@@ -670,75 +569,66 @@ __global__ void GammaBetaBackwardCUDAKernel(
T* dg,
T* db) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;

T_ACC * s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;

constexpr int unroll = 8;
T dYs[unroll];
T Xs[unroll];
T_ACC * means = s_data_typed;
T_ACC * rstds = s_data_typed + unroll * blockDim.y;
T_ACC dg_sum = 0;
T_ACC db_sum = 0;

if (j < N) {
constexpr int unroll_factor = 8;

T_ACC mean_reg;
T_ACC rstd_reg;
T dY_reg;
T X_reg;

// Main Loop
int bcounter;
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (bcounter = 0; bcounter < M/(blockDim.y * unroll); bcounter++){
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll;
#pragma unroll
for (int ii=0; ii<unroll; ii++){
if (threadIdx.x == 0) {
means[ii*blockDim.y + threadIdx.y] = mean[offset + ii];
rstds[ii*blockDim.y + threadIdx.y] = rstd[offset + ii];
}
dYs[ii] = dY[(offset + ii) * N + j ];
Xs[ii] = X[(offset + ii) * N + j];

}
__syncthreads();
#pragma unroll
for (int ii = 0; ii < unroll_factor; ++ii) {
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
for (int ii=0; ii<unroll; ii++){
dg_sum += dYs[ii] * (Xs[ii] - means[ii*blockDim.y + threadIdx.y]) * rstds[ii * blockDim.y + threadIdx.y];
db_sum += dYs[ii];
}
__syncthreads();
}

// Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (int ii = 0; ii < unroll_factor; ii++ ){
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll;
for (int ii = 0; ii<8; ii++ ){
T_ACC mean_val, rstd_val; // we don't use smem in the tail to avoid awkward synchronizations, perf penalty is negligible
if ((offset + ii) < M) {
dY_reg = dY[(offset + ii) * N + j ];
X_reg = X[(offset + ii) * N + j];
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
mean_val = mean[offset+ii];
rstd_val = rstd[offset+ii];
dYs[0] = dY[(offset + ii) * N + j ];
Xs[0] = X[(offset + ii) * N + j];
dg_sum += dYs[0] * (Xs[0] - mean_val) * rstd_val;
db_sum += dYs[0];
}
}

// Do the final reduction in shared memory
s_dg = s_data_typed;
s_db = s_data_typed + blockDim.x * blockDim.y;
s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
__syncthreads();

for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
for (int offset = blockDim.y/2; offset >=1; offset /= 2){
if (threadIdx.y < offset) {
s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
}
s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] += s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] +=
s_data_typed[blockDim.x * blockDim.y + (threadIdx.y + offset) * blockDim.x + threadIdx.x];
}
__syncthreads();
}

if (threadIdx.y == 0) {
if (dg) {
dg[j] = s_dg[threadIdx.x];
dg[j] = s_data_typed[threadIdx.x];
}
if (db) {
db[j] = s_db[threadIdx.x];
db[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y];
}
}
}
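The restored kernel finishes with a conventional shared-memory tree reduction along threadIdx.y rather than warp shuffles. A standalone sketch of that step (hypothetical names; assumes one partial sum per thread, blockDim.y a power of two, and shared memory sized for blockDim.x * blockDim.y accumulators):

__device__ void block_column_reduce_sum(float* smem, float partial) {
  smem[threadIdx.y * blockDim.x + threadIdx.x] = partial;
  __syncthreads();
  // Halve the number of active rows each step; row y accumulates row y + offset.
  for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
    if (threadIdx.y < offset) {
      smem[threadIdx.y * blockDim.x + threadIdx.x] +=
          smem[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
    }
    __syncthreads();
  }
  // Row 0 (threadIdx.y == 0) now holds one reduced value per threadIdx.x.
}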
@@ -873,8 +763,7 @@ void LayerNormBackwardKernelImplInternal(
T* dgamma_data =
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;

if (M < 128) {
if (M < 512) {
// For small batch size, do colwise reduce directly.
const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
GammaBetaBackwardSimpleCUDAKernel<T, T_ACC>
Expand All @@ -889,42 +778,19 @@ void LayerNormBackwardKernelImplInternal(
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
// This implementation relies on warp primitives and requires that M and N divide
// exactly to warp size.
dim3 threads{kWarpSize, kWarpSize};
int blocks = (N + threads.x - 1) / threads.x;

// If M and N divide by 32, we can use warp shuffles for the final reduction. That requires
// transposing values in shared memory, so we apply a padding to reduce bank conflicts.
size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
dim3 threads{16, 32};
int blocks = (N + threads.x - 1) / threads.x;
size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
GammaBetaBackwardCUDAKernel<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
dim3 threads{16, 32};
int blocks = (N + threads.x-1)/threads.x;
GammaBetaBackwardCUDAKernel<T, T_ACC>
<<<blocks, threads, 2 * sizeof(T_ACC) * threads.x * threads.y, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
}
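The dispatch in this hunk boils down to a ceiling division over N plus a fixed shared-memory budget for the dg/db partials; the deleted warp-shuffle branch additionally required M and N to be multiples of 32 and padded each shared-memory row by one element. A host-side sketch of the two configurations (hypothetical helper and flag; T_ACC stands for the accumulation type; assumes the CUDA headers already pulled in by this .cu file for dim3 and int64_t):

struct LaunchCfg {
  dim3 threads;
  unsigned int blocks;
  size_t shmem_bytes;
};

template <typename T_ACC>
LaunchCfg gamma_beta_backward_cfg(int64_t N, bool warp_shuffle_path) {
  LaunchCfg cfg;
  if (warp_shuffle_path) {
    cfg.threads = dim3(32, 32);
    // +1 element of padding per row so transposed reads avoid bank conflicts.
    cfg.shmem_bytes = 2 * sizeof(T_ACC) * (cfg.threads.x + 1) * cfg.threads.y;
  } else {
    cfg.threads = dim3(16, 32);
    cfg.shmem_bytes = 2 * sizeof(T_ACC) * cfg.threads.x * cfg.threads.y;
  }
  cfg.blocks = (N + cfg.threads.x - 1) / cfg.threads.x;  // ceil(N / threads.x)
  return cfg;
}

The restored branch corresponds to warp_shuffle_path == false: a 16 x 32 block launched over (N + 15) / 16 blocks.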
