Skip to content

Commit

Permalink
Format the codes
Browse files Browse the repository at this point in the history
  • Loading branch information
Sx Lau committed Sep 19, 2024
1 parent 3ab5634 commit 36c063a
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 53 deletions.
60 changes: 44 additions & 16 deletions dattri/algorithm/influence_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,19 +818,36 @@ def _get_layer_wise_reps(self,


class IFAttributorEKFAC(BaseInnerProductAttributor):
"""The inner product attributor with DataInf inverse hessian transformation."""
"""The inner product attributor with EK-FAC inverse hessian transformation."""

def __init__(self,
task: AttributionTask,
module_name: Optional[Union[str, List[str]]] = None,
device: Optional[str] = "cpu",
damping: float = 0.0,
device: Optional[str] = "cpu"
) -> None:
"""Initialize the DataInf inverse Hessian attributor.
Args:
task (AttributionTask): The task to be attributed. Must be an instance of
`AttributionTask`.
module_name (Optional[Union[str, List[str]]]): The name of the module to be
used to calculate the train/test representations. If None, all linear
modules are used. This should be a string or a list of strings if
multiple modules are needed. The name of module should follow the
key of model.named_modules(). Default: None.
device (str): Device to run the attributor on. Default is "cpu".
damping (float): Damping factor used for non-convexity in EK-FAC IFVP
calculation. Default is 0.0.
Raises:
ValueError: If there are multiple checkpoints in `task`.
"""
super().__init__(task, None, device)
if len(self.task.checkpoints) > 1:
error_msg = ("Received more than one checkpoint. "
"Ensemble of EK-FAC is not supported.")
raise ValueError(error_msg)
raise ValueError(error_msg)

if not module_name:
# Select all linear layers by default
Expand All @@ -852,9 +869,12 @@ def __init__(self,
}
self.module_to_name = {v: k for k, v in self.name_to_module.items()}

self.layer_cache = {} # cache for each layer
self.layer_cache = {} # cache for each layer

def _ekfac_hook(module, inputs, outputs):
def _ekfac_hook(module: torch.nn.Module,
inputs: Union[Tensor, Tuple[Tensor]],
outputs: Union[Tensor, Tuple[Tensor]],
) -> None:
# Unpack tuple outputs if necessary
if isinstance(inputs, tuple):
inputs = inputs[0]
Expand All @@ -873,10 +893,10 @@ def _ekfac_hook(module, inputs, outputs):
self.handles.append(mod.register_forward_hook(_ekfac_hook))

def cache(
self,
self,
full_train_dataloader: DataLoader,
max_iter: Optional[int] = None,
):
) -> None:
"""Cache the dataset and statistics for inverse hessian/fisher calculation.
Cache the full training dataset as other attributors.
Expand All @@ -890,7 +910,11 @@ def cache(
batches that will be used for estimating the the covariance matrices
and lambdas.
"""
from dattri.func.fisher import estimate_covariance, estimate_eigenvector, estimate_lambda
from dattri.func.fisher import (
estimate_covariance,
estimate_eigenvector,
estimate_lambda,
)

self._set_full_train_data(full_train_dataloader)

