[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer #6051

Open · wants to merge 15 commits into base: main
Conversation

LiuXiaoxuanPKU
Collaborator

Add logits_soft_cap support to FlashInfer, which is needed by the Gemma2 model, and add a simple Gemma2 test.
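
For reference, a minimal way to exercise the new path (a sketch, not part of the PR; it assumes the FlashInfer backend is installed, and the prompt and sampling parameters are arbitrary):

import os

# Select the FlashInfer attention backend, as the warning discussed below suggests.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2-9b-it")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)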

@LiuXiaoxuanPKU LiuXiaoxuanPKU requested a review from Yard1 July 2, 2024 00:30
@Yard1 Yard1 (Collaborator) left a comment

we should check the flashinfer version and raise if it's too old
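
One way the version check could look (a minimal sketch; the helper name, the "flashinfer" distribution name, and the 0.0.8 threshold are placeholders, not values confirmed by this PR):

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

# Placeholder: the real minimum is whichever release added logits_soft_cap.
MIN_FLASHINFER_VERSION = Version("0.0.8")

def _verify_flashinfer_version() -> None:
    try:
        installed = Version(version("flashinfer"))
    except PackageNotFoundError:
        raise RuntimeError("FlashInfer is not installed.")
    if installed < MIN_FLASHINFER_VERSION:
        raise RuntimeError(
            f"FlashInfer >= {MIN_FLASHINFER_VERSION} is required for "
            f"logits_soft_cap, got {installed}.")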

Comment on lines 666 to 670
logger.warning("Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
Collaborator

we should just raise an exception IMO.
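
A sketch of the suggested change in the same context as the snippet above, replacing the warning with an exception (the exact exception type and message are the author's call):

if logits_soft_cap is not None and self.attn_backend.get_name() != "flashinfer":
    raise ValueError(
        "Please use the FlashInfer backend for models with logits_soft_cap"
        " (i.e., Gemma-2); otherwise, the output may be incorrect. Set the"
        " backend via export VLLM_ATTENTION_BACKEND=FLASHINFER.")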

@comaniac comaniac self-assigned this Jul 2, 2024
Comment on lines 662 to 665
logits_soft_cap = getattr(self.model_config.hf_config,
                          'final_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
@yongxb

Could I check if logits_soft_cap is supposed to be the attn_logit_softcapping value instead? The two values are different in the Gemma2 config.

"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,

Collaborator

@yongxb Nice catch! final_logit_softcapping is used to cap the final logits before sampling. @LiuXiaoxuanPKU Could you please fix this?
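
To make the distinction concrete (a sketch; soft_cap is an illustrative helper, and the two constants are the Gemma-2 config values quoted above):

import torch

def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    """Gemma-2-style capping: squash values into (-cap, cap) via cap * tanh(x / cap)."""
    return cap * torch.tanh(x / cap)

# attn_logit_softcapping = 50.0: applied to pre-softmax attention scores
# inside the attention kernel, so this is the value to pass to FlashInfer.
attn_scores = soft_cap(torch.randn(4, 128) * 100, 50.0)

# final_logit_softcapping = 30.0: applied to the LM-head logits right before
# sampling, independent of the attention backend.
final_logits = soft_cap(torch.randn(4, 1024) * 100, 30.0)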

@zifeitong
Contributor

I think this warning can be removed to avoid confusion:

if self.config.attn_logit_softcapping is not None:

@zifeitong
Contributor

I am able to run and reproduce the reported MMLU scores for both 9b and 27b models 👍

However, if I don't disable CUDA graph, vLLM will crash with this error:

ERROR 07-04 03:34:04 async_llm_engine.py:53]   File "vllm/vllm/worker/model_runner.py", line 1202, in execute_model
ERROR 07-04 03:34:04 async_llm_engine.py:53]     model_input.attn_metadata.decode_wrapper = self.graph_runners[
ERROR 07-04 03:34:04 async_llm_engine.py:53] IndexError: list index out of range

@LiuXiaoxuanPKU
Collaborator Author

Thanks for reporting! Could you give me a minimal reproducible example? I can run Gemma-2 with FlashInfer and CUDA graph on my end. Thanks!

@zifeitong
Contributor

I am using the run_batch script:

python -m vllm.entrypoints.openai.run_batch -i requests.jsonl -o /dev/null --model google/gemma-2-9b-it --disable-log-request

requests.jsonl.zip

@LiuXiaoxuanPKU
Collaborator Author

I tried the script and data on an H100 and it seems to work. Could you report your environment? FlashInfer only supports GPUs with compute capability 8.0 or higher (https://developer.nvidia.com/cuda-gpus). Not sure if that might be the problem.
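
A quick way to check the GPU on the reporter's side (a sketch using PyTorch; the (8, 0) threshold mirrors the requirement mentioned above):

import torch

major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
if (major, minor) < (8, 0):
    print("FlashInfer needs compute capability 8.0 or higher; this GPU is too old.")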

@zifeitong
Contributor

I am using an H100 with CUDA 12.5. Can you try syncing your branch to the latest main? #4412 might be related (it refactors graph_runners).

@LiuXiaoxuanPKU
Collaborator Author

Yes, it was a merge conflict. Just fixed it, please try again. Thanks!

@zifeitong
Contributor

Thanks for the fix. It works now, w/ or w/o CUDA graph.
