[dattri.algorithm, dattri.func] Refactor the implementation of EKFAC #143

Merged
20 commits merged on Oct 27, 2024
229 changes: 229 additions & 0 deletions dattri/algorithm/influence_function.py
@@ -815,3 +815,232 @@ def _get_layer_wise_reps(self,
query_layers.append(query[:, current_idx : split_index[i] + current_idx])
current_idx += split_index[i]
return query_layers


class IFAttributorEKFAC(BaseInnerProductAttributor):
"""The inner product attributor with EK-FAC inverse FIM transformation."""

def __init__(self,
task: AttributionTask,
module_name: Optional[Union[str, List[str]]] = None,
device: Optional[str] = "cpu",
damping: float = 0.0,
) -> None:
"""Initialize the EK-FAC inverse FIM attributor.

Args:
            task (AttributionTask): The task to be attributed. Must be an instance of
                `AttributionTask`. The loss function for the EK-FAC attributor should
                return the following:
                - loss: a single scalar tensor, i.e., the loss averaged over the
                        batch.
                - mask (optional): a tensor of shape (batch_size, t), where 1's
                                   indicate the input positions on which the IFVP
                                   will be estimated and 0's indicate irrelevant
                                   positions (e.g., padding tokens).
                Here t is the number of steps, i.e., the sequence length of the input
                data. If the input data are non-sequential, t should be set to 1.
                The FIM will be estimated from this function (a sketch of this
                contract appears just below this docstring).
module_name (Optional[Union[str, List[str]]]): The name of the module to be
used to calculate the train/test representations. If None, all linear
modules are used. This should be a string or a list of strings if
                multiple modules are needed. Module names should follow the keys
                of model.named_modules(). Default: None.
device (str): Device to run the attributor on. Default is "cpu".
damping (float): Damping factor used for non-convexity in EK-FAC IFVP
calculation. Default is 0.0.

Raises:
ValueError: If there are multiple checkpoints in `task`.
"""
super().__init__(task, None, device)
if len(self.task.checkpoints) > 1:
            error_msg = ("Received more than one checkpoint. "
                         "Ensembling of EK-FAC is not supported.")
raise ValueError(error_msg)

if module_name is None:
# Select all linear layers by default
module_name = [
name for name, mod in self.task.model.named_modules()
if isinstance(mod, torch.nn.Linear)
]
if not isinstance(module_name, list):
module_name = [module_name]

self.module_name = module_name

self.damping = damping
self.name_to_module = {
name: self.task.model.get_submodule(name) for name in module_name
}
self.module_to_name = {v: k for k, v in self.name_to_module.items()}

self.layer_cache = {} # cache for each layer

# Update layer_name corresponding to selected modules
self.layer_name = []
for name in self.module_name:
self.layer_name.append(name + ".weight")
if self.name_to_module[name].bias is not None:
self.layer_name.append(name + ".bias")

def cache(
self,
full_train_dataloader: DataLoader,
max_iter: Optional[int] = None,
) -> None:
"""Cache the dataset and statistics for inverse FIM calculation.

        Cache the full training dataset, as other attributors do.
        Estimate and cache the covariance matrices, eigenvector matrices,
        and corrected eigenvalues based on samples of the training data.

Args:
full_train_dataloader (DataLoader): The dataloader
with full training samples for inverse FIM calculation.
            max_iter (int, optional): An integer indicating the maximum number of
                batches that will be used for estimating the covariance matrices
                and lambdas. Defaults to the length of `full_train_dataloader`.
"""
from dattri.func.fisher import (
estimate_covariance,
estimate_eigenvector,
estimate_lambda,
)

self._set_full_train_data(full_train_dataloader)

if max_iter is None:
max_iter = len(full_train_dataloader)

def _ekfac_hook(module: torch.nn.Module,
inputs: Union[Tensor, Tuple[Tensor]],
outputs: Union[Tensor, Tuple[Tensor]],
) -> None:
"""Hook function for caching the inputs and outputs of a module.

Args:
module (torch.nn.Module): The module to which the hook is registered.
inputs (Union[Tensor, Tuple[Tensor]]): The module input tensor(s).
outputs (Union[Tensor, Tuple[Tensor]]): The module output tensor(s).
"""
            # Unpack tuple inputs and outputs if necessary
if isinstance(inputs, tuple):
inputs = inputs[0]
Collaborator:

[bias issue] Here we use these inputs as the a_prev in our covariance and lambda calculations, so we should append a torch.ones column to the input to handle the bias.


if isinstance(outputs, tuple):
outputs = outputs[0]

if module.bias is not None:
# Attach ones to the end of inputs
ones = torch.ones(
inputs.shape[:-1] + (1,),
dtype=inputs.dtype,
device=inputs.device,
)
inputs = torch.cat([inputs, ones], dim=-1)

outputs.retain_grad()
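            # retain_grad above keeps .grad on this non-leaf output tensor;
            # presumably the covariance/lambda estimators read it as the
            # per-layer backward signal (an assumption about their contract,
            # not stated in this diff).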
name = self.module_to_name[module]
# Cache the inputs and outputs
self.layer_cache[name] = (inputs, outputs)

handles = []
for name in self.module_name:
            # After one forward pass of the model, the inputs and outputs of each
            # layer in `module_name` will be stored in `self.layer_cache[name]`
mod = self.task.model.get_submodule(name)
handles.append(mod.register_forward_hook(_ekfac_hook))

func = partial(self.task.get_target_func(), self.task.get_param()[0])
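        # `func` now maps a batch of data to its loss, with the model
        # parameters bound, so the estimators below can evaluate it
        # batch-by-batch.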
        # 1. Use sampled training batches to estimate the covariance matrices S and A
cov_matrices = estimate_covariance(func,
full_train_dataloader,
self.layer_cache,
max_iter,
device=self.device)

# 2. Calculate the eigenvalue decomposition of S and A
self.cached_q = estimate_eigenvector(cov_matrices)

        # 3. Use sampled training batches for eigenvalue correction
self.cached_lambdas = estimate_lambda(func,
full_train_dataloader,
self.cached_q,
self.layer_cache,
max_iter,
device=self.device)

# Remove hooks after preprocessing the FIM
for handle in handles:
handle.remove()

def transform_test_rep(
self,
ckpt_idx: int,
test_rep: torch.Tensor,
) -> torch.Tensor:
"""Calculate the transformation on the test representations.

Args:
ckpt_idx (int): Index of the model checkpoints. Used for ensembling
different trained model checkpoints.
test_rep (torch.Tensor): Test representations to be transformed.
Typically a 2-d tensor with shape (batch_size, num_parameters).

Returns:
torch.Tensor: Transformed test representations. Typically a 2-d
tensor with shape (batch_size, transformed_dimension).

Raises:
            ValueError: If a non-zero `ckpt_idx` is specified.
"""
if ckpt_idx != 0:
            error_msg = ("EK-FAC only supports a single model checkpoint, "
                         "but received a non-zero `ckpt_idx`.")
raise ValueError(error_msg)

# Unflatten the test_rep
full_model_params = {
k: p for k, p in self.task.model.named_parameters() if p.requires_grad
}
partial_model_params = {
Contributor:

How about using self.task.get_param(layer_name=self.layer_name, layer_split=True)?

Collaborator (Author):

I think it's a bit hard to use this function, because it only provides the flattened parameters. Here we need the original shape information for each layer.

Contributor (@jiaqima, Oct 13, 2024):

I think layer_split=True will give you a map to the module name.

Maybe @TheaperDeng has better ideas here.

Collaborator:

I think this can be done with an easy change to get_param, so I can handle it in the next PR and leave it as it is in this PR.

name: full_model_params[name] for name in self.layer_name
}
layer_test_rep = {}
current_index = 0
for name, params in partial_model_params.items():
size = math.prod(params.shape)
layer_test_rep[name] = test_rep[
:, current_index : current_index + size,
].reshape(-1, *params.shape)
current_index += size
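        # layer_test_rep now maps each parameter name to its slice of
        # test_rep, reshaped back to the parameter's original shape
        # (with a leading batch dimension).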

ifvp = {}

for name in self.module_name:
if self.name_to_module[name].bias is not None:
dim_out = layer_test_rep[name + ".weight"].shape[1]
dim_in = layer_test_rep[name + ".weight"].shape[2] + 1
_v = torch.cat(
[
layer_test_rep[name + ".weight"].flatten(start_dim=1),
layer_test_rep[name + ".bias"].flatten(start_dim=1),
],
dim=-1,
)
_v = _v.reshape(-1, dim_out, dim_in)
else:
_v = layer_test_rep[name + ".weight"]

_lambda = self.cached_lambdas[name]
q_a, q_s = self.cached_q[name]
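                # EK-FAC IFVP for this layer: rotate V into the eigenbasis
                # (q_s @ V @ q_a.T), rescale elementwise by the corrected
                # eigenvalues plus damping, then rotate back.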

_ifvp = q_s.T @ ((q_s @ _v @ q_a.T) / (_lambda + self.damping)) @ q_a
ifvp[name] = _ifvp.flatten(start_dim=1)

# Flatten the parameters again
transformed_test_rep_layers = [ifvp[name] for name in self.module_name]

return torch.cat(transformed_test_rep_layers, dim=1)
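
For reviewers who want to try the new attributor end to end, here is a minimal usage sketch. It is hypothetical: the toy model and dataloaders are made up, and the AttributionTask construction and attribute() call follow the existing dattri attributor interface as I understand it, so exact signatures may differ.

import torch
from torch.utils.data import DataLoader, TensorDataset

from dattri.algorithm.influence_function import IFAttributorEKFAC
from dattri.task import AttributionTask

# Toy model; IFAttributorEKFAC selects all nn.Linear modules by default.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
)

def loss_func(params, data):
    # Return the mean loss over the batch, as the docstring requires.
    # (A sequence model would return (loss, mask) instead.)
    x, y = data
    yhat = torch.func.functional_call(model, params, x)
    return torch.nn.functional.mse_loss(yhat, y)

task = AttributionTask(
    model=model,
    loss_func=loss_func,
    checkpoints=model.state_dict(),  # exactly one checkpoint; ensembling is unsupported
)

train_loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8,
)
test_loader = DataLoader(
    TensorDataset(torch.randn(16, 10), torch.randn(16, 1)), batch_size=8,
)

attributor = IFAttributorEKFAC(task=task, device="cpu", damping=1e-3)
attributor.cache(train_loader)  # estimate covariances, eigenvectors, lambdas
scores = attributor.attribute(train_loader, test_loader)  # (num_train, num_test)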