Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow linear to be consumed by nvFuser by default #1371

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions notebooks/writing_a_trace_transform_cpu_offloading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -507,7 +507,7 @@
"model, args, kwargs = get_model_and_args()\n",
"\n",
"# Check against the vanilla `thunder.jit` model\n",
"expected = thunder.jit(model)(*args, **kwargs)\n",
"expected = thunder.jit(model, nv_enable_linear=False)(*args, **kwargs)\n",
"\n",
"grad_output = torch.randn_like(expected)\n",
"expected_grads = torch.autograd.grad(expected, model.parameters(), grad_output)\n",
Expand All @@ -519,7 +519,7 @@
" expected_cpu = expected.to(\"cpu\")\n",
" expected_grads_cpu = tree_map(lambda t: t.to(\"cpu\"), expected_grads)\n",
"\n",
"jmodel = thunder.jit(model, transforms=[CPUOffloading()])\n",
"jmodel = thunder.jit(model, nv_enable_linear=False, transforms=[CPUOffloading()])\n",
"\n",
"actual = jmodel(*args, **kwargs)\n",
"\n",
Expand Down Expand Up @@ -642,7 +642,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -656,8 +656,8 @@
"source": [
"model, args, kwargs = get_model_and_args()\n",
"\n",
"measurement_thunder = benchmark(thunder.jit(model), model, args, kwargs)\n",
"measurement_thunder_offload = benchmark(thunder.jit(model, transforms=[CPUOffloading()]), model, args, kwargs)\n",
"measurement_thunder = benchmark(thunder.jit(model, nv_enable_linear=False), model, args, kwargs)\n",
"measurement_thunder_offload = benchmark(thunder.jit(model, nv_enable_linear=False, transforms=[CPUOffloading()]), model, args, kwargs)\n",
"\n",
"del model, args, kwargs\n",
"\n",
Expand Down
4 changes: 4 additions & 0 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,6 +2240,10 @@ def map_inside_replacement(x: Any) -> None:

def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> bool:
enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser linear.")
# Enable linear by default if nvFuser version is 0.2.23 or later
# Because nvFuser 0.2.23 has a fix for https://github.com/NVIDIA/Fuser/issues/3369
if enable_linear is None and nvfuser_version() >= LooseVersion("0.2.23"):
enable_linear = True
if not enable_linear:
return False
# Verify linear inputs and bias (optional) are supported tensors.
Expand Down
5 changes: 4 additions & 1 deletion thunder/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from functools import partial
import warnings
from looseversion import LooseVersion

import pytest
import torch
Expand All @@ -20,6 +21,7 @@
)
import thunder.tests.nanogpt_model as nanogpt_model
import thunder.tests.hf_bart_self_attn as hf_bart_self_attn
from thunder.executors.nvfuserex import nvfuser_version

#
# nanoGPT tests
Expand Down Expand Up @@ -557,4 +559,5 @@ def test_hf_llama():

top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols}
# changes this to fewer as needed, the goal is to not have too many fusions
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7
num_of_nvfusion = 4 if nvfuser_version() >= LooseVersion("0.2.23") else 7
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == num_of_nvfusion
1 change: 1 addition & 0 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def test_cse_rematerialization(executor, device, _):
disable_torch_autograd=True,
executors=executor.executors_list(),
nv_enable_bookend=False,
nv_enable_linear=False,
)
compiled_func(x, y)

Expand Down
Loading