diff --git a/notebooks/writing_a_trace_transform_cpu_offloading.ipynb b/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
index f692a9f40e..86d8263769 100644
--- a/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
+++ b/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
@@ -525,7 +525,7 @@
     "\n",
     "# Verify that saved tensors are on CPU.\n",
     "saved_tensor_devices = set()\n",
-    "for t in actual.grad_fn.next_functions[0][0].saved_tensors:\n",
+    "for t in actual.grad_fn.saved_tensors:\n",
     "    saved_tensor_devices.add(str(t.device))\n",
     "\n",
     "assert \"cpu\" in saved_tensor_devices  # Verify that we actually have saved tensors on CPU\n",
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 0f3faeaf4a..8a9827ca30 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -105,16 +105,18 @@ def detach_if_tensor(t):
 
         saved_tensors = tuple(map(detach_if_tensor, saved_tensors))
 
-        # We must save tensors using ctx.save_for_backward
-        ctx.save_for_backward(*saved_tensors)
-
         ctx.side_channel = side_channel
         if side_channel is not None:
            assert not side_channel
            ctx.side_channel["fw"] = flat_output
-
+            # We must save tensors using ctx.save_for_backward, but
+            # we want to save the tensors in the function returning the outputs to avoid memory leaks
+            # (basically ref-cycles via output.grad_fn.next_functions[0][0].saved_tensors[0] == output;
+            # PyTorch autograd handles this gracefully for output.grad_fn.saved_tensors).
+            ctx.side_channel["tensors_to_save"] = saved_tensors
            return torch.randn(1, device="meta", requires_grad=True)
         else:
+            ctx.save_for_backward(*saved_tensors)
            return flat_output
 
     # NOTE: If `torch.autograd.function.once_differentiable` is to be removed,
@@ -125,25 +127,26 @@ def detach_if_tensor(t):
     def backward(ctx, *raw_args):
         if ctx.side_channel is not None:
             args = ctx.side_channel.pop("bw")
+            saved_tensors_list = ctx.side_channel.pop("saved_tensors")
             assert not ctx.side_channel
         else:
             args = list(raw_args)
-        # ctx.saved_tensors is a tuple of tensors saved in forward. Our compiled
-        # backward is a really long function that takes all the tensors saved in
-        # forward and gradually uses them to compute the gradients of the
-        # inputs. Unfortunately, Python holds a reference to all arguments of a
-        # function until the function returns, even if we delete the variable
-        # "saved_tensors" inside the function, the tensors will still be held in
-        # memory until the function returns. Fortunately, Python passes mutable
-        # objects by reference, so we can just replace the saved_tensors with an
-        # empty list and the memory will be freed immediately. We must also
-        # delete the reference to the saved_tensors in the context, otherwise
-        # the memory will be freed only when the context is deleted.
-        saved_tensors_list = list(ctx.saved_tensors)  # Make a copy as we will mutate it
-
-        # This is an undocumented API, but it's the only way to clear the
-        # reference to the saved tensors in the context
-        ctx.maybe_clear_saved_tensors()  # Delete the reference to all saved tensors in the context
+            # ctx.saved_tensors is a tuple of tensors saved in forward. Our compiled
+            # backward is a really long function that takes all the tensors saved in
+            # forward and gradually uses them to compute the gradients of the
+            # inputs. Unfortunately, Python holds a reference to all arguments of a
+            # function until the function returns; even if we delete the variable
+            # "saved_tensors" inside the function, the tensors will still be held in
+            # memory until the function returns. Fortunately, Python passes mutable
+            # objects by reference, so we can just replace the saved_tensors with an
+            # empty list and the memory will be freed immediately. We must also
+            # delete the reference to the saved_tensors in the context, otherwise
+            # the memory will be freed only when the context is deleted.
+            saved_tensors_list = list(ctx.saved_tensors)  # Make a copy as we will mutate it
+
+            # This is an undocumented API, but it's the only way to clear the
+            # reference to the saved tensors in the context.
+            ctx.maybe_clear_saved_tensors()  # Delete the reference to all saved tensors in the context
         grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
         assert not args
@@ -165,6 +168,7 @@ def forward(ctx, dummy, side_channel, *args):
         ctx.side_channel = side_channel
         ctx.num_args = len(args)
         res = ctx.side_channel.pop("fw")
+        ctx.save_for_backward(*ctx.side_channel.pop("tensors_to_save"))
         assert not ctx.side_channel
         return res
 
@@ -172,6 +176,8 @@ def forward(ctx, dummy, side_channel, *args):
     def backward(ctx, *args):
         assert not ctx.side_channel
         ctx.side_channel["bw"] = list(args)
+        ctx.side_channel["saved_tensors"] = list(ctx.saved_tensors)  # see above
+        ctx.maybe_clear_saved_tensors()  # Delete the reference to all saved tensors in the context
         return torch.randn(1, device="meta"), None, *([None] * ctx.num_args)
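Note on the hunks above: the two `torch.autograd.Function`s in torch_autograd.py communicate through a `side_channel` dict, and this patch moves `ctx.save_for_backward` from the compute Function to the Function that actually returns the outputs, so the saved tensors live on `output.grad_fn` rather than on an upstream node that the output itself keeps alive. Below is a minimal standalone sketch of that pattern; the names `Compute` and `Output` and the dict keys are illustrative assumptions for the example, not Thunder's actual API.

```python
import torch


class Compute(torch.autograd.Function):
    """Stand-in for the Function that runs the compiled forward."""

    @staticmethod
    def forward(ctx, side_channel, x):
        ctx.side_channel = side_channel
        out = x * 2  # stand-in for the compiled computation
        side_channel["fw"] = out
        # Hand the tensors to save over to the output-returning Function
        # instead of calling ctx.save_for_backward here.
        side_channel["tensors_to_save"] = (x,)
        # Link this node into the graph via a cheap dummy meta tensor.
        return torch.randn(1, device="meta", requires_grad=True)

    @staticmethod
    def backward(ctx, _dummy_grad):
        grads = ctx.side_channel.pop("bw")
        (x,) = ctx.side_channel.pop("saved_tensors")  # saved via the side channel
        (g,) = grads
        return None, g * 2  # grads for (side_channel, x)


class Output(torch.autograd.Function):
    """Stand-in for the Function that returns the real outputs."""

    @staticmethod
    def forward(ctx, dummy, side_channel):
        ctx.side_channel = side_channel
        # Saving here puts the tensors on output.grad_fn, which PyTorch
        # handles without the ref-cycle described in the comment above.
        ctx.save_for_backward(*side_channel.pop("tensors_to_save"))
        return side_channel.pop("fw")

    @staticmethod
    def backward(ctx, g):
        ctx.side_channel["bw"] = [g]
        ctx.side_channel["saved_tensors"] = list(ctx.saved_tensors)
        ctx.maybe_clear_saved_tensors()  # undocumented API, also used above
        return torch.randn(1, device="meta"), None  # grads for (dummy, side_channel)


x = torch.randn(3, requires_grad=True)
sc = {}
out = Output.apply(Compute.apply(sc, x), sc)
out.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```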
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index 68e8a8e56b..900b026d6f 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -3187,3 +3187,26 @@ def fn(x):
     ):  # prims is unpack_sequence and any output is TensorProxy
         # Verify that we print information about the unpacked TensorProxy.
         assert "cpu f32[3]" in str(bsym)
+
+
+def test_apply_autograd_memory():
+    from thunder.executors.torch_autograd import connect_to_autograd
+
+    def foo():
+        def backward(*args):
+            return None
+
+        x = torch.randn(2, 2, requires_grad=True)
+        o = x.sum()
+
+        connect_to_autograd(
+            backward_fn=backward,
+            flat_args=(x,),
+            flat_output=(o,),
+            saved_tensors=(o,),
+            saved_other=(),
+            return_none_instead_of_grads=True,
+        )
+        return [weakref.ref(x), weakref.ref(o)]
+
+    assert not any(wr() for wr in foo())
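Related to the comment in `ThunderFunction.backward` above: the "replace the saved_tensors with an empty list" trick relies on Python passing mutable objects by reference, so the callee can drop the caller's last references to the saved tensors as soon as it is done with them, rather than when the backward frame exits. A minimal sketch of that trick in isolation; `consume` is a hypothetical stand-in for the compiled backward, not a Thunder function:

```python
import torch


def consume(saved_tensors_list):
    # Use the saved tensors (stand-in for the long compiled backward)...
    total = sum(t.sum() for t in saved_tensors_list)
    # ...then empty the caller's list in place. The tensors become
    # collectible right here, not only when every frame that still
    # references the list returns.
    saved_tensors_list.clear()
    return total


saved = [torch.randn(1024, 1024) for _ in range(4)]
consume(saved)
assert not saved  # the same list object was emptied by the callee
```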
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index af1c3232cb..718465538c 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1752,10 +1752,10 @@ def f(x, y):
 
     # With activation checkpointing, we are saving only the original input.
     # The intermediate values are recomputed during backward pass.
-    assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2
+    assert len(out.grad_fn.saved_tensors) == 2
     # We detach the saved tensors (which returns a new Python tensor backed by same storage)
     # the order seems to be non-deterministic sometimes
-    assert {t.data_ptr() for t in out.grad_fn.next_functions[0][0].saved_tensors} == {x.data_ptr(), y.data_ptr()}
+    assert {t.data_ptr() for t in out.grad_fn.saved_tensors} == {x.data_ptr(), y.data_ptr()}
 
     g = torch.ones_like(out)
     out.backward(g)
@@ -1948,8 +1948,8 @@ def fn(a):
     a = torch.randn(2, 2, device=device, requires_grad=True)
     res = jfn(a)
     res2 = jfn2(a)
-    assert len(res.grad_fn.next_functions[0][0].saved_tensors) == 3  # should be decomposed
-    assert len(res2.grad_fn.next_functions[0][0].saved_tensors) == 1
+    assert len(res.grad_fn.saved_tensors) == 3  # should be decomposed
+    assert len(res2.grad_fn.saved_tensors) == 1
 
     if NVFUSER_AVAILABLE and device == "cuda":
         # check everything is fused
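Why these assertions (and the notebook cell above) changed: tensors are now saved on the Function that returns the outputs, so they are visible directly on `out.grad_fn` instead of one node upstream via `next_functions[0][0]`. A minimal sketch with a plain custom Function showing the accessor the tests now use; `Mul2` is illustrative only, not Thunder code:

```python
import torch


class Mul2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        return g * 2


x = torch.randn(3, requires_grad=True)
out = Mul2.apply(x)
# Tensors saved by a custom autograd.Function are exposed on the output's
# own grad_fn node, and they share storage with the originals.
assert len(out.grad_fn.saved_tensors) == 1
assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()
```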