[FEAT] Improved PagedAttention FP8 (faster kvcache dequant v1) #346

Open

wants to merge 2 commits into base: llama_fp8_12062024

Conversation


@tjtanaa tjtanaa commented Dec 24, 2024

Description

This PR merges the optimized attention.cu kernel from https://github.com/ROCm/vllm/tree/shsanyal_devpa_308_opt into the llama_fp8_12062024 branch.

CAVEAT

Currently the attention.cu kernel does not support a block size of 32 or a head size of 64.
The vLLM model unit tests are failing because they use small models (e.g. Gemma, Llama) whose head size is 64.

Performance

The following are benchmark_throughput results for Llama-3.1-70B with fp8 dynamic quantization and kv-cache-dtype fp8_e4m3, using an input length of 2048 tokens and an output length of 2048 tokens:

Branch of vLLM ROCm fork   Req/s   Total Tokens/s   Output Tokens/s
main                       0.29    1196.2           598.1
llama-fp8-12062024         0.28    1152.46          576.23
pagged-attn-fp8            0.47    1932.74          966.37

@hongxiayang hongxiayang left a comment

Thank you very much for the integration. I left some comments and suggestions about coding style.

Comment on lines +33 to +39
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;

Suggested change
-  float old;
-  old = (value >= 0)
-            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
-            : __uint_as_float(
-                  atomicMin((unsigned int*)addr, __float_as_uint(value)));
-  return old;
+  return (value >= 0)
+             ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
+             : __uint_as_float(
+                   atomicMin((unsigned int*)addr, __float_as_uint(value)));
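
For context, the bit-level reasoning behind this helper: for non-negative IEEE-754 floats the bit pattern grows monotonically when reinterpreted as a signed int, so atomicMax on the int view works; for negative floats the pattern grows as the value gets more negative, so atomicMin on the unsigned view returns the larger (less negative) value. The helper body below is exactly the one suggested above; the surrounding kernel and host harness are only an illustrative sketch (the names atomicMaxFloat and max_kernel are not taken from this PR).

#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  // Non-negative floats: ordered like signed ints, so atomicMax on the int
  // view works. Negative floats: ordered in reverse as unsigned ints, so
  // atomicMin on the unsigned view yields the maximum. (NaN and -0.0f edge
  // cases are not handled here.)
  return (value >= 0)
             ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
             : __uint_as_float(
                   atomicMin((unsigned int*)addr, __float_as_uint(value)));
}

// Illustrative kernel: every thread folds one element into a global maximum.
__global__ void max_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicMaxFloat(out, in[i]);
}

int main() {
  const int n = 1024;
  float h_in[n], h_out = -1e30f;
  for (int i = 0; i < n; ++i) h_in[i] = 0.001f * i - 0.3f;  // max is ~0.723
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, &h_out, sizeof(float), cudaMemcpyHostToDevice);
  max_kernel<<<(n + 255) / 256, 256>>>(d_in, d_out, n);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max = %f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}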

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                           float const scale) {
  float x = 0.0f;

Kindly give the variable a meaningful name, like scaledValue.

  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;

If scale is zero, should there be error handling?

    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));

Consider naming r as result, or something like that.
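
Taken together, the review comments above (meaningful names, a guard when scale is zero) might shape the helper roughly as in the sketch below. This is only an illustration under stated assumptions: the cuda_fp8.h include, the FP8_TYPE alias, and the FP8_E4M3_MAX value of 448.0f stand in for whatever the surrounding file actually defines (the ROCm build would use the corresponding HIP type), the epsilon clamp is just one possible way to avoid dividing by zero, and the final cast stands in for the file's real FP8 conversion.

#include <cuda_fp8.h>  // assumption: NVIDIA-style FP8 types for this sketch

using FP8_TYPE = __nv_fp8_e4m3;     // assumed stand-in for the file's FP8_TYPE
#define FP8_E4M3_MAX 448.0f         // largest finite e4m3 value (assumed)

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float scaled_value = 0.0f;
  if constexpr (is_scale_inverted) {
    // Caller already passes 1/scale, so multiply instead of dividing.
    scaled_value = val * scale;
  } else {
    // One way to handle a zero (or tiny) scale: clamp it to a small epsilon
    // instead of dividing by zero. Asserting would be another option.
    float const safe_scale = fmaxf(scale, 1e-7f);
    scaled_value = val / safe_scale;
  }
  // Saturate to the representable e4m3 range before converting down.
  float const result = fmaxf(-FP8_E4M3_MAX, fminf(scaled_value, FP8_E4M3_MAX));
  return static_cast<FP8_TYPE>(result);
}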

    const scalar_t* __restrict__ input,
    int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

Suggested change
-  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+  int64_t index = blockDim.x * blockIdx.x + threadIdx.x;
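
For readers who want the surrounding picture: the cache[1024] and thread-index lines above are the opening of a block-wide abs-max reduction. A minimal sketch of the usual pattern follows, assuming a power-of-two block size of at most 1024; the kernel name absmax_reduction and the grid-stride loop are illustrative rather than copied from this PR, and atomicMaxFloat is the helper discussed earlier.

#include <cuda_runtime.h>

// Float atomic max via integer reinterpretation (same trick as above).
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  return (value >= 0)
             ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
             : __uint_as_float(
                   atomicMin((unsigned int*)addr, __float_as_uint(value)));
}

// Sketch of a block-wide abs-max reduction: each thread folds a strided
// slice of the input into a local maximum, the block reduces through shared
// memory, and one thread per block publishes the result atomically.
// Assumes blockDim.x is a power of two and <= 1024.
template <typename scalar_t>
__global__ void absmax_reduction(float* __restrict__ global_absmax,
                                 const scalar_t* __restrict__ input,
                                 int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t index = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;

  // Grid-stride loop over the elements owned by this thread.
  float local_max = 0.0f;
  for (int64_t i = index; i < num_elems;
       i += (int64_t)blockDim.x * gridDim.x) {
    local_max = fmaxf(local_max, fabsf(static_cast<float>(input[i])));
  }
  cache[threadIdx.x] = local_max;
  __syncthreads();

  // Tree reduction inside the block.
  for (unsigned stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      cache[threadIdx.x] = fmaxf(cache[threadIdx.x], cache[threadIdx.x + stride]);
    }
    __syncthreads();
  }

  // One atomic per block combines the per-block results.
  if (threadIdx.x == 0) {
    atomicMaxFloat(global_absmax, cache[0]);
  }
}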

@@ -150,6 +160,8 @@ def test_paged_attention(
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)
    #query = torch.ones_like(query)

Suggested change
-    #query = torch.ones_like(query)

Comment on lines +64 to +67
#print('>>> ref qkout shape',attn_weights.shape)
#print('>>> ref qkout',attn_weights)
#global REF_TENSOR
#REF_TENSOR = attn_weights

Suggested change
-    #print('>>> ref qkout shape',attn_weights.shape)
-    #print('>>> ref qkout',attn_weights)
-    #global REF_TENSOR
-    #REF_TENSOR = attn_weights

SEEDS = [0]
CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1)

This seems to change the multi-GPU test into a single-GPU-only test. Are you sure you want to have this change committed?

- NUM_BLOCKS = 1024 * 1024
- PARTITION_SIZE = 512
+ NUM_BLOCKS = 256 * 1024
+ PARTITION_SIZE = 256

Can you explain the reason for changing the values of these two constants? And is this change ROCm-specific?

@@ -101,7 +101,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        start_time = time.perf_counter()

        # Using default kv_scale
-       k_scale = v_scale = 1.0
+       k_scale = v_scale = 0.1

Also, can you explain the change to the default kv-scale?

@tjtanaa tjtanaa changed the title [FEAT] Improved PagedAttention FP8 (faster kvcache dequant) [FEAT] Improved PagedAttention FP8 (faster kvcache dequant shsanyal) Dec 27, 2024
@tjtanaa tjtanaa changed the title [FEAT] Improved PagedAttention FP8 (faster kvcache dequant shsanyal) [FEAT] Improved PagedAttention FP8 (faster kvcache dequant v1) Dec 27, 2024