[PyTorch] Drop FA as an installation requirement (#1226)
* WIP: make FA2 optional

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: fix logic

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor tweak

Signed-off-by: Charlene Yang <[email protected]>

* add L1 test to test all supported FA versions

Signed-off-by: Charlene Yang <[email protected]>

* update version to 2.1.1 and trim L1 tests

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update onnxruntime version

Signed-off-by: Charlene Yang <[email protected]>

* remove onnxruntime from L1 FA versions tests

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Oct 16, 2024
1 parent 43b9e1e commit 161b1d9
Showing 5 changed files with 244 additions and 129 deletions.
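
The practical effect of the change is that flash-attn becomes an optional, runtime-detected dependency of the PyTorch build rather than something pip pulls in automatically. A minimal sketch of that detection pattern follows; the names _flash_attn_version and _flash_attn_is_installed are illustrative stand-ins, not necessarily the exact identifiers used in transformer_engine/pytorch/attention.py.

# Minimal sketch of optional flash-attn detection; names are illustrative,
# not the exact identifiers in transformer_engine/pytorch/attention.py.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

try:
    _flash_attn_version = Version(version("flash-attn"))
    _flash_attn_is_installed = True
except PackageNotFoundError:
    _flash_attn_version = None
    _flash_attn_is_installed = False

# Backend selection can then prefer flash attention only when it is present
# and recent enough, and otherwise fall back to the fused/unfused backends.
_use_flash_attn = _flash_attn_is_installed and _flash_attn_version >= Version("2.1.1")
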
23 changes: 23 additions & 0 deletions qa/L1_pytorch_FA_versions_test/test.sh
@@ -0,0 +1,23 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==8.2.1
FA_versions=(2.1.1 2.3.0 2.4.0.post1 2.4.1 2.5.7 2.6.3 3.0.0b1)
for fa_version in "${FA_versions[@]}"
do
if [ "${fa_version}" \< "3.0.0" ]
then
pip install flash-attn==${fa_version}
else
pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper
wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
fi
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
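
For the 2.x entries the loop installs flash-attn from PyPI; for 3.0.0b1 it builds the flashattn-hopper package from the hopper/ subdirectory of the flash-attention repository and copies its Python interface file into site-packages, as shown above. A hedged sketch of verifying that FA3 setup afterwards (the module names follow the script; this is not Transformer Engine's own detection code):

# Sketch only: check that the flashattn_hopper package and the copied
# flash_attn_interface.py from the script above are importable.
import importlib.util


def fa3_interface_available() -> bool:
    if importlib.util.find_spec("flashattn_hopper") is None:
        return False
    return importlib.util.find_spec("flashattn_hopper.flash_attn_interface") is not None


print("FlashAttention-3 interface importable:", fa3_interface_available())
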
2 changes: 1 addition & 1 deletion setup.py
@@ -93,7 +93,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
install_reqs.extend(["torch"])
test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
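
This is the core of the change: flash-attn drops out of install_requires for the PyTorch build, so installing Transformer Engine no longer triggers a flash-attn build, and users who want the flash-attention backend install a supported version themselves. A hedged, user-side sketch of checking for it, reusing the version range from the removed requirement string (the message text is illustrative, not Transformer Engine's actual warning):

# User-side sketch; the specifier mirrors the removed requirement string.
from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet

_FA2_SPEC = SpecifierSet(">=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0")

try:
    fa_version = version("flash-attn")
except PackageNotFoundError:
    print("flash-attn not installed; the flash-attention backend will be unavailable.")
else:
    if fa_version not in _FA2_SPEC:
        print(f"flash-attn {fa_version} is outside the previously pinned range {_FA2_SPEC}.")
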
11 changes: 5 additions & 6 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -20,9 +20,8 @@
MultiheadAttention,
RotaryPositionEmbedding,
get_attention_backend,
_flash_attn_2_plus,
_flash_attn_2_3_plus,
_flash_attn_3_plus,
_flash_attn_3_is_installed,
check_set_window_size,
AttentionParams,
_attention_backends,
@@ -1353,7 +1352,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]

-if _flash_attn_3_plus and not is_training:
+if _flash_attn_3_is_installed and not is_training:
if RoPE:
pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
os.environ["NVTE_FLASH_ATTN"] = "1"
@@ -1381,7 +1380,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
-if _flash_attn_3_plus and not is_training:
+if _flash_attn_3_is_installed and not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
@@ -1534,7 +1533,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

-if _flash_attn_3_plus and not is_training:
+if _flash_attn_3_is_installed and not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
@@ -1561,7 +1560,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.1
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
-if _flash_attn_3_plus and not is_training:
+if _flash_attn_3_is_installed and not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
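
In the test suite, the version-comparison flags give way to a plain availability flag, so the FP8 tests only route through FlashAttention-3 when it is actually installed. A condensed illustration of that gating pattern, assuming the same import path as the test file (the test body here is a placeholder, not an excerpt from test_fused_attn.py):

import os

import pytest

# The flag below is what the updated tests import (see the diff above); the
# surrounding test is a condensed placeholder, not the real test body.
from transformer_engine.pytorch.attention import _flash_attn_3_is_installed


@pytest.mark.parametrize("is_training", [False, True])
def test_fp8_attention_backend_selection(is_training):
    if _flash_attn_3_is_installed and not is_training:
        # Run the FP8 forward pass through FlashAttention-3.
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
    else:
        # Otherwise stay on the fused-attention backend.
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
    # ... build the attention module, run forward/backward, compare outputs ...

The real tests additionally set _attention_backends["backend_selection_requires_update"] = True, as shown in the diff, so the backend choice is re-evaluated for each configuration.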