From c8714eb315f2f5da2bf06720b8e2074b96ecebdd Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 24 Jul 2024 14:47:01 -0700 Subject: [PATCH 1/2] clean up on tests --- thunder/tests/test_grad.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 2b35f1e88b..2963c90a41 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1630,13 +1630,11 @@ def foo(a, c): static_jit = thunder.jit(foo) out = dynamic_jit(a, c) - torch.autograd.backward(out, torch.rand_like(out), retain_graph=True) dynamic_trace = thunder.last_backward_traces(dynamic_jit)[-1] # dynamic trace should save `c` as proxy for backward assert any(map(lambda x: isinstance(x, Proxy), tree_flatten(dynamic_trace.args[0])[0])) out = static_jit(a, c) - torch.autograd.backward(out, torch.rand_like(out), retain_graph=True) static_trace = thunder.last_backward_traces(static_jit)[-1] # static trace should bake `c` as scalar number, so it won't show up in backward as proxy assert not any(map(lambda x: isinstance(x, Proxy), tree_flatten(static_trace.args[0])[0])) From 89b68369876379c294738c52fc7a494125c9ed6a Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 24 Jul 2024 14:59:14 -0700 Subject: [PATCH 2/2] renaming variables --- thunder/core/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index c4e3eebef2..e7bdeb8df7 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3743,8 +3743,8 @@ def backward_fn(saved_for_backward, cotangents): proxified if isinstance(entry, Proxy) else entry for proxified, entry in zip(flat_saves_proxified, flat_saves) ] - saved_for_backward = tree_unflatten(flat_filtered, saves_spec) - env = reconstruct_forward_env_for_backward(trace, saved_for_backward) + unproxiefied_saved_for_backward = tree_unflatten(flat_filtered, saves_spec) + env = reconstruct_forward_env_for_backward(trace, unproxiefied_saved_for_backward) if torch_autograd: cotangents = tree_unflatten(cotangents, output_spec)