diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index b0b86541d166d..06384abfb635b 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -36,53 +36,90 @@ __global__ void rms_norm_kernel(
   }
 }

-/* Helper struct to generate vectorized and packed FP16 ops
+
+/* Helper POD struct to generate vectorized and packed FP16 ops
    for appropriate overloads of fused_add_rms_norm_kernel.
    Only special member functions and functions that are necessary
    in that kernel are implemented. */
 template<int width>
-struct _half2Vec {
+struct _halfVec {
   /* Not theoretically necessary that width is a power of 2 but should
      almost always be the case for optimization purposes */
   static_assert(width > 0 && (width & (width - 1)) == 0,
                 "Width is not a positive power of 2!");
-  __half2 data[width];
+  __half data[width];

-  __device__ _half2Vec() = default;
-  __device__ ~_half2Vec() = default;
-  __device__ _half2Vec(const _half2Vec&) = default;
-  __device__ _half2Vec& operator=(const _half2Vec&) = default;
-  __device__ _half2Vec(_half2Vec&&) = default;
-  __device__ _half2Vec& operator=(_half2Vec&&) = default;
+  __device__ _halfVec() = default;
+  __device__ ~_halfVec() = default;
+  __device__ _halfVec(const _halfVec&) = default;
+  __device__ _halfVec& operator=(const _halfVec&) = default;
+  __device__ _halfVec(_halfVec&&) = default;
+  __device__ _halfVec& operator=(_halfVec&&) = default;

-  __device__ inline _half2Vec& operator+=(const _half2Vec& other) {
-    #pragma unroll
-    for (int i = 0; i < width; ++i)
-      data[i] += other.data[i];
+  __device__ inline _halfVec& operator+=(const _halfVec& other) {
+    if constexpr (width % 2 == 0) {
+      for (int i = 0; i < width; i += 2) {
+        __half2 z = __half2{data[i], data[i+1]};
+        z += __half2{other.data[i], other.data[i+1]};
+        data[i] = z.x;
+        data[i+1] = z.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] += other.data[i];
+    }
     return *this;
   }

-  __device__ inline _half2Vec& operator*=(const _half2Vec& other) {
-    #pragma unroll
-    for (int i = 0; i < width; ++i)
-      data[i] *= other.data[i];
+  __device__ inline _halfVec& operator*=(const _halfVec& other) {
+    if constexpr (width % 2 == 0) {
+      for (int i = 0; i < width; i += 2) {
+        __half2 z = __half2{data[i], data[i+1]};
+        z *= __half2{other.data[i], other.data[i+1]};
+        data[i] = z.x;
+        data[i+1] = z.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] *= other.data[i];
+    }
     return *this;
   }

-  __device__ inline _half2Vec& operator*=(const float scale) {
-    #pragma unroll
-    for (int i = 0; i < width; ++i)
-      data[i] = __float22half2_rn(__half22float2(data[i]) * scale);
+  __device__ inline _halfVec& operator*=(const float scale) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 zf = __half22float2(__half2{data[i], data[i+1]});
+        __half2 z = __float22half2_rn(zf * scale);
+        data[i] = z.x;
+        data[i+1] = z.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] = __float2half_rn(__half2float(data[i]) * scale);
+    }
     return *this;
   }

   __device__ inline float sum_squares() const {
     float result = 0.0f;
-    #pragma unroll
-    for (int i = 0; i < width; ++i) {
-      float2 z = __half22float2(data[i]);
-      result += z.x * z.x + z.y * z.y;
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 z = __half22float2(__half2{data[i], data[i+1]});
+        result += z.x * z.x + z.y * z.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i) {
+        float x = __half2float(data[i]);
+        result += x * x;
+      }
     }
     return result;
   }
@@ -93,9 +130,8 @@ struct _half2Vec {
    packed and vectorized operations, which help with the
    memory latency bottleneck. */
 template<typename scalar_t, int width>
-__global__ typename std::enable_if<
-  (width > 0) && std::is_same<scalar_t, c10::Half>::value,
-  void>::type
+__global__ std::enable_if_t<
+  (width > 0) && std::is_same_v<scalar_t, c10::Half>>
 fused_add_rms_norm_kernel(
   c10::Half* __restrict__ input,          // [..., hidden_size]
   c10::Half* __restrict__ residual,       // [..., hidden_size]
@@ -104,35 +140,41 @@ fused_add_rms_norm_kernel(
   const int num_tokens,
   const int hidden_size) {
-  static_assert(sizeof(_half2Vec<width>) == sizeof(c10::Half) * width * 2);
-  const int vec_hidden_size = hidden_size / (width * 2);
+  // Ensures reinterpret_cast does not mutate address for alignment reasons
+  static_assert(alignof(c10::Half) == alignof(_halfVec<width>));
+  // Sanity checks on our vector struct and type-punned pointer arithmetic
+  static_assert(std::is_pod_v<_halfVec<width>>);
+  static_assert(sizeof(_halfVec<width>) == sizeof(c10::Half) * width);
+  const int vec_hidden_size = hidden_size / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
-     not aliased in practice */
-  auto* __restrict__ input_v = reinterpret_cast<_half2Vec<width>*>(input);
-  auto* __restrict__ residual_v = reinterpret_cast<_half2Vec<width>*>(residual);
-  auto* __restrict__ weight_v = reinterpret_cast<const _half2Vec<width>*>(weight);
+     not aliased in practice. Argument pointers should not be dereferenced
+     in this kernel as that would be undefined behavior */
+  auto* __restrict__ input_v = reinterpret_cast<_halfVec<width>*>(input);
+  auto* __restrict__ residual_v = reinterpret_cast<_halfVec<width>*>(residual);
+  auto* __restrict__ weight_v = reinterpret_cast<const _halfVec<width>*>(weight);

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _half2Vec<width> temp = input_v[id];
+    _halfVec<width> temp = input_v[id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;
   }
   /* Keep the following if-else block in sync with
      the calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256)
+  if (num_tokens < 256) {
     variance = blockReduceSum<float, 256>(variance);
-  else variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 1024>(variance);
-  if (threadIdx.x == 0)
+  if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
   __syncthreads();

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _half2Vec<width> temp = residual_v[id];
+    _halfVec<width> temp = residual_v[id];
     temp *= s_variance;
     temp *= weight_v[idx];
     input_v[id] = temp;
@@ -164,9 +206,9 @@ __global__ void fused_add_rms_norm_kernel(
   }
   /* Keep the following if-else block in sync with
      the calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256)
+  if (num_tokens < 256) {
     variance = blockReduceSum<float, 256>(variance);
-  else variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 1024>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -239,32 +281,32 @@ void fused_add_rms_norm(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   /*If the tensor types are FP16, try to use the optimized kernel
-    with packed vectors. Max optimization is achieved with a width-4
-    vector of 2-packed-FP16s (equivalent to a vector of 8 FP16s)
+    with packed + vectorized ops.
+    Max optimization is achieved with a width-8 vector of FP16s
     since we can load at most 128 bits at once in a global memory op.
     However, we have to narrow the vectors if the hidden_size does
     not divide 8.
     Specifically, assuming hidden-size does not divide 8:
-    If the hidden_size divides 4, we can use a width-2 packed vector
-    (equivalent to a vector of 4 FP16s).
-    If the hidden_size divides 2 or 6, we can use a width-1
-    packed vector (equiv. to vector of 2 FP16s).
-    If the hidden_size is odd, we cannot use packed vectors
-    => cannot use the optimized kernel, which is signified
-    by setting (packed vector) width = 0.
+    If the hidden_size divides 4, we can use a width-4 vector.
+    If the hidden_size divides 2 or 6, we can use a width-2
+    vector.
+    If the hidden_size is odd, we can only use a width-1 vector
+    which provides no benefit over the base implementation
+    => we do not use the optimized kernel, which is signified
+    by setting width = 0.
   */
   switch (hidden_size % 8) {
     case 0:
-      LAUNCH_FUSED_ADD_RMS_NORM(4);
+      LAUNCH_FUSED_ADD_RMS_NORM(8);
       break;
     case 2:
       [[fallthrough]];
     case 6:
-      LAUNCH_FUSED_ADD_RMS_NORM(1);
+      LAUNCH_FUSED_ADD_RMS_NORM(2);
       break;
     case 4:
-      LAUNCH_FUSED_ADD_RMS_NORM(2);
+      LAUNCH_FUSED_ADD_RMS_NORM(4);
       break;
     default:
       LAUNCH_FUSED_ADD_RMS_NORM(0);
diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh
index 8d04d1f70e1ee..27750519fee12 100644
--- a/csrc/reduction_utils.cuh
+++ b/csrc/reduction_utils.cuh
@@ -44,9 +44,9 @@ __inline__ __device__ T warpReduceSum(T val) {
 }

 // Helper function to return the next largest power of 2
-constexpr int _nextPow2(int num) {
+static constexpr int _nextPow2(unsigned int num) {
   if (num <= 1) return num;
-  return 1 << (8 * sizeof(num) - __builtin_clz(num - 1));
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }

 /* Calculate the sum of all elements in a block */
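
Note (illustration only, not part of the patch): the width selection performed by the hidden_size % 8 switch in fused_add_rms_norm can be summarized by a small host-side sketch. The helper name pick_fp16_vec_width is hypothetical; it only mirrors the mapping from hidden_size % 8 to the argument passed to LAUNCH_FUSED_ADD_RMS_NORM, where width 0 selects the unvectorized fallback kernel.

  // Hypothetical helper: mirrors the switch in fused_add_rms_norm above.
  // Returns the packed-FP16 vector width used to launch the optimized
  // kernel; 0 selects the scalar fallback.
  constexpr int pick_fp16_vec_width(int hidden_size) {
    switch (hidden_size % 8) {
      case 0: return 8;   // 8 x FP16 = 128-bit loads/stores
      case 4: return 4;   // 4 x FP16 = 64-bit
      case 2: [[fallthrough]];
      case 6: return 2;   // 2 x FP16 = 32-bit
      default: return 0;  // odd hidden_size: no packing possible
    }
  }
  static_assert(pick_fp16_vec_width(4096) == 8, "divisible by 8");
  static_assert(pick_fp16_vec_width(4100) == 4, "divisible by 4 only");
  static_assert(pick_fp16_vec_width(4098) == 2, "divisible by 2 only");
  static_assert(pick_fp16_vec_width(4097) == 0, "odd -> scalar kernel");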
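
Likewise, a minimal host-side check of the _nextPow2 contract relied on by the block reduction, assuming a GCC/Clang-style host compiler where __builtin_clz is usable in constant expressions (nextPow2_ref is a hypothetical stand-in, not the function in the patch):

  #include <climits>

  // Mirrors _nextPow2: for num > 1, returns the smallest power of 2 >= num.
  static constexpr int nextPow2_ref(unsigned int num) {
    if (num <= 1) return num;
    // __builtin_clz counts leading zero bits of (num - 1); subtracting from
    // the total bit width of num gives the shift for rounding up.
    return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
  }
  static_assert(nextPow2_ref(1) == 1, "");
  static_assert(nextPow2_ref(2) == 2, "");
  static_assert(nextPow2_ref(3) == 4, "");
  static_assert(nextPow2_ref(1024) == 1024, "");
  static_assert(nextPow2_ref(1025) == 2048, "");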