
Jacobian computation fails for last layer laplace #265

Open
magnusross opened this issue Dec 5, 2024 · 3 comments
@magnusross

I am having some issues running last-layer Laplace for a certain type of model. I have included a minimal example below:

import torch
from laplace import Laplace
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn

# Define a model for 2D input; the linear layer acts on the last dimension
class SingleLinearLayer2D(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # self.layers = nn.ModuleList([nn.Linear(input_dim, output_dim), nn.Linear(output_dim, output_dim)])
        self.layers = nn.ModuleList([nn.Linear(input_dim, output_dim)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        # Flatten the 2D output to 1D
        return x.flatten(1, 2)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")
# Generate some 2D data
x = torch.randn(100, 10, 10, device=device)
y = torch.randn(100, 50, device=device)

# Create a DataLoader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Create the new model
model = SingleLinearLayer2D(input_dim=10, output_dim=5)
model = model.to(device)

la = Laplace(model, 'regression', hessian_structure='kron', subset_of_weights='last_layer')
la.fit(dataloader)
try:
    la(x, pred_type="glm", n_samples=10, link_approx='mc')
except RuntimeError as e:
    print(f"Predict failed for kron: {e}")

la = Laplace(model, 'regression', hessian_structure='diag', subset_of_weights='last_layer')
try:
    la.fit(dataloader)
except RuntimeError as e:
    print(f"Fit failed for diag: {e}")

The error is always something like this:

  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/baselaplace.py", line 850, in fit
    loss_batch, H_batch = self._curv_closure(X, y, N=N)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/baselaplace.py", line 1857, in _curv_closure
    return self.backend.diag(X, y, N=N, **self._asdl_fisher_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/curvature/curvature.py", line 417, in diag
    Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/curvature/curvature.py", line 162, in last_layer_jacobians
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/torch/functional.py", line 402, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given

You can see that for the diag approximation the fit fails, and for the kron approximation the fit works but the GLM prediction fails. The issue is the same on CUDA and CPU, and the same code works fine with subset_of_weights="all". I am not exactly sure which model architectures trigger this, but it seems to fail when the input to the model has more than one non-batch dimension while the output has a single one. It also fails for models with multiple layers, as you can see from the commented-out code.
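
To make the failure mode concrete, here is a minimal sketch of what seems to go wrong, with shapes inferred from the example above and variable names borrowed from the traceback (this is my reconstruction, not the library's actual code path): the last-layer input phi keeps the extra sequence dimension, so it is 3D rather than the 2D (batch, features) tensor that the einsum equation expects.

import torch

bsize, seq, in_feats, out_feats = 32, 10, 10, 5
output_size = seq * out_feats  # 50, matching the flatten in forward()

# Last-layer input features: 3D because of the extra sequence dimension
phi = torch.randn(bsize, seq, in_feats)
identity = torch.eye(output_size).unsqueeze(0).tile(bsize, 1, 1)

try:
    # The einsum from last_layer_jacobians assumes phi is (batch, features)
    torch.einsum("kp,kij->kijp", phi, identity)
except RuntimeError as e:
    print(e)  # einsum(): the number of subscripts in the equation (2) does not match ...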

Sorry if I haven't explained the issue super clearly, but hopefully the example gives enough information. Please let me know if I can provide anything else!

Thanks :)

@magnusross
Author

Note: I have also tried backends other than CurvlinopsEF (e.g. AsdlGGN), and it doesn't fix the problem.

@wiseodd
Collaborator

wiseodd commented Dec 12, 2024

Can you use this instead? #254

See also docs: https://aleximmer.github.io/Laplace/huggingface_example/#laplace-on-a-subset-of-an-llms-weights
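
In case it helps, here is a rough sketch of that approach for the model above, adapted from the linked docs (my adaptation, so treat the details as illustrative rather than verified): switch off gradients everywhere except the last layer, then fit Laplace with subset_of_weights='all' so that only the trainable parameters are covered.

# Reuses model and dataloader from the original example
for p in model.parameters():
    p.requires_grad_(False)
for p in model.layers[-1].parameters():  # keep only the last layer trainable
    p.requires_grad_(True)

# With gradients switched off elsewhere, 'all' covers just the last layer
la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='kron')
la.fit(dataloader)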

@magnusross
Copy link
Author

Thank, I'll try that. From reading some of the docs and other issues it seems that if you use the "switching off gradients" approach, you lose some performance, is that the case? In particular, would like to make use of the fast variance for predictions, since my use case has many (100s-1000s) of outputs, is that possible without using LLLaplace?
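
(For context on the variance point: in the GLM/linearized predictive, the per-input output covariance is f_var = J Σ Jᵀ, where J is the Jacobian of the outputs w.r.t. the covered weights and Σ is the posterior covariance over them. A generic sketch with assumed sizes, not the library's internal implementation:)

import torch

n, c, p = 4, 200, 500        # batch, outputs, number of covered parameters (assumed)
Js = torch.randn(n, c, p)    # per-input Jacobians w.r.t. the covered weights
Sigma = torch.eye(p)         # posterior covariance over those weights

# (n, c, c) predictive covariance; the cost grows with both c and p, which is
# why restricting p to the last layer matters when c is in the 100s-1000s
f_var = torch.einsum("ncp,pq,nkq->nck", Js, Sigma, Js)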
