Since `torchinterp1d` uses `torch.autograd.Function`, it is not compatible by default with `vmap`. Here's an example of code that will not run:
```python
import torch
from torchinterp1d import interp1d

def interpolate(xp):
    x = torch.linspace(-5, 5, 100)
    y = x**3
    return interp1d(x, y, torch.atleast_1d(xp))

xp = torch.rand(20) * 10 - 5
print(f"{xp=}, {torch.vmap(interpolate)(xp)=}")
```
The relevant part of the stack trace is:
```
RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
```
Based on the PyTorch docs, the fix may be as easy as setting `generate_vmap_rule=True` in `torchinterp1d`, but I haven't looked into this yet.
It'd be great to get a fix for this since `vmap` is incredibly useful.
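For reference, here is a minimal, self-contained sketch of the pattern those docs describe (the `Cube` class below is a hypothetical stand-in, not torchinterp1d's actual forward/backward): the `ctx` bookkeeping moves out of `forward` into a separate `setup_context` staticmethod, and `generate_vmap_rule = True` tells PyTorch to derive the batching rule automatically.

```python
import torch

class Cube(torch.autograd.Function):
    # Hypothetical example, not torchinterp1d code: lets PyTorch auto-generate
    # the vmap rule from forward/setup_context/backward.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # New-style API: forward no longer takes a `ctx` argument.
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving tensors for backward now happens here instead of in forward().
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 3 * x ** 2 * grad_output

# With this structure, vmap over the custom Function works:
xs = torch.linspace(-1.0, 1.0, 5)
print(torch.vmap(Cube.apply)(xs))
```

Presumably the analogous change in `torchinterp1d` would split its existing `forward(ctx, ...)` into a ctx-free `forward` plus a `setup_context`, and then `generate_vmap_rule=True` may be enough, as suggested above.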