Skip to content

Commit

Permalink
Improvements for building wheels (#1148)
Browse files Browse the repository at this point in the history
* Improvements for wheels

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes for wheel build

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Move package finder to common

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* format

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* FIx

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix CI and distributed test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix paddle ci

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
ksivaman committed Sep 3, 2024
1 parent 9437ceb commit 93f00a7
Show file tree
Hide file tree
Showing 13 changed files with 279 additions and 90 deletions.
3 changes: 2 additions & 1 deletion build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,15 @@ def install_and_import(package):
globals()[main_package] = importlib.import_module(main_package)


def uninstall_te_fw_packages():
def uninstall_te_wheel_packages():
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"uninstall",
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_paddle",
"transformer_engine_jax",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/wheel_utils/Dockerfile.aarch
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1

CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"]
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"]
2 changes: 1 addition & 1 deletion build_tools/wheel_utils/Dockerfile.x86
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1

CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"]
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"]
56 changes: 39 additions & 17 deletions build_tools/wheel_utils/build_wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
set -e

PLATFORM=${1:-manylinux_2_28_x86_64}
BUILD_COMMON=${2:-true}
BUILD_JAX=${3:-true}
BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_PADDLE=${5:-true}
BUILD_JAX=${5:-true}
BUILD_PADDLE=${6:-true}

export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-}
Expand All @@ -20,12 +21,33 @@ cd /TransformerEngine
git checkout $TARGET_BRANCH
git submodule update --init --recursive

if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
mv dist/* /wheelhouse/
fi

if $BUILD_COMMON ; then
VERSION=`cat build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"

# Create the wheel.
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt

# Repack the wheel for cuda specific package, i.e. cu12.
/opt/python/cp38-cp38/bin/wheel unpack dist/*
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
/opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE}

# Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
mv dist/"$whl_name" /wheelhouse/"$whl_name_target"
whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}"
rm -rf $WHL_BASE dist
mv *.whl /wheelhouse/"$whl_name_target"
fi

if $BUILD_PYTORCH ; then
Expand All @@ -37,8 +59,8 @@ fi

if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax
/opt/python/cp38-cp38/bin/pip install jax jaxlib
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi

Expand All @@ -48,30 +70,30 @@ if $BUILD_PADDLE ; then
dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64
cd /TransformerEngine/transformer_engine/paddle

/opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl
/opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt
/opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu

/opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl
/opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt
/opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu

/opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl
/opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt
/opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu

/opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl
/opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt
/opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu

/opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl
/opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt
/opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu

mv dist/* /wheelhouse/
fi
Expand Down
26 changes: 20 additions & 6 deletions qa/L0_jax_wheel/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,30 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install wheel

cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax

VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"

# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel

cd transformer_engine/jax
python setup.py sdist
NVTE_RELEASE_BUILD=1 python setup.py sdist

export NVTE_RELEASE_BUILD=0
pip install dist/*
cd $TE_PATH
pip install dist/*
pip install dist/*.whl --no-deps

python $TE_PATH/tests/jax/test_sanity_import.py
27 changes: 20 additions & 7 deletions qa/L0_paddle_wheel/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,28 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install wheel==0.44.0 pydantic

cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
pip install dist/*
cd transformer_engine/paddle
python setup.py bdist_wheel
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle

export NVTE_RELEASE_BUILD=0
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"

# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
pip install dist/*.whl --no-deps

cd transformer_engine/paddle
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
pip install dist/*

python $TE_PATH/tests/paddle/test_sanity_import.py
26 changes: 20 additions & 6 deletions qa/L0_pytorch_wheel/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,30 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install wheel

cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch

VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"

# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel

cd transformer_engine/pytorch
python setup.py sdist
NVTE_RELEASE_BUILD=1 python setup.py sdist

export NVTE_RELEASE_BUILD=0
pip install dist/*
cd $TE_PATH
pip install dist/*
pip install dist/*.whl --no-deps

python $TE_PATH/tests/pytorch/test_sanity_import.py
4 changes: 4 additions & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

set -e

# pkg_resources is deprecated in setuptools 70+ and the packaging submodule
# has been removed from it. This is a temporary fix until upstream MLM fix.
pip install setuptools==69.5.1

: ${TE_PATH:=/opt/transformerengine}
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py

Expand Down
108 changes: 64 additions & 44 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_frameworks,
install_and_import,
remove_dups,
uninstall_te_fw_packages,
uninstall_te_wheel_packages,
)

frameworks = get_frameworks()
Expand Down Expand Up @@ -106,46 +106,69 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:


if __name__ == "__main__":
# Dependencies
setup_requires, install_requires, test_requires = setup_requirements()

__version__ = te_version()

ext_modules = [setup_common_extension()]
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
# Remove residual FW packages since compiling from source
# results in a single binary with FW extensions included.
uninstall_te_fw_packages()
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension

ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
with open("README.rst", encoding="utf-8") as f:
long_description = f.read()

# Settings for building top level empty package for dependency management.
if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
assert bool(
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
), "NVTE_RELEASE_BUILD env must be set for metapackage build."
ext_modules = []
cmdclass = {}
package_data = {}
include_package_data = False
setup_requires = []
install_requires = ([f"transformer_engine_cu12=={__version__}"],)
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"paddle": [f"transformer_engine_paddle=={__version__}"],
}
else:
setup_requires, install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
package_data = {"": ["VERSION.txt"]}
include_package_data = True
extras_require = {"test": test_requires}

if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
# Remove residual FW packages since compiling from source
# results in a single binary with FW extensions included.
uninstall_te_wheel_packages()
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension

ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
)
)
)
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension

ext_modules.append(
setup_jax_extension(
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine",
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension

ext_modules.append(
setup_jax_extension(
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine",
)
)
)
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension

ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine",
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension

ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine",
)
)
)

# Configure package
setuptools.setup(
Expand All @@ -158,13 +181,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
"transformer_engine/build_tools",
],
),
extras_require={
"test": test_requires,
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"paddle": [f"transformer_engine_paddle=={__version__}"],
},
extras_require=extras_require,
description="Transformer acceleration library",
long_description=long_description,
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8, <3.13",
Expand All @@ -178,6 +198,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
setup_requires=setup_requires,
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=True,
package_data={"": ["VERSION.txt"]},
include_package_data=include_package_data,
package_data=package_data,
)
Loading

0 comments on commit 93f00a7

Please sign in to comment.