
[dattri.algorithm, dattri.func] Refactor the implementation of EKFAC #143

Merged: 20 commits merged into TRAIS-Lab:main on Oct 27, 2024

Conversation

@sx-liu (Collaborator) commented Sep 17, 2024

Description

1. Motivation and Context

The current implementation of EK-FAC is outdated. This PR refactors the EK-FAC attribution to follow the format of the other IF attributors, and updates the hook mechanism to avoid redundancy.

2. Summary of the change

  1. Refactor the EK-FAC base functions; add estimate_covariance, estimate_eigenvector and estimate_lambda for direct use by users.
  2. Remove ifvp_at_x_ekfac; add IFAttributorEKFAC instead for attribution.
  3. Remove the MLPCache class and the manual_cache_forward function, and use torch.nn.Module.register_forward_hook instead (a minimal sketch of the hook mechanism follows this list).
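
As referenced in item 3, a minimal sketch of hook-based caching (illustrative only; layer_cache mirrors the diff reviewed below, while the model and helper names are hypothetical):

import torch

layer_cache = {}

def make_hook(name):
    # Forward hooks receive (module, inputs, outputs); `inputs` is the tuple
    # of positional arguments passed to module.forward().
    def _hook(module, inputs, outputs):
        layer_cache[name] = (inputs, outputs)
    return _hook

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
]

model(torch.randn(3, 4))  # one forward pass populates layer_cache
for handle in handles:    # remove the hooks once caching is done
    handle.remove()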

3. What tests have been added/updated for the change?

  • N/A: No test will be added (please justify)
  • Unit test: Typically, this should be included if you implemented a new function/fixed a bug.
  • Application test: If you wrote an example for the toolkit, this test should be added.
  • Document test: If you added an external API, then you should check if the document is correctly generated.
  • ...

@@ -4,6 +4,8 @@

from typing import TYPE_CHECKING

from dattri.task import AttributionTask
Contributor:

This line is likely not needed



class IFAttributorEKFAC(BaseInnerProductAttributor):
"""The inner product attributor with DataInf inverse hessian transformation."""
Contributor:

DataInf -> EKFAC

@jiaqima (Contributor) commented Oct 4, 2024

Hi @sx-liu, thanks for the PR. Could you try to add some more detailed tests for the attributor's internal functions? Something like this:

def test_datainf_transform_test_rep(self):

@jiaqima changed the title from "Refactor the implementation of EKFAC" to "[dattri.algorithms] Refactor the implementation of EKFAC" on Oct 12, 2024
@jiaqima changed the title from "[dattri.algorithms] Refactor the implementation of EKFAC" to "[dattri.algorithm, dattri.func] Refactor the implementation of EKFAC" on Oct 12, 2024
@sx-liu (Collaborator, Author) commented Oct 13, 2024

Hi @jiaqima, I added one additional unit test for the transformed test rep, which checks the correlation with the ground truth. I think it would otherwise be hard to verify the correctness of the FIM. What do you think?
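
For illustration, such a correlation test might take roughly this shape (hypothetical; the helper name and threshold are made up):

import numpy as np
from scipy.stats import spearmanr

def check_attribution_correlation(scores: np.ndarray, ground_truth: np.ndarray, threshold: float = 0.5) -> None:
    # Rank correlation with brute-force influence values is a common proxy
    # when the FIM itself is hard to verify directly.
    corr, _ = spearmanr(scores.ravel(), ground_truth.ravel())
    assert corr > threshold, f"Spearman correlation too low: {corr:.3f}"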

device: Optional[str] = "cpu",
damping: float = 0.0,
) -> None:
"""Initialize the DataInf inverse Hessian attributor.
Contributor:

DataInf -> EKFAC

Hessian -> FIM



class IFAttributorEKFAC(BaseInnerProductAttributor):
"""The inner product attributor with EK-FAC inverse hessian transformation."""
Contributor:

hessian -> FIM

"Ensemble of EK-FAC is not supported.")
raise ValueError(error_msg)

if not module_name:
Contributor:

if module_name is None


self.layer_cache = {} # cache for each layer

def _ekfac_hook(module: torch.nn.Module,
Contributor:

Please add a docstring to explain the requirements for these arguments
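
A hypothetical docstring along the requested lines (the argument semantics follow torch.nn.Module.register_forward_hook; the wording is illustrative, not the PR's final text):

def _ekfac_hook(module: torch.nn.Module,
                inputs: tuple,
                outputs: torch.Tensor) -> None:
    """Forward hook that caches a layer's inputs and outputs.

    Args:
        module (torch.nn.Module): The layer this hook is registered on;
            expected to be a torch.nn.Linear for EK-FAC to apply.
        inputs (tuple): Positional inputs passed to module.forward();
            inputs[0] holds the activations used for the covariance A.
        outputs (torch.Tensor): The layer's output, whose gradient is
            used for the covariance S.
    """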

full_train_dataloader: DataLoader,
max_iter: Optional[int] = None,
) -> None:
"""Cache the dataset and statistics for inverse hessian/fisher calculation.
Contributor:

hessian/fisher -> FIM


Cache the full training dataset as other attributors.
Estimate and cache the covariance matrices, eigenvector matrices
and corrected eigenvalues based on the distribution of training data.
Contributor:

based on the distribution of training data -> based on samples of training data


Args:
full_train_dataloader (DataLoader): The dataloader
with full training samples for inverse hessian calculation.
Contributor:

hessian -> FIM

with full training samples for inverse hessian calculation.
max_iter (int, optional): An integer indicating the maximum number of
batches that will be used for estimating the covariance matrices
and lambdas.
Contributor:

Please document the default value.

# Cache the inputs and outputs
self.layer_cache[name] = (inputs, outputs)

self.handles = []
Contributor:

Move this part to self.cache()? It seems that self.handles are only needed in self.cache()

full_model_params = {
k: p for k, p in self.task.model.named_parameters() if p.requires_grad
}
partial_model_params = {
Contributor:

How about using self.task.get_param(layer_name=self.layer_name, layer_split=True)?

Collaborator (Author):

I think it's a bit hard to use this function, because it only provides the flattened parameters. Here we need the original shape information for each layer.

@jiaqima (Contributor) replied Oct 13, 2024

I think layer_split=True will give you a map to the module name.

Maybe @TheaperDeng has better ideas here.

Collaborator:

I think this can be done by an easy change to get_param, so I can handle it in the next PR and leave it as it is in this PR.

self.module_name = module_name

# Update layer_name corresponding to selected modules
self.layer_name = [name + ".weight" for name in self.module_name]
Collaborator:

[bias issue] Here I think we also need to append the name + ".bias"

max_iter: Optional[int] = None,
device: Optional[str] = "cpu",
) -> Dict[str, torch.tensor]:
"""Estimate the 'covariance' matrices S and A in EK-FAC IFVP.
Collaborator:

The comment here needs to be updated.
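
For reference, the EK-FAC IFVP for a single layer is typically computed in the Kronecker eigenbasis; a minimal sketch following George et al. (2018), with illustrative variable names:

import torch

def ekfac_ifvp(v: torch.Tensor,
               q_s: torch.Tensor,
               q_a: torch.Tensor,
               lam: torch.Tensor,
               damping: float = 0.0) -> torch.Tensor:
    # v: layer gradient of shape (out_dim, in_dim); q_s, q_a: eigenvectors
    # of the covariances S and A; lam: corrected eigenvalues, (out_dim, in_dim).
    v_kfe = q_s.T @ v @ q_a           # rotate into the Kronecker eigenbasis
    v_kfe = v_kfe / (lam + damping)   # rescale by damped eigenvalues
    return q_s @ v_kfe @ q_a.T        # rotate back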

"""
# Unpack tuple outputs if necessary
if isinstance(inputs, tuple):
inputs = inputs[0]
Collaborator:

[bias issue] Here we are using these inputs as the a_prev in our calculation of the covariance and lambda, but we should append a column of torch.ones to the input to handle the bias.
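
The standard trick the comment refers to, sketched (illustrative; tensor shapes assumed):

import torch

def append_ones(a_prev: torch.Tensor) -> torch.Tensor:
    # Append a column of ones so the bias is absorbed into the
    # input covariance A (homogeneous coordinates).
    ones = torch.ones(a_prev.shape[0], 1, device=a_prev.device, dtype=a_prev.dtype)
    return torch.cat([a_prev, ones], dim=-1)

a_prev = torch.randn(32, 128)          # (batch, in_features)
a_aug = append_ones(a_prev)            # (batch, in_features + 1)
A = a_aug.T @ a_aug / a_aug.shape[0]   # covariance now covers the bias too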

ifvp = {}

for name in self.module_name:
_v = layer_test_rep[name + ".weight"]
Collaborator:

[bias issue] Again, I think we need the ".bias" here.


Args:
func (Callable): A Python function that takes one or more arguments.
Must return the following,
- losses: a tensor of shape (batch_size,).
- loss: a single tensor of loss. Should be the mean loss by the
batch size.
- mask (optional): a tensor of shape (batch_size, t), where 1's
Collaborator:

Maybe we need to make this mask output clear in the documentation of the EKFAC IF attributor's init function.
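
A hypothetical loss function satisfying the documented contract (the mask semantics are assumed here, since the quoted docstring is truncated above; 1 is taken to mark positions that count toward the loss):

import torch
import torch.nn.functional as F

def loss_func(model: torch.nn.Module, batch: tuple) -> tuple:
    x, y = batch                       # y: (batch_size, t), -100 = ignored
    logits = model(x)                  # (batch_size, t, vocab)
    per_token = F.cross_entropy(
        logits.transpose(1, 2), y, reduction="none")  # (batch_size, t)
    mask = (y != -100).float()         # (batch_size, t)
    losses = (per_token * mask).sum(dim=1)  # per-sample loss, (batch_size,)
    loss = losses.mean()               # mean over the batch, per the contract
    return losses, loss, mask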

@TheaperDeng (Collaborator) commented

The EKFAC implementation overall looks good to me.

One comment is that it uses loss.backward, so we cannot disable autograd during the attribution process. This is inconsistent with the other attributors, but I think it's fine for now.

Another comment is about the bias term; I have made comments at the places I think need to change.

@sx-liu (Collaborator, Author) commented Oct 17, 2024

I just incorporated the bias into the gradient and FIM calculation, but I found that the correlation drops drastically. I will try some other metrics and double-check the implementation.

@sx-liu requested review from @jiaqima and @TheaperDeng on October 26, 2024
@TheaperDeng (Collaborator) approved

Thanks @sx-liu, the last bug was really hard to identify. LGTM

@jiaqima merged commit 0038c57 into TRAIS-Lab:main on Oct 27, 2024
5 checks passed