We extend gpt-fast to support SparQ attention, a bandwidth-efficient attention algorithm that speeds up generation for existing LLMs with no fine-tuning. For details of SparQ, see the paper.
The `main` branch tracks the gpt-fast repo. The `with-sparq` branch contains our modifications. You can compare `main` and `with-sparq` to see what we added.
You might also be interested in sparq-llama.cpp, our implementation of SparQ in llama.cpp.
We obtain the following speedups on an H100 PCIe, using BF16 for the model parameters and KV cache, and compressing the memory transfers 8x with SparQ:
"estimated theoretical max" shows an estimate of the best-cast speedup that could be achieved by SparQ if the attention operation was purely memory-bound, and all compute and communication was overlapped. See theoretical_speedups.py
for how this is calculated.
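To give a feel for this kind of estimate, here is a simplified sketch, not the actual logic in `theoretical_speedups.py`; the weight and KV cache sizes below are made-up examples. A memory-bound bound on the speedup can be computed from the bytes transferred per generated token with and without SparQ:

```python
# Simplified sketch of a memory-bound speedup estimate -- the real
# calculation lives in theoretical_speedups.py and may differ.

def estimated_max_speedup(param_bytes: float, kv_bytes: float,
                          compression: float = 8.0) -> float:
    """Best-case speedup assuming generation is purely memory-bound and all
    compute/communication overlaps with the memory transfers.

    param_bytes: bytes of model weights read per generated token
    kv_bytes:    bytes of KV cache read per generated token (dense attention)
    compression: assumed reduction in attention transfers from SparQ
    """
    dense = param_bytes + kv_bytes
    sparq = param_bytes + kv_bytes / compression
    return dense / sparq

# Hypothetical example: ~13.5 GB of BF16 weights, 4 GB of KV cache read per token
print(estimated_max_speedup(13.5e9, 4e9))  # ~1.25x in this made-up setting
```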
- Install Python >= 3.10
- Install the requirements:
  ```sh
  pip install -r requirements.txt
  ```
- Run
  ```sh
  huggingface-cli login
  ```
  or set the `HF_TOKEN` environment variable. The associated account must have access to `meta-llama/Llama-2-7b-chat-hf`.
- Download Llama 2 7b from Hugging Face, and prepare it for gpt-fast:
  ```sh
  ./scripts/prepare.sh "meta-llama/Llama-2-7b-chat-hf"
  ```
- Update `expected_gpu` in `run_speedup_benchmark.py` to the expected model of GPU (this avoids accidentally comparing results from different GPUs)
- Run the benchmark:
  ```sh
  python run_speedup_benchmark.py
  ```
SparQ is implemented in PyTorch, not as a custom kernel. However, we found that `torch.compile()` was able to generate a performant implementation.
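As a rough illustration of what the algorithm does, here is a minimal single-query (decode-step) sketch. It is not the code on the `with-sparq` branch; the tensor shapes, default `r`/`k` values, and simplified score scaling are assumptions made for the example. SparQ first ranks cache positions with a cheap approximation built from the largest-magnitude query components, then runs exact attention over only the top-scoring positions:

```python
# Illustrative single-query SparQ-style attention sketch (not the repo's
# implementation). q: (d,) query; K, V: (seq, d) cached keys/values.
import torch

def sparq_attention_sketch(q, K, V, r=32, k=128):
    d = q.shape[-1]

    # 1. Approximate scores from the r largest-magnitude query components,
    #    so only r of the d key components are read for every cached token.
    comp = torch.topk(q.abs(), r).indices
    s_hat = torch.softmax(q[comp] @ K[:, comp].T / d**0.5, dim=-1)

    # 2. Keep the k cache positions with the largest approximate scores.
    top = torch.topk(s_hat, min(k, K.shape[0])).indices

    # 3. Exact attention over the selected keys/values only.
    w = torch.softmax(q @ K[top].T / d**0.5, dim=-1)
    y_top = w @ V[top]

    # 4. Blend with the mean value vector to account for unselected positions,
    #    weighted by the approximate probability mass captured by the top-k.
    alpha = s_hat[top].sum()
    return alpha * y_top + (1 - alpha) * V.mean(dim=0)
```

The memory saving comes from steps 1 and 3: only a slice of each key is read to score the cache, and full keys/values are fetched only for the selected positions.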
This repo is based on gpt-fast, which is released under the BSD 3-Clause license. We also release our modifications under the BSD 3-Clause license. See LICENSE.