
Commit

Revert "[OPT] improve rms_norm kernel (#258)"
This reverts commit 15c78e7.
gshtras authored Nov 26, 2024
1 parent 2302ad6 commit 2fa3f79
Showing 2 changed files with 85 additions and 64 deletions.
145 changes: 83 additions & 62 deletions csrc/layernorm_kernels.cu
@@ -23,68 +23,103 @@

namespace vllm {

template <typename scalar_t>
struct __align__(16) vec8_t {
scalar_t x, y, z, w, u, v, s, t;

__device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
__device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u,
scalar_t v, scalar_t s, scalar_t t)
: x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {}

__device__ vec8_t operator*(const vec8_t& other) const {
return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
u * other.u, v * other.v, s * other.s, t * other.t);
}

__device__ vec8_t operator*(const float& scale) const {
return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale,
v * scale, s * scale, t * scale);
}

__device__ vec8_t operator+(const vec8_t& other) const {
return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
u + other.u, v + other.v, s + other.s, t + other.t);
}

__device__ void operator+=(const vec8_t& other) {
x += other.x;
y += other.y;
z += other.z;
w += other.w;
u += other.u;
v += other.v;
s += other.s;
t += other.t;
}

__device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};

#ifdef __HIP__MI300_MI250__

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float v8_variance_sum = 0.0f;

const int64_t tx = threadIdx.x;
const int64_t bx = blockIdx.x;
const int64_t num_threads = blockDim.x;
vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};

auto* __restrict__ out_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(out);
auto* __restrict__ input_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(
input + bx * static_cast<int64_t>(hidden_size));
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
vec8_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec8_t<scalar_t> const*>(input);
vec8_t<scalar_t> const* vectorized_weight =
reinterpret_cast<vec8_t<scalar_t> const*>(weight);
const int vec_hidden_size = hidden_size >> 3;

// Compute variance. Be careful: hidden_size should be a multiple of 4.
for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
v8_variance_sum += temp.sum_squares();
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
v8_variance += x * x;
}
float v8_variance_sum = v8_variance.sum();

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;

float variance =
BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, num_threads);
BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

variance = s_variance;

for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
temp *= variance;
temp *= weight_v[idx];
out_v[bx * static_cast<int64_t>(vec_hidden_size) + idx] = temp;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
vec8_t<scalar_t> v8_w = vectorized_weight[idx];
vectorized_out[blockIdx.x * vec_hidden_size + idx] =
v8_in * s_variance * v8_w;
}
}

template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
#else

// TODO(maleksan): Investigate why vectorization doesn't work for Navi.
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}

@@ -98,13 +133,14 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
out[blockIdx.x * static_cast<int64_t>(hidden_size) + idx] =
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}

#endif

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
@@ -218,37 +254,22 @@ struct Vec<c10::BFloat16, 8> {

} // namespace vllm

#define LAUNCH_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { \
vllm::rms_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size, \
vec_hidden_size); \
});

void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int vec_size = 16 / input.element_size();
int vec_hidden_size = hidden_size / vec_size;

dim3 grid(num_tokens);
dim3 block(std::min(vec_hidden_size, 1024));
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#ifdef __HIP__MI300_MI250__
if (vec_size % 8 == 0) {
LAUNCH_RMS_NORM(8);
} else {
LAUNCH_RMS_NORM(0);
}
#else
LAUNCH_RMS_NORM(0);
#endif
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
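For reference, both kernel variants in the diff above compute the same RMS norm: each row is scaled by the reciprocal root-mean-square of its elements and then by the per-channel weight. A minimal single-threaded C++ sketch of that computation (illustration only, not part of this commit; names are placeholders):

#include <cmath>

// Reference RMS norm over a [num_tokens, hidden_size] row-major buffer.
void rms_norm_reference(float* out, const float* input, const float* weight,
                        float epsilon, int num_tokens, int hidden_size) {
  for (int token = 0; token < num_tokens; ++token) {
    const float* in_row = input + static_cast<long>(token) * hidden_size;
    float* out_row = out + static_cast<long>(token) * hidden_size;

    // Mean of squares over the hidden dimension.
    float sum_squares = 0.0f;
    for (int i = 0; i < hidden_size; ++i) {
      sum_squares += in_row[i] * in_row[i];
    }
    const float inv_rms = 1.0f / std::sqrt(sum_squares / hidden_size + epsilon);

    // Normalize and apply the learned per-channel weight.
    for (int i = 0; i < hidden_size; ++i) {
      out_row[i] = in_row[i] * inv_rms * weight[i];
    }
  }
}
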
4 changes: 2 additions & 2 deletions csrc/type_convert.cuh
@@ -49,7 +49,7 @@ struct _typeConvert<c10::Half> {
}
};

#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
@@ -162,4 +162,4 @@ struct alignas(16) _f16Vec {
return result;
}
};
} // namespace vllm
} // namespace vllm
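
The removed layernorm code selected between a vectorized and a scalar kernel through std::enable_if_t and _typeConvert<scalar_t>::exists, and the guard reverted here in type_convert.cuh controls whether the BF16 specialization (and therefore exists == true) is compiled at all. A simplified host-side sketch of that dispatch pattern (illustration only; the type and function names below are placeholders, not vLLM's actual types):

#include <cstdio>
#include <type_traits>

// Stand-in for vllm::_typeConvert: 'exists' is only true for types with a
// packed conversion, mirroring the #if guard around the BF16 specialization.
template <typename T>
struct TypeConvert { static constexpr bool exists = false; };

struct FakeBF16 {};  // placeholder for c10::BFloat16
template <>
struct TypeConvert<FakeBF16> { static constexpr bool exists = true; };

// Vectorized overload: selected when width > 0 and a conversion exists.
template <typename T, int width>
std::enable_if_t<(width > 0) && TypeConvert<T>::exists> norm_kernel() {
  std::puts("vectorized path");
}

// Scalar fallback: selected when width == 0 or no conversion exists.
template <typename T, int width>
std::enable_if_t<(width == 0) || !TypeConvert<T>::exists> norm_kernel() {
  std::puts("scalar fallback");
}

int main() {
  norm_kernel<FakeBF16, 8>();  // vectorized path
  norm_kernel<float, 0>();     // scalar fallback
}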
