diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py
index eafd30ce0..ddb7f80e5 100644
--- a/thunder/dynamo/compiler_graph_benchmark.py
+++ b/thunder/dynamo/compiler_graph_benchmark.py
@@ -2,6 +2,7 @@
 from itertools import chain
 from pytest_benchmark.fixture import BenchmarkFixture
 from typing import TYPE_CHECKING
+from looseversion import LooseVersion
 
 import torch
 from thunder.dynamo import ThunderCompiler
@@ -124,6 +125,23 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):
 
     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         split_module = super().__call__(gm, sample_args)
+
+        def has_checkpoint_node(g):
+            if g.find_nodes(op="call_function", target=torch.ops.higher_order.tag_activation_checkpoint):
+                return True
+            for n in g.nodes:
+                if n.op == "call_module" and has_checkpoint_node(getattr(g.owning_module, n.target).graph):
+                    return True
+            return False
+
+        if LooseVersion(torch.__version__) < LooseVersion("2.6.0"):
+            # NOTE: PyTorch 2.6 changes the structure of the GraphModule when activation checkpointing is used.
+            # Before PyTorch 2.6 it is hard to retrieve the example input tensors for a GraphModule that contains a checkpoint operator.
+            if has_checkpoint_node(split_module.graph):
+                raise RuntimeError(
+                    "Benchmarking of Torch activation checkpointing is only supported with PyTorch 2.6 or later."
+                )
+
         compiled_functions_to_submodule = {
             v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
         }
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index b123400ec..b128357b9 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING
+import copy
 
 import torch
 from torch.fx.passes.split_module import split_module
@@ -131,9 +132,10 @@ def callback(node) -> int:
         return partition_cnt
 
     # `split_module` iterates over nodes and determines the partition to place them based on the callback.
-    split_gm: torch.fx.GraphModule = split_module(
+    original_split_gm: torch.fx.GraphModule = split_module(
         gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
     )
+    split_gm = copy.deepcopy(original_split_gm)
 
     def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
         return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
@@ -142,6 +144,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
     thunder_compiled_fns = []
     submodule_to_compiled_fns = {}
     for node in split_gm.graph.nodes:
+        node_name = node.name
         if is_thunder_supported_partition(node):
             graph_module = getattr(split_gm, node.name)
             # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
@@ -150,13 +153,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
             thunder_compiled_fns.append(jit_fn)
-            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
+            submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
+                jit_fn, CompilerType.THUNDER
+            )
         elif node.name.startswith("submod"):  # For inductor
             graph_module = getattr(split_gm, node.name)
             jit_fn = torch_inductor(graph_module)
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
-            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
+            submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
+                jit_fn, CompilerType.TORCH_INDUCTOR
+            )
         else:
             # Everything else is a glue code to call and pass outputs between the other partitions.
             pass
@@ -166,6 +173,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
 
     return split_gm, SubgraphInfo(
         gm,
+        original_split_gm,
         split_gm,
         thunder_compiled_fns,
         submodule_to_compiled_fns,
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index d434d0234..668f2ef0b 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -80,17 +80,21 @@ class SubgraphInfo:
 
     Attributes:
         original_graph_module: The original graph module.
-        split_graph_module: The graph module for the split subgraph.
+        original_split_graph_module: The original split graph module before any transformations are applied.
+            Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
+            and before any submodules are compiled by Thunder.
+        split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
         thunder_compiled_fns: List of thunder optimized callables.
             This could be :obj:`None` if there the graph module was not supported by thunder.
             Look at the :attr:`split_reasons` for further information.
-        submodule_to_compiled_functions: Dict from subgraph to compiled function.
+        submodule_to_compiled_functions: Dict from subgraph in :attr:`original_split_graph_module` to compiled function.
            This will be a dict with one pair in case the graph was not split.
         split_reasons: List of reasons explaining why the subgraph was split.
             Present only if there are was a split.
""" original_graph_module: torch.fx.GraphModule + original_split_graph_module: torch.fx.GraphModule | None split_graph_module: torch.fx.GraphModule | None thunder_compiled_fns: list[Callable] | None submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction] @@ -466,8 +470,7 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule): Args: gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace. """ - new_graph = copy.deepcopy(gm.graph) - for n in new_graph.nodes: + for n in gm.graph.nodes: # replace the torch operator in "call_function" node if n.op == "call_function": assert isinstance(n.target, Callable) @@ -476,19 +479,18 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule): check( n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder" ) - with new_graph.inserting_before(n): - thunder_node = new_graph.call_function( + with gm.graph.inserting_before(n): + thunder_node = gm.graph.call_function( _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs ) n.replace_all_uses_with(thunder_node) - new_graph.erase_node(n) + gm.graph.erase_node(n) else: if n.op == "call_module": raise RuntimeError( "Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs" ) - new_graph.lint() - gm.graph = new_graph + gm.graph.lint() recompile_graph(gm) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index da9129dcb..42299c149 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -5,6 +5,7 @@ import torch.fx import torch.nn as nn import torch.nn.functional as F +from looseversion import LooseVersion from thunder import dtypes from thunder.dynamo import ThunderCompiler @@ -445,6 +446,10 @@ def func(x): IS_WINDOWS, reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", ), + pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("2.6.0"), + reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275", + ), ), ) @requiresCUDA @@ -639,6 +644,35 @@ def f(x): compiled(x) +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("2.6.0"), + reason="The checkpoint function becomes a submodule of the module containing `tag_activation_checkpoint` in PyTorch 2.6.0.", +) +@requiresCUDA +def test_ThunderCompilerGraphBenchmarking_checkpoint(benchmark): + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(10, 20) + + def forward(self, x): + x = torch.utils.checkpoint.checkpoint(self.layer1, x) + x = F.relu(x) + return x + + x = torch.randn(5, 10).cuda().requires_grad_() + model = SimpleModel().cuda().train() + + exe_backend = ThunderCompiler() + backend = ThunderCompilerGraphBenchmarking( + benchmark, executors={"inductor": torch.compile, "thunderfx": torch.compile(backend=exe_backend)} + ) + # Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in + # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421 + jf = torch._dynamo.optimize(backend=backend)(model) + out = jf(x) + + @requiresCUDA @pytest.mark.filterwarnings(r"ignore:`torch\.cpu\.amp\.autocast\((.*?)\)` is deprecated.*:FutureWarning") def test_checkpoint_converter():