Commit f9f6b23

Merge branch 'main' into thunderfx_tutorial

kiya00 authored Dec 12, 2024
2 parents 87737f6 + bb20d73
Showing 4 changed files with 33 additions and 21 deletions.
23 changes: 14 additions & 9 deletions thunder/core/symbol.py
@@ -619,18 +619,23 @@ def import_ctx(self):
             # NOTE If the call ctx was specified directly, then no import is needed to call the function
             import_ctx = {}
         else:
+            from thunder.extend import AdHocExecutor
+
             # BoundSymbols of Symbols without Python implementations (either because they
             #   have Python implementations or defined call ctxs) are assumed to need
             #   a module import to run properly
-            assert self.sym.module is not None  # TODO: Is this a valid assumption?
-            module_name = self.sym.module.__name__
-            import_ctx = {module_name: self.sym.module}
-
-            # TODO Include the other modules on the path?
-            #   Also includes the root module of this (potential) submodule
-            if "." in module_name:
-                root_name = module_name.split(".")[0]
-                import_ctx[root_name] = sys.modules[root_name]
+            if isinstance(self.sym.executor, AdHocExecutor):
+                import_ctx = {}
+            else:
+                assert self.sym.module is not None  # TODO: Is this a valid assumption?
+                module_name = self.sym.module.__name__
+                import_ctx = {module_name: self.sym.module}
+
+                # TODO Include the other modules on the path?
+                #   Also includes the root module of this (potential) submodule
+                if "." in module_name:
+                    root_name = module_name.split(".")[0]
+                    import_ctx[root_name] = sys.modules[root_name]

         self._import_ctx.update(import_ctx)
         return self._import_ctx
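The net effect: bound symbols claimed by an AdHocExecutor no longer contribute a module import to the generated trace, presumably because ad-hoc executors supply their callables directly rather than exposing them through an importable module. A minimal sketch of the resulting logic, assuming only the attributes visible in the diff (sym.executor, sym.module); the standalone helper name is hypothetical:

    import sys
    from thunder.extend import AdHocExecutor

    def module_imports_for(sym) -> dict:
        # Symbols run by an ad-hoc executor carry their own implementation,
        # so the trace needs no import to call them.
        if isinstance(sym.executor, AdHocExecutor):
            return {}
        assert sym.module is not None
        module_name = sym.module.__name__
        imports = {module_name: sym.module}
        # A dotted submodule also needs its root package importable.
        if "." in module_name:
            root_name = module_name.split(".")[0]
            imports[root_name] = sys.modules[root_name]
        return imports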
13 changes: 13 additions & 0 deletions thunder/tests/framework.py
@@ -62,6 +62,19 @@ class NOTHING:
 IS_WINDOWS = platform.system() == "Windows"


+def _bitsandbytes_available():
+    if not package_available("bitsandbytes"):
+        return False
+    try:
+        import bitsandbytes
+    except (ImportError, RuntimeError):
+        return False
+    return True
+
+
+BITSANDBYTES_AVAILABLE = _bitsandbytes_available()
+
+
 def version_between(version: str, *, min_ver: str | None = None, max_ver: str | None = None):
     v = packaging.version.parse(version)
     if min_ver is not None and v < packaging.version.parse(min_ver):
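Centralizing the check matters because bitsandbytes can raise RuntimeError as well as ImportError at import time (the except clause catches both), and computing the flag once at module import avoids repeating the try/except in every test body. A hedged usage sketch; the test name is hypothetical, and the skipif line mirrors the ones in the diffs below:

    import pytest
    from thunder.tests.framework import BITSANDBYTES_AVAILABLE

    @pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available")
    def test_uses_bitsandbytes():
        # The guard above guarantees this import succeeds when the test runs.
        import bitsandbytes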
7 changes: 2 additions & 5 deletions thunder/tests/test_networks.py
@@ -17,6 +17,7 @@
     DynamoThunderExecutor,
     _all_test_executors,
     version_between,
+    BITSANDBYTES_AVAILABLE,
 )
 import thunder.tests.nanogpt_model as nanogpt_model
 import thunder.tests.hf_bart_self_attn as hf_bart_self_attn
@@ -288,13 +289,9 @@ def dummy(*args):
     version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"),
     reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413",
 )
+@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available")
 @requiresCUDA
 def test_quantization():
-    try:
-        import bitsandbytes
-    except (ImportError, RuntimeError):
-        pytest.skip("bitsandbytes not found")
-
     from thunder.tests import litgpt_model
     from lightning.fabric.plugins import BitsandbytesPrecision

11 changes: 4 additions & 7 deletions thunder/tests/test_transforms.py
@@ -5,7 +5,7 @@

 import thunder
 from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform, nvtx_push, nvtx_pop
-from thunder.tests.framework import requiresCUDA, version_between
+from thunder.tests.framework import requiresCUDA, version_between, BITSANDBYTES_AVAILABLE


 @requiresCUDA
@@ -117,7 +117,7 @@ def test_materialization():
     version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"),
     reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413",
 )
-@pytest.mark.skipif(not package_available("bitsandbytes"), reason="`bitsandbytes` is not available")
+@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available")
 @requiresCUDA
 def test_quantization_on_meta():
     from thunder.transforms import MaterializationTransform
@@ -194,10 +194,7 @@ def test_quantization_on_meta():
     version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"),
     reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413",
 )
-@pytest.mark.skipif(
-    not package_available("bitsandbytes"),
-    reason="`bitsandbytes` is not available",
-)
+@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available")
 @requiresCUDA
 def test_nvfuser_cse():
     with torch.device("cuda"):
@@ -305,7 +302,7 @@ def f(x):
     version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"),
     reason="https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1413",
 )
-@pytest.mark.skipif(not package_available("bitsandbytes"), reason="`bitsandbytes` is not available")
+@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="`bitsandbytes` is not available")
 @requiresCUDA
 def test_materialization_init():
     from thunder.transforms import MaterializationTransform