[dattri.algorithm, dattri.func] Refactor the implementation of EKFAC #143
Conversation
@@ -4,6 +4,8 @@

from typing import TYPE_CHECKING

from dattri.task import AttributionTask
This line is likely not needed
class IFAttributorEKFAC(BaseInnerProductAttributor):
    """The inner product attributor with DataInf inverse hessian transformation."""
DataInf -> EKFAC
Hi @sx-liu, thanks for the PR. Could you try to add some more detailed tests of the attributor's internal functions? It could be something like this:
Hi @jiaqima, I added one additional unit test for the transformed test rep, which looks at the correlation with the ground truth. I think it might be hard to otherwise check the correctness of the FIM. What do you think?
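For concreteness, a minimal sketch of such a correlation check (hedged: this helper is illustrative, not the actual test added in this PR; it only assumes the transformed representation and a ground-truth iFVP baseline are available as tensors):

```python
import torch
from scipy.stats import spearmanr


def assert_rank_correlated(transformed: torch.Tensor,
                           ground_truth: torch.Tensor,
                           threshold: float = 0.9) -> None:
    """Check that the EK-FAC transformed test rep ranks entries like the ground truth."""
    corr = spearmanr(
        transformed.flatten().detach().cpu().numpy(),
        ground_truth.flatten().detach().cpu().numpy(),
    )[0]
    assert corr > threshold, f"Spearman correlation too low: {corr:.3f}"
```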
    device: Optional[str] = "cpu",
    damping: float = 0.0,
) -> None:
    """Initialize the DataInf inverse Hessian attributor.
DataInf -> EKFAC
Hessian -> FIM
class IFAttributorEKFAC(BaseInnerProductAttributor):
    """The inner product attributor with EK-FAC inverse hessian transformation."""
hessian -> FIM
"Ensemble of EK-FAC is not supported.") | ||
raise ValueError(error_msg) | ||
|
||
if not module_name: |
if module_name is None
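A small illustration of the difference between the two checks (the distinction only matters if an empty list is ever passed):

```python
module_name = []             # hypothetical edge case: explicitly empty selection
print(not module_name)       # True  -> `if not module_name` treats [] the same as None
print(module_name is None)   # False -> `if module_name is None` keeps the empty list
```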
self.layer_cache = {}  # cache for each layer

def _ekfac_hook(module: torch.nn.Module,
Please add a docstring to explain the requirements for these arguments
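A possible docstring sketch for the hook (hedged: it only assumes the standard torch.nn.Module.register_forward_hook calling convention and a per-layer cache like the one in the diff):

```python
import torch

layer_cache = {}  # stand-in for self.layer_cache in the attributor


def _ekfac_hook(module: torch.nn.Module, inputs, outputs) -> None:
    """Forward hook that caches activations needed for EK-FAC statistics.

    Args:
        module (torch.nn.Module): The (typically Linear) layer the hook is
            registered on.
        inputs (tuple): Positional inputs to the module's forward pass; forward
            hooks receive them as a tuple, so the first element is unpacked.
        outputs (torch.Tensor): The module's forward output (pre-activation),
            used later for the gradient covariance factor.
    """
    if isinstance(inputs, tuple):
        inputs = inputs[0]
    layer_cache[id(module)] = (inputs, outputs)
```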
    full_train_dataloader: DataLoader,
    max_iter: Optional[int] = None,
) -> None:
    """Cache the dataset and statistics for inverse hessian/fisher calculation.
hessian/fisher -> FIM
    Cache the full training dataset as other attributors.
    Estimate and cache the covariance matrices, eigenvector matrices
    and corrected eigenvalues based on the distribution of training data.
based on the distribution of training data -> based on samples of training data.
    Args:
        full_train_dataloader (DataLoader): The dataloader
            with full training samples for inverse hessian calculation.
hessian -> FIM
            with full training samples for inverse hessian calculation.
        max_iter (int, optional): An integer indicating the maximum number of
            batches that will be used for estimating the covariance matrices
            and lambdas.
Please also document the default behavior when this is not specified.
    # Cache the inputs and outputs
    self.layer_cache[name] = (inputs, outputs)

self.handles = []
Move this part to self.cache()? It seems that self.handles is only needed in self.cache().
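A sketch of the suggested restructuring (hedged: the hook-factory helper name is an assumption), registering the hooks inside cache() and removing them once the statistics are estimated:

```python
def cache(self, full_train_dataloader, max_iter=None) -> None:
    # Register forward hooks only for the selected modules; the handles can
    # stay local because nothing outside cache() needs them.
    handles = [
        module.register_forward_hook(self._make_ekfac_hook(name))  # hypothetical factory
        for name, module in self.task.model.named_modules()
        if name in self.module_name
    ]
    try:
        ...  # forward/backward passes to estimate covariances, eigenvectors, lambdas
    finally:
        for handle in handles:
            handle.remove()  # detach hooks once caching is done
```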
full_model_params = {
    k: p for k, p in self.task.model.named_parameters() if p.requires_grad
}
partial_model_params = {
How about using self.task.get_param(layer_name=self.layer_name, layer_split=True)
I think it's a bit hard to use this function, because it only provides the flattened parameters. Here we need the original shape information for each layer.
I think layer_split = True will give you a map to the module name.
Maybe @TheaperDeng has better ideas here.
I think this can be done with an easy change to get_param, so I can handle it in the next PR and leave it as it is in this PR.
self.module_name = module_name

# Update layer_name corresponding to selected modules
self.layer_name = [name + ".weight" for name in self.module_name]
[bias issue] Here I think we also need to append name + ".bias".
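A sketch of the change being suggested (assuming every selected module actually has a bias parameter):

```python
# Include both the weight and the bias of each selected module.
self.layer_name = [
    name + suffix for name in self.module_name for suffix in (".weight", ".bias")
]
```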
dattri/func/fisher.py (outdated)
    max_iter: Optional[int] = None,
    device: Optional[str] = "cpu",
) -> Dict[str, torch.tensor]:
    """Estimate the 'covariance' matrices S and A in EK-FAC IFVP.
The comment here needs to be changed.
""" | ||
# Unpack tuple outputs if necessary | ||
if isinstance(inputs, tuple): | ||
inputs = inputs[0] |
[bias issue] Here, we are using these inputs as the a_prev in our calculation for the cov and lambda, but we should append a torch.ones column to the input to handle the bias.
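A sketch of that suggestion, operating on the inputs variable from the quoted code (assumes the last dimension is the feature dimension):

```python
# Append a constant 1 column so the bias is folded into the A covariance
# and lambda estimates (the usual KFAC homogeneous-coordinates trick).
ones = torch.ones(*inputs.shape[:-1], 1, device=inputs.device, dtype=inputs.dtype)
inputs = torch.cat([inputs, ones], dim=-1)
```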
ifvp = {}

for name in self.module_name:
    _v = layer_test_rep[name + ".weight"]
[bias issue] Again, I think we need the ".bias" here as well.
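A sketch of what that could look like inside the loop (hedged: the shapes are assumptions, chosen to match the ones-augmented inputs above):

```python
# Concatenate the per-sample bias gradient as an extra input column so the
# test representation lines up with the (in_features + 1)-sized FIM factors.
_v_weight = layer_test_rep[name + ".weight"]            # (batch, out, in)
_v_bias = layer_test_rep[name + ".bias"].unsqueeze(-1)  # (batch, out, 1)
_v = torch.cat([_v_weight, _v_bias], dim=-1)            # (batch, out, in + 1)
```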
Args:
    func (Callable): A Python function that takes one or more arguments.
        Must return the following,
        - losses: a tensor of shape (batch_size,).
        - loss: a single tensor of loss. Should be the mean loss by the
          batch size.
        - mask (optional): a tensor of shape (batch_size, t), where 1's
Maybe we need to make this mask output clear in the docstring of the EKFAC IF attributor's __init__ function.
The EKFAC implementation overall looks good to me. One comment is that it uses [...]. Another comment is about the bias term; I have made some comments at the places I think need to change.
I just incorporated the bias into the gradient and FIM calculation, but I found that the correlation drops drastically. I will try some other metrics and double-check the implementation.
Thanks @sx-liu, the last bug was really hard to identify. LGTM
Description
1. Motivation and Context
The current implementation of EK-FAC is outdated. Refactor the EK-FAC attribution to follow the format of the other IF attributors, and update the hook mechanism to avoid redundancy.
2. Summary of the change
- Expose estimate_covariance, estimate_eigenvector and estimate_lambda for direct use by the users.
- Remove ifvp_at_x_ekfac; add IFAttributorEKFAC instead for attribution (see the usage sketch after this list).
- Remove the MLPCache class and the manual_cache_forward function, and use torch.register_forward_hook instead.
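A hedged usage sketch of the refactored attributor, illustrating the "same format as other IF attributors" point; the constructor arguments (device, damping, module_name) and cache(full_train_dataloader, max_iter) come from the diff above, while the import path, AttributionTask setup, and attribute() call are assumptions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from dattri.task import AttributionTask
from dattri.algorithm import IFAttributorEKFAC  # import path assumed

# Hypothetical task: a tiny MLP and a loss function wrapped in AttributionTask.
model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

def loss_func(params, data):
    x, y = data
    out = torch.func.functional_call(model, params, x)
    return torch.nn.functional.cross_entropy(out, y)

task = AttributionTask(loss_func=loss_func, model=model, checkpoints=model.state_dict())

train_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=16)
test_loader = DataLoader(TensorDataset(torch.randn(8, 10), torch.randint(0, 2, (8,))), batch_size=8)

attributor = IFAttributorEKFAC(
    task=task,
    module_name=["0", "2"],  # Linear submodules to apply EK-FAC to (names assumed)
    device="cpu",
    damping=0.0,
)
attributor.cache(train_loader, max_iter=None)  # estimate covariances, eigenvectors, lambdas
scores = attributor.attribute(train_loader, test_loader)
```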
3. What tests have been added/updated for the change?