Use larger block sizes for decode; optimize warp and block reduce fully
mawong-amd committed Mar 24, 2024
1 parent 2d1baac commit cdfe6f2
Showing 2 changed files with 47 additions and 27 deletions.
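The core of the change is that the launch block size for the fused add + RMS norm kernel now depends on the number of tokens: decode-style calls (few tokens, hence few blocks) get larger blocks, presumably so each CU still has enough in-flight work, while large-batch calls keep smaller blocks for better block occupancy. A minimal sketch of that heuristic, using the 256-token threshold from the diff below (the helper name is illustrative, not part of the commit):

// Sketch only: mirrors the block-size selection added to fused_add_rms_norm below.
inline int pick_block_size(int num_tokens, int hidden_size) {
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  return (hidden_size < max_block_size) ? hidden_size : max_block_size;
}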
52 changes: 29 additions & 23 deletions csrc/layernorm_kernels.cu
@@ -88,16 +88,10 @@ struct _half2Vec {
}
};
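Only the closing braces of _half2Vec are visible in this hunk. For orientation, a packed FP16 vector type of this kind could look roughly as follows, inferred from the operations the kernel below applies to it (operator+= and sum_squares()); the member layout and names here are assumptions, not the commit's actual struct:

#include <cuda_fp16.h>

// Assumed shape only: `width` half values stored as __half2 pairs so that
// loads, stores, and adds are packed. Not the actual _half2Vec definition.
template<int width>
struct _half2Vec_sketch {
  static_assert(width % 2 == 0, "width must be even");
  __half2 data[width / 2];

  __device__ _half2Vec_sketch& operator+=(const _half2Vec_sketch& other) {
    #pragma unroll
    for (int i = 0; i < width / 2; ++i)
      data[i] = __hadd2(data[i], other.data[i]);  // packed half2 addition
    return *this;
  }

  __device__ float sum_squares() const {
    float acc = 0.0f;
    #pragma unroll
    for (int i = 0; i < width / 2; ++i) {
      const float2 f = __half22float2(data[i]);   // accumulate in float
      acc += f.x * f.x + f.y * f.y;
    }
    return acc;
  }
};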

/* Max blockSize to use for fused_add_rms_norm_kernel
This kernel is memory-latency bound in many scenarios, so a smaller
block size allows for increased block occupancy on CUs and better
latency hiding on global mem ops. */
#define _FUSED_RMS_MAX_BLOCKSIZE 256

/* Function overload in the case of FP16 tensors.
Additional optimizations we can make in this case are packed and
vectorized operations, which help with the aforementioned memory
latency bottleneck. */
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template<typename scalar_t, int width>
__global__ typename std::enable_if<
(width > 0) && std::is_same<scalar_t, c10::Half>::value,
@@ -122,15 +116,18 @@ fused_add_rms_norm_kernel(

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_half2Vec<width> temp = residual_v[id];
temp += input_v[id];
residual_v[id] = temp;
_half2Vec<width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
variance = blockReduceSum<float, _FUSED_RMS_MAX_BLOCKSIZE>(variance);
if (threadIdx.x == 0) {
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256)
variance = blockReduceSum<float, 1024>(variance);
else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0)
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
@@ -144,9 +141,8 @@ fused_add_rms_norm_kernel(


/* Generic fused_add_rms_norm_kernel
No optimizations in this case, the width field is not used
but necessary for the correct overloading to occur in the
FP16 case.
The width field is not used but necessary for the correct
overloading to occur in the FP16 case.
*/
template<typename scalar_t, int width> // width is not used in this overload
__global__ void fused_add_rms_norm_kernel(
@@ -160,12 +156,17 @@ __global__ void fused_add_rms_norm_kernel(
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx];
x += (float) residual[blockIdx.x * hidden_size + idx];
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
residual[blockIdx.x * hidden_size + idx] = z;
}
variance = blockReduceSum<float>(variance);
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256)
variance = blockReduceSum<float, 1024>(variance);
else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
@@ -229,7 +230,12 @@ void fused_add_rms_norm(
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, _FUSED_RMS_MAX_BLOCKSIZE));
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16, try to use the optimized kernel
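The truncated comment above introduces the dispatch between the packed FP16 kernel and the generic fallback. A hedged sketch of the kind of check such a dispatch needs, assuming a vector width of 8 halves (16 bytes); the helper name, width, and alignment constant are illustrative, and the commit's actual dispatch code is not shown in this hunk:

#include <cstdint>

// Sketch only: the vectorized kernel is safe to use when the hidden dimension
// is a multiple of the vector width and both tensors are aligned for 16-byte
// packed accesses; otherwise fall back to the generic width-0 overload.
inline bool can_use_vectorized_fp16(const void* input, const void* residual,
                                    int hidden_size, bool is_fp16) {
  const bool size_ok  = hidden_size % 8 == 0;
  const bool align_ok =
      reinterpret_cast<std::uintptr_t>(input) % 16 == 0 &&
      reinterpret_cast<std::uintptr_t>(residual) % 16 == 0;
  return is_fp16 && size_ok && align_ok;
}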
22 changes: 18 additions & 4 deletions csrc/reduction_utils.cuh
@@ -36,18 +36,26 @@ template<typename T, int numLanes = warpSize>
__inline__ __device__ T warpReduceSum(T val) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= warpSize);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
}
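As a quick illustration of how the XOR-shuffle butterfly above converges (not part of the commit): for numLanes = 8, three rounds suffice and every lane ends up holding the full sum.

// Illustrative trace for numLanes = 8 (log2(8) = 3 shuffle rounds):
//   round 1 (mask 4): lane i now holds val[i] + val[i^4]
//   round 2 (mask 2): lane i now holds the sum of the four lanes sharing its low bit
//   round 3 (mask 1): every lane holds the sum of all 8 lanes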

// Helper function to return the next largest power of 2
constexpr int _nextPow2(int num) {
if (num <= 1) return num;
return 1 << (8 * sizeof(num) - __builtin_clz(num - 1));
}
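A quick sanity check of the helper's behavior (illustrative, not part of the commit): it rounds up to the nearest power of 2 and leaves exact powers of 2 unchanged.

// Illustrative compile-time checks against the _nextPow2 defined above.
static_assert(_nextPow2(1)  == 1,  "");
static_assert(_nextPow2(3)  == 4,  "");
static_assert(_nextPow2(33) == 64, "");
static_assert(_nextPow2(64) == 64, "");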

/* Calculate the sum of all elements in a block */
template<typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
val = warpReduceSum<T>(val);
// If the block fits into a single warp, we are already done
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > warpSize) {
val = warpReduceSum<T>(val);
// Calculates max number of lanes that need to participate in the last warpReduce
constexpr int maxActiveLanes = (maxBlockSize + warpSize - 1) / warpSize;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % warpSize;
@@ -57,8 +65,14 @@ __inline__ __device__ T blockReduceSum(T val) {

__syncthreads();

val = (threadIdx.x < (blockDim.x / (float) warpSize)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T, maxActiveLanes>(val);
// Only (a subset of) the first warp needs to participate in the last warpReduce
if (threadIdx.x < (blockDim.x / (float) warpSize)) {
val = shared[lane];
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
}
} else {
// A single warpReduce is equal to blockReduce
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
}
return val;
}
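A usage sketch (not part of the commit) showing why callers keep the maxBlockSize template argument in sync with their launch configuration: it must be a compile-time upper bound on blockDim.x, since it sizes the shared-memory scratch array and selects the single-warp fast path.

// Sketch only: block-wide row sum using blockReduceSum with a matching
// compile-time bound on the launch block size. Names here are illustrative.
template<int MAX_BLOCK>
__global__ void sum_rows_sketch(const float* in, float* out, int row_len) {
  float partial = 0.0f;
  for (int i = threadIdx.x; i < row_len; i += blockDim.x)
    partial += in[blockIdx.x * row_len + i];
  partial = blockReduceSum<float, MAX_BLOCK>(partial);  // block-wide sum
  if (threadIdx.x == 0)
    out[blockIdx.x] = partial;                           // one result per block
}
// Launched, e.g., as: sum_rows_sketch<256><<<num_rows, 256, 0, stream>>>(in, out, row_len);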
