[inplace] Silently incorrect gradient when leaf variable is used in an inplace operation #1284
Thunder's backward function simply returns None:

```python
In [8]: thunder.last_backward_traces(jforward)[0]
Out[8]:
# Constructed by Backward pass
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t2, = cotangents
  # C0 (empty sequence)
  # C1 (empty sequence)
  return (None,)
```

This is happening because copy_'s backward rule is incorrectly implemented to do exactly that:

```python
In [16]: from thunder.core.transforms import augmented_forward_impls

In [17]: from thunder.core.transforms import backward_impls

In [18]: from inspect import getsource

In [19]: print(getsource(augmented_forward_impls[thunder.prims.copy_.id]))
    prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()),

In [20]: print(getsource(backward_impls[thunder.prims.copy_.id]))
    prims.PrimIDs.COPY_: lambda g: (None, None),
```
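For reference, eager PyTorch's autograd treats `copy_` as an identity with respect to the source tensor, so the cotangent should flow through to `copy_from` rather than being dropped. A quick eager check (variable names are illustrative):

```python
import torch

# In eager PyTorch, copying from a tensor that requires grad puts the
# destination into the autograd graph, and the gradient flows back to the source.
src = torch.ones(3, requires_grad=True)
dst = torch.zeros(3)
dst.copy_(src)
dst.sum().backward()
print(src.grad)  # tensor([1., 1., 1.])
```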
Even if the grad rule is fixed, there's another problem: there's no grad_fn assigned to the output.

```python
import torch
import thunder

def forward(x):
    return x.mul_(2)

x = torch.ones(3, requires_grad=True)
v = torch.ones_like(x, requires_grad=False)

jforward = thunder.jit(forward)
out = jforward(x)
assert out.grad_fn
```
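For comparison, eager PyTorch records a `grad_fn` on the result of an in-place op whenever the tensor participates in autograd; the minimal check below uses a non-leaf clone, since eager mode rejects in-place ops on leaf tensors that require grad (names are illustrative):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.clone()       # non-leaf, part of the autograd graph
out = y.mul_(2)     # in-place on a non-leaf is allowed in eager mode
print(out.grad_fn)  # <MulBackward0 ...>, whereas the Thunder output above has none
```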
Here's a working example:

```python
def _copy__impl(copy_from, copy_to):
    copy_to.copy_(copy_from)
    return copy_from  # Current Thunder implementation returns copy_to

import thunder.executors.torchex

ex = thunder.executors.torchex.ex
thunder.executors.torchex._copy__impl = _copy__impl
thunder.executors.torchex.copy_ = ex.register_operator("copy_", meta=thunder.prims.copy_, tags=(thunder.prims.OpTags.DONT_DCE,), fn=_copy__impl)
thunder.executors.torchex._register_implementation(thunder.prims.copy_, thunder.executors.torchex.copy_, checker=thunder.executors.torchex._always_executable)

import thunder
import torch
from thunder.core.transforms import backward_impls

backward_impls[thunder.prims.copy_.id] = lambda g: (g, None)  # Current Thunder implementation returns None, None

def f(x):
    return x.mul_(2)

x = torch.ones(3, requires_grad=True)
v = torch.ones_like(x, requires_grad=False)

jf = thunder.jit(f, skip_inplace_functionalization=True)  # It doesn't work with skip_inplace_functionalization=False
out = jf(x); print(out)
assert out.grad_fn
out.backward(v)
torch.testing.assert_close(x.grad, 2.0 * v)
```

There are two main code changes (marked with comments above):

- `_copy__impl` returns `copy_from` instead of `copy_to`.
- The backward rule for `prims.copy_` returns `(g, None)` instead of `(None, None)`.
After this there's a bug left in the return statement of the forward trace:

```python
return {'output': t1, 'flat_args': [x], 'flat_output': (t1,)}, ((), ())
```

With `skip_inplace_functionalization=False`:

```python
return {'output': t1, 'flat_args': [t1], 'flat_output': (t1,)}, ((), ())
```

Alternatively to making this example work (Thunder is perfectly capable of doing that), we can imitate PyTorch's behavior by combining …
I think it would be OK (and probably easy) to error out. As PyTorch doesn't support this, I think it is unlikely that a user will write this code and expect it to work.
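For reference, this is what eager PyTorch does for the equivalent code (a quick check, assuming current PyTorch behavior):

```python
import torch

x = torch.ones(3, requires_grad=True)
x.mul_(2)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```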
I agree with erroring out. @beverlylytle, could you please help us here and implement raising an error for this case in Thunder?
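A minimal sketch of such a check, kept independent of Thunder's internals (`check_inplace_on_leaf` is hypothetical; where the jit would call it is an implementation detail):

```python
import torch

def check_inplace_on_leaf(t: torch.Tensor, op_name: str) -> None:
    # Hypothetical helper: mirror eager PyTorch's restriction by rejecting
    # in-place operations on leaf tensors that require grad.
    if t.is_leaf and t.requires_grad:
        raise RuntimeError(
            f"{op_name}: a leaf Variable that requires grad "
            "is being used in an in-place operation"
        )
```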