
[FEAT] Improved PagedAttention FP8 (faster kvcache dequant v1) #346

Open
Wants to merge 2 commits into base: llama_fp8_12062024
8 changes: 5 additions & 3 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -9,8 +9,8 @@
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)

-NUM_BLOCKS = 1024 * 1024
-PARTITION_SIZE = 512
+NUM_BLOCKS = 256 * 1024
+PARTITION_SIZE = 256


Can you explain the reason for changing the values of these two constants? And is this change ROCm-specific?



@torch.inference_mode()
@@ -101,7 +101,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
start_time = time.perf_counter()

# Using default kv_scale
-k_scale = v_scale = 1.0
+k_scale = v_scale = 0.1


Also, can you explain the default kv-scale change?


for _ in range(num_iters):
if version == "v1":
@@ -161,6 +161,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
kv_cache_dtype,
k_scale,
v_scale,
+None,
+PARTITION_SIZE
)
else:
raise ValueError(f"Invalid version: {version}")
177 changes: 173 additions & 4 deletions csrc/quantization/fp8/common.cu
@@ -1,16 +1,185 @@
#include "common.cuh"
#include "dispatch_utils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#if defined(USE_CUDA_FP8_FORMAT)
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

#if defined(USE_CUDA_FP8_FORMAT)
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
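
Aside (illustration only, not part of this file): FP8_E4M3_MAX is the bound that dynamic quantization targets. The reduction kernel below stores absmax / FP8_E4M3_MAX into *scale, and scaled_fp8_conversion then clamps val / scale to ±FP8_E4M3_MAX, so the largest-magnitude element lands on the bound. A minimal host-side sketch of that arithmetic, with assumed example numbers:

#include <algorithm>
#include <cstdio>

int main() {
  const float kFp8Max = 224.0f;          // the ROCm bound chosen above
  const float absmax = 13.7f;            // hypothetical tensor abs-max
  const float scale = absmax / kFp8Max;  // what the reduction kernel stores in *scale
  const float x = 13.7f;                 // the largest-magnitude element
  // Same clamp as scaled_fp8_conversion (scale not inverted):
  const float q = std::max(-kFp8Max, std::min(x / scale, kFp8Max));
  std::printf("scale=%f quantized (pre-cast)=%f\n", scale, q);  // q is ~224.0
  return 0;
}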

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));

return old;
}
Comment on lines +33 to +39


Suggested change:
-  float old;
-  old = (value >= 0)
-            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
-            : __uint_as_float(
-                  atomicMin((unsigned int*)addr, __float_as_uint(value)));
-  return old;
+  return (value >= 0)
+             ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
+             : __uint_as_float(
+                   atomicMin((unsigned int*)addr, __float_as_uint(value)));
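
Aside on this suggestion (not part of the PR): both the original and the suggested form rely on the same bit trick. For non-negative IEEE-754 floats, the raw bit patterns compared as signed ints order the same way as the floats themselves, which is what makes atomicMax on the reinterpreted pointer valid. A minimal host-side check with arbitrary example values:

#include <cassert>
#include <cstdint>
#include <cstring>

static int32_t float_bits(float f) {
  int32_t i;
  std::memcpy(&i, &f, sizeof(i));  // bit-copy instead of a type-punning cast
  return i;
}

int main() {
  const float a = 0.5f, b = 3.25f;  // arbitrary non-negative examples
  // The bit patterns order exactly as the float values do.
  assert((a < b) == (float_bits(a) < float_bits(b)));
  return 0;
}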


template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {

Kindly give the variable a meaningful name, like scaledValue.

x = val * scale;
} else {
x = val / scale;
}

If scale is zero, is there any error handling?


float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#if defined(USE_CUDA_FP8_FORMAT)

Consider naming r as result, or something similar.

return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

Suggested change:
-  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+  int64_t index = blockDim.x * blockIdx.x + threadIdx.x;

// First store maximum for all values processed by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;

__syncthreads();

// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
}
}
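
Not part of the diff: a minimal host-side sketch of how this kernel is meant to be driven, following the comment above — *scale starts at a value <= 0.0 and all blocks finish before anything reads it. The helper name, launch configuration, and float input type are illustrative assumptions, not the PR's actual call site.

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

void compute_dynamic_fp8_scale(const float* d_input, float* d_scale,
                               int64_t num_elems, cudaStream_t stream) {
  // Initialize *scale to 0.0f so atomicMaxFloat can only raise it.
  cudaMemsetAsync(d_scale, 0, sizeof(float), stream);

  const int threads = 1024;  // matches the fixed cache[1024] in the kernel
  const int blocks = static_cast<int>(
      std::min<int64_t>(1024, (num_elems + threads - 1) / threads));
  vllm::segmented_max_reduction<float>
      <<<blocks, threads, 0, stream>>>(d_scale, d_input, num_elems);

  // Wait for every block before any consumer dereferences *scale.
  cudaStreamSynchronize(stream);
}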

template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};

typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);

int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;

More common C++ const usage and variable style:

Suggested change:
-  int64_t const num_vec_elems = num_elems >> 2;
+  const int64_t numVecElems = numElems >> 2;


#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}

return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input,