add support for QuantizedCache (#5)
* add support for QuantizedCache

* update README

* upgrade version

* update README
SimJeg authored Nov 21, 2024
1 parent 34a7f57 commit 64b3c17
Showing 4 changed files with 81 additions and 41 deletions.
26 changes: 19 additions & 7 deletions README.md
@@ -72,6 +72,25 @@ _Average performance on the RULER dataset with 4k context length and Loogle Shor

Please refer to the [evaluation](evaluation/README.md) directory for more details and results.

## KV cache quantization

We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:

```python
from transformers import QuantizedCacheConfig, QuantoQuantizedCache

config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)

pipe(..., cache=cache)
```

By default, the `DynamicCache` is used (no quantization).

> [!IMPORTANT]
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto==0.2.4`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
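
Putting it together, here is a minimal end-to-end sketch combining a quantized cache with a press. The pipeline task name, model, and press choice are assumptions for illustration and are not part of this change:

```python
from transformers import pipeline, QuantizedCacheConfig, QuantoQuantizedCache

from kvpress import ExpectedAttentionPress

# Placeholder model; any causal LM supported by the kvpress pipeline works.
pipe = pipeline(
    "kv-press-text-generation",  # assumed to be the task name registered by kvpress
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="cuda",
    torch_dtype="auto",
)

press = ExpectedAttentionPress(compression_ratio=0.5)  # any press works the same way
cache = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4))

context = "..."   # long document whose KV cache will be compressed
question = "..."  # question about the document

result = pipe(context, question=question, press=press, cache=cache)
print(result)  # the returned dict holds the generated answer(s)
```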

## FAQ

<details><summary>
@@ -165,10 +184,3 @@ Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more d
</details>

<details><summary>

### Is quantization supported ?
</summary>

We don't support quantization of the KV cache yet. Quantization can achieve up to 4x compression moving from (b)float16 to int4 and we believe it is orthogonal to the KV cache pruning strategies proposed in this repository.

</details>
67 changes: 39 additions & 28 deletions kvpress/pipeline.py
@@ -7,7 +7,7 @@
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, DynamicCache, Pipeline
from transformers import AutoModelForCausalLM, Cache, DynamicCache, QuantizedCache, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor

@@ -32,6 +32,7 @@ def _sanitize_parameters(
press: Optional[BasePress] = None,
max_new_tokens: int = 50,
max_context_length: Optional[int] = None,
cache: Optional[Cache] = None,
**kwargs,
):
"""
@@ -42,7 +43,7 @@ def _sanitize_parameters(
----------
question : str, optional
The question to be asked about the context. Exclusive with `questions`.
questions : List[str], optional
questions : list[str], optional
A list of questions to be asked about the context. Exclusive with `question`.
answer_prefix : str, optional
The prefix to be added to the generated answer.
@@ -52,12 +53,14 @@ def _sanitize_parameters(
The maximum number of new tokens to generate for each answer.
max_context_length : int, optional
The maximum number of tokens in the context. By default will use the maximum length supported by the model.
cache : Cache, optional
The cache to use for the forward pass. Defaults to None (DynamicCache).
**kwargs : dict
Additional keyword arguments, currently ignored.
Returns
-------
Tuple[Dict, Dict, Dict]
Tuple[dict, dict, dict]
A tuple containing three dictionaries:
- preprocess_kwargs: The keyword arguments for the preprocess function.
- forward_kwargs: The keyword arguments for the forward function.
@@ -75,7 +78,7 @@ def _sanitize_parameters(
"answer_prefix": answer_prefix,
"max_context_length": max_context_length,
}
forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens}
forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens, "cache": cache}
return preprocess_kwargs, forward_kwargs, postprocess_kwargs

def preprocess(
@@ -90,7 +93,7 @@ def preprocess(
Returns
-------
Dict[str, GenericTensor]
dict[str, GenericTensor]
A dictionary containing the tokenized context (key: "context_ids") and questions (key: "questions_ids").
"""
@@ -127,47 +130,56 @@ def preprocess(
return {"context_ids": context_ids, "questions_ids": question_ids}

def _forward(
self, input_tensors: dict[str, GenericTensor], max_new_tokens: int = 50, press: Optional[BasePress] = None
self,
input_tensors: dict[str, GenericTensor],
max_new_tokens: int = 50,
press: Optional[BasePress] = None,
cache: Optional[Cache] = None,
):
"""
Forward pass of the kv-press pipeline.
Parameters
----------
input_tensors : Dict[str, GenericTensor]
input_tensors : dict[str, GenericTensor]
A dictionary containing the tokenized context and questions.
max_new_tokens : int, optional
The maximum number of new tokens to generate for each answer. Defaults to 50.
press : BasePress, optional
The key-value press to use for compression. Defaults to None.
cache : Cache, optional
The cache to use for the forward pass. Defaults to None (DynamicCache).
Returns
-------
List[str]
list[str]
A list of generated answers.
"""

context_ids = input_tensors["context_ids"].to(self.model.device)
context_length = context_ids.shape[1]

# Prefilling using the press on the context
if cache is None:
cache = DynamicCache()

with press(self.model) if press is not None else contextlib.nullcontext():
past_key_values = self.model(
self.model(
input_ids=context_ids,
past_key_values=DynamicCache(),
past_key_values=cache,
output_attentions=isinstance(press, ObservedAttentionPress),
num_logits_to_keep=1,
).past_key_values
)

logger.debug(f"Context Length: {context_length}")
logger.debug(f"Compressed Context Length: {past_key_values.get_seq_length()}")
logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")

# Greedy decoding for each question
answers = []
for question_ids in input_tensors["questions_ids"]:
answer = self.generate_answer(
question_ids=question_ids.to(self.model.device),
past_key_values=past_key_values,
cache=cache,
context_length=context_length,
max_new_tokens=max_new_tokens,
)
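
The prefill above now writes into whatever `Cache` instance is supplied (falling back to `DynamicCache`), and the compressed cache is reused for every question. The same pattern can be used outside the pipeline; a standalone sketch with a placeholder model (requires `optimum-quanto` for the quantized cache):

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    QuantizedCacheConfig,
    QuantoQuantizedCache,
)

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# Quantized KV cache; use DynamicCache() instead for the default, unquantized behavior
cache = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4))
context_ids = tokenizer("A long context ...", return_tensors="pt").input_ids.to(model.device)

with torch.no_grad():
    model(input_ids=context_ids, past_key_values=cache, num_logits_to_keep=1)

print(cache.get_seq_length())  # number of context tokens now held in the cache
```
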
@@ -181,7 +193,7 @@ def postprocess(self, model_outputs, single_question):
return {"answers": model_outputs}

def generate_answer(
self, question_ids: torch.Tensor, past_key_values: DynamicCache, context_length: int, max_new_tokens: int
self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
) -> str:
"""
Generate an answer to a question using greedy decoding.
@@ -190,7 +202,7 @@ def generate_answer(
----------
question_ids : torch.Tensor
The tokenized question.
past_key_values : DynamicCache
cache : Cache
The compressed key-value cache.
context_length : int
The length of the context.
@@ -203,18 +215,15 @@
The generated answer.
"""

cache_seq_lengths = [
past_key_values.get_seq_length(layer_idx=layer_idx) for layer_idx in range(len(past_key_values))
]

cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
position_ids = torch.arange(
context_length, context_length + question_ids.shape[1], device=self.model.device
).unsqueeze(0)

# if the user doesn't provide a question, skip forward pass
outputs = self.model(
input_ids=question_ids.to(self.model.device),
past_key_values=past_key_values,
past_key_values=cache,
position_ids=position_ids,
num_logits_to_keep=1,
)
@@ -229,7 +238,7 @@
for i in range(max_new_tokens - 1):
outputs = self.model(
input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
past_key_values=outputs.past_key_values,
past_key_values=cache,
position_ids=position_ids + i,
)
new_id = outputs.logits[0, -1].argmax()
@@ -238,13 +247,15 @@
break
answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)

# remove the generated tokens from the cache
past_key_values.key_cache = [
key[:, :, :cache_seq_len] for key, cache_seq_len in zip(past_key_values.key_cache, cache_seq_lengths)
]
past_key_values.value_cache = [
value[:, :, :cache_seq_len] for value, cache_seq_len in zip(past_key_values.value_cache, cache_seq_lengths)
]
# Remove the generated tokens from the cache
if isinstance(cache, QuantizedCache):
key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
else:
key_attr, value_attr = "key_cache", "value_cache"

setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)])
setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)])

return answer
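
The `getattr`/`setattr` indirection above exists because `QuantizedCache` stores its tensors in the private `_quantized_key_cache`/`_quantized_value_cache` attributes rather than `key_cache`/`value_cache`. Read as a standalone helper, the truncation step looks like this (a restatement of the diff for clarity, not additional functionality):

```python
from transformers import Cache, QuantizedCache

def truncate_cache(cache: Cache, cache_seq_lengths: list[int]) -> None:
    """Drop tokens appended during answer generation so the compressed context can be reused."""
    if isinstance(cache, QuantizedCache):
        key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
    else:
        key_attr, value_attr = "key_cache", "value_cache"
    setattr(cache, key_attr, [k[:, :, :n] for k, n in zip(getattr(cache, key_attr), cache_seq_lengths)])
    setattr(cache, value_attr, [v[:, :, :n] for v, n in zip(getattr(cache, value_attr), cache_seq_lengths)])
```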


27 changes: 22 additions & 5 deletions kvpress/presses/base_press.py
@@ -8,7 +8,14 @@

import torch
from torch import nn
from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM
from transformers import (
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
PreTrainedModel,
Qwen2ForCausalLM,
QuantizedCache,
)

logger = logging.getLogger(__name__)

@@ -92,8 +99,12 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
return output

keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]
if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
else:
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]

with torch.no_grad():
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
@@ -104,8 +115,14 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

# Update cache
cache.key_cache[module.layer_idx] = keys.gather(2, indices)
cache.value_cache[module.layer_idx] = values.gather(2, indices)
keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()
if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values

return output
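
For quantized caches, the hook dequantizes the layer, scores and prunes it in full precision, then re-quantizes the pruned tensors. The per-layer update can be read as a standalone sketch mirroring the hook (it relies on the same private `QuantizedCache` helpers shown in the diff):

```python
import torch
from transformers import QuantizedCache

def prune_layer(cache, layer_idx: int, indices: torch.Tensor) -> None:
    """Keep only the selected positions for one layer, re-quantizing if the cache is quantized.

    `indices` is expected to have shape (batch, num_kv_heads, n_kept, head_dim), as built in the hook.
    """
    quantized = isinstance(cache, QuantizedCache)
    if quantized:
        keys = cache._dequantize(cache._quantized_key_cache[layer_idx])
        values = cache._dequantize(cache._quantized_value_cache[layer_idx])
    else:
        keys = cache.key_cache[layer_idx]
        values = cache.value_cache[layer_idx]

    # Gather the kept positions along the sequence dimension
    keys = keys.gather(2, indices).contiguous()
    values = values.gather(2, indices).contiguous()

    if quantized:
        cache._quantized_key_cache[layer_idx] = cache._quantize(keys, axis=cache.axis_key)
        cache._quantized_value_cache[layer_idx] = cache._quantize(values, axis=cache.axis_value)
    else:
        cache.key_cache[layer_idx] = keys
        cache.value_cache[layer_idx] = values
```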

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.0.1"
version = "0.0.2"
readme = "README.md"

[tool.poetry.dependencies]
