From fef423b39661511b1bbb5a5e29b80c6260c228e6 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Thu, 28 Nov 2024 10:23:57 +0200
Subject: [PATCH] Register grad for torch.Tensor.item (#1481)

The rule is simple because PyTorch doesn't support propagating grads
through item calls: item returns a Python scalar that is disconnected
from the PyTorch Autograd graph.

Fixes #1479.
---
 thunder/tests/test_grad.py | 15 +++++++++++++++
 thunder/torch/__init__.py  |  4 ++++
 2 files changed, 19 insertions(+)

diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index fbd479cb3..7a6d5b9a8 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -623,6 +623,21 @@ def test_vjp_correctness_zeta_manual(op, device, dtype, executor, comp):
         comp(grad_rhs, expected_grad[0], equal_nan=True)


+@ops((get_opinfo("item"),), supported_dtypes=(dtypes.float64,))
+def test_vjp_correctness_torch_item_manual(op, device, dtype, executor, comp):
+    from thunder.torch import item
+
+    for sample in op.sample_inputs(device, dtype, requires_grad=True, no_rhs_numbers=True):
+        out = op.torch_reference(*sample.args, **sample.kwargs)
+        flat_op, flat_args, spec = flatten_func(item, sample.args, sample.kwargs)
+        initial_trace = thunder.trace()(vjp(flat_op), flat_args, (None,))
+        actual_out, (grad_in,) = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)(
+            flat_args, (None,)
+        )
+        assert grad_in is None, "grad_in should be None"
+        comp(actual_out, out, equal_nan=True)
+
+
 @ops((get_opinfo("nll_loss"),), supported_dtypes=(dtypes.float64,))
 def test_vjp_correctness_nll_loss_manual(op, device, dtype, executor, comp):
     for sample in op.sample_inputs(device, dtype, requires_grad=True, no_rhs_numbers=True):
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 5ec568e02..6327b4d05 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -4889,6 +4889,10 @@ def item(a: TensorLike) -> Number:
     return prims.item(a)


+# PyTorch does not support backward for torch.item
+register_grad(item.id, item)
+
+
 # TODO Move this to nn.functional
 @torchsymbol(torch.nn.functional.linear)
 def linear(a: TensorLike, w: TensorLike, /, bias: None | TensorLike = None) -> TensorLike:
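
Below is a minimal eager-PyTorch sketch (illustrative variable names only, not
part of the patch) of the behavior the commit message describes: item() hands
back a plain Python scalar, the autograd graph is cut at that point, and no
grad can flow back through the call, which is why registering the op itself as
its grad rule is enough.

    import torch

    x = torch.tensor(3.0, requires_grad=True)
    y = x.item()          # plain Python float; the autograd graph ends here
    print(type(y))        # <class 'float'>

    # Rebuilding a tensor from the scalar starts a fresh graph, so gradients
    # cannot flow back into x through the item() call.
    z = torch.tensor(y, requires_grad=True) * 2
    z.backward()
    print(x.grad)         # None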