support graph-by-graph benchmarking for PyTorch native checkpointing (#…
kiya00 authored Nov 19, 2024
1 parent f206afa commit 60f3ee1
Showing 4 changed files with 74 additions and 12 deletions.
18 changes: 18 additions & 0 deletions thunder/dynamo/compiler_graph_benchmark.py
@@ -2,6 +2,7 @@
from itertools import chain
from pytest_benchmark.fixture import BenchmarkFixture
from typing import TYPE_CHECKING
from looseversion import LooseVersion

import torch
from thunder.dynamo import ThunderCompiler
@@ -124,6 +125,23 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
split_module = super().__call__(gm, sample_args)

def has_checkpoint_node(g):
if g.find_nodes(op="call_function", target=torch.ops.higher_order.tag_activation_checkpoint):
return True
for n in g.nodes:
if n.op == "call_module" and has_checkpoint_node(getattr(g.owning_module, n.target).graph):
return True
return False

if LooseVersion(torch.__version__) < LooseVersion("2.6.0"):
# NOTE: PyTorch 2.6 changes the structure of GraphModule when using activation checkpointing.
# It's hard to retrieve the example input tensors for a GraphModule that contains a checkpoint operator before PyTorch 2.6
if has_checkpoint_node(split_module.graph):
raise RuntimeError(
"The benchmarking of the Torch activation checkpointing is only supported with PyTorch version 2.6 or later."
)

compiled_functions_to_submodule = {
v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
}
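For context, here is a minimal, standalone sketch (not part of the commit) of how a checkpointed call surfaces as the `tag_activation_checkpoint` higher-order op in a Dynamo-captured FX graph, and how the recursive detection added above can be reproduced with a throwaway inspection backend; the `inspect_backend` and `fn` names are hypothetical.

```python
import torch
import torch.nn as nn


def has_checkpoint_node(g: torch.fx.Graph) -> bool:
    # Same idea as the helper added in this commit: look for the higher-order checkpoint
    # op, recursing into the graphs of any call_module submodules.
    if g.find_nodes(op="call_function", target=torch.ops.higher_order.tag_activation_checkpoint):
        return True
    for n in g.nodes:
        if n.op == "call_module" and has_checkpoint_node(getattr(g.owning_module, n.target).graph):
            return True
    return False


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # A do-nothing Dynamo backend that only inspects the captured graph.
    print("contains checkpoint op:", has_checkpoint_node(gm.graph))
    return gm.forward


layer = nn.Linear(10, 10)


@torch.compile(backend=inspect_backend)
def fn(x):
    return torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)


fn(torch.randn(2, 10, requires_grad=True))
```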
14 changes: 11 additions & 3 deletions thunder/dynamo/splitter.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import copy

import torch
from torch.fx.passes.split_module import split_module
@@ -131,9 +132,10 @@ def callback(node) -> int:
return partition_cnt

# `split_module` iterates over the nodes and uses the callback to determine which partition to place each of them in.
split_gm: torch.fx.GraphModule = split_module(
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)
split_gm = copy.deepcopy(original_split_gm)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
@@ -142,6 +144,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
thunder_compiled_fns = []
submodule_to_compiled_fns = {}
for node in split_gm.graph.nodes:
node_name = node.name
if is_thunder_supported_partition(node):
graph_module = getattr(split_gm, node.name)
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
@@ -150,13 +153,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
# Update the node name from "submod_*" to "thunder_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
thunder_compiled_fns.append(jit_fn)
submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
jit_fn, CompilerType.THUNDER
)
elif node.name.startswith("submod"): # For inductor
graph_module = getattr(split_gm, node.name)
jit_fn = torch_inductor(graph_module)
# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
jit_fn, CompilerType.TORCH_INDUCTOR
)
else:
# Everything else is glue code that calls and passes outputs between the other partitions.
pass
@@ -166,6 +173,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:

return split_gm, SubgraphInfo(
gm,
original_split_gm,
split_gm,
thunder_compiled_fns,
submodule_to_compiled_fns,
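To illustrate the `split_module` call the splitter builds on, here is a standalone sketch (not the commit's code): the toy module and the rule "treat `torch.sin` as unsupported" are invented, but the callback/partition mechanics and the `submod_*` naming match what the splitter relies on.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class Toy(nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        z = torch.sin(y)  # pretend this op is unsupported and must go in its own partition
        return z + 1


gm = symbolic_trace(Toy())

partition_cnt = 0


def callback(node: torch.fx.Node) -> int:
    # Bump the partition counter when we hit the "unsupported" op, so supported and
    # unsupported regions end up in different submodules (submod_0, submod_1, ...).
    global partition_cnt
    if node.op == "call_function" and node.target is torch.sin:
        partition_cnt += 1
    return partition_cnt


split_gm = split_module(gm, root_m=None, split_callback=callback, keep_original_order=True)
print(split_gm)  # call_module nodes submod_0 / submod_1 wrapping the two regions
```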
20 changes: 11 additions & 9 deletions thunder/dynamo/utils.py
@@ -80,17 +80,21 @@ class SubgraphInfo:
Attributes:
original_graph_module: The original graph module.
split_graph_module: The graph module for the split subgraph.
original_split_graph_module: The original split graph module before any transformations are applied.
Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
and before any submodules are compiled by Thunder.
split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
thunder_compiled_fns: List of thunder optimized callables.
This could be :obj:`None` if the graph module was not supported by thunder.
Look at the :attr:`split_reasons` for further information.
submodule_to_compiled_functions: Dict from subgraph to compiled function.
submodule_to_compiled_functions: Dict from subgraph in :attr:`original_split_graph_module` to compiled function.
This will be a dict with one pair in case the graph was not split.
split_reasons: List of reasons explaining why the subgraph was split.
Present only if there was a split.
"""

original_graph_module: torch.fx.GraphModule
original_split_graph_module: torch.fx.GraphModule | None
split_graph_module: torch.fx.GraphModule | None
thunder_compiled_fns: list[Callable] | None
submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction]
@@ -466,8 +470,7 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule):
Args:
gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
"""
new_graph = copy.deepcopy(gm.graph)
for n in new_graph.nodes:
for n in gm.graph.nodes:
# replace the torch operator in "call_function" node
if n.op == "call_function":
assert isinstance(n.target, Callable)
@@ -476,19 +479,18 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule):
check(
n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
)
with new_graph.inserting_before(n):
thunder_node = new_graph.call_function(
with gm.graph.inserting_before(n):
thunder_node = gm.graph.call_function(
_torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
)
n.replace_all_uses_with(thunder_node)
new_graph.erase_node(n)
gm.graph.erase_node(n)
else:
if n.op == "call_module":
raise RuntimeError(
"Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
)
new_graph.lint()
gm.graph = new_graph
gm.graph.lint()
recompile_graph(gm)
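The change above drops the deep copy and mutates `gm.graph` in place, which is the standard FX rewrite pattern. A small standalone sketch of that pattern (illustrative only; the `torch.sin` → `torch.cos` mapping stands in for `_torch_to_thunder_function_map`):

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class M(nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1


gm = symbolic_trace(M())
replacement = {torch.sin: torch.cos}  # stands in for the torch-to-thunder map

for n in list(gm.graph.nodes):
    if n.op == "call_function" and n.target in replacement:
        # Insert the replacement call next to the old node, rewire users, drop the old node.
        with gm.graph.inserting_before(n):
            new_node = gm.graph.call_function(replacement[n.target], args=n.args, kwargs=n.kwargs)
        n.replace_all_uses_with(new_node)
        gm.graph.erase_node(n)

gm.graph.lint()
gm.recompile()  # plain FX recompile; the commit uses thunder's recompile_graph helper
print(gm.code)               # forward now calls torch.cos
print(gm(torch.zeros(3)))    # tensor([2., 2., 2.])
```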


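As a rough illustration of how the `SubgraphInfo` fields documented above can be inspected after compilation, here is a hedged sketch: it assumes the backend exposes its infos as a `subgraph_infos` list (as the benchmarking subclass suggests) and that a small CPU-only run is sufficient.

```python
import torch
import torch.nn as nn
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
compiled = torch.compile(model, backend=backend)
compiled(torch.randn(2, 4))

for info in backend.subgraph_infos:  # assumed attribute, mirroring the benchmark subclass
    print(type(info.original_graph_module))   # the Dynamo-produced GraphModule
    print(type(info.split_graph_module))      # submodules replaced by compiled callables
    for submodule, compiled_fn in info.submodule_to_compiled_functions.items():
        # After this commit the keys are submodules of original_split_graph_module.
        print(submodule, compiled_fn.compiled_fn)
```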
34 changes: 34 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -5,6 +5,7 @@
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from looseversion import LooseVersion

from thunder import dtypes
from thunder.dynamo import ThunderCompiler
@@ -445,6 +446,10 @@ def func(x):
IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275",
),
),
)
@requiresCUDA
@@ -639,6 +644,35 @@ def f(x):
compiled(x)


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
reason="The checkpoint function becomes a submodule of the module containing `tag_activation_checkpoint` in PyTorch 2.6.0.",
)
@requiresCUDA
def test_ThunderCompilerGraphBenchmarking_checkpoint(benchmark):
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(10, 20)

def forward(self, x):
x = torch.utils.checkpoint.checkpoint(self.layer1, x)
x = F.relu(x)
return x

x = torch.randn(5, 10).cuda().requires_grad_()
model = SimpleModel().cuda().train()

exe_backend = ThunderCompiler()
backend = ThunderCompilerGraphBenchmarking(
benchmark, executors={"inductor": torch.compile, "thunderfx": torch.compile(backend=exe_backend)}
)
# Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in
# https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
jf = torch._dynamo.optimize(backend=backend)(model)
out = jf(x)


@requiresCUDA
@pytest.mark.filterwarnings(r"ignore:`torch\.cpu\.amp\.autocast\((.*?)\)` is deprecated.*:FutureWarning")
def test_checkpoint_converter():
