Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix empty (nullptr) channelwise scales when loading wNa16 using compressed tensors #6798

Merged

Conversation

LucasWilkinson
Copy link
Contributor

When running row parallel with channelwise scales, the integer divide in:

// in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py:
scales_and_zp_size = input_size_per_partition // group_size

would result in a 0 sized tensor (with shape (0, 8192), i.e. tensor([], device='cuda:2', size=(0, 8192), dtype=torch.bfloat16)) leading to a nullptr getting passed into marlin.

This is because when its channelwise group_size was getting set to input_size which is greater than input_size_per_partition in the row parallel case when tp > 1

Now for channelwise scales when running row parallel with a tp > 1 we replicate the scales to all gpus.

This issue was present for the model: "nm-testing/Meta-Llama-3.1-70B-Instruct-quantized.w8a16"

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/channelwise-bug-fix branch from cdbe26d to a3b0efa Compare July 25, 2024 20:22
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for tracking this down

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 25, 2024
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice debugging, thank you!

@mgoin mgoin enabled auto-merge (squash) July 25, 2024 21:28
@simon-mo simon-mo disabled auto-merge July 25, 2024 22:05
@simon-mo simon-mo merged commit cd7edc4 into vllm-project:main Jul 25, 2024
69 of 72 checks passed
cadedaniel pushed a commit to cadedaniel/vllm-public that referenced this pull request Jul 27, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants