Commit

Merge branch 'main' into bitsandbytes
kiya00 authored Dec 12, 2024
2 parents 01bb86b + f89ceca commit ba385cf
Showing 4 changed files with 77 additions and 24 deletions.
6 changes: 2 additions & 4 deletions thunder/dynamo/__init__.py
@@ -1,6 +1,4 @@
-from thunder.dynamo.compiler import ThunderCompiler
+from thunder.dynamo.compiler import ThunderCompiler, thunderfx
 
 
-__all__ = [
-    "ThunderCompiler",
-]
+__all__ = ["ThunderCompiler", "thunderfx"]
50 changes: 34 additions & 16 deletions thunder/dynamo/compiler.py
@@ -3,25 +3,20 @@
 from looseversion import LooseVersion
 from typing import TYPE_CHECKING
-import warnings
+import inspect
 
 import torch
 
-from thunder.core.baseutils import run_once
 from thunder.core.utils import safe_zip
 from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType
 from thunder.dynamo.splitter import _splitter
+from thunder.core.utils import check
 
 if TYPE_CHECKING:
     from thunder.dynamo.utils import SubgraphInfo
     from os import PathLike
 
 
-@run_once
-def _warn_thunder_compiler():
-    warnings.warn(
-        "The ThunderCompiler is in active development and may not work as expected."
-        + " Please report any issues you encounter to the Lightning Thunder team."
-    )
+from collections.abc import Callable
 
 
 class ThunderCompiler:
@@ -32,9 +32,7 @@ def __init__(self, **thunder_options):
         function.
 
         Keyword arguments:
-            thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Besides all the arguments to :func:`thunder.jit`,
-                it accepts ``torch_inductor_options`` which are passed to :func:`torch.compile` if part of the graph
-                is not supported by thunder.
+            thunder_options: a dictionary of options to pass to :func:`thunder.jit`.
 
         Example:
             >>> import torch
@@ -52,8 +45,6 @@ def __init__(self, **thunder_options):
         """
         from thunder import jit
 
-        _warn_thunder_compiler()
-
         if LooseVersion(torch.__version__) < LooseVersion("2.4.0"):
             # NOTE: PyTorch 2.3 or lower has bug in `split_module` function used in splitter.
             # See https://github.com/Lightning-AI/lightning-thunder/pull/1075#issuecomment-2324918409
@@ -67,11 +58,9 @@ def __init__(self, **thunder_options):
         # Ref to the documentation of `SubgraphInfo` to know more about the information it contains.
         self.subgraph_infos: list[SubgraphInfo] = []
 
-        torch_inductor_options = thunder_options.pop("torch_inductor_options", {})
-
         self.thunder_options = thunder_options
         self._thunder_jit = partial(jit, **thunder_options)
-        self._torch_compile = partial(torch.compile, **torch_inductor_options)
+        self._torch_compile = torch.compile
 
     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         gm = remove_empty_autocast(gm)
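
A minimal sketch of how the backend is used after this change (illustrative code, not part of the diff; keyword arguments now go only to thunder.jit, since torch_inductor_options is no longer accepted):

import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()  # any kwargs here are forwarded to thunder.jit

def fn(x):
    return torch.sin(x) + torch.cos(x)

cfn = torch.compile(fn, backend=backend)
cfn(torch.randn(4, 4))
len(backend.subgraph_infos)  # one SubgraphInfo per Dynamo-produced subgraph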
@@ -127,3 +116,32 @@ def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytest_benchmark
                 f"graph{graph_idx}_{cur_name}",
                 use_pytest_benchmark,
             )
+
+
+def thunderfx(fn: Callable, /, **kwargs) -> Callable:
+    """Compiles a callable (function or model) by using Thunder as the backend of :func:`torch.compile`
+    Args:
+        fn: A :class:`~torch.nn.Module` or a function to compile.
+    Keyword Args:
+        **kwargs: a dictionary of options to pass to :func:`torch.compile` and :func:`thunder.jit`.
+    Returns:
+        The compiled callable
+    """
+    import thunder
+
+    torch_compile_kwarg_names = inspect.getfullargspec(torch.compile).kwonlyargs
+    thunder_jit_kwarg_names = inspect.getfullargspec(thunder.jit).kwonlyargs
+    overlap = [kwarg_name for kwarg_name in thunder_jit_kwarg_names if kwarg_name in torch_compile_kwarg_names]
+    check(
+        not overlap,
+        lambda: f"There are overlapping kwargs between thunder.jit and torch.compile: {overlap}",
+        ValueError,
+    )
+
+    torch_compile_options = {k: v for k, v in kwargs.items() if k in torch_compile_kwarg_names}
+    thunder_options = {k: v for k, v in kwargs.items() if k not in torch_compile_kwarg_names}
+
+    backend = ThunderCompiler(**thunder_options)
+    compiled = torch.compile(fn, backend=backend, **torch_compile_options)
+    compiled._backend = backend
+    return compiled
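
A usage sketch of the new entry point (illustrative; "dynamic" is a torch.compile keyword-only argument, so it is routed to torch.compile, while unrecognized kwargs go to thunder.jit):

import torch
from thunder.dynamo import thunderfx

def foo(x):
    return torch.sin(x) + torch.cos(x)

cfoo = thunderfx(foo, dynamic=False)
out = cfoo(torch.randn(4, 4))
infos = cfoo._backend.subgraph_infos  # the ThunderCompiler backend is kept for introspection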
21 changes: 18 additions & 3 deletions thunder/dynamo/utils.py
@@ -780,10 +780,20 @@ def reproducer(
     if use_pytest_benchmark:
         code_str += f"""import pytest
 
+# NOTE: The reproducer function has already been processed by TorchDynamo.
+# If we let it go through TorchDynamo again, it could be segmented further.
+# To avoid this, we directly use Inductor here.
+# See issue https://github.com/Lightning-AI/lightning-thunder/issues/1521
+def torch_inductor(fn, inputs):
+    from torch._inductor import compile as inductor_compile
+    from torch.fx import symbolic_trace
+    fx_graph = symbolic_trace(fn)
+    return inductor_compile(fx_graph, inputs)
+
+
 bench_executors_dict = {{}}
 bench_executors_dict["thunder"]=partial(thunder.jit, {thunder_options_str})
-bench_executors_dict["torch.compile"]=torch.compile
-bench_executors_dict["dynamo_eager"]=partial(torch.compile, backend="eager")
+bench_executors_dict["torch_inductor"]=torch_inductor
 bench_executors_dict["eager"]=None
 """
 if has_cuda_args:
@@ -812,7 +822,12 @@ def reproducer(
     else:
         func_str = f"""{func_str}
 mod = DynamoModule()
-compiled = mod if executor == None else executor(mod)
+if executor == None:
+    compiled = mod
+elif executor == torch_inductor:
+    compiled = executor(mod, inputs)
+else:
+    compiled = executor(mod)
 """
     if not has_cuda_args:
         func_str += f"""benchmark(compiled, *inputs)"""
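
The torch_inductor helper emitted into the reproducer reduces to the following standalone pattern: symbolic_trace captures the FX graph without going through TorchDynamo, so Inductor compiles it whole instead of re-segmenting it. A minimal sketch (fn and inputs here are illustrative):

import torch
from torch.fx import symbolic_trace
from torch._inductor import compile as inductor_compile

def fn(x):
    return torch.sin(x) + torch.cos(x)

inputs = [torch.randn(4, 4)]
fx_graph = symbolic_trace(fn)                  # FX trace; no TorchDynamo involved
compiled = inductor_compile(fx_graph, inputs)  # hand the whole graph to Inductor
result = compiled(*inputs)                     # the reproducer invokes it the same way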
24 changes: 23 additions & 1 deletion thunder/tests/test_dynamo.py
@@ -10,7 +10,7 @@
 from looseversion import LooseVersion
 
 from thunder import dtypes
-from thunder.dynamo import ThunderCompiler
+from thunder.dynamo import ThunderCompiler, thunderfx
 from thunder.dynamo.utils import CompilerType
 from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
 from thunder import last_traces
@@ -930,3 +930,25 @@ def check(file_name, cmd):
     cmd = "pytest" if use_pytest_benchmark else "python"
     for fname in [s1, s2, s3]:
         check(fname, cmd)
+
+
+@requiresCUDA
+def test_thunderfx():
+    def foo(x):
+        return torch.sin(x) + torch.cos(x)
+
+    x = torch.randn(4, 4, device="cuda", requires_grad=True)
+    cfoo = thunderfx(foo)
+    cfoo(x)
+    thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns
+    assert len(thunder_compiled_fns) == 1
+    assert last_traces(thunder_compiled_fns[0])
+
+    from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
+
+    cfoo = thunderfx(foo, dynamic=True, transforms=[NvtxProfileTransform()])
+    cfoo(x)
+    thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns
+    assert len(thunder_compiled_fns) == 1
+    trc = last_traces(thunder_compiled_fns[-1])[-1]
+    assert any(bsym.sym.id == "nvtx_range_push" for bsym in trc.bound_symbols)