-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] Improved PagedAttention FP8 (faster kvcache dequant v1) #346
base: llama_fp8_12062024
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,8 @@ | |
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, | ||
create_kv_caches_with_random) | ||
|
||
NUM_BLOCKS = 1024 * 1024 | ||
PARTITION_SIZE = 512 | ||
NUM_BLOCKS = 256 * 1024 | ||
PARTITION_SIZE = 256 | ||
|
||
|
||
@torch.inference_mode() | ||
|
@@ -101,7 +101,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: | |
start_time = time.perf_counter() | ||
|
||
# Using default kv_scale | ||
k_scale = v_scale = 1.0 | ||
k_scale = v_scale = 0.1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also, can you explain the default kv-scale change? |
||
|
||
for _ in range(num_iters): | ||
if version == "v1": | ||
|
@@ -161,6 +161,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: | |
kv_cache_dtype, | ||
k_scale, | ||
v_scale, | ||
None, | ||
PARTITION_SIZE | ||
) | ||
else: | ||
raise ValueError(f"Invalid version: {version}") | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,16 +1,185 @@ | ||||||||||||||||||||||||||
#include "common.cuh" | ||||||||||||||||||||||||||
#include "dispatch_utils.h" | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#include <ATen/cuda/CUDAContext.h> | ||||||||||||||||||||||||||
#include <torch/all.h> | ||||||||||||||||||||||||||
#include <c10/cuda/CUDAGuard.h> | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#ifndef USE_ROCM | ||||||||||||||||||||||||||
#include <cmath> | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#include "cuda_compat.h" | ||||||||||||||||||||||||||
#include "dispatch_utils.h" | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#if defined(USE_CUDA_FP8_FORMAT) | ||||||||||||||||||||||||||
#include <cub/util_type.cuh> | ||||||||||||||||||||||||||
#include <cub/cub.cuh> | ||||||||||||||||||||||||||
#else | ||||||||||||||||||||||||||
#include <hipcub/util_type.hpp> | ||||||||||||||||||||||||||
#include <hipcub/hipcub.hpp> | ||||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#if defined(USE_CUDA_FP8_FORMAT) | ||||||||||||||||||||||||||
using FP8_TYPE = c10::Float8_e4m3fn; | ||||||||||||||||||||||||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = | ||||||||||||||||||||||||||
std::numeric_limits<FP8_TYPE>::max(); | ||||||||||||||||||||||||||
#else | ||||||||||||||||||||||||||
#include "amd/hip_float8.h" | ||||||||||||||||||||||||||
using FP8_TYPE = c10::Float8_e4m3fnuz; | ||||||||||||||||||||||||||
// Using the default max value from pytorch (240.0) will cause accuracy | ||||||||||||||||||||||||||
// issue when running dynamic quantization. Here use 224.0f for rocm. | ||||||||||||||||||||||||||
constexpr auto FP8_E4M3_MAX = 224.0f; | ||||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
namespace vllm { | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { | ||||||||||||||||||||||||||
float old; | ||||||||||||||||||||||||||
old = (value >= 0) | ||||||||||||||||||||||||||
? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) | ||||||||||||||||||||||||||
: __uint_as_float( | ||||||||||||||||||||||||||
atomicMin((unsigned int*)addr, __float_as_uint(value))); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
return old; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
Comment on lines
+33
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <bool is_scale_inverted> | ||||||||||||||||||||||||||
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, | ||||||||||||||||||||||||||
float const scale) { | ||||||||||||||||||||||||||
float x = 0.0f; | ||||||||||||||||||||||||||
if constexpr (is_scale_inverted) { | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kindly name the variable with meaningful name, like scaledValue. |
||||||||||||||||||||||||||
x = val * scale; | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
x = val / scale; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if scale is zero, error handling? |
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); | ||||||||||||||||||||||||||
#if defined(USE_CUDA_FP8_FORMAT) | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. naming r as result, or something like that |
||||||||||||||||||||||||||
return static_cast<c10::Float8_e4m3fn>(r); | ||||||||||||||||||||||||||
#else | ||||||||||||||||||||||||||
// Use hardware cvt instruction for fp8 on rocm | ||||||||||||||||||||||||||
return c10::Float8_e4m3fnuz(hip_fp8(r).data, | ||||||||||||||||||||||||||
c10::Float8_e4m3fnuz::from_bits()); | ||||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// Compute the absolute maximum m of the input tensor and store | ||||||||||||||||||||||||||
// m / float8_e4m3::max() in *scale. Each thread block performs a | ||||||||||||||||||||||||||
// reduction tree and the memory in scale is atomically updated. | ||||||||||||||||||||||||||
// So to get the right answer, *scale needs to be initialized to | ||||||||||||||||||||||||||
// a value <= 0.0 and we need to wait for all thread blocks to | ||||||||||||||||||||||||||
// finish before consuming *scale. | ||||||||||||||||||||||||||
template <typename scalar_t> | ||||||||||||||||||||||||||
__global__ void segmented_max_reduction(float* __restrict__ scale, | ||||||||||||||||||||||||||
const scalar_t* __restrict__ input, | ||||||||||||||||||||||||||
int64_t num_elems) { | ||||||||||||||||||||||||||
__shared__ float cache[1024]; | ||||||||||||||||||||||||||
int64_t i = blockDim.x * blockIdx.x + threadIdx.x; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
// First store maximum for all values processes by | ||||||||||||||||||||||||||
// the current thread in cache[threadIdx.x] | ||||||||||||||||||||||||||
scalar_t tmp = 0.0; | ||||||||||||||||||||||||||
while (i < num_elems) { | ||||||||||||||||||||||||||
float x = static_cast<float>(input[i]); | ||||||||||||||||||||||||||
tmp = max(tmp, fabs(x)); | ||||||||||||||||||||||||||
i += blockDim.x * gridDim.x; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
cache[threadIdx.x] = tmp; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
__syncthreads(); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// Now perform parallel reduction within the thread block | ||||||||||||||||||||||||||
int ib = blockDim.x / 2; | ||||||||||||||||||||||||||
while (ib != 0) { | ||||||||||||||||||||||||||
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { | ||||||||||||||||||||||||||
cache[threadIdx.x] = cache[threadIdx.x + ib]; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
__syncthreads(); | ||||||||||||||||||||||||||
ib /= 2; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
// Finally, since cache[0] contains the maximum for this thread block, | ||||||||||||||||||||||||||
// atomically write the max to the target location | ||||||||||||||||||||||||||
if (threadIdx.x == 0) { | ||||||||||||||||||||||||||
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename scalar_t> | ||||||||||||||||||||||||||
struct __align__(8) vec4_t { | ||||||||||||||||||||||||||
scalar_t x; | ||||||||||||||||||||||||||
scalar_t y; | ||||||||||||||||||||||||||
scalar_t z; | ||||||||||||||||||||||||||
scalar_t w; | ||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
typedef struct __align__(4) { | ||||||||||||||||||||||||||
FP8_TYPE x; | ||||||||||||||||||||||||||
FP8_TYPE y; | ||||||||||||||||||||||||||
FP8_TYPE z; | ||||||||||||||||||||||||||
FP8_TYPE w; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
float8x4_t; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename scalar_t> | ||||||||||||||||||||||||||
__device__ float thread_max_vec(scalar_t const* __restrict__ input, | ||||||||||||||||||||||||||
int64_t const num_elems, int const tid, | ||||||||||||||||||||||||||
int const step) { | ||||||||||||||||||||||||||
// Vectorized input/output to better utilize memory bandwidth. | ||||||||||||||||||||||||||
vec4_t<scalar_t> const* vectorized_in = | ||||||||||||||||||||||||||
reinterpret_cast<vec4_t<scalar_t> const*>(input); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
int64_t const num_vec_elems = num_elems >> 2; | ||||||||||||||||||||||||||
float absmax_val = 0.0f; | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More common c++ const usage and variable style
Suggested change
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#pragma unroll 4 | ||||||||||||||||||||||||||
for (int64_t i = tid; i < num_vec_elems; i += step) { | ||||||||||||||||||||||||||
vec4_t<scalar_t> in_vec = vectorized_in[i]; | ||||||||||||||||||||||||||
absmax_val = max(absmax_val, fabs(in_vec.x)); | ||||||||||||||||||||||||||
absmax_val = max(absmax_val, fabs(in_vec.y)); | ||||||||||||||||||||||||||
absmax_val = max(absmax_val, fabs(in_vec.z)); | ||||||||||||||||||||||||||
absmax_val = max(absmax_val, fabs(in_vec.w)); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// Handle the remaining elements if num_elems is not divisible by 4 | ||||||||||||||||||||||||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { | ||||||||||||||||||||||||||
absmax_val = max(absmax_val, fabs(input[i])); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
return absmax_val; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename scalar_t, bool is_scale_inverted> | ||||||||||||||||||||||||||
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, | ||||||||||||||||||||||||||
scalar_t const* __restrict__ input, | ||||||||||||||||||||||||||
float const scale, | ||||||||||||||||||||||||||
int64_t const num_elems, | ||||||||||||||||||||||||||
int const tid, int const step) { | ||||||||||||||||||||||||||
// Vectorized input/output to better utilize memory bandwidth. | ||||||||||||||||||||||||||
vec4_t<scalar_t> const* vectorized_in = | ||||||||||||||||||||||||||
reinterpret_cast<vec4_t<scalar_t> const*>(input); | ||||||||||||||||||||||||||
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
int64_t const num_vec_elems = num_elems >> 2; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#pragma unroll 4 | ||||||||||||||||||||||||||
for (int64_t i = tid; i < num_vec_elems; i += step) { | ||||||||||||||||||||||||||
vec4_t<scalar_t> in_vec = vectorized_in[i]; | ||||||||||||||||||||||||||
float8x4_t out_vec; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted>( | ||||||||||||||||||||||||||
static_cast<float>(in_vec.x), scale); | ||||||||||||||||||||||||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted>( | ||||||||||||||||||||||||||
static_cast<float>(in_vec.y), scale); | ||||||||||||||||||||||||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted>( | ||||||||||||||||||||||||||
static_cast<float>(in_vec.z), scale); | ||||||||||||||||||||||||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted>( | ||||||||||||||||||||||||||
static_cast<float>(in_vec.w), scale); | ||||||||||||||||||||||||||
vectorized_out[i] = out_vec; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// Handle the remaining elements if num_elems is not divisible by 4 | ||||||||||||||||||||||||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { | ||||||||||||||||||||||||||
out[i] = scaled_fp8_conversion<is_scale_inverted>( | ||||||||||||||||||||||||||
static_cast<float>(input[i]), scale); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename scalar_t> | ||||||||||||||||||||||||||
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out, | ||||||||||||||||||||||||||
const scalar_t* __restrict__ input, | ||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain what is the reason changing the values of the two constants? and is this change ROCm specific?