Skip to content

Commit

Permalink
Moved rollout code to #254
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Feb 27, 2024
1 parent e57eca3 commit 19f781f
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 247 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238))
- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237))
- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173))
- Added `rollout` (`inseq.data.aggregation_functions.RolloutAggregationFunction`) aggregation function for `SequenceAttributionAggregator` class ([#173](https://github.com/inseq-team/inseq/pull/173)).
- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)).

## 🔧 Fixes & Refactoring
Expand Down
2 changes: 0 additions & 2 deletions inseq/attr/feat/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .rollout import rollout_fn
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing

Expand All @@ -10,6 +9,5 @@
"MonotonicPathBuilder",
"ValueZeroing",
"Lime",
"rollout_fn",
"SequentialIntegratedGradients",
]
180 changes: 0 additions & 180 deletions inseq/attr/feat/ops/rollout.py

This file was deleted.

40 changes: 0 additions & 40 deletions inseq/data/aggregation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torch
from torch.linalg import vector_norm

from ..attr.feat.ops import rollout_fn
from ..utils import Registry, available_classes
from ..utils.typing import (
ScoreTensor,
Expand Down Expand Up @@ -94,45 +93,6 @@ def __call__(self, scores: torch.Tensor, dim: int, vnorm_ord: int = 2) -> ScoreT
return vector_norm(scores, ord=vnorm_ord, dim=dim)


class RolloutAggregationFunction(AggregationFunction):
aggregation_function_name = "rollout"

def __init__(self):
super().__init__()
self.takes_single_tensor: bool = False
self.takes_sequence_scores: bool = True

def __call__(
self,
scores: Union[torch.Tensor, tuple[torch.Tensor, ...]],
dim: int,
sequence_scores: dict[str, torch.Tensor] = {},
) -> ScoreTensor:
dec_self_prefix = "decoder_self"
enc_self_prefix = "encoder_self"
dec_match = [name for name in sequence_scores.keys() if name.startswith(dec_self_prefix)]
enc_match = [name for name in sequence_scores.keys() if name.startswith(enc_self_prefix)]
if isinstance(scores, torch.Tensor):
# If no matching prefix is found, we assume the decoder-only target-only rollout case
if not dec_match or not enc_match:
return rollout_fn(scores, dim=dim)
# If both prefixes are found, we assume the encoder-decoder source-only rollout case
else:
enc_match = sequence_scores[enc_match[0]]
dec_match = sequence_scores[dec_match[0]]
return rollout_fn((enc_match, scores, dec_match), dim=dim)[0]
elif not enc_match:
raise KeyError(
"Could not find encoder self-importance scores in sequence scores. "
"Encoder self-importance scores are required for encoder-decoder rollout. They should be provided "
f"as an entry in the sequence scores dictionary with key starting with '{enc_self_prefix}', and "
"value being a tensor of shape (src_seq_len, src_seq_len, ..., rollout_dim)."
)
else:
enc_match = sequence_scores[enc_match[0]]
return rollout_fn((enc_match,) + scores, dim=dim)


DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
"source_attributions": {"spans": "absmax"},
"target_attributions": {"spans": "absmax"},
Expand Down
24 changes: 0 additions & 24 deletions tests/attr/feat/ops/test_rollout.py

This file was deleted.

0 comments on commit 19f781f

Please sign in to comment.