Implement RelatIF for Explicit, CG, and LISSA (#170)
* Implement RelatIF for explicit, CG, and Lissa methods

* add test for RelatIF

* add an example for relatIF

* change parameter name

* use separate train and test representations for relatif_method

* update readme and examples test

* fix example

* make test_batch_rep calculation conditional
jxbb824 authored Feb 10, 2025
1 parent 02aa090 commit e909f76
Showing 7 changed files with 460 additions and 49 deletions.
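For context, RelatIF (relative influence) rescales the standard influence-function score by a per-training-sample denominator so that training points with large gradient norms do not dominate the ranking. A hedged summary of the two normalizations wired in by this commit, matching the `"l"` and `"theta"` options documented in `attribute` below (sign conventions aside, the numerator is what `attribute` already computes as `train_batch_rep @ test_batch_rep.T`):

```latex
% Plain influence of training point z_i on test point z_t:
\[
  \mathcal{I}(z_i, z_t) = g_t^{\top} H^{-1} g_i
\]
% RelatIF divides each score by a per-training-sample denominator:
\[
  \mathcal{I}_{\ell}(z_i, z_t) = \frac{g_t^{\top} H^{-1} g_i}{\sqrt{g_i^{\top} H^{-1} g_i}}
  \qquad
  \mathcal{I}_{\theta}(z_i, z_t) = \frac{g_t^{\top} H^{-1} g_i}{\lVert H^{-1} g_i \rVert}
\]
```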
1 change: 1 addition & 0 deletions .github/workflows/examples_test.yml
@@ -31,6 +31,7 @@ jobs:
python examples/pretrained_benchmark/trak_dropout_lds.py --device cpu
python examples/brittleness/mnist_lr_brittleness.py --method cg --device cpu
python examples/data_cleaning/influence_function_data_cleaning.py --device cpu --train_size 1000 --val_size 100 --test_size 100 --remove_number 10
python examples/relatIF/influence_function_comparison.py --no_output
- name: Uninstall the package
run: |
pip uninstall -y dattri
2 changes: 1 addition & 1 deletion README.md
@@ -208,6 +208,7 @@ We have implemented most of the state-of-the-art methods. The categories and ref
| | [Arnoldi](https://arxiv.org/abs/2112.03052) |
| | [DataInf](https://arxiv.org/abs/2310.00902)|
| | [EK-FAC](https://arxiv.org/abs/2308.03296) |
| | [RelatIF](https://arxiv.org/pdf/2003.11630) |
| [TracIn](https://arxiv.org/abs/2002.08484) | [TracInCP](https://arxiv.org/abs/2002.08484) |
| | [Grad-Dot](https://arxiv.org/abs/2102.05262) |
| | [Grad-Cos](https://arxiv.org/abs/2102.05262) |
@@ -249,7 +250,6 @@ We have implemented most of the state-of-the-art methods. The categories and ref
- More algorithms and low-level utility functions to come
- KNN filter
- TF-IDF filter
- RelativeIF
- In-Run Shapley
- [LoGra](https://arxiv.org/abs/2405.13954)
- Better documentation
60 changes: 59 additions & 1 deletion dattri/algorithm/base.py
@@ -259,6 +259,7 @@ def attribute(
self,
train_dataloader: DataLoader,
test_dataloader: DataLoader,
relatif_method: Optional[str] = None,
) -> torch.Tensor:
"""Calculate the influence of the training set on the test set.
@@ -270,6 +271,12 @@
not be shuffled.
test_dataloader (DataLoader): Dataloader for test samples to calculate
the influence. The dataloader should not be shuffled.
relatif_method (Optional[str]): Method for normalizing the
influence values.
Supported options:
- `"l"`: Normalizes by `sqrt(g_i^T (H^-1 g_i))`.
- `"theta"`: Normalizes by `||H^-1 g_i||`.
- `None`: No normalization applied.
Returns:
torch.Tensor: The influence of the training set on the test set, with
@@ -314,6 +321,23 @@
ckpt_idx=checkpoint_idx,
data=train_batch_data,
)

denom = None
if relatif_method is not None:
if relatif_method == "l":
test_batch_rep = self.generate_test_rep(
ckpt_idx=checkpoint_idx,
data=train_batch_data,
)
else:
test_batch_rep = None
denom = self._compute_denom(
checkpoint_idx,
train_batch_rep,
test_batch_rep,
relatif_method=relatif_method,
)

# transform the train representations
train_batch_rep = self.transform_train_rep(
ckpt_idx=checkpoint_idx,
@@ -356,9 +380,43 @@
)

tda_output[row_st:row_ed, col_st:col_ed] += (
train_batch_rep @ test_batch_rep.T
train_batch_rep @ test_batch_rep.T / denom.unsqueeze(-1)
if denom is not None
else train_batch_rep @ test_batch_rep.T
)

tda_output /= checkpoint_idx + 1

return tda_output

def _compute_denom(
self,
ckpt_idx: int, # noqa: ARG002
train_batch_rep: torch.Tensor,
test_batch_rep: Optional[torch.Tensor] = None,
relatif_method: Optional[str] = None, # noqa: ARG002
) -> torch.Tensor:
"""Compute the denominator for the influence calculation.
Args:
ckpt_idx (int): The index of the checkpoint being used for influence
calculation.
train_batch_rep (torch.Tensor): The representation of the training batch
at the given checkpoint.
test_batch_rep (Optional[torch.Tensor]): The representation of the
training batch generated with `generate_test_rep` at the given
checkpoint (only provided when `relatif_method` is `"l"`; otherwise `None`).
relatif_method (Optional[str]): Normalization method.
- `"l"`: Computes `sqrt(g_i^T (H^-1 g_i))`.
- `"theta"`: Computes `||H^-1 g_i||`.
- `None`: Raises an error.
Returns:
torch.Tensor: The computed denominator for normalization, a
1-D tensor of shape (batch_size,).
"""
_ = self
_ = test_batch_rep

batch_size = train_batch_rep.size(0)
return train_batch_rep.new_ones(batch_size)
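
As a concrete illustration, the sketch below shows how an attributor with an explicitly materialized inverse Hessian could fill in the two denominators described in the `_compute_denom` docstring. This is a minimal, self-contained example rather than the dattri implementation; `train_grads`, `hessian_inv`, and `compute_denom` are hypothetical stand-ins for the representations the attributor builds internally.

```python
import torch


def compute_denom(
    train_grads: torch.Tensor,
    hessian_inv: torch.Tensor,
    relatif_method: str,
) -> torch.Tensor:
    """Per-sample RelatIF denominators for a batch of training gradients.

    train_grads: (batch_size, num_params) raw gradients g_i.
    hessian_inv: (num_params, num_params) symmetric approximation of H^-1.
    """
    ihvp = train_grads @ hessian_inv  # row i is (H^-1 g_i)^T
    if relatif_method == "l":
        # sqrt(g_i^T H^-1 g_i), clamped so numerical error cannot push the
        # quadratic form slightly negative before the square root
        return torch.sqrt((train_grads * ihvp).sum(dim=1).clamp(min=0.0))
    if relatif_method == "theta":
        # ||H^-1 g_i||
        return ihvp.norm(dim=1)
    msg = f"Unsupported relatif_method: {relatif_method}"
    raise ValueError(msg)


if __name__ == "__main__":
    torch.manual_seed(0)
    g_train = torch.randn(4, 6)                    # 4 training samples, 6 params
    g_test = torch.randn(3, 6)                     # 3 test samples
    h_inv = torch.linalg.inv(torch.eye(6) * 2.0)   # toy inverse Hessian
    denom = compute_denom(g_train, h_inv, "l")
    scores = (g_train @ h_inv @ g_test.T) / denom.unsqueeze(-1)
    print(denom.shape, scores.shape)               # torch.Size([4]) torch.Size([4, 3])
```

With the new signature introduced in this commit, callers obtain normalized scores via `attribute(train_dataloader, test_dataloader, relatif_method="l")` (or `"theta"`), and unnormalized influence by leaving `relatif_method` as `None`.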