diff --git a/setup.py b/setup.py index e418cb95ff..6cee4690dc 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,18 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not found_pybind11(): setup_reqs.append("pybind11") + # Framework-specific requirements + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + if "pytorch" in frameworks: + install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"]) + test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) + if "jax" in frameworks: + install_reqs.extend(["jax", "flax>=0.7.1"]) + test_reqs.extend(["numpy", "praxis"]) + if "paddle" in frameworks: + install_reqs.append("paddlepaddle-gpu") + test_reqs.append("numpy") + return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]