[inplace] Silently incorrect gradient when leaf variable is used in an inplace operation #1284

Open
kshitij12345 opened this issue Oct 10, 2024 · 5 comments · May be fixed by #1458

Comments

@kshitij12345
Collaborator

kshitij12345 commented Oct 10, 2024

import torch
import thunder

def forward(x):
    return x.mul_(2)

x = torch.ones(3, requires_grad=True)
v = torch.ones_like(x, requires_grad=False)

# PyTorch Eager errors out.
# forward(x) # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

jforward = thunder.jit(forward)
jforward(x).backward(v)
print(x.grad)  # tensor([1., 1., 1.]), silently incorrect; the expected gradient is 2 * v = tensor([2., 2., 2.])
@IvanYashchuk
Collaborator

Thunder's backward function simply returns None:

In [8]: thunder.last_backward_traces(jforward)[0]
Out[8]: 
# Constructed by Backward pass
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t2, = cotangents
  # C0 (empty sequence)
  # C1 (empty sequence)
  return (None,)

This is happening because copy_'s backward rule is incorrectly implemented to return None:

In [16]: from thunder.core.transforms import augmented_forward_impls

In [17]: from thunder.core.transforms import backward_impls

In [18]: from inspect import getsource

In [19]: print(getsource(augmented_forward_impls[thunder.prims.copy_.id]))
    prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()),


In [20]: print(getsource(backward_impls[thunder.prims.copy_.id]))
    prims.PrimIDs.COPY_: lambda g: (None, None),
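
A corrected backward rule would propagate the cotangent to the source tensor instead of dropping it. A minimal sketch, mirroring the workaround shown later in this thread:

import thunder
from thunder.core.transforms import backward_impls

# Sketch: return the cotangent for copy_from and None for copy_to,
# instead of the current (None, None).
backward_impls[thunder.prims.copy_.id] = lambda g: (g, None)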

@IvanYashchuk
Collaborator

Even if the grad rule is fixed, there's another problem: no grad_fn is assigned to the output:

import torch
import thunder

def forward(x):
    return x.mul_(2)

x = torch.ones(3, requires_grad=True)
v = torch.ones_like(x, requires_grad=False)

jforward = thunder.jit(forward)
out = jforward(x)
assert out.grad_fn  # fails: out.grad_fn is None
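
For contrast, an in-place operation on a non-leaf tensor in eager PyTorch does get a grad_fn. A minimal sketch for illustration (not part of the original report):

import torch

x = torch.ones(3, requires_grad=True)
y = x.clone()       # non-leaf tensor, so in-place mutation is allowed
out = y.mul_(2)     # recorded by autograd
assert out.grad_fn is not None  # MulBackward0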

@IvanYashchuk
Collaborator

Here's a working example:

def _copy__impl(copy_from, copy_to):
    copy_to.copy_(copy_from)
    return copy_from # Current Thunder implementation returns copy_to

import thunder.executors.torchex
ex = thunder.executors.torchex.ex
thunder.executors.torchex._copy__impl = _copy__impl
thunder.executors.torchex.copy_ = ex.register_operator("copy_", meta=thunder.prims.copy_, tags=(thunder.prims.OpTags.DONT_DCE,), fn=_copy__impl)
thunder.executors.torchex._register_implementation(thunder.prims.copy_, thunder.executors.torchex.copy_, checker=thunder.executors.torchex._always_executable)

import thunder
import torch

from thunder.core.transforms import backward_impls
backward_impls[thunder.prims.copy_.id] = lambda g: (g, None) # Current Thunder implementation returns None, None

def f(x):
    return x.mul_(2)

x = torch.ones(3, requires_grad=True)
v = torch.ones_like(x, requires_grad=False)
jf = thunder.jit(f, skip_inplace_functionalization=True) # It doesn't work with skip_inplace_functionalization=False

out = jf(x); print(out)
assert out.grad_fn

out.backward(v)
torch.testing.assert_close(x.grad, 2.0 * v)

There are two main code changes:

  1. Modify the _copy__impl implementation to return copy_from instead. This is tied to the internals of PyTorch's autograd: a grad_fn is not created when copy_to is returned there.
  2. Modify copy_'s backward rule in Thunder to propagate the incoming cotangent instead of returning None.

After this, there's still a bug in the skip_inplace_functionalization=False code path preventing creation of grad_fn. The main difference between the True/False paths is that in the False path the "flat_args" output is replaced with the output of copy_. Here is the return statement from the last execution trace for both paths:
skip_inplace_functionalization=True:

return {'output': t1, 'flat_args': [x], 'flat_output': (t1,)}, ((), ())

skip_inplace_functionalization=False:

return {'output': t1, 'flat_args': [t1], 'flat_output': (t1,)}, ((), ())

Alternatively to making this example work (Thunder is perfectly capable of doing that), we could imitate PyTorch's behavior by checking the is_leaf and requires_grad attributes and raising the same error from within the _copy__impl function.
PyTorch's error: https://github.com/pytorch/pytorch/blob/4d9b5a87e4cc1e2ff81dce5123ee98a7c1b2d6a8/torch/csrc/autograd/VariableTypeUtils.h#L80-L84
which is checked here when copy_ is called on a leaf tensor requiring grad:
https://github.com/pytorch/pytorch/blob/4d9b5a87e4cc1e2ff81dce5123ee98a7c1b2d6a8/torch/csrc/autograd/VariableTypeManual.cpp#L204
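
A minimal sketch of that check, reusing the _copy__impl(copy_from, copy_to) signature from the snippet above (an illustration, not an actual patch):

def _copy__impl(copy_from, copy_to):
    # Mirror PyTorch eager: refuse to mutate a leaf tensor that requires grad.
    if copy_to.is_leaf and copy_to.requires_grad:
        raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.")
    copy_to.copy_(copy_from)
    return copy_to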

IvanYashchuk removed their assignment Oct 28, 2024
@kshitij12345
Collaborator Author

I think it would be OK (and probably easy) to error out. As PyTorch doesn't support this, I think it is unlikely that a user will write this code and expect it to work.

@IvanYashchuk
Collaborator

I agree with erroring out. @beverlylytle, could you please help us here and implement raising an error for this case in Thunder?

beverlylytle self-assigned this Nov 7, 2024