Revert "[OPT] improve rms_norm kernel" #293

Closed
wants to merge 1 commit into from
Closed
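Both rms_norm_kernel variants touched by this revert compute the same RMSNorm: each thread block handles one token row, reduces the row's sum of squares, and scales the row by rsqrt(mean square + epsilon) times the learned weight. The following host-side reference is an illustrative sketch only (not part of this PR; the function name is made up), kept in float for clarity even though the device kernels cast back to the input dtype:

#include <cmath>
#include <cstdint>
#include <vector>

// Reference RMSNorm over [num_tokens, hidden_size] rows, matching the math in
// the kernels below: out = in * rsqrt(mean(in^2) + epsilon) * weight.
void rms_norm_reference(std::vector<float>& out, const std::vector<float>& in,
                        const std::vector<float>& weight, int num_tokens,
                        int hidden_size, float epsilon) {
  for (int t = 0; t < num_tokens; ++t) {
    const float* row = in.data() + static_cast<int64_t>(t) * hidden_size;
    float sum_sq = 0.0f;
    for (int i = 0; i < hidden_size; ++i) sum_sq += row[i] * row[i];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / hidden_size + epsilon);
    for (int i = 0; i < hidden_size; ++i) {
      out[static_cast<int64_t>(t) * hidden_size + i] =
          row[i] * inv_rms * weight[i];
    }
  }
}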
145 changes: 83 additions & 62 deletions csrc/layernorm_kernels.cu
@@ -23,68 +23,103 @@

namespace vllm {

template <typename scalar_t>
struct __align__(16) vec8_t {
scalar_t x, y, z, w, u, v, s, t;

__device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
__device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u,
scalar_t v, scalar_t s, scalar_t t)
: x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {}

__device__ vec8_t operator*(const vec8_t& other) const {
return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
u * other.u, v * other.v, s * other.s, t * other.t);
}

__device__ vec8_t operator*(const float& scale) const {
return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale,
v * scale, s * scale, t * scale);
}

__device__ vec8_t operator+(const vec8_t& other) const {
return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
u + other.u, v + other.v, s + other.s, t + other.t);
}

__device__ void operator+=(const vec8_t& other) {
x += other.x;
y += other.y;
z += other.z;
w += other.w;
u += other.u;
v += other.v;
s += other.s;
t += other.t;
}

__device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};

#ifdef __HIP__MI300_MI250__

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float v8_variance_sum = 0.0f;

const int64_t tx = threadIdx.x;
const int64_t bx = blockIdx.x;
const int64_t num_threads = blockDim.x;
vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};

auto* __restrict__ out_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(out);
auto* __restrict__ input_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(
input + bx * static_cast<int64_t>(hidden_size));
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
vec8_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec8_t<scalar_t> const*>(input);
vec8_t<scalar_t> const* vectorized_weight =
reinterpret_cast<vec8_t<scalar_t> const*>(weight);
const int vec_hidden_size = hidden_size >> 3;

// Compute variance. Note: hidden_size must be a multiple of the vector width (8).
for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
v8_variance_sum += temp.sum_squares();
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
v8_variance += x * x;
}
float v8_variance_sum = v8_variance.sum();

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;

float variance =
BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, num_threads);
BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

variance = s_variance;

for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
temp *= variance;
temp *= weight_v[idx];
out_v[bx * static_cast<int64_t>(vec_hidden_size) + idx] = temp;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
vec8_t<scalar_t> v8_w = vectorized_weight[idx];
vectorized_out[blockIdx.x * vec_hidden_size + idx] =
v8_in * s_variance * v8_w;
}
}

template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
#else

// TODO(maleksan): Investigate why vectorization doesn't work for Navi.
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}

@@ -98,13 +133,14 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
out[blockIdx.x * static_cast<int64_t>(hidden_size) + idx] =
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}

#endif

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
@@ -218,37 +254,22 @@ struct Vec<c10::BFloat16, 8> {

} // namespace vllm

#define LAUNCH_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { \
vllm::rms_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size, \
vec_hidden_size); \
});

void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int vec_size = 16 / input.element_size();
int vec_hidden_size = hidden_size / vec_size;

dim3 grid(num_tokens);
dim3 block(std::min(vec_hidden_size, 1024));
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#ifdef __HIP__MI300_MI250__
if (vec_size % 8 == 0) {
LAUNCH_RMS_NORM(8);
} else {
LAUNCH_RMS_NORM(0);
}
#else
LAUNCH_RMS_NORM(0);
#endif
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
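The host entry point rms_norm keeps the same external signature on both sides of this revert; only the launch configuration and the kernel it dispatches to change. A minimal illustrative caller (not part of the PR; the example function is hypothetical and assumes the declaration below is normally provided by the project's ops header):

#include <torch/torch.h>

// Matches the definition above; shown inline only to keep the sketch self-contained.
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon);

void rms_norm_example() {
  const int64_t num_tokens = 4, hidden_size = 4096;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  torch::Tensor input = torch::randn({num_tokens, hidden_size}, opts);
  torch::Tensor weight = torch::ones({hidden_size}, opts);
  torch::Tensor out = torch::empty_like(input);
  rms_norm(out, input, weight, /*epsilon=*/1e-6);  // launches one block per token row
}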
4 changes: 2 additions & 2 deletions csrc/type_convert.cuh
@@ -49,7 +49,7 @@ struct _typeConvert<c10::Half> {
}
};

#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
@@ -162,4 +162,4 @@ struct alignas(16) _f16Vec {
return result;
}
};
} // namespace vllm
} // namespace vllm
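The type_convert.cuh hunk narrows the guard around the BF16 _typeConvert specialization so that, with the revert applied, it is compiled only for CUDA architectures of SM80 and newer, leaving ROCm on the fallback path (consistent with the TODO in the hunk). A CUDA-only illustrative sketch of that guard pattern, using a hypothetical helper that is not from the file:

#include <cuda_bf16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// SM80+ device code: hardware BF16 arithmetic is available.
__device__ inline float bf16_add_example(__nv_bfloat16 a, __nv_bfloat16 b) {
  return __bfloat162float(__hadd(a, b));
}
#else
// Older architectures (and the host compilation pass): widen to float first.
__device__ inline float bf16_add_example(__nv_bfloat16 a, __nv_bfloat16 b) {
  return __bfloat162float(a) + __bfloat162float(b);
}
#endif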