Fix correctness regression (from PR#258) in Llama-3.2-90B-Vision-Instruct-FP8-KV test (#294)

* Fix correctness regression in Llama-3.2-90B-Vision-Instruct-FP8-KV test

* Fixed platform substitution

* Re-remove vectorization on Navi

* Typo

* Using the same signatures in both paths

* Using 3 versions of RMS norm kernel: 1. Optimal vectorization using _f16Vec for supported types; 2. Fallback for types that don't support conversion; 3. Fallback for shapes that can't be vectorized

* clang-format

* Thinking of it, we don't really need the alternative vectorized kernel

* clang-format

* Using 64 bit types for indices

---------

Co-authored-by: wunhuang <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
3 people authored Nov 28, 2024
1 parent 529cefe commit 6cf8eb4
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions csrc/layernorm_kernels.cu
@@ -16,21 +16,17 @@
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
// This kernel uses the _f16Vec to represent vectorized data.
// A conversion to/from float should exist
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
const size_t hidden_size, const size_t vec_hidden_size) {
__shared__ float s_variance;
float v8_variance_sum = 0.0f;

@@ -46,7 +42,7 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

// Compute variance. Be careful, hidden_size should be a multiple of 4.
for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
v8_variance_sum += temp.sum_squares();
}
@@ -64,27 +60,27 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]

variance = s_variance;

for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
temp *= variance;
temp *= weight_v[idx];
out_v[bx * static_cast<int64_t>(vec_hidden_size) + idx] = temp;
}
}

// Non vectorized kernel for unusual shapes/types without conversion
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
const size_t hidden_size, const size_t) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}

@@ -97,10 +93,9 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
}
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
out[blockIdx.x * static_cast<int64_t>(hidden_size) + idx] =
for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
@@ -234,21 +229,18 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
int num_tokens = input.numel() / hidden_size;
int vec_size = 16 / input.element_size();
int vec_hidden_size = hidden_size / vec_size;
bool can_run_vectorize = (hidden_size % vec_size) == 0;

dim3 grid(num_tokens);
dim3 block(std::min(vec_hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#ifdef __HIP__MI300_MI250__
if (vec_size % 8 == 0) {
if (vec_size % 8 == 0 && can_run_vectorize) {
dim3 block(std::min(vec_hidden_size, 1024));
LAUNCH_RMS_NORM(8);
} else {
dim3 block(std::min(hidden_size, 1024));
LAUNCH_RMS_NORM(0);
}
#else
LAUNCH_RMS_NORM(0);
#endif
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
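
For readers skimming the diff, the following is a minimal, standalone sketch of the dispatch pattern the commit message describes: two __global__ overloads share one name, std::enable_if_t in the return type enables the vectorized version only for widths > 0 and types with a float conversion, the scalar fallback covers everything else, and the host side branches on whether hidden_size is divisible by the vector width. All names here (type_convert, demo_rms_norm, launch_demo, kVecWidth) are hypothetical stand-ins for illustration, not the committed vLLM code.

// Illustrative sketch only -- hypothetical names, not the committed code.
// Two __global__ overloads share a name; std::enable_if_t in the return type
// enables exactly one of them for a given (scalar type, vector width) pair,
// mirroring the _typeConvert<scalar_t>::exists gate used in the diff above.
#include <algorithm>
#include <type_traits>

template <typename T>
struct type_convert {                    // stand-in for vLLM's _typeConvert
  static constexpr bool exists = false;
};
template <>
struct type_convert<float> {
  static constexpr bool exists = true;
};

// Vectorized path: enabled only for width > 0 and types convertible to float.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && type_convert<scalar_t>::exists>
demo_rms_norm(scalar_t* out, const scalar_t* in, size_t hidden_size) {
  // ... vectorized body (e.g. 16-byte loads) would go here ...
}

// Fallback path: width == 0 (shape not vectorizable) or a type without a
// float conversion; every element is processed individually.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !type_convert<scalar_t>::exists>
demo_rms_norm(scalar_t* out, const scalar_t* in, size_t hidden_size) {
  // ... scalar body would go here ...
}

// Host-side dispatch, mirroring the can_run_vectorize check added in rms_norm().
void launch_demo(float* out, const float* in, int num_tokens, int hidden_size) {
  constexpr int kVecWidth = 16 / sizeof(float);  // elements per 16-byte vector
  const bool can_run_vectorize = (hidden_size % kVecWidth) == 0;
  dim3 grid(num_tokens);
  if (can_run_vectorize) {
    dim3 block(std::min(hidden_size / kVecWidth, 1024));
    demo_rms_norm<float, kVecWidth>
        <<<grid, block>>>(out, in, static_cast<size_t>(hidden_size));
  } else {
    dim3 block(std::min(hidden_size, 1024));
    demo_rms_norm<float, 0>
        <<<grid, block>>>(out, in, static_cast<size_t>(hidden_size));
  }
}

Because the width-0 fallback instantiation is valid for every dtype, a launch such as LAUNCH_RMS_NORM(0) stays usable on all platforms and shapes, which is what lets the host code in the diff branch purely on divisibility and hardware support.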
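
The "Using 64 bit types for indices" bullet addresses a potential overflow: the flattened offset blockIdx.x * hidden_size + idx can exceed INT_MAX once a tensor has more than about 2^31 elements. A tiny self-contained illustration with hypothetical sizes (not taken from the Llama test):

// Why 64-bit indexing matters: with 32-bit arithmetic the flattened offset
// blockIdx.x * hidden_size + idx wraps once it passes INT_MAX (2147483647).
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t hidden_size = 16384;   // hypothetical hidden size
  const int64_t token_idx = 131072;    // hypothetical row index
  const int64_t offset = token_idx * hidden_size;  // 2147483648 > INT32_MAX
  std::printf("offset = %lld, fits in int32: %s\n",
              static_cast<long long>(offset),
              offset <= INT32_MAX ? "yes" : "no");
  return 0;
}

With hidden_size declared as size_t in the new kernel signatures, blockIdx.x * hidden_size is already evaluated in 64-bit arithmetic, so the explicit static_cast<int64_t> that the old code carried on each access is no longer needed.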
