diff --git a/build_tools/utils.py b/build_tools/utils.py index 964a445bc4..d846b87f22 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -301,7 +301,7 @@ def install_and_import(package): globals()[main_package] = importlib.import_module(main_package) -def uninstall_te_fw_packages(): +def uninstall_te_wheel_packages(): subprocess.check_call( [ sys.executable, @@ -309,6 +309,7 @@ def uninstall_te_fw_packages(): "pip", "uninstall", "-y", + "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_paddle", "transformer_engine_jax", diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index a0bcd80347..7d839958cb 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 602d99ed4d..7dedf2a761 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 1896fc4e42..7682a2b6aa 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -5,10 +5,11 @@ set -e PLATFORM=${1:-manylinux_2_28_x86_64} -BUILD_COMMON=${2:-true} -BUILD_JAX=${3:-true} +BUILD_METAPACKAGE=${2:-true} +BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} -BUILD_PADDLE=${5:-true} +BUILD_JAX=${5:-true} +BUILD_PADDLE=${6:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} @@ -20,12 +21,33 @@ cd /TransformerEngine git checkout $TARGET_BRANCH git submodule update --init --recursive +if $BUILD_METAPACKAGE ; then + cd /TransformerEngine + NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt + mv dist/* /wheelhouse/ +fi + if $BUILD_COMMON ; then + VERSION=`cat build_tools/VERSION.txt` + WHL_BASE="transformer_engine-${VERSION}" + + # Create the wheel. /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + + # Repack the wheel for cuda specific package, i.e. cu12. + /opt/python/cp38-cp38/bin/wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" + /opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE} + + # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" - mv dist/"$whl_name" /wheelhouse/"$whl_name_target" + whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" + rm -rf $WHL_BASE dist + mv *.whl /wheelhouse/"$whl_name_target" fi if $BUILD_PYTORCH ; then @@ -37,8 +59,8 @@ fi if $BUILD_JAX ; then cd /TransformerEngine/transformer_engine/jax - /opt/python/cp38-cp38/bin/pip install jax jaxlib - /opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib + /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi @@ -48,30 +70,30 @@ if $BUILD_PADDLE ; then dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 cd /TransformerEngine/transformer_engine/paddle - /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl + /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt - /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl + /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt - /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl + /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt - /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl + /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt - /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl + /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt - /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu mv dist/* /wheelhouse/ fi diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 109633495b..2c3b832933 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/jax -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/jax/test_sanity_import.py diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh index e2d6d38dd4..30fbb1df1f 100644 --- a/qa/L0_paddle_wheel/test.sh +++ b/qa/L0_paddle_wheel/test.sh @@ -6,15 +6,28 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel==0.44.0 pydantic + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel -pip install dist/* -cd transformer_engine/paddle -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle -export NVTE_RELEASE_BUILD=0 +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +pip install dist/*.whl --no-deps + +cd transformer_engine/paddle +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel pip install dist/* python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index e108e93cdb..fd8457c44b 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/pytorch -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/pytorch/test_sanity_import.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index fef48fd4b0..50394c33a9 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -4,6 +4,10 @@ set -e +# pkg_resources is deprecated in setuptools 70+ and the packaging submodule +# has been removed from it. This is a temporary fix until upstream MLM fix. +pip install setuptools==69.5.1 + : ${TE_PATH:=/opt/transformerengine} pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py diff --git a/setup.py b/setup.py index 4e5359e9c0..0b0639aea6 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ get_frameworks, install_and_import, remove_dups, - uninstall_te_fw_packages, + uninstall_te_wheel_packages, ) frameworks = get_frameworks() @@ -106,46 +106,69 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if __name__ == "__main__": - # Dependencies - setup_requires, install_requires, test_requires = setup_requirements() - __version__ = te_version() - ext_modules = [setup_common_extension()] - if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - # Remove residual FW packages since compiling from source - # results in a single binary with FW extensions included. - uninstall_te_fw_packages() - if "pytorch" in frameworks: - from build_tools.pytorch import setup_pytorch_extension - - ext_modules.append( - setup_pytorch_extension( - "transformer_engine/pytorch/csrc", - current_file_path / "transformer_engine" / "pytorch" / "csrc", - current_file_path / "transformer_engine", + with open("README.rst", encoding="utf-8") as f: + long_description = f.read() + + # Settings for building top level empty package for dependency management. + if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): + assert bool( + int(os.getenv("NVTE_RELEASE_BUILD", "0")) + ), "NVTE_RELEASE_BUILD env must be set for metapackage build." + ext_modules = [] + cmdclass = {} + package_data = {} + include_package_data = False + setup_requires = [] + install_requires = ([f"transformer_engine_cu12=={__version__}"],) + extras_require = { + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], + "paddle": [f"transformer_engine_paddle=={__version__}"], + } + else: + setup_requires, install_requires, test_requires = setup_requirements() + ext_modules = [setup_common_extension()] + cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} + package_data = {"": ["VERSION.txt"]} + include_package_data = True + extras_require = {"test": test_requires} + + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + # Remove residual FW packages since compiling from source + # results in a single binary with FW extensions included. + uninstall_te_wheel_packages() + if "pytorch" in frameworks: + from build_tools.pytorch import setup_pytorch_extension + + ext_modules.append( + setup_pytorch_extension( + "transformer_engine/pytorch/csrc", + current_file_path / "transformer_engine" / "pytorch" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "jax" in frameworks: - from build_tools.jax import setup_jax_extension - - ext_modules.append( - setup_jax_extension( - "transformer_engine/jax/csrc", - current_file_path / "transformer_engine" / "jax" / "csrc", - current_file_path / "transformer_engine", + if "jax" in frameworks: + from build_tools.jax import setup_jax_extension + + ext_modules.append( + setup_jax_extension( + "transformer_engine/jax/csrc", + current_file_path / "transformer_engine" / "jax" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "paddle" in frameworks: - from build_tools.paddle import setup_paddle_extension - - ext_modules.append( - setup_paddle_extension( - "transformer_engine/paddle/csrc", - current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine", + if "paddle" in frameworks: + from build_tools.paddle import setup_paddle_extension + + ext_modules.append( + setup_paddle_extension( + "transformer_engine/paddle/csrc", + current_file_path / "transformer_engine" / "paddle" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) # Configure package setuptools.setup( @@ -158,13 +181,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: "transformer_engine/build_tools", ], ), - extras_require={ - "test": test_requires, - "pytorch": [f"transformer_engine_torch=={__version__}"], - "jax": [f"transformer_engine_jax=={__version__}"], - "paddle": [f"transformer_engine_paddle=={__version__}"], - }, + extras_require=extras_require, description="Transformer acceleration library", + long_description=long_description, + long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8, <3.13", @@ -178,6 +198,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: setup_requires=setup_requires, install_requires=install_requires, license_files=("LICENSE",), - include_package_data=True, - package_data={"": ["VERSION.txt"]}, + include_package_data=include_package_data, + package_data=package_data, ) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index f4eb2c419f..46cfa9176a 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -4,6 +4,7 @@ """FW agnostic user-end APIs""" +import sys import glob import sysconfig import subprocess @@ -15,6 +16,16 @@ import transformer_engine +def is_package_installed(package): + """Checks if a pip package is installed.""" + return ( + subprocess.run( + [sys.executable, "-m", "pip", "show", package], capture_output=True, check=False + ).returncode + == 0 + ) + + def get_te_path(): """Find Transformer Engine install path using pip""" return Path(transformer_engine.__path__[0]).parent diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 3200c8a019..05adbd624c 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -5,21 +5,50 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import ctypes +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_jax" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[jax]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[jax]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 62fa1fe626..50cf2186d6 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -6,9 +6,41 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging +from importlib.metadata import version + +from transformer_engine.common import is_package_installed + def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_paddle" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[paddle]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[paddle]==VERSION'", + module_name, + ) + from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 1c755491b0..89b20002a7 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -6,25 +6,54 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import importlib +import importlib.util import sys import torch +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_torch" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[pytorch]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[pytorch]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) - module_name = "transformer_engine_torch" spec = importlib.util.spec_from_file_location(module_name, so_path) solib = importlib.util.module_from_spec(spec) sys.modules[module_name] = solib