Re-apply the blockReduceSum fix for warp divergence
mawong-amd committed Mar 26, 2024
1 parent b036bac commit c109d8a
Showing 2 changed files with 4 additions and 7 deletions.
csrc/layernorm_kernels.cu: 4 changes (2 additions, 2 deletions)
@@ -331,9 +331,9 @@ void fused_add_rms_norm(
   dim3 block(std::min(hidden_size, max_block_size));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  /*If the tensor types are FP16, try to use the optimized kernel
+  /*If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
-  Max optimization is achieved with a width-8 vector of FP16s
+  Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, we have to narrow the vectors if the hidden_size does
     not divide 8.
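The comment updated above explains the vectorization strategy: a global memory operation moves at most 128 bits per thread, so packing eight 16-bit FP16/BF16 values into one vector saturates the widest load. As a rough illustration of that idea (not the vLLM kernel itself; the Half8 type and scaleHalf8 kernel below are hypothetical names), a minimal packed-load kernel could look like:

#include <cuda_fp16.h>

// Hypothetical 128-bit vector type: 8 packed FP16 values, 16-byte aligned so
// the compiler can service each load/store with a single 128-bit memory op.
struct alignas(16) Half8 {
  __half data[8];
};

// Illustrative kernel: each thread loads one Half8 (128 bits), scales the
// eight elements, and writes them back with one 128-bit store.
__global__ void scaleHalf8(const Half8* __restrict__ in,
                           Half8* __restrict__ out,
                           float scale, int num_vecs) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_vecs) return;
  Half8 v = in[idx];
#pragma unroll
  for (int i = 0; i < 8; ++i) {
    v.data[i] = __float2half(__half2float(v.data[i]) * scale);
  }
  out[idx] = v;
}

When hidden_size does not divide by 8, the real kernel falls back to narrower vector widths, as the comment notes.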
csrc/reduction_utils.cuh: 7 changes (2 additions, 5 deletions)
@@ -53,11 +53,8 @@ __inline__ __device__ T blockReduceSum(T val) {

     __syncthreads();

-    // Only (a subset of) the first warp needs to participate in the last warpReduce
-    if (threadIdx.x < (blockDim.x / float(warpSize))) {
-      val = shared[lane];
-      val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
-    }
+    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : 0.0f;
+    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
   } else {
     // A single warpReduce is equal to blockReduce
     val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
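This hunk removes the divergent guard around the final warp reduction: instead of only a subset of the first warp entering an if block, every thread evaluates a branchless ternary (inactive lanes contribute 0) and then calls warpReduceSum unconditionally, so the shuffle instructions no longer sit inside a divergent branch. Below is a self-contained sketch of a block reduction in that shape; it is a simplification under stated assumptions (fixed WARP_SIZE of 32, no maxBlockSize/_nextPow2 templating, blockDim.x a multiple of 32, and a hypothetical sumKernel driver), not the vLLM source.

#include <cuda_runtime.h>

static constexpr int WARP_SIZE = 32;

template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
  // Butterfly reduction over a full warp using register shuffles.
  // Assumes all 32 lanes of the warp are active.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffffu, val, mask);
  return val;
}

template <typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[WARP_SIZE];   // one partial sum per warp
  const int lane = threadIdx.x % WARP_SIZE;
  const int wid  = threadIdx.x / WARP_SIZE;

  val = warpReduceSum<T>(val);             // reduce within each warp
  if (lane == 0) shared[wid] = val;        // warp leaders publish partials
  __syncthreads();

  // Branchless final step, mirroring the patched code: every thread takes the
  // same path through warpReduceSum; lanes beyond the number of active warps
  // simply contribute zero, so no divergent branch surrounds the shuffles.
  val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : T(0);
  val = warpReduceSum<T>(val);
  return val;
}

// Toy driver: each block reduces its slice and accumulates into *out,
// which the host is expected to zero-initialize before launch.
__global__ void sumKernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float x = (i < n) ? in[i] : 0.0f;
  float s = blockReduceSum<float>(x);
  if (threadIdx.x == 0) atomicAdd(out, s);
}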
