Conversation

@phymhan phymhan commented Apr 14, 2025

Motivation

The KV cache stores the intermediate representations from previous tokens to accelerate autoregressive decoding. For long sequences, the KV cache can consume more GPU memory than the model weights. During inference, LLM decoding becomes memory-bound, with most of the time spent on data transfer rather than computation. This has led to active research on KV cache quantization, but quantization errors can accumulate as more tokens are generated, causing later tokens to deviate from expected outputs.

This PR

This PR adds the state-of-the-art training-free KV cache quantization method: SQuat (Subspace-orthogonal KV cache quantization). It can significantly reduce memory overhead and latency while maintaining model accuracy.
SQuat constructs a subspace that captures critical task-relevant information, then constrains the quantization error to lie orthogonal to this subspace, minimizing its effect on the output of the attention mechanism.
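
For intuition, here is a minimal, hedged sketch of the orthogonality idea (not the PR's actual implementation; the dimensions, the toy quantizer, and all variable names are illustrative): given an orthonormal basis Q of the task-relevant subspace, the quantized key is adjusted so that the remaining quantization error has no component inside span(Q).

# Illustrative only: keep the quantization error orthogonal to a given subspace
import torch

d, r = 128, 5                              # head dim and subspace dim (illustrative)
Q, _ = torch.linalg.qr(torch.randn(d, r))  # orthonormal basis of the subspace
k = torch.randn(d)                         # a key vector to be quantized
k_hat = torch.round(k * 4) / 4             # stand-in for any low-bit quantizer
err = k - k_hat                            # raw quantization error
k_hat = k_hat + Q @ (Q.T @ err)            # put the in-subspace part of the error back
# the remaining error (I - Q Q^T) err is orthogonal to every direction in span(Q)
print(torch.allclose(Q.T @ (k - k_hat), torch.zeros(r), atol=1e-5))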

🌟 Highlights

  • Training-free: No fine-tuning or calibration data needed
  • On-the-fly: Runs during inference without modifying the model
  • Theory-grounded: Derived from a theoretical analysis of how quantization error affects the attention output

⚡ Efficient

  • Reduces GPU peak memory by 2.17× to 2.82×
  • Improves throughput by 2.45× to 3.60×
  • Outperforms existing KV cache quantization methods on benchmark tasks

🏃🏻 Example

Run example.py, or use the snippet below:

# LLaMA model with SQuat
import torch
from models.llama_squat import LlamaForCausalLM_SQuat
from transformers import LlamaConfig, AutoTokenizer

K_BITS = 2                 # 2- and 4-bit KV cache are currently supported
V_BITS = 2                 # 2- and 4-bit KV cache are currently supported
GROUP_SIZE = 32            # example value
RESIDUAL_LENGTH = 128      # number of most recent tokens kept in fp16 (example value)
CACHE_DIR = "PATH_TO_YOUR_SAVE_DIR"  # replace with your model cache directory

config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
config.k_bits = K_BITS
config.v_bits = V_BITS
config.group_size = GROUP_SIZE
config.residual_length = RESIDUAL_LENGTH
config.use_flash = True
config.method = "squat_pre"
config.subspace_dim = 5
config.squat_lambda = 0.001
config.quant_group_size = 64

model = LlamaForCausalLM_SQuat.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    config=config,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    use_fast=False,
    trust_remote_code=True,
)

# Inference: see the generation example below
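
A minimal generation call using the model and tokenizer above (the prompt and generation settings are illustrative):

prompt = "Briefly explain what a KV cache is."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))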

@Rocketknight1
Member

cc @MekkCyber

@haowang94

@MekkCyber Sorry to bother you --- just wondering if you could take a look at this PR when you have time. Happy to make any changes needed!

@MekkCyber

Sorry @haowang94, forgot about it! Will take a look asap


@MekkCyber MekkCyber left a comment


Hi @phymhan! The project looks great — thanks for sharing. That said, it adds quite a bit of complexity on the model side. Some of that logic might be better suited for a new cache implementation instead, like what’s done here:
https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L821

If you’re interested in exploring an integration into transformers, I’d be happy to help! The only limitation I see for now is that we probably can’t support calling kernels inside the model for on-the-fly unpacking & dequantization; we would dequantize on the cache side instead.
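
For reference, the user-facing side of the existing quantized-cache path in transformers looks roughly like the sketch below (based on the public generate API; the model name and cache settings are illustrative and not SQuat-specific):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

# the KV cache is quantized on the fly inside generate(); a new method would plug in here
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))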

@phymhan
Author

phymhan commented Apr 30, 2025

Hi @MekkCyber, definitely---this sounds interesting to us. Thanks for the suggestion and the willingness to help! We'll experiment with the current implementation of the quantized cache class and explore integrating this method as a new cache implementation. We'll keep you posted.

@MekkCyber

Sounds great! Very excited about this 🔥

@phymhan
Author

phymhan commented May 9, 2025

Hey @MekkCyber, just a quick follow-up, we've made a PR here: huggingface/transformers#38055. Would love your thoughts when you get a chance!
