Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Kernel ] Enable Dynamic Per Token fp8 #6547

Merged

Conversation

robertgshaw2-neuralmagic
Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic commented Jul 18, 2024

SUMMARY:

  • turns on dynamic per token quantization for fp8
  • enables fallback if cutlass kernels are not supported
  • add test cases

RESULTS:

nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors

  • GSM with dynamic per token:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.753|±  |0.0136|
|     |       |strict-match    |     5|exact_match||0.756|±  |0.0136|
  • GSM with dynamic per tensor:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.748|±  |0.0137|
|     |       |strict-match    |     5|exact_match||0.747|±  |0.0138|

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title enable dyn per token [ Kernel ] Enable Dynamic Per Token fp8 Jul 18, 2024
@robertgshaw2-neuralmagic
Copy link
Collaborator Author

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 18, 2024
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic marked this pull request as ready for review July 18, 2024 16:55
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
qinput, x_scale = ops.scaled_fp8_quant(input,
input_scale,
use_per_token_if_dynamic=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't use_per_token_if_dynamic be set by a scheme or something? I don't see why it should always be true for this case

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what case would we not want to use dynamic per token?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgoin the hypothesis here is that dynamic per token is an overall win over dynamic per tensor when supported. Higher accuracy, but also easier to fuse RMSNorm + Quant, and fewer dependencies so better parallelization for the quantize kernels overall. Downside is more scales.

We'll need to benchmark, but IMO ok for this PR

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume there is a decent performance hit compared to produce/using per-tensor scale, is this not the case?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't assume a decent performance hit -- the overheads from the CUTLASS epilogues are quite small (like 3%), and there are advantages to doing per-token when quantizing as well. We'll measure.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed my mind -- as I am implementing the per-token/per-channel wrapper for torch._scaled_mm, I think it would be nicer to grab this from a config somewhere

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to pass this from the scheme. However, this became very hard because I needed to check if cutlass is supported in multiple places

I think that deciding this here is the right place

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic enabled auto-merge (squash) July 19, 2024 18:34
@robertgshaw2-neuralmagic
Copy link
Collaborator Author

DO NOT LAND UNTIL VARUNS PR GOES IN FIRST

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 4cc24f0 into vllm-project:main Jul 19, 2024
72 checks passed
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic deleted the turn-on-fp8-dyn-per-token branch July 19, 2024 23:08
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
gnpinkert pushed a commit to gnpinkert/vllm that referenced this pull request Jul 26, 2024
cduk pushed a commit to cduk/vllm-pascal that referenced this pull request Aug 6, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants