-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathscorer_press.py
71 lines (55 loc) · 2.17 KB
/
scorer_press.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.base_press import BasePress
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class ScorerPress(BasePress):
    """
    Default press method for using a score method.

    Any ScorerPress subclass must implement the `score` method, which computes a tensor of scores for each
    key-value pair. The `compress` method keeps the KV pairs with the highest scores and prunes the
    lowest-scoring ones.
    The cache is uniformly pruned across all heads and layers using the compression_ratio parameter.
    """

    # Fraction of KV pairs to prune; must satisfy 0 <= compression_ratio < 1 (0 disables pruning).
    compression_ratio: float = 0.0

    def __post_init__(self):
        """Validate that `compression_ratio` lies in the half-open interval [0, 1)."""
        assert 0 <= self.compression_ratio < 1, "Compression ratio must be between 0 and 1"

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> torch.Tensor:
        """
        Compute a tensor of scores with shape (bsz, num_key_value_heads, q_len).

        The KV pairs with the lowest scores will be pruned in the `compress` method.
        Subclasses must override this method; it receives the same arguments that are
        passed to `compress`.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prune `keys` and `values` down to the highest-scoring KV pairs.

        Parameters
        ----------
        module : nn.Module
            Module being compressed; its `head_dim` attribute is read here.
        hidden_states : torch.Tensor
            Its dimension 1 is used as the sequence length `q_len`.
        keys, values : torch.Tensor
            4-D tensors that are gathered along dimension 2 with indices expanded to
            `module.head_dim` on the last dimension
            (presumably (bsz, num_key_value_heads, seq_len, head_dim) - confirm against caller).
        attentions : torch.Tensor
            Forwarded unchanged to `score`.
        kwargs : dict
            Forwarded unchanged to `score`.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The pruned (keys, values). When `compression_ratio` is 0 the inputs are
            returned as-is, without calling `score`.
        """
        if self.compression_ratio == 0:
            return keys, values

        # Compute scores
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        # Number of KV pairs to keep, truncated toward zero.
        # NOTE(review): the float product can land just below an integer
        # (e.g. 10 * (1 - 0.8) evaluates to 1.9999999999999996, giving n_kept == 1);
        # confirm this truncation is acceptable.
        q_len = hidden_states.shape[1]
        n_kept = int(q_len * (1 - self.compression_ratio))

        # Indices of the n_kept highest-scoring KV pairs, i.e. the ones that survive pruning.
        indices = scores.topk(n_kept, dim=-1).indices
        # Repeat each index across head_dim so gather selects whole key/value vectors.
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

        # Prune keys and values
        keys = keys.gather(2, indices).contiguous()
        values = values.gather(2, indices).contiguous()

        return keys, values