diff --git a/notebooks/writing_a_trace_transform_cpu_offloading.ipynb b/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
index 86d8263769..ef10be3304 100644
--- a/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
+++ b/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
@@ -472,7 +472,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -507,7 +507,7 @@
    "model, args, kwargs = get_model_and_args()\n",
    "\n",
    "# Check against the vanilla `thunder.jit` model\n",
-   "expected = thunder.jit(model)(*args, **kwargs)\n",
+   "expected = thunder.jit(model, nv_enable_linear=False)(*args, **kwargs)\n",
    "\n",
    "grad_output = torch.randn_like(expected)\n",
    "expected_grads = torch.autograd.grad(expected, model.parameters(), grad_output)\n",
@@ -519,7 +519,7 @@
    "    expected_cpu = expected.to(\"cpu\")\n",
    "    expected_grads_cpu = tree_map(lambda t: t.to(\"cpu\"), expected_grads)\n",
    "\n",
-   "jmodel = thunder.jit(model, transforms=[CPUOffloading()])\n",
+   "jmodel = thunder.jit(model, nv_enable_linear=False, transforms=[CPUOffloading()])\n",
    "\n",
    "actual = jmodel(*args, **kwargs)\n",
    "\n",
@@ -642,7 +642,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -656,8 +656,8 @@
    "source": [
    "model, args, kwargs = get_model_and_args()\n",
    "\n",
-   "measurement_thunder = benchmark(thunder.jit(model), model, args, kwargs)\n",
-   "measurement_thunder_offload = benchmark(thunder.jit(model, transforms=[CPUOffloading()]), model, args, kwargs)\n",
+   "measurement_thunder = benchmark(thunder.jit(model, nv_enable_linear=False), model, args, kwargs)\n",
+   "measurement_thunder_offload = benchmark(thunder.jit(model, nv_enable_linear=False, transforms=[CPUOffloading()]), model, args, kwargs)\n",
    "\n",
    "del model, args, kwargs\n",
    "\n",
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 4e5a204f8c..8fcd703a60 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2261,6 +2261,10 @@ def map_inside_replacement(x: Any) -> None:
 
 def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> bool:
     enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser linear.")
+    # Enable linear by default if the nvFuser version is 0.2.23 or later,
+    # because nvFuser 0.2.23 has a fix for https://github.com/NVIDIA/Fuser/issues/3369.
+    if enable_linear is None and nvfuser_version() >= LooseVersion("0.2.23"):
+        enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py
index 0cbef7299e..073958875d 100644
--- a/thunder/tests/test_networks.py
+++ b/thunder/tests/test_networks.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from functools import partial
 import warnings
+from looseversion import LooseVersion
 
 import pytest
 import torch
@@ -20,6 +21,7 @@
 )
 import thunder.tests.nanogpt_model as nanogpt_model
 import thunder.tests.hf_bart_self_attn as hf_bart_self_attn
+from thunder.executors.nvfuserex import nvfuser_version
 
 #
 # nanoGPT tests
@@ -557,4 +559,5 @@ def test_hf_llama():
     top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols}
 
     # changes this to fewer as needed, the goal is to not have too many fusions
-    assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7
+    num_of_nvfusion = 4 if nvfuser_version() >= LooseVersion("0.2.23") else 7
+    assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == num_of_nvfusion
diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py
index 2f8cd0ed09..5681b00d21 100644
--- a/thunder/tests/test_nvfuser.py
+++ b/thunder/tests/test_nvfuser.py
@@ -346,6 +346,7 @@ def test_cse_rematerialization(executor, device, _):
         disable_torch_autograd=True,
         executors=executor.executors_list(),
         nv_enable_bookend=False,
+        nv_enable_linear=False,
     )
 
     compiled_func(x, y)
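
Reviewer note (not part of the patch): a minimal sketch of the default-resolution logic this diff adds to `_linear_check`, to make the new behavior easy to check at a glance. The standalone helper name `resolve_nv_enable_linear` is hypothetical; `nvfuser_version()`, `LooseVersion`, and the 0.2.23 threshold come from the diff above. An explicit user setting, e.g. `thunder.jit(model, nv_enable_linear=False)` as used in the notebook and tests here, still overrides the version-based default.

    from looseversion import LooseVersion

    def resolve_nv_enable_linear(user_option: bool | None, nvfuser_ver: LooseVersion) -> bool:
        # An explicit True/False from the `nv_enable_linear` compile option always wins.
        if user_option is not None:
            return user_option
        # Unset (None): default to enabled on nvFuser >= 0.2.23, which fixes
        # https://github.com/NVIDIA/Fuser/issues/3369; stay disabled on older versions.
        return nvfuser_ver >= LooseVersion("0.2.23")

    # Expected behavior under the new default:
    assert resolve_nv_enable_linear(None, LooseVersion("0.2.23")) is True
    assert resolve_nv_enable_linear(None, LooseVersion("0.2.22")) is False
    assert resolve_nv_enable_linear(False, LooseVersion("0.2.23")) is False  # user override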