[Kernel] Unify the kernel used in flash attention backend #6052

Open · wants to merge 18 commits into base: main
Conversation

@LiuXiaoxuanPKU (Collaborator) commented Jul 2, 2024

Currently, we use different kernels for different phases. Concretely, we use flash_attn_with_kvcache for the decoding phase and flash_attn_varlen_func for the prefill phase and prefix caching. For chunked prefill, we launch both kernels to handle the prefill tokens and the decoding tokens separately. The current approach has some drawbacks:

  1. It complicates the attention backend logic because we need both prefill_metadata and decode_metadata.
  2. It passes more fields from the model_runner to the backend than needed, because flash_attn_with_kvcache and flash_attn_varlen_func have different input requirements.
  3. It uses two kernels for chunked prefill, which is not optimal for performance.
  4. It can degrade performance because we need to build prefill_metadata and decode_metadata on the fly, although this is likely minor since we cache the two metadata objects.

Moreover, flash_attn_with_kvcache and flash_attn_varlen_func have similar performance, as they share the same underlying implementation.
Ideally, we should use a single kernel to handle all cases: the prefill phase, the decoding phase, and prefix caching. For chunked prefill, we should launch a single kernel that handles both the prefill tokens and the decoding tokens.

This PR tries to simplify the logic in the attention backend and use a single kernel. This is also needed for the MQA scorer (#5691) for speculative decoding.
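
For illustration, a minimal sketch of what the unified path could look like, assuming the vllm_flash_attn fork (whose flash_attn_varlen_func can read the paged KV cache through a block_table argument); the wrapper below and its argument names are illustrative and may not match the PR's exact code:

```python
# Hedged sketch, not the PR's exact code: one flash_attn_varlen_func call that
# covers prefill tokens, decode tokens (query length 1), and prefix-cache hits
# in the same batch, instead of splitting the batch into two kernel launches.
import torch
from vllm_flash_attn import flash_attn_varlen_func  # assumed import path

def unified_attention(
    query: torch.Tensor,            # [num_tokens, num_heads, head_dim]
    key_cache: torch.Tensor,        # paged KV cache
    value_cache: torch.Tensor,
    query_start_loc: torch.Tensor,  # cumulative query lengths, len = batch + 1
    seq_start_loc: torch.Tensor,    # cumulative context lengths, len = batch + 1
    max_query_len: int,
    max_seq_len: int,
    block_tables: torch.Tensor,     # [batch, max_blocks_per_seq]
    scale: float,
) -> torch.Tensor:
    return flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=query_start_loc,
        cu_seqlens_k=seq_start_loc,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_seq_len,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,   # lets the varlen kernel gather K/V pages
    )
```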

LiuXiaoxuanPKU marked this pull request as draft on July 2, 2024 04:55
LiuXiaoxuanPKU changed the title from "[Kernel] Unify the kernel used in flash attention backend" to "[WIP][Kernel] Unify the kernel used in flash attention backend" on Jul 2, 2024
@rkooo567 (Collaborator) commented Jul 2, 2024

I think the direction makes sense! It is also a more cuda-graph-friendly approach.

QQ:

  1. Is this PR ready?
  2. The original reason I didn't try this before was that I heard the perf wasn't that different (or was worse, due to some optimizations for the decode case). Can you share the benchmark results?

LiuXiaoxuanPKU marked this pull request as ready for review on July 9, 2024 04:30
@LiuXiaoxuanPKU (Collaborator, Author) commented Jul 9, 2024

Yeah, the PR should be ready for review.

Some kernel benchmark numbers on a single A100, all numbers are in ms.

| Number of query tokens | Number of heads | Head dim | flash_attn_varlen_func | flash_attn_with_kvcache |
|---|---|---|---|---|
| 100 | 12 | 64 | 0.0917 | 0.0365 |
| 500 | 12 | 64 | 0.379 | 0.383 |
| 1000 | 12 | 64 | 1.292 | 1.290 |
| 100 | 32 | 128 | 0.0989 | 0.100 |
| 500 | 32 | 128 | 1.550 | 1.549 |
| 1000 | 32 | 128 | 5.819 | 5.837 |
| 100 | 64 | 128 | 0.161 | 0.160 |
| 500 | 64 | 128 | 2.965 | 3.004 |
| 1000 | 64 | 128 | 11.308 | 11.388 |

There is only one case where we see significant performance degradation (flash_attn_varlen_func is slower than flash_attn_with_kvcache):

| Number of query tokens | Number of heads | Head dim | flash_attn_varlen_func | flash_attn_with_kvcache |
|---|---|---|---|---|
| 100 | 12 | 64 | 0.0917 | 0.0365 |

In all other cases, the performance is quite similar.

LiuXiaoxuanPKU changed the title from "[WIP][Kernel] Unify the kernel used in flash attention backend" to "[Kernel] Unify the kernel used in flash attention backend" on Jul 9, 2024
@comaniac (Collaborator) left a comment

Overall LGTM and it's much cleaner now!
cc @rkooo567 and @cadedaniel to have a final pass.

Comment on lines +111 to +116
# Fields that are not used in flash attention backend,
# but used in other backends
context_lens_tensor: Optional[torch.Tensor] = None
seq_lens_tensor: Optional[torch.Tensor] = None
max_prefill_seq_len: Optional[int] = None
max_decode_seq_len: Optional[int] = None
Collaborator: Good finding! I'll remove them after refactoring prepare input.

Comment on lines 235 to 239
if kv_cache is None or (attn_metadata.block_tables is not None
                        and attn_metadata.block_tables.numel()) == 0:
    k = key
    v = value
    block_tables = None
Collaborator: This should be for pure prefill or memory profiling? Better to add a comment for it.

@@ -151,8 +151,8 @@ def execute_model(
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
if prefill_meta is None and \
Collaborator: Note: This code snippet is removed by #6338, so this isn't a problem anymore.

@@ -655,6 +655,7 @@ def _prepare_model_input_tensors(
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
seq_lens.append(1)
query_lens.append(1)
Collaborator: Not used?

Collaborator (Author): This is used when calculating query_start_loc, which is the input to the flash attention backend when using the unified kernel.
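
For context, here is a tiny illustration (not vLLM's exact code) of how query_start_loc can be derived from the collected query_lens; the dummy sequences padded in for cuda graph capture contribute a query length of 1, which keeps the cumulative offsets well formed:

```python
import torch

# Per-sequence query lengths, e.g. one prefill chunk plus decode/padded slots.
query_lens = [5, 1, 1, 1]
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(
    torch.tensor(query_lens, dtype=torch.int32), dim=0)
# query_start_loc == [0, 5, 6, 7, 8]; this is what gets passed as cu_seqlens_q
# to the unified flash attention kernel.
```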

for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
                                      vars(decode_meta_actual)):
    assert attr_expected[1] == attr_actual[1]
if attn_metadata.prefill_metadata:
Collaborator: Is it always None for the flash attention backend now?

@rkooo567 (Collaborator) commented:
The review ETA is tonight!

Besides, I'd like to know the e2e performance improvement (or confirmation that it matches the current performance). Is it possible to run some e2e benchmarks with/without the PR and share the results?

@rkooo567 (Collaborator) left a comment

(I prefer to see the e2e result before merging it! but PR looks beautiful :))

rkooo567 added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Jul 16, 2024
@jjjjohnson commented Jul 22, 2024

Looks like the model output is garbled and totally different after unifying the kernel...
I changed flash_attn.py back to the original implementation, with flash_attn_with_kvcache for decode and flash_attn_varlen_func for prefill, and the result is normal.
Have you checked the correctness? @LiuXiaoxuanPKU

@jjjjohnson commented Jul 22, 2024

If I add --enforce-eager, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567
My guess is that flash_attn_varlen_func ONLY works without cuda graph... but I don't know why.

@LiuXiaoxuanPKU (Collaborator, Author) commented:

> Looks like the model output is garbled and totally different after unifying the kernel... Have you checked the correctness?

Thanks for reporting, will take a look.

@LiuXiaoxuanPKU (Collaborator, Author) commented:

> If I add --enforce-eager, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. My guess is that flash_attn_varlen_func ONLY works without cuda graph... but I don't know why.

@jjjjohnson Could you provide the model/prompt you used for testing? The results seem correct for basic_correctness. Thanks!

@LiuXiaoxuanPKU (Collaborator, Author) commented:

@rkooo567
Some e2e performance numbers for llama-7b on a single H100 with cuda graph. All numbers are the 50th-percentile request latency in seconds, measured with the script.

| input_len | output_len | batch_size | this PR | main branch |
|---|---|---|---|---|
| 32 | 128 | 1 | 0.883 | 0.917 |
| 32 | 128 | 2 | 0.877 | 0.912 |
| 32 | 128 | 4 | 0.906 | 0.933 |
| 32 | 128 | 8 | 0.946 | 0.956 |
| 32 | 128 | 16 | 1.065 | 1.084 |
| 32 | 128 | 32 | 1.236 | 1.259 |
| 512 | 32 | 1 | 0.229 | 0.251 |
| 512 | 32 | 2 | 0.237 | 0.259 |
| 512 | 32 | 4 | 0.267 | 0.288 |
| 512 | 32 | 8 | 0.326 | 0.347 |
| 512 | 32 | 16 | 0.456 | 0.498 |
| 512 | 32 | 32 | 0.699 | 0.779 |

@LiuXiaoxuanPKU (Collaborator, Author) commented:

I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise, it fails. I'm wondering if your case is related to lora. I cannot reproduce the bug without lora.

@rkooo567 (Collaborator) commented:

Hmm, that's pretty odd. There's nothing lora-related in this kernel, IIUC.

@comaniac (Collaborator) commented:

Btw, I saw a CI failure in LM Eval Small Models as follows:

[2024-07-23T14:54:54Z] >               assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
[2024-07-23T14:54:54Z] E               assert False
[2024-07-23T14:54:54Z] E                +  where False = <function isclose at 0x7f72938ba070>(0.593, 0.0, rtol=0.02)
[2024-07-23T14:54:54Z] E                +    where <function isclose at 0x7f72938ba070> = numpy.isclose

Looks like the measured_value is 0, so the output may be garbage in this case.

@jjjjohnson commented Jul 24, 2024

> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise, it fails. I'm wondering if your case is related to lora. I cannot reproduce the bug without lora.

I tried Qwen/Qwen-14B-Chat without lora; it can be any prompt, and the result is totally different with and without enforce-eager.

@jjjjohnson commented Jul 24, 2024

Looks like a short prompt is OK; if you change to example_long_prompts, the tests fail...

@LiuXiaoxuanPKU (Collaborator, Author) commented:

I tried example_long_prompts with Qwen and it did fail. But after looking into it, it fails in both eager and non-eager mode, and it also fails with other backends such as XFORMERS. Therefore, it seems like a numerical issue in that case. Did you observe similar things?

@LiuXiaoxuanPKU (Collaborator, Author) commented Jul 24, 2024

> I tried Qwen/Qwen-14B-Chat without lora; it can be any prompt, and the result is totally different with and without enforce-eager.

Could you provide the exact prompt and the hardware you use? After some manual checking on an H100 with Qwen/Qwen-14B-Chat, setting enforce-eager or not gives the same output. It might also be possible that bugs in the cuda graph preparation do not show up consistently. Thanks!

@jjjjohnson commented Jul 25, 2024

> Could you provide the exact prompt and the hardware you use?

I use an A800 with TP=1.
Prompt:
The rapid advancement in artificial intelligence (AI) has yielded a variety of groundbreaking technologies, among which Large Language Models (LLMs) have garnered widespread attention and utility. LLMs, such as OpenAI’s GPT-4, Google's BERT, and others, have profoundly transformed the landscape of natural language processing (NLP) over the past few years. But what exactly are LLMs, and why are they so significant?At their core, LLMs are a subset of machine learning models designed to understand.LLMs are versatile and can be fine-tuned for a variety of applications. From drafting emails and writing code to translating languages and composing poetry, the potential use cases are vast.

If I change
@pytest.mark.parametrize("backend", ["FLASH_ATTN","XFORMERS"])
to
@pytest.mark.parametrize("backend", ["XFORMERS","FLASH_ATTN"])
the tests pass... Pretty odd...

@jon-chuang (Contributor) commented Aug 9, 2024

@jjjjohnson, when you say the test fails, was the output gibberish or still something reasonable? Changing the kernel may change the numerics slightly.

> example_long_prompts

I think there is a higher likelihood of accumulating numerical error for long prompts, so this checks out?

@LiuXiaoxuanPKU (Collaborator, Author) commented:

Updates for this PR:

  1. We will take a less aggressive approach: keep the original kernels for prefill and decoding, and use flash_attn_varlen_func for mixed batches, i.e., batches containing both prefill tokens and decoding tokens. The goal is to enable cuda graph for chunked prefill and speculative decoding (a rough dispatch sketch follows below).
  2. We need to debug the cuda graph compatibility of the flash_attn_varlen_func kernel, as it fails the unit tests.
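
For clarity, a rough sketch of the dispatch this implies (illustrative only; the function below is hypothetical and not the backend's actual code):

```python
# Illustrative dispatch, not the PR's exact code: keep the existing kernels for
# pure prefill and pure decode batches, and use flash_attn_varlen_func only
# when a batch mixes prefill and decode tokens (e.g. chunked prefill).
def pick_attention_kernel(num_prefill_tokens: int, num_decode_tokens: int) -> str:
    if num_prefill_tokens > 0 and num_decode_tokens > 0:
        return "flash_attn_varlen_func"    # single launch for the mixed batch
    if num_decode_tokens > 0:
        return "flash_attn_with_kvcache"   # decode-only batch
    return "flash_attn_varlen_func"        # prefill-only / prefix caching
```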

@pengwu22 commented:

Hi @LiuXiaoxuanPKU

Based on your current test case defined in tests/kernels/test_flash_attn.py, here is a modified version: test_varlen_cg.py

It should now pass the given case for mixed prefill and decode with vllm_flash_attn v2.6.2: python3 -m pytest test_varlen_cg.py

The major modifications are the following when using flash_attn_varlen_func with cuda graph:

  • We need to keep max_query_len and max_kv_len static, to ensure the CPU variables have no effect on the results.
    • Given that params.num_splits is 0 now and we still provide a page table, it will dispatch to flash_fwd_splitkv_kernel. So we need to keep (i) all the kernel grid dims static, which use max_query_len, and (ii) the kernel template static, which uses max_kv_len.
  • Use static GPU memory for g_cu_query_lens, g_cu_kv_lens, and g_block_tables, and pad the remaining batch indices with non-decreasing seqlens (see the sketch below).
    • Because their shapes have a batch dimension, we need to keep them static and pad the rest. This requires that the capture uses the largest number of query tokens needed. For example, we capture with [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)] and can run with [(5, 18), (1, 473), (1, 6), (0, 0), (0, 0), (0, 0), (0, 0)]. The (0, 0) entries are needed to keep the prepared padded cu_*_lens non-decreasing, so that the GPU blocks responsible for the padded dims won't pollute the result.

Feel free to try it out. Hope it helps :)
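
To make the padding scheme concrete, here is a small sketch (a hypothetical helper, not code from the linked test) that builds the non-decreasing, padded cu_*_lens described above:

```python
import torch

def pad_cu_seqlens(query_lens, kv_lens, capture_batch_size, device="cpu"):
    # Build padded, non-decreasing cumulative lengths for cuda-graph replay.
    # Unused batch slots behave like the (0, 0) entries above: they repeat the
    # last real cumulative value, so the GPU blocks assigned to them do no work.
    # (In real use these would be copies into static, pre-allocated GPU buffers.)
    assert len(query_lens) == len(kv_lens) <= capture_batch_size
    cu_q = torch.zeros(capture_batch_size + 1, dtype=torch.int32, device=device)
    cu_k = torch.zeros(capture_batch_size + 1, dtype=torch.int32, device=device)
    n = len(query_lens)
    cu_q[1:n + 1] = torch.cumsum(
        torch.tensor(query_lens, dtype=torch.int32, device=device), dim=0)
    cu_k[1:n + 1] = torch.cumsum(
        torch.tensor(kv_lens, dtype=torch.int32, device=device), dim=0)
    cu_q[n + 1:] = cu_q[n]   # pad the tail so the sequence stays non-decreasing
    cu_k[n + 1:] = cu_k[n]
    return cu_q, cu_k

# Replay the example above, [(5, 18), (1, 473), (1, 6)], on a graph captured
# with batch size 7:
cu_q, cu_k = pad_cu_seqlens([5, 1, 1], [18, 473, 6], capture_batch_size=7)
# cu_q -> [0, 5, 6, 7, 7, 7, 7, 7]
# cu_k -> [0, 18, 491, 497, 497, 497, 497, 497]
```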
