Hello!

```python
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod
import torch

mod = Mod(lambda x: (x+1).mean(), in_keys=["a"], out_keys=["b"])
td = TensorDict(a=torch.randn(10, 11, requires_grad=True), batch_size=[10])

# vmap the module over the batch dimension of the TensorDict
vmap_mod = torch.vmap(mod, (0,))
td_out = vmap_mod(td)
print(td_out)

# Gradient of "b" with respect to the input TensorDict, for a single sample
grad = torch.func.grad(lambda td: mod(td)["b"])
print(grad(td[0])["a"])

# vmap the gradient function over the batch
vmap_grad = torch.vmap(grad, (0,))
td_out = vmap_grad(td)
print(td_out["a"])
```

So vmapping a grad works (gradding a vmap doesn't, I believe). Is there a version of this script that explains what you're trying to do and where it breaks?
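For comparison, here is a minimal sketch of the same vmap-of-grad pattern with plain tensors and no TensorDict. The function `f` is an illustrative stand-in for the lambda wrapped by `Mod` above. For a row of 11 elements, the gradient of `(x + 1).mean()` is `1/11` everywhere, so the vmapped result should be a 10×11 tensor filled with `1/11`.

```python
import torch
from torch.func import grad, vmap

def f(x: torch.Tensor) -> torch.Tensor:
    # Scalar output per sample, mirroring the lambda in the TensorDictModule above
    return (x + 1).mean()

# No requires_grad needed: torch.func.grad tracks gradients on its own
a = torch.randn(10, 11)

# grad(f) computes df/dx for ONE sample of shape (11,);
# vmap maps it over the leading batch dimension of size 10.
per_sample_grads = vmap(grad(f))(a)

print(per_sample_grads.shape)  # torch.Size([10, 11])
print(torch.allclose(per_sample_grads, torch.full((10, 11), 1 / 11)))  # True
```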
Hi everyone,
I am trying to move my code to TensorDict (td) and vmap along the batch dimension, following https://discuss.pytorch.org/t/how-to-apply-vmap-on-a-heterogeneous-tensor/214109/5. Since vmap does not support a batched TensorDict, for now I just use batch_size = 1.
Everything works very well until I need to rewrite the derivative module. I have a td with the distances between atoms, the model predicts the energy from them, and then I need to calculate the force by differentiating the energy with respect to the distances.
Before, I used `torch.grad`, which was very straightforward. But within vmap it raises `element 0 of tensors does not require grad and does not have a grad_fn`. I found that the BatchedTensor of `energy` has a `grad_fn`, but for the "vmapped" tensor, the real (unwrapped) tensor has `requires_grad=False`. Apparently(?) we cannot combine vmap with `torch.grad`, only with `torch.func.grad`, according to https://discuss.pytorch.org/t/use-vmap-and-grad-to-calculate-gradients-for-one-layer-independently-for-each-input-in-batch/187556/2?u=roy-kid and https://discuss.pytorch.org/t/simple-use-case-compete-per-sample-gradient-with-autograd/207317/4?u=roy-kid.

Since `torch.func.vmap` can only specify arguments by position, how can I differentiate `energy` with respect to `distance`? Here is my pseudocode:

Thanks for your help!
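One possible way around the positional-argument restriction, sketched under stated assumptions: make `distance` its own positional argument and pass everything else as further arguments. `torch.func.grad` differentiates with respect to `argnums=0` by default, and `in_dims` tells `vmap` which arguments are batched. The model (a small `nn.Linear` stack applied to each pairwise distance), `energy_fn`, `force_fn` and `pair_weight` are all hypothetical, and plain tensors are used instead of a TensorDict. The force is taken as the negative gradient, F = -dE/dr, and the result is checked against a per-sample loop with `torch.autograd.grad`.

```python
import torch
from torch import nn
from torch.func import grad, vmap

torch.manual_seed(0)
n_batch, n_pairs = 4, 6

# Hypothetical toy model: maps each pairwise distance to a pair energy
model = nn.Sequential(nn.Linear(1, 8), nn.Tanh(), nn.Linear(8, 1))

def energy_fn(distance: torch.Tensor, pair_weight: torch.Tensor) -> torch.Tensor:
    """Scalar energy of ONE sample.

    distance:    (n_pairs,) distances between atoms
    pair_weight: (n_pairs,) hypothetical per-pair coefficient shared by all samples
    """
    pair_energy = model(distance.unsqueeze(-1)).squeeze(-1)  # (n_pairs,)
    return (pair_weight * pair_energy).sum()

def force_fn(distance: torch.Tensor, pair_weight: torch.Tensor) -> torch.Tensor:
    # Force = -dE/d(distance); argnums=0 (the default) selects the distance argument
    return -grad(energy_fn, argnums=0)(distance, pair_weight)

distance = torch.rand(n_batch, n_pairs) + 0.5  # batched: one row per sample
pair_weight = torch.rand(n_pairs)              # not batched

# in_dims=(0, None): map over dim 0 of distance, reuse pair_weight for every sample
forces = vmap(force_fn, in_dims=(0, None))(distance, pair_weight)
print(forces.shape)  # torch.Size([4, 6])

# Reference: per-sample loop with torch.autograd.grad (here the input must require grad)
ref = []
for d in distance:
    d = d.clone().requires_grad_(True)
    (dE_dd,) = torch.autograd.grad(energy_fn(d, pair_weight), d)
    ref.append(-dE_dd)
ref = torch.stack(ref)

print(torch.allclose(forces, ref, atol=1e-5))  # True
```

With a TensorDict, the untested assumption would be to pull the distance tensor out first (e.g. `td["distance"]`, a hypothetical key) and pass it positionally to the vmapped function, rather than differentiating with respect to the TensorDict itself.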
Also posted in the forum: https://discuss.pytorch.org/t/combine-vmap-func-grad-with-tensordict/215086?u=roy-kid