Expand Down Expand Up @@ -936,7 +960,15 @@ def transform_test_rep(
Returns:
torch.Tensor: Transformed test representations. Typically a 2-d
tensor with shape (batch_size, transformed_dimension).
Raises:
ValueError: If specifies a non-zero `ckpt_idx`.
"""
if ckpt_idx != 0:
error_msg = ("EK-FAC only supports single model checkpoint, "
"but receives non-zero `ckpt_idx`.")
raise ValueError(error_msg)

# Unflatten the test_rep
full_model_params = {
k: p for k, p in self.task.model.named_parameters() if p.requires_grad
Expand All @@ -949,8 +981,8 @@ def transform_test_rep(
for name, params in partial_model_params.items():
size = math.prod(params.shape)
layer_test_rep[name] = test_rep[
:, current_index : current_index + size
].reshape((-1,) + params.shape)
:, current_index : current_index + size,
].reshape(-1, *params.shape)
current_index += size

ifvp = {}
Expand All @@ -960,14 +992,10 @@ def transform_test_rep(
_lambda = self.cached_lambdas[name]
q_a, q_s = self.cached_q[name]

_ifvp= q_s.T @ ((q_s @ _v @ q_a.T) / (_lambda + self.damping)) @ q_a
_ifvp = q_s.T @ ((q_s @ _v @ q_a.T) / (_lambda + self.damping)) @ q_a
ifvp[name] = _ifvp.flatten(start_dim=1)

# Flatten the parameters again
transformed_test_rep_layers = []

for name in self.module_name:
transformed_test_rep_layers.append(ifvp[name])
transformed_test_rep_layers = [ifvp[name] for name in self.module_name]

return torch.cat(transformed_test_rep_layers, dim=1)

137 changes: 101 additions & 36 deletions dattri/func/fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@

if TYPE_CHECKING:
from collections.abc import Callable
from typing import ClassVar, Dict, Generator, List, Optional, Tuple, Union
from typing import Dict, Generator, Optional, Tuple, Union

from torch import Tensor


import warnings
from functools import wraps

import torch
from torch.func import grad
Expand Down Expand Up @@ -243,29 +242,30 @@ def _update_covariance(
total_samples: int,
mask: torch.Tensor,
) -> Dict[str, Tuple[torch.tensor]]:
"""Update the running estimation of the 'covariance' matrices S and A in EK-FAC IFVP.
"""Update the running estimation of the covariance matrices S and A in EK-FAC IFVP.
Args:
curr_estimate (List[List[Tuple[torch.Tensor]]]): A list of lists of tuples
of tensors, storing the running estimation of the layer-wise covariances.
mlp_cache (List[MLPCache]): A list of `MLPCache` passed to the main
EK-FAC function.
layer_cache (Dict[str, Tuple[torch.tensor]]): A dict that caches a pair
of (inputs, outputs) for each module during the forward process.
total_samples (int): An integer indicating the number of total valid
samples in the current batch.
samples in all previous batchs.
mask (torch.Tensor): A tensor of shape (batch_size, t), where 1's
indicate that the IFVP will be estimated on these input positions and
0's indicate that these positions are irrelevant (e.g. padding tokens).
Returns:
A list of lists of tuples of tensors, storing the updated running covariances.
Dict[str, Tuple[torch.tensor]]: A dict of tuples of tensors, storing the
updated running covariances.
"""
batch_samples = int(mask.sum())
for layer_name, (a_prev, s_curr) in layer_cache.items():
for layer_name, (a_prev_raw, s_curr_raw) in layer_cache.items():
# Uniformly reshape the tensors into (batch_size, t, ...)
# The t here is the sequence length or time steps for sequential input
# t = 1 if the given input is not sequential
if a_prev.ndim == 2: # noqa: PLR2004
a_prev = a_prev.unsqueeze(1)
if a_prev_raw.ndim == 2: # noqa: PLR2004
a_prev = a_prev_raw.unsqueeze(1)

a_prev_masked = a_prev * mask[..., None].to(a_prev.device)

Expand All @@ -275,18 +275,20 @@ def _update_covariance(
batch_cov_a /= batch_samples

# Calculate batch covariance matrix for S
ds_curr = s_curr.grad
ds_curr = s_curr_raw.grad

ds_curr_reshaped = ds_curr.view(-1, s_curr.size(-1))
ds_curr_reshaped = ds_curr.view(-1, s_curr_raw.size(-1))
batch_cov_s = ds_curr_reshaped.transpose(0, 1) @ ds_curr_reshaped
batch_cov_s /= batch_samples

# Update the running covariance matrices for A and S
if layer_name in curr_estimate:
old_weight = total_samples / (total_samples + batch_samples)
new_weight = batch_samples / (total_samples + batch_samples)
new_cov_a = old_weight * curr_estimate[layer_name][0] + new_weight * batch_cov_a
new_cov_s = old_weight * curr_estimate[layer_name][1] + new_weight * batch_cov_s
new_cov_a = (old_weight * curr_estimate[layer_name][0] +
new_weight * batch_cov_a)
new_cov_s = (old_weight * curr_estimate[layer_name][1] +
new_weight * batch_cov_s)
curr_estimate[layer_name] = (new_cov_a, new_cov_s)
else:
# First time access
Expand All @@ -296,27 +298,26 @@ def _update_covariance(


def _update_lambda(
curr_estimate: List[List[torch.Tensor]],
curr_estimate: Dict[str, torch.tensor],
layer_cache: Dict[str, Tuple[torch.tensor]],
cached_q: List[List[Tuple[torch.Tensor]]],
cached_q: Dict[str, Tuple[torch.tensor]],
total_samples: int,
mask: torch.Tensor,
max_steps_for_vec: int = 10,
) -> List[List[torch.Tensor]]:
) -> Dict[str, torch.tensor]:
"""Update the running estimation of the corrected eigenvalues in EK-FAC IFVP.
Args:
curr_estimate (List[List[torch.Tensor]]): A list of lists of tensors,
curr_estimate (Dict[str, torch.tensor]): A list of lists of tensors,
storing the running estimation of the layer-wise lambdas. The list
has the same length as `mlp_cache` in the main function, and each
of the member has the same length as the list in the cache.
mlp_cache (List[MLPCache]): A list of `MLPCache` passed to the main
EK-FAC function.
cached_q (List[List[Tuple[torch.Tensor]]]): A list of lists of tuples
of tensors, storing the layer-wise eigenvector matrices calculated
in the EK-FAC main function.
layer_cache (Dict[str, Tuple[torch.tensor]]): A dict that caches a pair
of (inputs, outputs) for each module during the forward process.
cached_q (Dict[str, Tuple[torch.tensor]]): A dict of tuples of tensors,
storing the layer-wise eigenvector matrices.
total_samples (int): An integer indicating the number of total valid
samples in the current batch.
samples in all previous batchs.
mask (torch.Tensor): A tensor of shape (batch_size, t), where 1's
indicate that the IFVP will be estimated on these input positions and
0's indicate that these positions are irrelevant (e.g. padding tokens).
Expand All @@ -327,15 +328,15 @@ def _update_lambda(
`dtheta`.
Returns:
A list of lists of tensors, storing the updated running lambdas.
Dict[str, torch.tensor]: A dict of tensors, storing the updated running lambdas.
"""
for layer_name, (a_prev, s_curr) in layer_cache.items():
for layer_name, (a_prev_raw, s_curr_raw) in layer_cache.items():
# Uniformly reshape the tensors into (batch_size, t, ...)
# The t here is the sequence length or time steps for sequential input
# t = 1 if the given input is not sequential
ds_curr = s_curr.grad
if a_prev.ndim == 2: # noqa: PLR2004
a_prev = a_prev.unsqueeze(1)
ds_curr = s_curr_raw.grad
if a_prev_raw.ndim == 2: # noqa: PLR2004
a_prev = a_prev_raw.unsqueeze(1)
ds_curr = ds_curr.unsqueeze(1)

a_prev_masked = a_prev * mask[..., None].to(a_prev.device)
Expand Down Expand Up @@ -385,8 +386,34 @@ def estimate_covariance(
dataloader: torch.utils.data.DataLoader,
layer_cache: Dict[str, Tuple[torch.tensor]],
max_iter: Optional[int] = None,
device: Optional[str] = "cpu"
device: Optional[str] = "cpu",
) -> Dict[str, Tuple[torch.tensor]]:
"""Estimate the 'covariance' matrices S and A in EK-FAC IFVP.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return the following,
- loss: a single tensor of loss. Should be the mean loss by the
batch size.
- mask (optional): a tensor of shape (batch_size, t), where 1's
indicate that the IFVP will be estimated on these
input positions and 0's indicate that these positions
are irrelevant (e.g. padding tokens).
t is the number of steps, or sequence length of the input data. If the
input data are non-sequential, t should be set to 1.
The FIM will be estimated on this function.
dataloader (torch.utils.data.DataLoader): The dataloader with full training
samples for FIM estimation.
layer_cache (Dict[str, Tuple[torch.tensor]]): A dict that caches a pair
of (inputs, outputs) for each module during the forward process.
max_iter (Optional[int]): An integer indicating the maximum number of
batches that will be used for estimating the covariance matrices.
device (Optional[str]): Device to run the attributor on. Default is "cpu".
Returns:
Dict[str, Tuple[torch.tensor]]: A dict that contains a pair of
estimated covariance for each module.
"""
if max_iter is None:
max_iter = len(dataloader)

Expand All @@ -410,7 +437,7 @@ def estimate_covariance(

with torch.no_grad():
# Estimate covariance
cov_matrices = _update_covariance(
covariances = _update_covariance(
covariances,
layer_cache,
total_samples,
Expand All @@ -425,8 +452,18 @@ def estimate_covariance(


def estimate_eigenvector(
covariances: List[List[Tuple[torch.Tensor]]],
) -> List[List[Tuple[torch.Tensor]]]:
covariances: Dict[str, Tuple[torch.Tensor]],
) -> Dict[str, Tuple[torch.Tensor]]:
"""Perform eigenvalue decomposition to covarince matrices.
Args:
covariances (Dict[str, Tuple[torch.Tensor]]): A dict that
contains a pair of estimated covariance for each module.
Returns:
Dict[str, Tuple[torch.Tensor]]: A dict that contains a
pair of eigenvector matrices for each module.
"""
cached_q = {}
for layer_name, (cov_a, cov_s) in covariances.items():
_, q_a = torch.linalg.eigh(cov_a, UPLO="U")
Expand All @@ -439,11 +476,39 @@ def estimate_eigenvector(
def estimate_lambda(
func: Callable,
dataloader: torch.utils.data.DataLoader,
eigenvectors: List[List[Tuple[torch.Tensor]]],
eigenvectors: Dict[str, Tuple[torch.tensor]],
layer_cache: Dict[str, Tuple[torch.tensor]],
max_iter: Optional[int] = None,
device: Optional[str] = "cpu"
) -> Dict[str, Tuple[torch.tensor]]:
device: Optional[str] = "cpu",
) -> Dict[str, torch.tensor]:
"""Estimate the 'covariance' matrices S and A in EK-FAC IFVP.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return the following,
- loss: a single tensor of loss. Should be the mean loss by the
batch size.
- mask (optional): a tensor of shape (batch_size, t), where 1's
indicate that the IFVP will be estimated on these
input positions and 0's indicate that these positions
are irrelevant (e.g. padding tokens).
t is the number of steps, or sequence length of the input data. If the
input data are non-sequential, t should be set to 1.
The FIM will be estimated on this function.
dataloader (torch.utils.data.DataLoader): The dataloader with full training
samples for FIM estimation.
eigenvectors (Dict[str, Tuple[torch.tensor]]): A dict that contains a
pair of eigenvector matrices for each module.
layer_cache (Dict[str, Tuple[torch.tensor]]): A dict that caches a pair
of (inputs, outputs) for each module during the forward process.
max_iter (Optional[int]): An integer indicating the maximum number of
batches that will be used for estimating the lambdas.
device (Optional[str]): Device to run the attributor on. Default is "cpu".
Returns:
Dict[str, torch.tensor]: A dict that contains the estimated lambda
for each module.
"""
if max_iter is None:
max_iter = len(dataloader)

Expand Down
1 change: 0 additions & 1 deletion test/dattri/func/test_fisher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Unit test for ifvp calculator."""

import types

import numpy as np
import torch
Expand Down

0 comments on commit 36c063a

Please sign in to comment.