From 6cf8eb42a4d3b1744784cdc0f4bcf0de77d6d3c4 Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Thu, 28 Nov 2024 09:22:59 +0800
Subject: [PATCH] Fix correctness regression (from PR#258) in
 Llama-3.2-90B-Vision-Instruct-FP8-KV test (#294)

* Fix correctness regression in Llama-3.2-90B-Vision-Instruct-FP8-KV test

* Fixed platform substitution

* Re-remove vectorization on Navi

* Typo

* Using the same signatures in both paths

* Using 3 versions of RMS norm kernel:
  1. Optimal vectorization using _f16Vec for supported types;
  2. Fallback for types that don't support conversion;
  3. Fallback for shapes that can't be vectorized

* clang-format

* Thinking of it, we don't really need the alternative vectorized kernel

* clang-format

* Using 64 bit types for indices

---------

Co-authored-by: wunhuang
Co-authored-by: Gregory Shtrasberg
---
 csrc/layernorm_kernels.cu | 40 ++++++++++++++++------------------------
 1 file changed, 16 insertions(+), 24 deletions(-)

diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 405ba213628f6..e14ad972a0a45 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -16,21 +16,17 @@
 #include "quantization/fp8/nvidia/quant_utils.cuh"
 #endif
 
-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
-                           defined(__gfx941__) || defined(__gfx942__))
-  #define __HIP__MI300_MI250__
-#endif
-
 namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
+// This kernel uses the _f16Vec to represent vectorized data.
+// A conversion to/from float should exist
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 rms_norm_kernel(scalar_t* __restrict__ out,          // [..., hidden_size]
                 const scalar_t* __restrict__ input,  // [..., hidden_size]
                 const scalar_t* __restrict__ weight, // [hidden_size]
                 const float epsilon, const int num_tokens,
-                const int hidden_size, const int vec_hidden_size) {
+                const size_t hidden_size, const size_t vec_hidden_size) {
   __shared__ float s_variance;
   float v8_variance_sum = 0.0f;
 
@@ -46,7 +42,7 @@ rms_norm_kernel(scalar_t* __restrict__ out,          // [..., hidden_size]
       reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
 
   // Compute variance. Be careful, hidden_size should multiple of 4.
-  for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
+  for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) {
     _f16Vec<scalar_t, width> temp = input_v[idx];
     v8_variance_sum += temp.sum_squares();
   }
@@ -64,7 +60,7 @@ rms_norm_kernel(scalar_t* __restrict__ out,          // [..., hidden_size]
 
   variance = s_variance;
 
-  for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
+  for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) {
     _f16Vec<scalar_t, width> temp = input_v[idx];
     temp *= variance;
     temp *= weight_v[idx];
@@ -72,19 +68,19 @@ rms_norm_kernel(scalar_t* __restrict__ out,          // [..., hidden_size]
   }
 }
 
+// Non vectorized kernel for unusual shapes/types without conversion
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 rms_norm_kernel(scalar_t* __restrict__ out,          // [..., hidden_size]
                 const scalar_t* __restrict__ input,  // [..., hidden_size]
                 const scalar_t* __restrict__ weight, // [hidden_size]
                 const float epsilon, const int num_tokens,
-                const int hidden_size, const int vec_hidden_size) {
+                const size_t hidden_size, const size_t) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x =
-        (float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
+  for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float)input[blockIdx.x * hidden_size + idx];
     variance += x * x;
   }
 
@@ -97,10 +93,9 @@ rms_norm_kernel(scalar_t* __restrict__ out,          // [..., hidden_size]
   }
   __syncthreads();
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x =
-        (float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
-    out[blockIdx.x * static_cast<int64_t>(hidden_size) + idx] =
+  for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float)input[blockIdx.x * hidden_size + idx];
+    out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }
@@ -234,21 +229,18 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
   int num_tokens = input.numel() / hidden_size;
   int vec_size = 16 / input.element_size();
   int vec_hidden_size = hidden_size / vec_size;
+  bool can_run_vectorize = (hidden_size % vec_size) == 0;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(vec_hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-#ifdef __HIP__MI300_MI250__
-  if (vec_size % 8 == 0) {
+  if (vec_size % 8 == 0 && can_run_vectorize) {
+    dim3 block(std::min(vec_hidden_size, 1024));
     LAUNCH_RMS_NORM(8);
   } else {
+    dim3 block(std::min(hidden_size, 1024));
     LAUNCH_RMS_NORM(0);
   }
-#else
-  LAUNCH_RMS_NORM(0);
-#endif
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
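
The net effect of the patch is that rms_norm() now picks between the two rms_norm_kernel overloads from the element type and the shape alone: the width-8 vectorized kernel when the type has an _f16Vec conversion and hidden_size is divisible by the 16-byte vector width, and the width-0 scalar kernel otherwise, on every platform. The host-only C++ sketch below mirrors that selection logic; it is not part of the vLLM sources, and the names supports_f16_vec, rms_norm_sketch, and dispatch are illustrative stand-ins for _typeConvert<T>::exists, the two rms_norm_kernel overloads, and rms_norm.

// Illustrative sketch only -- assumed names, not the vLLM sources.
#include <cstddef>
#include <cstdio>
#include <type_traits>

// Stand-in for vllm::_typeConvert<T>::exists: treat 2-byte (fp16-like) types
// as convertible to the packed vector representation.
template <typename T>
struct supports_f16_vec : std::bool_constant<sizeof(T) == 2> {};

// Vectorized path: enabled only when width > 0 and the type can be packed.
template <typename scalar_t, int width>
std::enable_if_t<(width > 0) && supports_f16_vec<scalar_t>::value>
rms_norm_sketch(const scalar_t* /*input*/, std::size_t /*hidden_size*/,
                std::size_t vec_hidden_size) {
  std::printf("vectorized kernel: width=%d, vec_hidden_size=%zu\n", width,
              vec_hidden_size);
}

// Scalar fallback: same signature, enabled for width == 0 or unsupported types.
template <typename scalar_t, int width>
std::enable_if_t<(width == 0) || !supports_f16_vec<scalar_t>::value>
rms_norm_sketch(const scalar_t* /*input*/, std::size_t hidden_size,
                std::size_t /*unused*/) {
  std::printf("scalar fallback: hidden_size=%zu\n", hidden_size);
}

// Host-side dispatch mirroring the patched rms_norm(): vectorize only when the
// 16-byte vector width evenly divides hidden_size.
template <typename scalar_t>
void dispatch(const scalar_t* input, std::size_t hidden_size) {
  std::size_t vec_size = 16 / sizeof(scalar_t);
  std::size_t vec_hidden_size = hidden_size / vec_size;
  bool can_run_vectorize = (hidden_size % vec_size) == 0;
  if (vec_size % 8 == 0 && can_run_vectorize) {
    rms_norm_sketch<scalar_t, 8>(input, hidden_size, vec_hidden_size);
  } else {
    rms_norm_sketch<scalar_t, 0>(input, hidden_size, vec_hidden_size);
  }
}

int main() {
  static unsigned short half_like[8192] = {};  // 2-byte elements, like fp16/bf16
  static float float_data[4096] = {};          // 4-byte elements, no packed conversion
  dispatch(half_like, 8192);   // vec_size = 8, divisible     -> vectorized kernel
  dispatch(half_like, 8190);   // vec_size = 8, not divisible -> scalar fallback
  dispatch(float_data, 4096);  // vec_size = 4                -> scalar fallback
  return 0;
}

Because both overloads share one signature, the LAUNCH_RMS_NORM(width) macro can instantiate either path with the same argument list, which appears to be what the "Using the same signatures in both paths" bullet in the commit message refers to.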