I am having some issues running last-layer Laplace for a certain type of model. I have included a minimal example below:
import torch
from laplace import Laplace
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn

# Define a new model for 2D input
class SingleLinearLayer2D(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SingleLinearLayer2D, self).__init__()
        # self.layers = nn.ModuleList([nn.Linear(input_dim, output_dim), nn.Linear(output_dim, output_dim)])
        self.layers = nn.ModuleList([nn.Linear(input_dim, output_dim)])

    def forward(self, x):
        # Flatten the 2D input to 1D
        for layer in self.layers:
            x = layer(x)
        return x.flatten(1, 2)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

# Generate some 2D data
x = torch.randn(100, 10, 10, device=device)
y = torch.randn(100, 50, device=device)

# Create a DataLoader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Create the new model
model = SingleLinearLayer2D(input_dim=10, output_dim=5)
model = model.to(device)

la = Laplace(model, 'regression', hessian_structure='kron', subset_of_weights='last_layer')
la.fit(dataloader)
try:
    la(x, pred_type="glm", n_samples=10, link_approx='mc')
except RuntimeError as e:
    print(f"Predict failed for kron: {e}")

la = Laplace(model, 'regression', hessian_structure='diag', subset_of_weights='last_layer')
try:
    la.fit(dataloader)
except RuntimeError as e:
    print(f"Fit failed for diag: {e}")
The error is always something like this:
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/baselaplace.py", line 850, in fit
    loss_batch, H_batch = self._curv_closure(X, y, N=N)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/baselaplace.py", line 1857, in _curv_closure
    return self.backend.diag(X, y, N=N, **self._asdl_fisher_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/curvature/curvature.py", line 417, in diag
    Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/laplace/curvature/curvature.py", line 162, in last_layer_jacobians
    Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/magnus/.conda/envs/laplace-tsf/lib/python3.12/site-packages/torch/functional.py", line 402, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
You can see that for the diag approximation the fit fails, while for the kron approximation the fit works but the glm prediction fails. The issue is the same on cuda and cpu. The same code works fine with subset_of_weights="all". I am not sure exactly which model architectures cause this issue, but it seems to fail when each input sample has more than one dimension (here 10 x 10) while the output per sample is flattened to a single dimension. It also fails for models with multiple layers, as you can see from the commented-out line in the code.
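The einsum failure in the traceback can be reproduced without laplace. The following is a minimal sketch, assuming that `phi` holds the features fed to the last layer, which for this model would have shape (batch, 10, 10) rather than the 2D (batch, features) shape that the equation "kp,kij->kijp" is written for. The shape chosen for `identity` is also an assumption made for illustration, not taken from the library source.

```python
import torch

bsize, output_size = 4, 5

# Assumption: identity is a per-sample identity over the outputs, shape (k, i, j).
identity = torch.eye(output_size).unsqueeze(0).repeat(bsize, 1, 1)

# 2D features (batch, p): the shape the equation "kp,kij->kijp" expects.
phi_2d = torch.randn(bsize, 10)
Js = torch.einsum("kp,kij->kijp", phi_2d, identity).reshape(bsize, output_size, -1)
print(Js.shape)  # torch.Size([4, 5, 50])

# 3D features (batch, 10, 10), matching the 3D input used in the example.
phi_3d = torch.randn(bsize, 10, 10)
try:
    torch.einsum("kp,kij->kijp", phi_3d, identity)
    failed = False
except RuntimeError as e:
    failed = True
    print(f"einsum failed: {e}")
```

If that assumption holds, any model whose penultimate features keep more than one non-batch dimension would hit the same mismatch, which matches the behaviour described above.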
Sorry, I feel I haven't explained the issue super clearly, but hopefully the example gives enough information. Please let me know if I can provide anything else!
Thanks :)
Thanks, I'll try that. From reading some of the docs and other issues, it seems that you lose some performance with the "switching off gradients" approach. Is that the case? In particular, I would like to make use of the fast variance for predictions, since my use case has many (100s to 1000s of) outputs. Is that possible without using LLLaplace?
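For reference, "switching off gradients" is taken here to mean freezing every parameter except those of the final layer. Below is a minimal sketch in plain PyTorch, using a hypothetical two-layer variant of the model from the report (the class name `TwoLinearLayers2D` is made up for illustration). It only shows the freezing step and says nothing about whether the fast predictive variance is still available on that path.

```python
import torch
import torch.nn as nn

# Hypothetical two-layer variant of the model from the report.
class TwoLinearLayers2D(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(input_dim, output_dim), nn.Linear(output_dim, output_dim)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x.flatten(1, 2)

model = TwoLinearLayers2D(input_dim=10, output_dim=5)

# Freeze everything, then re-enable gradients for the final layer only.
for param in model.parameters():
    param.requires_grad = False
for param in model.layers[-1].parameters():
    param.requires_grad = True

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['layers.1.weight', 'layers.1.bias']

# The forward pass is unchanged: (batch, 10, 10) -> (batch, 50).
out = model(torch.randn(3, 10, 10))
print(out.shape)  # torch.Size([3, 50])
```

The frozen model would then presumably be passed to Laplace with subset_of_weights="all", the setting reported to work in the original post.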