GPTQ / Quantization support? #174

Closed

nikshepsvn opened this issue Jun 21, 2023 · 19 comments

Comments

@nikshepsvn

Will vLLM support 4-bit GPTQ models?

@WoosukKwon
Collaborator

Thanks for the feature request! Quantization is not currently supported, but it's definitely on our roadmap. Please stay tuned.

@nikshepsvn
Author

How do I best go about tracking this? Is there a discord or public roadmap somewhere I can look at?

@Symbolk

Symbolk commented Aug 15, 2023

> How do I best go about tracking this? Is there a discord or public roadmap somewhere I can look at?

See Roadmap here: #244

@chu-tianxiang
Contributor

chu-tianxiang commented Aug 23, 2023

I looked into this a bit today and it seems straightforward to integrate AutoGPTQ into vLLM, so I implemented a preliminary version for LLaMA (see this commit) and did a few benchmarks on a single A100-80G. I don't know why, but it's slower than expected.

python benchmark_throughput.py --model TheBloke/Llama-2-13B-chat-GPTQ --dataset ShareGPT_V3_unfiltered_cleaned_split.json

| Model                          | Throughput (requests/s) | Throughput (tokens/s) |
|--------------------------------|-------------------------|-----------------------|
| meta-llama/Llama-2-13b-chat-hf | 4.00                    | 1915                  |
| TheBloke/Llama-2-13B-chat-GPTQ | 3.32                    | 1587                  |
| TheBloke/Llama-2-70B-chat-GPTQ | 1.09                    | 519                   |

@osilverstein

> I looked into this a bit today and it seems straightforward to integrate AutoGPTQ into vLLM, so I implemented a preliminary version for LLaMA (see this commit) and did a few benchmarks on a single A100-80G. I don't know why, but it's slower than expected.
>
> python benchmark_throughput.py --model TheBloke/Llama-2-13B-chat-GPTQ --dataset ShareGPT_V3_unfiltered_cleaned_split.json
>
> | Model                          | Throughput (requests/s) | Throughput (tokens/s) |
> |--------------------------------|-------------------------|-----------------------|
> | meta-llama/Llama-2-13b-chat-hf | 4.00                    | 1915                  |
> | TheBloke/Llama-2-13B-chat-GPTQ | 3.32                    | 1587                  |
> | TheBloke/Llama-2-70B-chat-GPTQ | 1.09                    | 519                   |

What's the baseline with the normal version?

@chu-tianxiang
Contributor

chu-tianxiang commented Aug 25, 2023

> What's the baseline with the normal version?

If you mean the throughput: in the table above, TheBloke/Llama-2-13B-chat-GPTQ is quantized from meta-llama/Llama-2-13b-chat-hf, and its throughput is about 17% lower.

I dug into the kernel code of the quant linear layer and found that it falls back to dequantization followed by fp16 matrix multiplication when the batch size is larger than 8, so the performance degradation is understandable.
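For readers unfamiliar with that fallback, here is a toy sketch of a GPTQ-style QuantLinear (not the AutoGPTQ kernel itself; the packed buffer shapes follow the common GPTQ convention and are an assumption on my part). It shows why larger batches lose the speed advantage: the 4-bit weights get expanded back to fp16 and pushed through a plain dense matmul.

```python
import torch


class QuantLinearSketch(torch.nn.Module):
    """Toy 4-bit GPTQ-style linear layer, used only to illustrate the fallback."""

    def __init__(self, in_features: int, out_features: int, group_size: int = 128):
        super().__init__()
        self.group_size = group_size
        # 4-bit weights packed 8-per-int32 along the input dimension.
        self.register_buffer(
            "qweight", torch.zeros(in_features // 8, out_features, dtype=torch.int32))
        # One (scale, zero) pair per group of `group_size` input channels.
        self.register_buffer(
            "scales", torch.ones(in_features // group_size, out_features, dtype=torch.float16))
        self.register_buffer(
            "qzeros", torch.zeros(in_features // group_size, out_features // 8, dtype=torch.int32))

    def dequantize(self) -> torch.Tensor:
        shifts = torch.arange(0, 32, 4, device=self.qweight.device, dtype=torch.int32)
        # Unpack 8 nibbles per int32 -> full (in_features, out_features) grid.
        w = (self.qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
        w = w.reshape(-1, self.qweight.shape[1]).to(torch.float16)
        z = (self.qzeros.unsqueeze(-1) >> shifts.view(1, 1, -1)) & 0xF
        z = z.reshape(self.qzeros.shape[0], -1).to(torch.float16)
        group = torch.arange(w.shape[0], device=w.device) // self.group_size
        return (w - z[group]) * self.scales[group]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # In the real kernel a batch size <= 8 takes a fused int4 path; larger
        # batches fall back to dequantize + dense fp16 matmul, which explains
        # the throughput gap versus the unquantized model.
        return x.to(torch.float16) @ self.dequantize()
```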

@chu-tianxiang
Contributor

As an update, I added a tensor-parallel QuantLinear layer and supported most AutoGPTQ-compatible models in this branch. The code has not been thoroughly tested yet because there are far too many combinations of model architectures and GPTQ settings to cover.
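For anyone wondering how tensor parallelism interacts with the packed buffers, the snippet below is a rough illustration (not the branch's actual code; the shapes follow the usual GPTQ packing convention) of slicing a QuantLinear's buffers for a column-parallel shard, where the output dimension is split across ranks:

```python
import torch


def shard_quant_linear_columns(qweight: torch.Tensor,
                               scales: torch.Tensor,
                               qzeros: torch.Tensor,
                               tp_rank: int,
                               tp_size: int):
    """Slice GPTQ-packed buffers for one rank of a column-parallel linear layer.

    Assumed (conventional) shapes:
      qweight: (in_features // 8, out_features)   int32, 8 x 4-bit per entry
      scales:  (num_groups, out_features)         fp16
      qzeros:  (num_groups, out_features // 8)    int32, 8 x 4-bit per entry
    """
    out_features = qweight.shape[1]
    shard = out_features // tp_size
    assert shard % 8 == 0, "output shard must stay aligned to the int32 packing"
    cols = slice(tp_rank * shard, (tp_rank + 1) * shard)
    packed_cols = slice(tp_rank * shard // 8, (tp_rank + 1) * shard // 8)
    return qweight[:, cols], scales[:, cols], qzeros[:, packed_cols]
```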

@singularity-sg

@chu-tianxiang I tried forking your vllm-gptq branch and successfully deployed the TheBloke/Llama-2-13b-Chat-GPTQ model. However, when I tried the TheBloke/Llama-2-7b-Chat-GPTQ model, it threw the following exception whenever I made a query. I wonder if the issue is with the model itself or something else; I'll dig further into this when I have the chance, but it looks like the Sampler generated a probability tensor with invalid values:

INFO:     13.229.18.8:52663 - "POST /v1/completions HTTP/1.1" 500 Internal Server Error
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/home/ubuntu/vllm-gptq/vllm/engine/async_llm_engine.py", line 28, in _raise_exception_on_finish
    task.result()
  File "/home/ubuntu/vllm-gptq/vllm/engine/async_llm_engine.py", line 351, in run_engine_loop
    has_requests_in_progress = await self.engine_step()
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/engine/async_llm_engine.py", line 330, in engine_step
    request_outputs = await self.engine.step_async()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/engine/async_llm_engine.py", line 191, in step_async
    output = await self._run_workers_async(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/engine/async_llm_engine.py", line 220, in _run_workers_async
    all_outputs = await asyncio.gather(*all_outputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.pyenv/versions/3.11.3/lib/python3.11/asyncio/tasks.py", line 684, in _wrap_awaitable
    return (yield from awaitable.__await__())
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(RuntimeError): ray::RayWorker.execute_method() (pid=68296, ip=172.31.40.30, actor_id=3b90ca9f90ebf20a67ae6c2c01000000, repr=<vllm.engine.ray_utils.RayWorker object at 0x7ef09122dad0>)
  File "/home/ubuntu/vllm-gptq/vllm/engine/ray_utils.py", line 32, in execute_method
    return executor(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.pyenv/versions/vllm-gptq/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/worker/worker.py", line 305, in execute_model
    output = self.model(
             ^^^^^^^^^^^
  File "/home/ubuntu/.pyenv/versions/vllm-gptq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/model_executor/models/llama.py", line 296, in forward
    next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.pyenv/versions/vllm-gptq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/model_executor/layers/sampler.py", line 85, in forward
    return _sample(probs, logprobs, input_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/model_executor/layers/sampler.py", line 451, in _sample
    sample_results = _random_sample(seq_groups, is_prompts,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/vllm-gptq/vllm/model_executor/layers/sampler.py", line 342, in _random_sample
    random_samples = torch.multinomial(probs,
                     ^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
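As a side note, a purely diagnostic guard like the one below (not part of vLLM) can be dropped in just before the `torch.multinomial` call to confirm whether the quantized weights are producing non-finite probabilities, rather than the sampler itself misbehaving:

```python
import torch


def check_probs(probs: torch.Tensor) -> None:
    # torch.multinomial raises exactly this error when a row contains
    # inf/NaN or negative mass, which usually points to overflow upstream.
    if not torch.isfinite(probs).all():
        raise ValueError("probability tensor contains inf or NaN")
    if (probs < 0).any():
        raise ValueError("probability tensor contains negative entries")
```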

@wejoncy
Contributor

wejoncy commented Oct 25, 2023

Hi,
If anyone wants to try GPTQ quantization in vLLM, please use the QLLM repo to quantize the model (LLaMA); the result will be compatible with the AWQ path in vLLM. And of course you can select AWQ to quantize it as well.

@David-Lee-1990

Is baichuan-gptq supported?

@sssuperrrr

Is acceleration supported for Qwen-72B-Chat-Int4?

@uncensorie

uncensorie commented Dec 9, 2023

> Hi, If anyone wants to try GPTQ quantization in vLLM, please use the QLLM repo to quantize the model (LLaMA); the result will be compatible with the AWQ path in vLLM. And of course you can select AWQ to quantize it as well.

Something is off with this QLLM GPTQ quantization @wejoncy ... not all the dependencies are specified in the requirements file. Also, I tried quantizing three times, and every time it breaks when it tries to save the file once the quant is done. Tried Llama2-70b and Mistral 7B.

[Screenshot of the error, 2023-12-09 15:21]

@wejoncy
Contributor

wejoncy commented Dec 10, 2023

> Hi, If anyone wants to try GPTQ quantization in vLLM, please use the QLLM repo to quantize the model (LLaMA); the result will be compatible with the AWQ path in vLLM. And of course you can select AWQ to quantize it as well.

> Something is off with this QLLM GPTQ quantization @wejoncy ... not all the dependencies are specified in the requirements file. Also, I tried quantizing three times, and every time it breaks when it tries to save the file once the quant is done. Tried Llama2-70b and Mistral 7B.

Hi,
Thanks for trying this out, and sorry for the inconvenience. This bug has been fixed in the latest main.
For now, you have two ways to use the GPTQ quant method in vLLM with the qllm tool:

  1. For Llama-family models, convert to AWQ if you didn't enable act_order, the bits are 4, and there are no mixed bits inside.
  2. Use GPTQ directly, but the GPTQ branch in vLLM is still on its way to being merged.

@jacobwarren

Is there any update on 8-bit support? That would help Mixtral generate usable outputs on a single (non-overpriced) GPU.

@hmellor
Collaborator

hmellor commented Feb 2, 2024

I have successfully used both GPTQ and AWQ models with vLLM.

Should this issue be considered solved @WoosukKwon?
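For reference, this is a minimal sketch of how a GPTQ checkpoint can be loaded through vLLM's Python API (the model name is just an example from this thread, and the exact arguments may vary by version):

```python
# Minimal sketch, assuming a vLLM build with GPTQ support (see #2330).
from vllm import LLM, SamplingParams

llm = LLM(model="TheBloke/Llama-2-13B-chat-GPTQ", quantization="gptq")
params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["What does GPTQ quantization do?"], params)
print(outputs[0].outputs[0].text)
```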

@jacobwarren

@hmellor it currently works with 4-bit, but not 8-bit. For now you have to use chu-tianxiang/vllm-gptq to get 8-bit support.

@hmellor
Collaborator

hmellor commented Mar 6, 2024

Closing as this was resolved by #2330

@hmellor hmellor closed this as completed Mar 6, 2024
@SuperBruceJia

Does vLLM support int-2 GPTQ models well?

Thank you very much!


mht-sharma pushed a commit to mht-sharma/vllm that referenced this issue Oct 30, 2024
Sanity check done:
Server mode; BS1 perf; Llama405b FP8