Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR 848 clean up. #857

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3743,8 +3743,8 @@ def backward_fn(saved_for_backward, cotangents):
proxified if isinstance(entry, Proxy) else entry
for proxified, entry in zip(flat_saves_proxified, flat_saves)
]
saved_for_backward = tree_unflatten(flat_filtered, saves_spec)
env = reconstruct_forward_env_for_backward(trace, saved_for_backward)
unproxiefied_saved_for_backward = tree_unflatten(flat_filtered, saves_spec)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@IvanYashchuk embarrassingly I don't quite get how this would work with a tree_map... how would I be able to zip two pytree while still keep them as a pytree?

env = reconstruct_forward_env_for_backward(trace, unproxiefied_saved_for_backward)

if torch_autograd:
cotangents = tree_unflatten(cotangents, output_spec)
Expand Down
2 changes: 0 additions & 2 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,13 +1630,11 @@ def foo(a, c):
static_jit = thunder.jit(foo)

out = dynamic_jit(a, c)
torch.autograd.backward(out, torch.rand_like(out), retain_graph=True)
dynamic_trace = thunder.last_backward_traces(dynamic_jit)[-1]
# dynamic trace should save `c` as proxy for backward
assert any(map(lambda x: isinstance(x, Proxy), tree_flatten(dynamic_trace.args[0])[0]))

out = static_jit(a, c)
torch.autograd.backward(out, torch.rand_like(out), retain_graph=True)
static_trace = thunder.last_backward_traces(static_jit)[-1]
# static trace should bake `c` as scalar number, so it won't show up in backward as proxy
assert not any(map(lambda x: isinstance(x, Proxy), tree_flatten(static_trace.args[0])[0]))
Loading