diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 7db7b88f64..3c64737ee6 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -163,6 +163,11 @@ def apply_rematerialization_for_consumer( filter(lambda x: x.name not in map(lambda x: x.name, new_consumer_args), consumer.args) ) + # In the case where there are no tensors to rematerialize it is + # possible to terminate early and return the consumer as it was. + if not rematerialized_inputs: + return consumer + # Construct a temporary Trace object with subsymbols from the producer. trace = TraceCtx(None) trace.bound_symbols = producer.subsymbols diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 2224bfc9ba..cb4ece6aa0 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -210,7 +210,7 @@ def foo(a, x): assert len(fusions[0].subsymbols) == 3 # Verifies the intermediate consumer - assert fusions[1].subsymbols[-2].args[0].name == "g" + assert fusions[1].subsymbols[-1].args[0].name == "g" @instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,)) diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index d0f3462391..d5b57320fa 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -19,6 +19,7 @@ from thunder.examine import get_fusions from thunder.tests.framework import instantiate, NOTHING, nvFuserExecutor, TorchExecutor, requiresCUDA from thunder.tests.make_tensor import make_tensor +import thunder.torch as ltorch @value_and_grad @@ -199,6 +200,40 @@ def test_apply_rematerialization_consumer(executor, device, _): assert tuple(new_consumer.subsymbols) == tuple(new_consumer_case2.subsymbols) +@instantiate( + dtypes=NOTHING, + executors=(nvFuserExecutor,), +) +@disable_rematerialization_in_nvfuser_fusion +def test_apply_rematerialization_consumer_early_exit(executor, device, _): + @value_and_grad + def foo(t0): + t1 = ttorch.exp(t0) + t2 = ttorch.matmul(t1, t1) + return t2 + + t0 = make_tensor(2, 2, dtype=torch.float32, device=device) + initial_trace = thunder.trace()(foo, t0) + compiled_func = thunder.jit(initial_trace.python_callable()) + _ = compiled_func(t0) + traces = thunder.last_traces(compiled_func) + trace = traces[-1] + nvfuser_symbols = tuple(filter(lambda x: x.sym.name.startswith("nvFusion"), trace.bound_symbols)) + assert len(nvfuser_symbols) == 2 + + producer = nvfuser_symbols[0] + consumer = nvfuser_symbols[1] + + # Create a cut that has t0 as extra information and + # that contains all arguments(t2) from consumer. + cut = ("t0", "t2") + new_consumer = apply_rematerialization_for_consumer(producer, consumer, cut) + + # Check that the new consumer is the old consumer + assert id(new_consumer) == id(consumer) + assert tuple(new_consumer.subsymbols) == tuple(consumer.subsymbols) + + @instantiate( dtypes=NOTHING, executors=(nvFuserExecutor,),