diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 50bf38c001..2906241e11 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -619,18 +619,23 @@ def import_ctx(self): # NOTE If the call ctx was specified directly, then no import is needed to call the function import_ctx = {} else: + from thunder.extend import AdHocExecutor + # BoundSymbols of Symbols without Python implementations (either because they # have Python implementations or defined call ctxs) are assumed to need # a module import to run properly - assert self.sym.module is not None # TODO: Is this a valid assumption? - module_name = self.sym.module.__name__ - import_ctx = {module_name: self.sym.module} - - # TODO Include the other modules on the path? - # Also includes the root module of this (potential) submodule - if "." in module_name: - root_name = module_name.split(".")[0] - import_ctx[root_name] = sys.modules[root_name] + if isinstance(self.sym.executor, AdHocExecutor): + import_ctx = {} + else: + assert self.sym.module is not None # TODO: Is this a valid assumption? + module_name = self.sym.module.__name__ + import_ctx = {module_name: self.sym.module} + + # TODO Include the other modules on the path? + # Also includes the root module of this (potential) submodule + if "." in module_name: + root_name = module_name.split(".")[0] + import_ctx[root_name] = sys.modules[root_name] self._import_ctx.update(import_ctx) return self._import_ctx diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index a54cb0d841..313658900c 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -62,6 +62,19 @@ class NOTHING: IS_WINDOWS = platform.system() == "Windows" +def _bitsandbytes_available(): + if not package_available("bitsandbytes"): + return False + try: + import bitsandbytes + except (ImportError, RuntimeError): + return False + return True + + +BITSANDBYTES_AVAILABLE = _bitsandbytes_available() + + def version_between(version: str, *, min_ver: str | None = None, max_ver: str | None = None): v = packaging.version.parse(version) if min_ver is not None and v < packaging.version.parse(min_ver): diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 0cbef7299e..4fd40fb11d 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -17,6 +17,7 @@ DynamoThunderExecutor, _all_test_executors, version_between, + BITSANDBYTES_AVAILABLE, ) import thunder.tests.nanogpt_model as nanogpt_model import thunder.tests.hf_bart_self_attn as hf_bart_self_attn @@ -288,13 +289,9 @@ def dummy(*args): version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"), reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413", ) +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available") @requiresCUDA def test_quantization(): - try: - import bitsandbytes - except (ImportError, RuntimeError): - pytest.skip("bitsandbytes not found") - from thunder.tests import litgpt_model from lightning.fabric.plugins import BitsandbytesPrecision diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py index 42a0ba0e6f..1f7ea58562 100644 --- a/thunder/tests/test_transforms.py +++ b/thunder/tests/test_transforms.py @@ -5,7 +5,7 @@ import thunder from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform, nvtx_push, nvtx_pop -from thunder.tests.framework import requiresCUDA, version_between +from thunder.tests.framework import requiresCUDA, version_between, BITSANDBYTES_AVAILABLE @requiresCUDA @@ -117,7 +117,7 @@ def test_materialization(): version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"), reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413", ) -@pytest.mark.skipif(not package_available("bitsandbytes"), reason="`bitsandbytes` is not available") +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available") @requiresCUDA def test_quantization_on_meta(): from thunder.transforms import MaterializationTransform @@ -194,10 +194,7 @@ def test_quantization_on_meta(): version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"), reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413", ) -@pytest.mark.skipif( - not package_available("bitsandbytes"), - reason="`bitsandbytes` is not available", -) +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available") @requiresCUDA def test_nvfuser_cse(): with torch.device("cuda"): @@ -305,7 +302,7 @@ def f(x): version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"), reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413", ) -@pytest.mark.skipif(not package_available("bitsandbytes"), reason="`bitsandbytes` is not available") +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available") @requiresCUDA def test_materialization_init(): from thunder.transforms import MaterializationTransform