Commit

Refactor vector to use half to maintain same alignment as c10::Half; move packed logic into member functions
mawong-amd committed Mar 25, 2024
1 parent cdfe6f2 commit 20f8bd1
Showing 2 changed files with 97 additions and 55 deletions.
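
The "maintain same alignment as c10::Half" part of the title is the crux of the change: c10::Half and __half are 2-byte types, while __half2 carries 4-byte alignment, so a vector struct built from __half2 cannot in general be the target of a reinterpret_cast from a c10::Half pointer without risking a misaligned access. A minimal sketch of the relationship (the struct names here are illustrative stand-ins, not the types in this diff):

#include <cuda_fp16.h>
#include <c10/util/Half.h>

template<int width> struct OldStyleVec { __half2 data[width]; };  // stand-in for a __half2-based vector
template<int width> struct NewStyleVec { __half  data[width]; };  // stand-in for a __half-based vector

static_assert(alignof(c10::Half) == 2 && alignof(__half) == 2,
              "the FP16 scalar types are 2-byte aligned");
static_assert(alignof(OldStyleVec<4>) == 4,
              "a __half2 payload forces 4-byte alignment on the whole struct");
static_assert(alignof(NewStyleVec<8>) == alignof(c10::Half),
              "a __half payload keeps the alignment of the cast's source type");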
148 changes: 95 additions & 53 deletions csrc/layernorm_kernels.cu
@@ -36,53 +36,90 @@ __global__ void rms_norm_kernel(
}
}

/* Helper struct to generate vectorized and packed FP16 ops

/* Helper POD struct to generate vectorized and packed FP16 ops
for appropriate overloads of fused_add_rms_norm_kernel.
Only special member functions and functions that are necessary
in that kernel are implemented.
*/
template<int width>
struct _half2Vec {
struct _halfVec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
__half2 data[width];
__half data[width];

__device__ _half2Vec() = default;
__device__ ~_half2Vec() = default;
__device__ _half2Vec(const _half2Vec<width>&) = default;
__device__ _half2Vec& operator=(const _half2Vec<width>&) = default;
__device__ _half2Vec(_half2Vec<width>&&) = default;
__device__ _half2Vec& operator=(_half2Vec<width>&&) = default;
__device__ _halfVec() = default;
__device__ ~_halfVec() = default;
__device__ _halfVec(const _halfVec<width>&) = default;
__device__ _halfVec& operator=(const _halfVec<width>&) = default;
__device__ _halfVec(_halfVec<width>&&) = default;
__device__ _halfVec& operator=(_halfVec<width>&&) = default;

__device__ inline _half2Vec& operator+=(const _half2Vec<width>& other) {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] += other.data[i];
__device__ inline _halfVec& operator+=(const _halfVec<width>& other) {
if constexpr (width % 2 == 0) {
for (int i = 0; i < width; i += 2) {
__half2 z = __half2{data[i], data[i+1]};
z += __half2{other.data[i], other.data[i+1]};
data[i] = z.x;
data[i+1] = z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] += other.data[i];
}
return *this;
}

__device__ inline _half2Vec& operator*=(const _half2Vec<width>& other) {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] *= other.data[i];
__device__ inline _halfVec& operator*=(const _halfVec<width>& other) {
if constexpr (width % 2 == 0) {
for (int i = 0; i < width; i += 2) {
__half2 z = __half2{data[i], data[i+1]};
z *= __half2{other.data[i], other.data[i+1]};
data[i] = z.x;
data[i+1] = z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] *= other.data[i];
}
return *this;
}

__device__ inline _half2Vec& operator*=(const float scale) {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] = __float22half2_rn(__half22float2(data[i]) * scale);
__device__ inline _halfVec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 zf = __half22float2(__half2{data[i], data[i+1]});
__half2 z = __float22half2_rn(zf * scale);
data[i] = z.x;
data[i+1] = z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] = __float2half_rn(__half2float(data[i]) * scale);
}
return *this;
}

__device__ inline float sum_squares() const {
float result = 0.0f;
#pragma unroll
for (int i = 0; i < width; ++i) {
float2 z = __half22float2(data[i]);
result += z.x * z.x + z.y * z.y;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = __half22float2(__half2{data[i], data[i+1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = __half2float(data[i]);
result += x * x;
}
}
return result;
}
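
Each member function above follows the same shape: when width is even, adjacent __half elements are packed into a __half2 so one paired FP16 instruction (or paired conversion) handles two lanes at once, with a scalar loop as the odd-width fallback. A standalone sketch of that pattern using only cuda_fp16.h intrinsics (the function name is made up for illustration):

#include <cuda_fp16.h>

// Element-wise a += b over a small __half array, two lanes per operation
// when the width is even.
template<int width>
__device__ void packed_add_inplace(__half (&a)[width], const __half (&b)[width]) {
  if constexpr (width % 2 == 0) {
#pragma unroll
    for (int i = 0; i < width; i += 2) {
      // One packed FP16 add covers a[i] and a[i+1].
      __half2 z = __hadd2(__halves2half2(a[i], a[i + 1]),
                          __halves2half2(b[i], b[i + 1]));
      a[i]     = __low2half(z);
      a[i + 1] = __high2half(z);
    }
  } else {
#pragma unroll
    for (int i = 0; i < width; ++i)
      a[i] = __hadd(a[i], b[i]);
  }
}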
@@ -93,9 +130,8 @@ struct _half2Vec {
packed and vectorized operations, which help with the
memory latency bottleneck. */
template<typename scalar_t, int width>
__global__ typename std::enable_if<
(width > 0) && std::is_same<scalar_t, c10::Half>::value,
void>::type
__global__ std::enable_if_t<
(width > 0) && std::is_same_v<scalar_t, c10::Half>>
fused_add_rms_norm_kernel(
c10::Half* __restrict__ input, // [..., hidden_size]
c10::Half* __restrict__ residual, // [..., hidden_size]
@@ -104,35 +140,41 @@ fused_add_rms_norm_kernel(
const int num_tokens,
const int hidden_size)
{
static_assert(sizeof(_half2Vec<width>) == sizeof(c10::Half) * width * 2);
const int vec_hidden_size = hidden_size / (width * 2);
// Ensures reinterpret_cast does not mutate address for alignment reasons
static_assert(alignof(c10::Half) == alignof(_halfVec<width>));
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_halfVec<width>>);
static_assert(sizeof(_halfVec<width>) == sizeof(c10::Half) * width);
const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice */
auto* __restrict__ input_v = reinterpret_cast<_half2Vec<width>*>(input);
auto* __restrict__ residual_v = reinterpret_cast<_half2Vec<width>*>(residual);
auto* __restrict__ weight_v = reinterpret_cast<const _half2Vec<width>*>(weight);
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_halfVec<width>*>(input);
auto* __restrict__ residual_v = reinterpret_cast<_halfVec<width>*>(residual);
auto* __restrict__ weight_v = reinterpret_cast<const _halfVec<width>*>(weight);

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_half2Vec<width> temp = input_v[id];
_halfVec<width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256)
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0)
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_half2Vec<width> temp = residual_v[id];
_halfVec<width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
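
The return type above is the whole dispatch mechanism: a __global__ kernel must return void, and std::enable_if_t<cond> is exactly void when cond holds, so the specialized overload only participates in overload resolution for c10::Half with a positive width. A toy illustration of the pattern with mutually exclusive conditions (not the actual overload pair in this file):

#include <cuda_fp16.h>
#include <type_traits>

// Enabled only for __half with a positive packing width.
template<typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && std::is_same_v<scalar_t, __half>>
toy_kernel(scalar_t* __restrict__ data, int n) {
  // packed / vectorized path
}

// Fallback for every other instantiation.
template<typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !std::is_same_v<scalar_t, __half>>
toy_kernel(scalar_t* __restrict__ data, int n) {
  // scalar path
}

// toy_kernel<__half, 8><<<grid, block>>>(ptr, n);  // picks the packed overload
// toy_kernel<float, 8><<<grid, block>>>(ptr, n);   // picks the fallback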
@@ -164,9 +206,9 @@ __global__ void fused_add_rms_norm_kernel(
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256)
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
else variance = blockReduceSum<float, 256>(variance);
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
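
Both kernels repeat the same warning: the branch on num_tokens chooses the compile-time bound handed to blockReduceSum, and that bound has to cover the block size the host code actually launches with. One plausible host-side shape of that pairing, purely for illustration (the real calculation lives in fused_add_rms_norm and is not shown in these hunks):

#include <algorithm>

// Hypothetical sketch: block size capped at 1024 threads for small batches,
// 256 otherwise, matching the kernel's blockReduceSum<float, 1024/256> split.
inline int pick_block_size(int num_tokens, int hidden_size) {
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  return std::min(hidden_size, max_block_size);
}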
@@ -239,32 +281,32 @@ void fused_add_rms_norm(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16, try to use the optimized kernel
with packed vectors. Max optimization is achieved with a width-4
vector of 2-packed-FP16s (equivalent to a vector of 8 FP16s)
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16s
since we can load at most 128 bits at once in a global memory op.
However, we have to narrow the vectors if the hidden_size does
not divide 8.
Specifically, assuming hidden-size does not divide 8:
If the hidden_size divides 4, we can use a width-2 packed vector
(equivalent to a vector of 4 FP16s).
If the hidden_size divides 2 or 6, we can use a width-1
packed vector (equiv. to vector of 2 FP16s).
If the hidden_size is odd, we cannot use packed vectors
=> cannot use the optimized kernel, which is signified
by setting (packed vector) width = 0.
If the hidden_size divides 4, we can use a width-4 vector.
If the hidden_size divides 2 or 6, we can use a width-2
vector.
If the hidden_size is odd, we can only use a width-1 vector
which provides no benefit over the base implementation
=> we do not use the optimized kernel, which is signified
by setting width = 0.
*/
switch (hidden_size % 8) {
case 0:
LAUNCH_FUSED_ADD_RMS_NORM(4);
LAUNCH_FUSED_ADD_RMS_NORM(8);
break;
case 2:
[[fallthrough]];
case 6:
LAUNCH_FUSED_ADD_RMS_NORM(1);
LAUNCH_FUSED_ADD_RMS_NORM(2);
break;
case 4:
LAUNCH_FUSED_ADD_RMS_NORM(2);
LAUNCH_FUSED_ADD_RMS_NORM(4);
break;
default:
LAUNCH_FUSED_ADD_RMS_NORM(0);
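
Read as a function of hidden_size, the dispatch above picks the widest FP16 vector width that still divides the row length, falling back to the scalar kernel for odd sizes. A small host-side sketch with spot checks (the helper name is invented for illustration):

// Vector width chosen for the packed FP16 kernel, mirroring the switch above.
// 0 means "fall back to the unoptimized kernel".
constexpr int fp16_vec_width(int hidden_size) {
  switch (hidden_size % 8) {
    case 0:          return 8;  // 8 x FP16 = 128-bit loads/stores
    case 4:          return 4;
    case 2: case 6:  return 2;
    default:         return 0;  // odd hidden size: no packing possible
  }
}

static_assert(fp16_vec_width(4096) == 8);
static_assert(fp16_vec_width(2052) == 4);
static_assert(fp16_vec_width(1030) == 2);
static_assert(fp16_vec_width(1023) == 0);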
4 changes: 2 additions & 2 deletions csrc/reduction_utils.cuh
@@ -44,9 +44,9 @@ __inline__ __device__ T warpReduceSum(T val) {
}

// Helper function to return the next largest power of 2
constexpr int _nextPow2(int num) {
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (8 * sizeof(num) - __builtin_clz(num - 1));
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

/* Calculate the sum of all elements in a block */
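
_nextPow2 rounds its argument up to the next power of two via count-leading-zeros, and the switch to CHAR_BIT * sizeof(num) spells out the bit width instead of hard-coding it. A host-side sketch with a few spot checks (renamed so it does not collide with the header's function; __builtin_clz is the GCC/Clang builtin the header already relies on):

#include <climits>

// Same formula as the diff: clz on (num - 1) gives the index of the next
// power of two; num <= 1 is passed through unchanged.
constexpr int next_pow2(unsigned int num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

static_assert(next_pow2(1)    == 1);
static_assert(next_pow2(2)    == 2);    // already a power of two: unchanged
static_assert(next_pow2(5)    == 8);    // clz(4) = 29, 32 - 29 = 3, 1 << 3
static_assert(next_pow2(768)  == 1024); // rounds a block size up to 1024
static_assert(next_pow2(1024) == 1024);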
