Skip to content

Commit

Permalink
Add chunk press
Browse files Browse the repository at this point in the history
  • Loading branch information
maxjeblick authored Jan 21, 2025
1 parent 7f6730d commit de204fc
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Finally we provide special presses:
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio
- `ComposedPress`: compose multiple presses together by chaining their forward hooks
- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`.
- `ChunkPress`: compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences. The press can be used with any other press that inherits from `ScorerPress`. The method was introduced [here](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280).

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

Expand Down
10 changes: 0 additions & 10 deletions evaluation/ruler/calculate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,7 @@ def string_match_all(preds, refs):
return round(score, 2)


METRICS_DICT = {
"niah": string_match_all,
"vt": string_match_all,
"cwe": string_match_all,
"fwe": string_match_all,
"qa": string_match_part,
}


def calculate_metrics(df: pd.DataFrame) -> dict:

scores = {}

np_pattern = re.compile(r"[\x00-\x1f]")
Expand Down
4 changes: 3 additions & 1 deletion kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# SPDX-License-Identifier: Apache-2.0


from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
Expand All @@ -19,7 +21,6 @@
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

from kvpress.attention_patch import patch_attention_functions
# Patch the attention functions to support head-wise compression
patch_attention_functions()

Expand All @@ -40,4 +41,5 @@
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
"ChunkPress",
]
10 changes: 5 additions & 5 deletions kvpress/presses/adakv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ class AdaKVPress(BasePress):
This press has been reviewed by Yuan Feng, first author of AdaKV.
"""

scorer: ScorerPress
press: ScorerPress
alpha_safeguard: float = 0.20

def __post_init__(self):
assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input"
assert isinstance(self.press, ScorerPress), "AdaKVPress requires a ScorerPress as input"
assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"

@property
def compression_ratio(self):
return self.scorer.compression_ratio
return self.press.compression_ratio

@compression_ratio.setter
def compression_ratio(self, value):
self.scorer.compression_ratio = value
self.press.compression_ratio = value

def compress(self, module, hidden_states, keys, values, attentions, kwargs):
if self.compression_ratio == 0:
Expand All @@ -41,7 +41,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
assert module.config._attn_implementation != "eager", "eager mode not supported"

# Compute scores
scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs)
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
bsz, num_key_value_heads, q_len = scores.shape

# Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
Expand Down
75 changes: 75 additions & 0 deletions kvpress/presses/chunk_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class ChunkPress(BasePress):
"""
Wrapper class for any ScorerPress.
Chunks keys and values into chunks of size chunk_length and compresses each chunk separately.
This ensures that the context is compressed uniformly across the entire context.
This method was proposed in FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models
https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280
"""

press: ScorerPress
chunk_length: int = 1024

def __post_init__(self):
assert isinstance(self.press, ScorerPress), "ChunkPress requires a ScorerPress as input"

@property
def compression_ratio(self):
return self.press.compression_ratio

@compression_ratio.setter
def compression_ratio(self, value):
self.press.compression_ratio = value

def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:

if self.press.compression_ratio == 0:
return keys, values

assert attentions is None, "ChunkPress does not support attentions."

kv_len = keys.shape[2]

indices = []
for i in range(0, kv_len, self.chunk_length):
chunk_scores = self.press.score(
module,
hidden_states[:, i : i + self.chunk_length],
keys[:, :, i : i + self.chunk_length],
values[:, :, i : i + self.chunk_length],
attentions,
kwargs,
)
chunk_length = keys[:, :, i : i + self.chunk_length].shape[2]
n_kept = max(1, int(chunk_length * (1 - self.press.compression_ratio)))
chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices
indices.append(chunk_indices)

indices = torch.cat(indices, dim=-1)
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()

return keys, values
2 changes: 1 addition & 1 deletion kvpress/presses/composed_press.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass

from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.adakv_press import AdaKVPress


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion kvpress/presses/think_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def compute_window_queries(self, module, hidden_states, position_embeddings):

# Apply RoPE
cos, sin = position_embeddings
cos, sin = cos[:, -self.window_size:], sin[:, -self.window_size:]
cos, sin = cos[:, -self.window_size :], sin[:, -self.window_size :]
query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

return query_states
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.2.0"
version = "0.2.1"
readme = "README.md"

[tool.poetry.dependencies]
Expand Down
20 changes: 17 additions & 3 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from transformers import DynamicCache

from kvpress import (
AdaKVPress,
ChunkPress,
ComposedPress,
KeyRerotationPress,
KnormPress,
ObservedAttentionPress,
AdaKVPress,
ThinKPress,
ScorerPress,
ThinKPress,
)
from tests.default_presses import default_presses
from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401
Expand All @@ -29,8 +30,19 @@ def test_composed_press(unit_test_model): # noqa: F811
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values


def test_chunk_press(unit_test_model): # noqa: F811
press = KnormPress(compression_ratio=0.5)
for chunk_length in [2, 4, 8, 128]:
composed_press = ChunkPress(press=press, chunk_length=chunk_length)
with composed_press(unit_test_model):
input_ids = torch.randint(0, 1024, (1, 256))
cache = DynamicCache()
unit_test_model(input_ids, past_key_values=cache).past_key_values
assert cache.get_seq_length() == 128


@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress])
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress])
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
cls = press_dict["cls"]
for kwargs in press_dict["kwargs"]:
Expand All @@ -44,6 +56,8 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
return
else:
press = AdaKVPress(press=press)
if isinstance(wrapper_press, ChunkPress):
press = ChunkPress(press=press, chunk_length=2)

with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"]
Expand Down
1 change: 1 addition & 0 deletions tests/test_attention_patch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch

from kvpress.attention_patch import search_hyperplane


Expand Down

0 comments on commit de204fc

Please sign in to comment.