
QuEST is a Quantization-Aware Training (QAT) method that enables stable training of 1-bit Llama-type LLMs and makes 4-bit training Pareto-optimal.
Create a conda environment and install dependencies (we recommend Python 3.10):
conda create -n env python=3.10
conda activate env
pip install -r requirements.txt
Run a simple training job on the SlimPajama 6B dataset:
python ./src/main.py
The above command trains a 213.34M-parameter model with a Llama-style architecture. We recommend using the --compile flag, which speeds up training noticeably (up to 20% in our setup); see the example below.
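For instance, to enable compilation on the default run:
python ./src/main.py --compile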
One approach to reducing the massive costs of large language models (LLMs) is the use of quantized or sparse representations for training or deployment. While post-training compression methods are very popular, the question of obtaining even more accurate compressed models by directly training over such representations, i.e., Quantization-Aware Training (QAT), is still open: for example, a recent study (arXiv:2411.04330v2) put the "optimal" bit-width at which models can be trained using QAT, while staying accuracy-competitive with standard FP16/BF16 precision, at 8-bit weights and activations.
We advance this state-of-the-art via a new method called QuEST, which is Pareto-competitive with FP16, i.e., it provides better accuracy at lower model size, while training models with weights and activations in 4 bits or less. Moreover, QuEST allows stable training with 1-bit weights and activations. QuEST achieves this by improving two key aspects of QAT methods: (1) accurate and fast quantization of the (continuous) distributions of weights and activations via Hadamard normalization and MSE-optimal fitting; (2) a new trust gradient estimator based on the idea of explicitly minimizing the error between the noisy gradient computed over quantized states and the "true" (but unknown) full-precision gradient. Experiments on Llama-type architectures show that QuEST induces stable scaling laws across the entire range of hardware-supported precisions, and can be extended to sparse representations. We provide GPU kernel support showing that models produced by QuEST can be executed efficiently. The code is available in this repository.
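For intuition, here is a minimal, self-contained Python sketch of the forward quantization step described above: Hadamard normalization followed by an MSE-optimal symmetric INT4 fit. It is an illustration only; the function names and the grid-searched scale are our own assumptions, not the implementation in this repository, and the backward-pass trust gradient estimator is not shown.

import torch
from scipy.linalg import hadamard  # scipy is used here only for the Hadamard construction

def hadamard_normalize(x):
    # Rotate the last dimension with an orthonormal Hadamard matrix so the
    # entries look approximately Gaussian before quantization.
    d = x.shape[-1]  # must be a power of two for scipy's construction
    H = torch.tensor(hadamard(d), dtype=x.dtype) / d ** 0.5
    return x @ H

def mse_optimal_int4(x):
    # Search a small grid of clipping scales and keep the one with the lowest
    # reconstruction MSE (a stand-in for the paper's MSE-optimal fit).
    best = (float("inf"), None, None)
    for frac in torch.linspace(0.5, 1.0, 11):
        scale = frac * x.abs().max() / 7  # symmetric INT4 levels in [-7, 7]
        q = torch.clamp(torch.round(x / scale), -8, 7)
        err = ((q * scale - x) ** 2).mean().item()
        if err < best[0]:
            best = (err, q, scale)
    return best[1], best[2]

x = torch.randn(16, 256)                  # e.g. a block of activations
xh = hadamard_normalize(x)
q, scale = mse_optimal_int4(xh)
print("reconstruction MSE:", ((q * scale - xh) ** 2).mean().item())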
See train.sh for an example of how to run quantized training.
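For example (assuming the script's default arguments suit your setup):
bash ./train.sh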
We provide Triton/CUDA kernels for INT4 inference on NVIDIA Ampere and newer GPUs; the code is tuned in particular for the RTX 4090.
[Work in Progress] The floating-point type is currently hardcoded to float16 in the matrix-multiplication and dequantization CUTLASS template. We are adapting it to bfloat16; the difference in data type should not affect the speedup benchmarks.
- cmake
- C++ compiler (GCC/clang/...)
- nvcc
cd gemm-quest
pip install -e .
After installation, copy the Python scripts from the gemm-quest/quest folder into the src folder of the main QuEST project (see the example command below). You can then run the layer-wise and end-to-end benchmarks with these scripts; you may need to adjust some configuration values in the files to match your machine.
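For example, assuming gemm-quest is checked out inside the main QuEST project folder:
cp gemm-quest/quest/*.py ./src/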
This notebook and similar ones provide a way to check that the trained models generate coherent text and are, indeed, servable in quantized formats.
To run the HellaSWAG benchmark, please run the following command:
python ./src/eval_hswag.py --model_name "<MODEL_OUTPUT_NAME>" --ckpts_dir "./ckpts"
This repository is based on the epfml/schedules-and-scaling repository accompanying the paper "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations". We thank the original authors for making their code public under an open license.
@misc{panferov2025queststabletrainingllms,
title={QuEST: Stable Training of LLMs with 1-Bit Weights and Activations},
author={Andrei Panferov and Jiale Chen and Soroush Tabesh and Roberto L. Castro and Mahdi Nikdan and Dan Alistarh},
year={2025},
eprint={2502.05003},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2502.05003},
}