[dattri.algorithm, dattri.func] Refactor the implementation of EKFAC #143
@@ -815,3 +815,232 @@ def _get_layer_wise_reps(self,
            query_layers.append(query[:, current_idx : split_index[i] + current_idx])
            current_idx += split_index[i]
        return query_layers


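Annotation (background, not part of the diff): EK-FAC (George et al., 2018) approximates the FIM of each linear layer as a Kronecker product

$$G \approx A \otimes S, \qquad A = \mathbb{E}[a\,a^\top], \qquad S = \mathbb{E}\big[\nabla_{s}\ell \,(\nabla_{s}\ell)^\top\big],$$

where $a$ is the layer input (with a 1 appended when the layer has a bias) and $s$ is the layer's pre-activation output. Writing the eigendecompositions as $A = Q_A^\top \Lambda_A Q_A$ and $S = Q_S^\top \Lambda_S Q_S$ (rows of $Q$ as eigenvectors, matching the convention implied by the code below), and letting $\Lambda$ be a matrix of corrected eigenvalues estimated in this eigenbasis, the damped inverse-FIM-vector product for a matrix-shaped per-layer gradient $V$ is

$$(G + \lambda I)^{-1} V \approx Q_S^\top \Big( \big(Q_S V Q_A^\top\big) \oslash \big(\Lambda + \lambda\big) \Big) Q_A,$$

where $\oslash$ denotes elementwise division. This is the expression computed in `transform_test_rep` below; `cache` estimates $A$, $S$, the $Q$'s, and $\Lambda$.
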
class IFAttributorEKFAC(BaseInnerProductAttributor):
    """The inner product attributor with EK-FAC inverse FIM transformation."""

    def __init__(self,
        task: AttributionTask,
        module_name: Optional[Union[str, List[str]]] = None,
        device: Optional[str] = "cpu",
        damping: float = 0.0,
    ) -> None:
"""Initialize the EK-FAC inverse FIM attributor. | ||
|
||
Args: | ||
task (AttributionTask): The task to be attributed. Must be an instance of | ||
`AttributionTask`. The loss function for EK-FAC attributor should return | ||
the following, | ||
- loss: a single tensor of loss. Should be the mean loss by the | ||
batch size. | ||
- mask (optional): a tensor of shape (batch_size, t), where 1's | ||
indicate that the IFVP will be estimated on these | ||
input positions and 0's indicate that these positions | ||
are irrelevant (e.g. padding tokens). | ||
t is the number of steps, or sequence length of the input data. If the | ||
input data are non-sequential, t should be set to 1. | ||
The FIM will be estimated on this function. | ||
module_name (Optional[Union[str, List[str]]]): The name of the module to be | ||
used to calculate the train/test representations. If None, all linear | ||
modules are used. This should be a string or a list of strings if | ||
multiple modules are needed. The name of module should follow the | ||
key of model.named_modules(). Default: None. | ||
device (str): Device to run the attributor on. Default is "cpu". | ||
damping (float): Damping factor used for non-convexity in EK-FAC IFVP | ||
calculation. Default is 0.0. | ||
|
||
Raises: | ||
ValueError: If there are multiple checkpoints in `task`. | ||
""" | ||
        super().__init__(task, None, device)
        if len(self.task.checkpoints) > 1:
            error_msg = ("Received more than one checkpoint. "
                         "Ensemble of EK-FAC is not supported.")
            raise ValueError(error_msg)

        if module_name is None:
            # Select all linear layers by default
            module_name = [
                name for name, mod in self.task.model.named_modules()
                if isinstance(mod, torch.nn.Linear)
            ]
        if not isinstance(module_name, list):
            module_name = [module_name]

        self.module_name = module_name

        self.damping = damping
        self.name_to_module = {
            name: self.task.model.get_submodule(name) for name in module_name
        }
        self.module_to_name = {v: k for k, v in self.name_to_module.items()}

        self.layer_cache = {}  # cache for each layer

        # Update layer_name corresponding to the selected modules
        self.layer_name = []
        for name in self.module_name:
            self.layer_name.append(name + ".weight")
            if self.name_to_module[name].bias is not None:
                self.layer_name.append(name + ".bias")

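Annotation (illustrative, not part of the diff): a minimal sketch of a loss function that satisfies the contract documented above, for a sequence model. The name `model` and the batch layout `(input_ids, labels, attention_mask)` are hypothetical assumptions, not dattri API.

    import torch
    import torch.nn.functional as F

    def target_func(params, batch):
        input_ids, labels, attention_mask = batch  # hypothetical batch layout
        logits = torch.func.functional_call(model, params, (input_ids,))
        # A single scalar: the mean loss over the batch, as required above.
        loss = F.cross_entropy(
            logits.flatten(0, 1), labels.flatten(), reduction="mean",
        )
        # Optional mask of shape (batch_size, t): 1 marks positions the IFVP
        # is estimated on, 0 marks irrelevant positions (e.g. padding).
        return loss, attention_mask
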
    def cache(
        self,
        full_train_dataloader: DataLoader,
        max_iter: Optional[int] = None,
    ) -> None:
        """Cache the dataset and statistics for the inverse FIM calculation.

        Cache the full training dataset, as other attributors do.
        Estimate and cache the covariance matrices, eigenvector matrices,
        and corrected eigenvalues based on samples of the training data.

        Args:
            full_train_dataloader (DataLoader): The dataloader
                with full training samples for the inverse FIM calculation.
            max_iter (int, optional): The maximum number of batches that will
                be used for estimating the covariance matrices and lambdas.
                Defaults to the length of `full_train_dataloader`.
        """
        from dattri.func.fisher import (
            estimate_covariance,
            estimate_eigenvector,
            estimate_lambda,
        )

        self._set_full_train_data(full_train_dataloader)

        if max_iter is None:
            max_iter = len(full_train_dataloader)

        def _ekfac_hook(module: torch.nn.Module,
            inputs: Union[Tensor, Tuple[Tensor]],
            outputs: Union[Tensor, Tuple[Tensor]],
        ) -> None:
            """Hook function for caching the inputs and outputs of a module.

            Args:
                module (torch.nn.Module): The module to which the hook is registered.
                inputs (Union[Tensor, Tuple[Tensor]]): The module input tensor(s).
                outputs (Union[Tensor, Tuple[Tensor]]): The module output tensor(s).
            """
            # Unpack tuple inputs/outputs if necessary
            if isinstance(inputs, tuple):
                inputs = inputs[0]

            if isinstance(outputs, tuple):
                outputs = outputs[0]

            if module.bias is not None:
                # Append ones to the end of the inputs so that the bias is
                # handled as one extra input dimension
                ones = torch.ones(
                    inputs.shape[:-1] + (1,),
                    dtype=inputs.dtype,
                    device=inputs.device,
                )
                inputs = torch.cat([inputs, ones], dim=-1)

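            # Keep the gradient on the layer output so that the covariance S
            # and the corrected eigenvalues can later be read from
            # `outputs.grad` after a backward pass (assumption: this is how
            # dattri.func.fisher consumes `layer_cache`).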
            outputs.retain_grad()
            name = self.module_to_name[module]
            # Cache the inputs and outputs
            self.layer_cache[name] = (inputs, outputs)

        handles = []
        for name in self.module_name:
            # Once the model has run a forward pass, the input and output of
            # each layer in `module_name` will be stored in
            # `self.layer_cache[name]`
            mod = self.task.model.get_submodule(name)
            handles.append(mod.register_forward_hook(_ekfac_hook))

        func = partial(self.task.get_target_func(), self.task.get_param()[0])
        # 1. Use random batches to estimate the covariance matrices S and A
        cov_matrices = estimate_covariance(func,
                                           full_train_dataloader,
                                           self.layer_cache,
                                           max_iter,
                                           device=self.device)

        # 2. Calculate the eigenvalue decompositions of S and A
        self.cached_q = estimate_eigenvector(cov_matrices)

        # 3. Use random batches for the eigenvalue correction
        self.cached_lambdas = estimate_lambda(func,
                                              full_train_dataloader,
                                              self.cached_q,
                                              self.layer_cache,
                                              max_iter,
                                              device=self.device)

        # Remove the hooks after preprocessing the FIM
        for handle in handles:
            handle.remove()

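Annotation (background, not part of the diff): step 3 is the standard EK-FAC eigenvalue correction. Presumably `estimate_lambda` computes, per layer, the second moment of the per-sample weight gradient rotated into the Kronecker eigenbasis,

$$\Lambda_{ij} = \mathbb{E}\Big[\big(Q_S \,\nabla_W \ell\, Q_A^\top\big)_{ij}^2\Big],$$

which replaces the naive Kronecker product of the eigenvalues of $S$ and $A$ and yields the `_lambda` factors used in `transform_test_rep` below.
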
    def transform_test_rep(
        self,
        ckpt_idx: int,
        test_rep: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the transformation on the test representations.

        Args:
            ckpt_idx (int): Index of the model checkpoint. Used for ensembling
                different trained model checkpoints.
            test_rep (torch.Tensor): Test representations to be transformed.
                Typically a 2-d tensor with shape (batch_size, num_parameters).

        Returns:
            torch.Tensor: Transformed test representations. Typically a 2-d
                tensor with shape (batch_size, transformed_dimension).

        Raises:
            ValueError: If a non-zero `ckpt_idx` is specified.
        """
        if ckpt_idx != 0:
            error_msg = ("EK-FAC only supports a single model checkpoint, "
                         "but received a non-zero `ckpt_idx`.")
            raise ValueError(error_msg)

        # Unflatten the test_rep
        full_model_params = {
            k: p for k, p in self.task.model.named_parameters() if p.requires_grad
        }
        partial_model_params = {
            name: full_model_params[name] for name in self.layer_name
        }
        layer_test_rep = {}
        current_index = 0
        for name, params in partial_model_params.items():
            size = math.prod(params.shape)
            layer_test_rep[name] = test_rep[
                :, current_index : current_index + size,
            ].reshape(-1, *params.shape)
            current_index += size

        ifvp = {}

        for name in self.module_name:
            if self.name_to_module[name].bias is not None:
                dim_out = layer_test_rep[name + ".weight"].shape[1]
                dim_in = layer_test_rep[name + ".weight"].shape[2] + 1
                # Concatenate the weight and bias parts so that each sample's
                # representation matches the (dim_out, dim_in) layout of the
                # cached factors (the bias acts as an extra input column)
                _v = torch.cat(
                    [
                        layer_test_rep[name + ".weight"].flatten(start_dim=1),
                        layer_test_rep[name + ".bias"].flatten(start_dim=1),
                    ],
                    dim=-1,
                )
                _v = _v.reshape(-1, dim_out, dim_in)
            else:
                _v = layer_test_rep[name + ".weight"]

            _lambda = self.cached_lambdas[name]
            q_a, q_s = self.cached_q[name]

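            # Damped EK-FAC inverse-FIM-vector product in the eigenbasis:
            # rotate the per-sample gradient into the Kronecker eigenbasis
            # (q_s @ _v @ q_a.T), divide elementwise by the corrected
            # eigenvalues plus damping, and rotate back.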
            _ifvp = q_s.T @ ((q_s @ _v @ q_a.T) / (_lambda + self.damping)) @ q_a
            ifvp[name] = _ifvp.flatten(start_dim=1)

        # Flatten the parameters again
        transformed_test_rep_layers = [ifvp[name] for name in self.module_name]

        return torch.cat(transformed_test_rep_layers, dim=1)
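
Annotation (illustrative, not part of the diff): a minimal end-to-end usage sketch. Here `task`, `train_loader`, and `test_loader` are hypothetical objects prepared elsewhere, and `attribute` is assumed to be the interface inherited from `BaseInnerProductAttributor`.

    attributor = IFAttributorEKFAC(task, device="cuda", damping=1e-3)
    # One-off preprocessing: estimate the covariances, eigenvectors,
    # and corrected eigenvalues on (up to) 100 training batches.
    attributor.cache(train_loader, max_iter=100)
    # Attribution scores of each training sample for each test sample,
    # computed with the EK-FAC approximate inverse FIM.
    scores = attributor.attribute(train_loader, test_loader)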
Review discussion (on the `partial_model_params` construction in `transform_test_rep`):
"How about using `self.task.get_param(layer_name=self.layer_name, layer_split=True)`?"
"I think it's a bit hard to use this function, because it only provides the flattened parameters. Here we need the original shape information for each layer."
"I think `layer_split=True` will give you a map to the module name. Maybe @TheaperDeng has better ideas here."
"Since I think this can be done by easily change to the" (comment truncated in the source)

Review comment (on the forward hook in `cache`):
"[bias issue] Here, we are using these `inputs` as the `a_prev` in our calculation for cov and lambda, while we should append a `torch.ones` to the input to handle the bias."