[PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements #1100

Merged 24 commits on Sep 5, 2024
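For orientation, a minimal usage sketch of the feature this PR enables: running Transformer Engine's MultiheadAttention with rotary position embeddings under an FP8 recipe that sets fp8_dpa and fp8_mha. This is an illustrative sketch, not code from the PR; it assumes an FP8-capable (Hopper-class) GPU, a supported attention backend, and that the DelayedScaling recipe exposes the fp8_dpa/fp8_mha flags exercised by the tests in this PR. The shapes and recipe settings are arbitrary.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

seq_len, batch, heads, head_dim = 512, 2, 16, 64
hidden = heads * head_dim

# MultiheadAttention defaults to the "sbhd" input layout: [seq, batch, hidden].
mha = te.MultiheadAttention(hidden, heads, params_dtype=torch.bfloat16).cuda()
rope = RotaryPositionEmbedding(head_dim)
rope_emb = rope(seq_len).cuda()  # rotary frequencies for the full sequence length

# fp8_dpa runs the dot-product attention in FP8; fp8_mha additionally keeps the
# activations between the QKV/output projections and the attention in FP8.
recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)

x = torch.randn(seq_len, batch, hidden, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = mha(x, rotary_pos_emb=rope_emb)  # [seq, batch, hidden]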
Commits (24)

The diff below shows changes from 4 of the 24 commits.
59b99ca
fp8 mha with rope
yaox12 Aug 13, 2024
c46f82c
avoid index select in cast ops
yaox12 Aug 14, 2024
dafd73f
avoid index select in fused_attn_fwd
yaox12 Aug 14, 2024
0d2ff34
rename is_first_module_in_mha to fp8_output
yaox12 Aug 14, 2024
0e837c3
resolve comments
yaox12 Aug 15, 2024
33c3ed6
resolve comments
yaox12 Aug 15, 2024
13feabb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
ae856e4
move transpose to backward for fp8 input
yaox12 Aug 16, 2024
7e26d22
fix ut
yaox12 Aug 19, 2024
fae44b6
Merge branch 'main' into xiny/fp8_mha_with_rope
yaox12 Aug 21, 2024
521c77a
resolve comments
yaox12 Aug 21, 2024
10c6961
Merge branch 'main' into xiny/fp8_mha_with_rope
yaox12 Aug 21, 2024
dd30c2d
update argument list for CP
yaox12 Aug 21, 2024
a94b3ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2024
bf56399
Merge branch 'main' into xiny/fp8_mha_with_rope
yaox12 Aug 26, 2024
400d526
fix for FA3
yaox12 Aug 26, 2024
b935e13
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2024
9eca369
remove unnecessary copy of scale_inv
yaox12 Aug 26, 2024
e3b75db
skip fp8 dpa/mha tests when fa3 is not available
yaox12 Aug 27, 2024
46d428f
Merge branch 'main' into xiny/fp8_mha_with_rope
yaox12 Aug 28, 2024
6b80dd6
Merge branch 'main' into xiny/fp8_mha_with_rope
yaox12 Aug 30, 2024
df6132f
Merge branch 'main' into xiny/fp8_mha_with_rope
yaox12 Sep 3, 2024
c017154
Merge branch 'main' into xiny/fp8_mha_with_rope
yaox12 Sep 4, 2024
f9da6d7
fix a merge bug
yaox12 Sep 4, 2024
22 changes: 8 additions & 14 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1347,41 +1347,35 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]

global _attention_backends
if not is_training:
if RoPE:
pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, is_training
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
)

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
)

logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE
)

tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
)

atol = 5e-1
@@ -1422,7 +1416,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
)


def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE):
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
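The test above gauges FP8 versus F16 agreement with a range-normalized RMSE (rmse_tol = 0.1) on top of loose element-wise tolerances. Below is a minimal sketch of that style of check; the helper and assertion are written out here for illustration, and the exact assertion form is an assumption based on the fwd_rmse/fwd_range computation shown in the hunk.

import torch

def _rmse(a: torch.Tensor, b: torch.Tensor) -> float:
    # Root-mean-square error between two same-shaped tensors.
    return torch.sqrt(torch.mean((a.double() - b.double()) ** 2)).item()

def assert_fp8_close_to_f16(fp8_out: torch.Tensor, f16_out: torch.Tensor, rmse_tol: float = 0.1):
    # Normalize the RMSE by the combined value range of both outputs, as the
    # fwd_range computation above does, so the tolerance scales with the
    # magnitude of the activations.
    rmse = _rmse(fp8_out, f16_out)
    value_range = max(fp8_out.max().item(), f16_out.max().item()) - min(
        fp8_out.min().item(), f16_out.min().item()
    )
    assert rmse < rmse_tol * value_range, f"rmse {rmse:.4f} > {rmse_tol} * range {value_range:.4f}"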
11 changes: 10 additions & 1 deletion transformer_engine/pytorch/attention.py
@@ -47,7 +47,11 @@
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
get_fp8_te_dtype,
get_fp8_torch_dtype,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -381,6 +385,11 @@ def get_attention_backend(

# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention and not _flash_attn_3_plus:
logger.debug(
"Disabling FlashAttention as FlashAttention 3 is not available for FP8 DPA/FP8 MHA."
)
use_flash_attention = False
if use_flash_attention and is_training:
logger.debug("Disabling FlashAttention as it does not support FP8 training")
use_flash_attention = False
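The filter added above turns off the FlashAttention path for FP8 DPA/MHA unless FlashAttention 3 is available, and in any case for FP8 training. A standalone sketch of that decision logic follows; the function name and plain boolean arguments are illustrative, while in the module these checks sit inside get_attention_backend.

def allow_flash_attention_for_fp8(
    use_flash_attention: bool,
    flash_attn_3_available: bool,
    is_training: bool,
) -> bool:
    """Mirror of the two checks added for FP8 DPA/MHA backend selection."""
    # FP8 attention on the FlashAttention path requires FlashAttention 3.
    if use_flash_attention and not flash_attn_3_available:
        use_flash_attention = False
    # FlashAttention does not support FP8 training; fall back to fused attention.
    if use_flash_attention and is_training:
        use_flash_attention = False
    return use_flash_attention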
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -86,8 +86,6 @@ def forward(
fsdp_group: Union[dist_group_type, None],
) -> torch.Tensor:
is_input_fp8 = isinstance(inp, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]

# Make sure input dimensions are compatible
in_features = weight.shape[-1]
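For context on the two deleted lines: a Float8Tensor already carries its dequantization factor in _scale_inv, so mirroring it into fp8_meta["scaling_fwd"].scale_inv duplicated state the GEMM path can read from the input directly. A toy illustration of the idea, using simplified stand-ins rather than Transformer Engine's internals:

import torch

class ToyFloat8Tensor:
    """Simplified stand-in for Float8Tensor: FP8 payload plus the 1/scale needed to dequantize it."""

    def __init__(self, data_fp8: torch.Tensor, scale_inv: torch.Tensor):
        self._data = data_fp8        # quantized values (uint8 storage)
        self._scale_inv = scale_inv  # per-tensor dequantization factor

def input_scale_inv(inp, default: torch.Tensor) -> torch.Tensor:
    # Read the dequantization factor straight from the FP8 input when it has one;
    # copying it into a separate fp8_meta buffer (the deleted lines) adds nothing.
    return inp._scale_inv if isinstance(inp, ToyFloat8Tensor) else default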