Skip to content

Commit

Permalink
Proxy order KeyError fix (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardofelluga authored Oct 4, 2024
1 parent da1a441 commit b4ca020
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
5 changes: 5 additions & 0 deletions thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ def apply_rematerialization_for_consumer(
filter(lambda x: x.name not in map(lambda x: x.name, new_consumer_args), consumer.args)
)

# In the case where there are no tensors to rematerialize it is
# possible to terminate early and return the consumer as it was.
if not rematerialized_inputs:
return consumer

# Construct a temporary Trace object with subsymbols from the producer.
trace = TraceCtx(None)
trace.bound_symbols = producer.subsymbols
Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def foo(a, x):
assert len(fusions[0].subsymbols) == 3

# Verifies the intermediate consumer
assert fusions[1].subsymbols[-2].args[0].name == "g"
assert fusions[1].subsymbols[-1].args[0].name == "g"


@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,))
Expand Down
35 changes: 35 additions & 0 deletions thunder/tests/test_nvfuser_remat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from thunder.examine import get_fusions
from thunder.tests.framework import instantiate, NOTHING, nvFuserExecutor, TorchExecutor, requiresCUDA
from thunder.tests.make_tensor import make_tensor
import thunder.torch as ltorch


@value_and_grad
Expand Down Expand Up @@ -199,6 +200,40 @@ def test_apply_rematerialization_consumer(executor, device, _):
assert tuple(new_consumer.subsymbols) == tuple(new_consumer_case2.subsymbols)


@instantiate(
dtypes=NOTHING,
executors=(nvFuserExecutor,),
)
@disable_rematerialization_in_nvfuser_fusion
def test_apply_rematerialization_consumer_early_exit(executor, device, _):
@value_and_grad
def foo(t0):
t1 = ttorch.exp(t0)
t2 = ttorch.matmul(t1, t1)
return t2

t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
initial_trace = thunder.trace()(foo, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
nvfuser_symbols = tuple(filter(lambda x: x.sym.name.startswith("nvFusion"), trace.bound_symbols))
assert len(nvfuser_symbols) == 2

producer = nvfuser_symbols[0]
consumer = nvfuser_symbols[1]

# Create a cut that has t0 as extra information and
# that contains all arguments(t2) from consumer.
cut = ("t0", "t2")
new_consumer = apply_rematerialization_for_consumer(producer, consumer, cut)

# Check that the new consumer is the old consumer
assert id(new_consumer) == id(consumer)
assert tuple(new_consumer.subsymbols) == tuple(consumer.subsymbols)


@instantiate(
dtypes=NOTHING,
executors=(nvFuserExecutor,),
Expand Down

0 comments on commit b4ca020

Please sign in to comment.