It's a very niche problem, but it tripped me up big time :')
Issue
With model.eval(), z_pred will not have tracked gradients (z_pred[-1].requires_grad == False).
Any subsequent custom torch.autograd.grad call then fails with: RuntimeError: One of the differentiated Tensors does not require grad.
Minimal example
import torch
from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)
        # DEQ solver from torchdeq
        self.deq = get_deq()
        apply_norm(self.layer, 'weight_norm')

    def implicit_layer(self, z, pos):
        # pos has to enter the fixed-point computation, otherwise
        # torch.autograd.grad(energy, pos) has no path from energy to pos
        return self.layer(z) + pos.mean(dim=-1, keepdim=True)

    def forward(self, x, pos):
        z = torch.zeros_like(x)
        reset_norm(self.layer)
        f = lambda z: self.implicit_layer(z, pos)
        z_pred, info = self.deq(f, z)
        # if model.eval() -> z_pred[-1].requires_grad is False!
        energy = z_pred[-1].sum(dim=-1, keepdim=True)  # per-sample energy, shape (10, 1)
        forces = -1 * (
            torch.autograd.grad(
                energy,
                # differentiate with respect to pos;
                # 'One of the differentiated Tensors appears to not have been used in the graph'
                # means pos was not used to calculate the energy (see implicit_layer above)
                pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                # allow_unused=True,
            )[0]
        )
        return energy, forces


def run(model, eval=False):
    if eval:
        model.eval()
    else:
        model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for step in range(10):
        x = torch.randn(10, 10)
        pos = torch.randn(10, 3, requires_grad=True)  # forces are taken w.r.t. pos
        energy, forces = model(x, pos)
        # loss
        optimizer.zero_grad()
        energy_target = torch.randn(10, 1)
        energy_loss = torch.nn.functional.mse_loss(energy, energy_target)
        force_target = torch.randn(10, 3)
        force_loss = torch.nn.functional.mse_loss(forces, force_target)
        loss = energy_loss + force_loss
        if not eval:
            loss.backward()
            optimizer.step()
    return True


if __name__ == '__main__':
    model = MyModel()
    success = run(model, eval=False)
    print(f'train success: {success}')
    # the eval run raises the RuntimeError described above
    success = run(model, eval=True)
    print(f'eval success: {success}')
With model.train() it works perfectly well. With model.eval() we get the error: RuntimeError: One of the differentiated Tensors does not require grad.
Desired behaviour
A flag to set so that z_pred[-1].requires_grad is always True, even under model.eval(), e.g. self.deq = get_deq(grad_in_eval=True).
Thanks a lot for your interest! I think a quick fix is to keep self.deq in train mode while the other components of the model are in eval mode.
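A minimal sketch of that workaround, using only the standard nn.Module train()/eval() switches (that torchdeq's gradient tracking is gated by the DEQ module's own training flag is an assumption here):

model = MyModel()
model.eval()       # put every submodule in eval mode ...
model.deq.train()  # ... except the DEQ solver, so the fixed point keeps its autograd graph

x = torch.randn(10, 10)
pos = torch.randn(10, 3, requires_grad=True)
energy, forces = model(x, pos)
print(energy.requires_grad)  # expected to be True, since the DEQ ran in train mode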
I appreciate the suggestion! I think we can implement such a feature in the library. Feel free to submit a PR.
I'll be back to close this issue soon.