Enable gptqmodel #35012
base: main
Changes from 38 commits
@@ -22,15 +22,41 @@ Try GPTQ quantization with PEFT in this [notebook](https://colab.research.google

</Tip>

The [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) library implements the GPTQ algorithm, a post-training quantization technique where each row of the weight matrix is quantized independently to find a version of the weights that minimizes the error. These weights are quantized to int4, but they're restored to fp16 on the fly during inference. This can save your memory-usage by 4x because the int4 weights are dequantized in a fused kernel rather than a GPU's global memory, and you can also expect a speedup in inference because using a lower bitwidth takes less time to communicate.
Both the [GPTQModel](https://github.com/ModelCloud/GPTQModel) and [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) libraries implement the GPTQ algorithm, a post-training quantization technique where each row of the weight matrix is quantized independently to find a version of the weights that minimizes error. These weights are quantized to int4, stored as int32 (int4 x 8), and dequantized (restored) to fp16 on the fly during inference. This can cut memory usage by almost 4x because the int4 weights are often dequantized in a fused kernel. You can also expect a substantial speedup in inference due to the lower bandwidth requirements of the lower bitwidth.
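To make the "int4 x 8" storage concrete, here is a small illustrative sketch (not the libraries' actual packing code) of how eight 4-bit values fit into a single 32-bit integer:

```py
# Illustration only: pack eight 4-bit values (0..15) into one 32-bit integer,
# mirroring the "int4 x 8" storage described above.
values = [3, 7, 15, 0, 9, 2, 5, 11]

packed = 0
for i, v in enumerate(values):
    packed |= v << (4 * i)  # each value occupies its own 4-bit slot

# Unpack by shifting and masking each 4-bit slot back out.
unpacked = [(packed >> (4 * i)) & 0xF for i in range(8)]
assert unpacked == values
```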

Before you begin, make sure the following libraries are installed:
[GPTQModel](https://github.com/ModelCloud/GPTQModel) began as a maintained fork of AutoGPTQ but has since differentiated itself through the following major differences:

* Model support: GPTQModel continues to support all of the latest LLM releases.
* Multi-modal support: GPTQModel supports accurate quantization of Qwen 2-VL and Ovis 1.6-VL image-to-text models.
* Platform support: validated macOS (Apple Silicon) and Windows 11 support.
* Hardware support: Apple Silicon M1+, Intel/AMD CPUs, and Intel Data Center Max + Arc GPUs.
* IPEX kernel for accelerated Intel/AMD CPU and Intel GPU (Data Center Max + Arc) support.
* Updated Marlin kernel from Neural Magic that is highly optimized for the A100.
* Updated kernels with auto-padding for legacy model support and models with non-uniform in/out-features.
* Faster quantization, lower memory usage, and more accurate default quantization via the GPTQModel quantization APIs.
* User- and developer-friendly APIs.

[AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) will likely be deprecated in the future due to the lack of continued support for new models and features.

Before you begin, make sure the following libraries are installed and updated to the latest release:

```bash
pip install auto-gptq
pip install --upgrade accelerate optimum transformers
```

Then install either GPTQModel or AutoGPTQ.

```bash
pip install gptqmodel --no-build-isolation
```

or

```bash
pip install auto-gptq --no-build-isolation
```

To quantize a model (currently only supported for text models), you need to create a [`GPTQConfig`] class and set the number of bits to quantize to, a dataset to calibrate the weights for quantization, and a tokenizer to prepare the dataset.
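A minimal sketch of that flow, based on the description above (the model and dataset names here are illustrative):

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"  # illustrative; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)

# bits to quantize to, a calibration dataset, and the tokenizer used to prepare it
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

# quantization runs while the model loads
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", quantization_config=gptq_config
)
```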

@@ -92,9 +118,14 @@ from transformers import AutoModelForCausalLM

```py
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")
```

## Marlin

[Marlin](https://github.com/IST-DASLab/marlin) is a 4-bit-only CUDA GPTQ kernel that is highly optimized for the Nvidia A100 GPU (Ampere) architecture, where the loading, dequantization, and execution of post-dequantized weights are highly parallelized, offering a substantial inference improvement over the original CUDA GPTQ kernel. Marlin is only available for quantized inference and does not support model quantization.

> **Review comment:** Also, can we add a snippet to show the user how to use it? Generally, it will help the user a lot if we explain a bit how the `backend` attribute in `GPTQConfig` works.

> **Reply:** @SunMarc Good idea. Example usage of selection of Marlin via …
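A minimal sketch of what such a snippet could look like, assuming the `backend` attribute accepts a string naming the kernel (the `"marlin"` value, the dtype, and the model repo are illustrative):

```py
import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# Assumption: passing backend="marlin" selects the Marlin kernel for inference
# on an already-quantized 4-bit GPTQ checkpoint.
gptq_config = GPTQConfig(bits=4, backend="marlin")
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq",  # placeholder repo from the docs above
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=gptq_config,
)
```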

## ExLlama

[ExLlama](https://github.com/turboderp/exllama) is a Python/C++/CUDA implementation of the [Llama](model_doc/llama) model that is designed for faster inference with 4-bit GPTQ weights (check out these [benchmarks](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)). The ExLlama kernel is activated by default when you create a [`GPTQConfig`] object. To boost inference speed even further, use the [ExLlamaV2](https://github.com/turboderp/exllamav2) kernels by configuring the `exllama_config` parameter:
[ExLlama](https://github.com/turboderp/exllama) is a CUDA implementation of the [Llama](model_doc/llama) model that is designed for faster inference with 4-bit GPTQ weights (check out these [benchmarks](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)). The ExLlama kernel is activated by default when you create a [`GPTQConfig`] object. To boost inference speed even further, use the [ExLlamaV2](https://github.com/turboderp/exllamav2) kernels by configuring the `exllama_config` parameter:
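A minimal sketch of configuring the ExLlamaV2 kernels this way, assuming `exllama_config` takes a dict with a `version` key (the model repo is a placeholder):

```py
import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# Assumption: {"version": 2} requests the ExLlamaV2 kernels instead of ExLlamaV1.
gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq",
    device_map="auto",
    quantization_config=gptq_config,
)
```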

@@ -110,11 +141,11 @@ Only 4-bit models are supported, and we recommend deactivating the ExLlama kerne

</Tip>

The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2), then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file.
The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ or GPTQModel, then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file.

```py
import torch
from transformers import AutoModelForCausalLM, GPTQConfig
gptq_config = GPTQConfig(bits=4, use_exllama=False)
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="cpu", quantization_config=gptq_config)
```

> **Review comment:** Thanks for adding 😊