[PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements #1100

Merged: 24 commits merged from xiny/fp8_mha_with_rope into main on Sep 5, 2024
Commits (24):
59b99ca  fp8 mha with rope (yaox12, Aug 13, 2024)
c46f82c  avoid index select in cast ops (yaox12, Aug 14, 2024)
dafd73f  avoid index select in fused_attn_fwd (yaox12, Aug 14, 2024)
0d2ff34  rename is_first_module_in_mha to fp8_output (yaox12, Aug 14, 2024)
0e837c3  resolve comments (yaox12, Aug 15, 2024)
33c3ed6  resolve comments (yaox12, Aug 15, 2024)
13feabb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 15, 2024)
ae856e4  move transpose to backward for fp8 input (yaox12, Aug 16, 2024)
7e26d22  fix ut (yaox12, Aug 19, 2024)
fae44b6  Merge branch 'main' into xiny/fp8_mha_with_rope (yaox12, Aug 21, 2024)
521c77a  resolve comments (yaox12, Aug 21, 2024)
10c6961  Merge branch 'main' into xiny/fp8_mha_with_rope (yaox12, Aug 21, 2024)
dd30c2d  update argument list for CP (yaox12, Aug 21, 2024)
a94b3ad  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 21, 2024)
bf56399  Merge branch 'main' into xiny/fp8_mha_with_rope (yaox12, Aug 26, 2024)
400d526  fix for FA3 (yaox12, Aug 26, 2024)
b935e13  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 26, 2024)
9eca369  remove unnecessary copy of scale_inv (yaox12, Aug 26, 2024)
e3b75db  skip fp8 dpa/mha tests when fa3 is not available (yaox12, Aug 27, 2024)
46d428f  Merge branch 'main' into xiny/fp8_mha_with_rope (yaox12, Aug 28, 2024)
6b80dd6  Merge branch 'main' into xiny/fp8_mha_with_rope (yaox12, Aug 30, 2024)
df6132f  Merge branch 'main' into xiny/fp8_mha_with_rope (yaox12, Sep 3, 2024)
c017154  Merge branch 'main' into xiny/fp8_mha_with_rope (yaox12, Sep 4, 2024)
f9da6d7  fix a merge bug (yaox12, Sep 4, 2024)
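
The commit list above covers the substance of the PR: FP8 MHA gains RoPE support, the cast ops and fused_attn_fwd avoid Python-side index selects, is_first_module_in_mha is renamed to fp8_output, and the transpose of FP8 inputs moves to the backward pass. Below is a minimal sketch of the FP8-MHA-with-RoPE usage that the test changes further down exercise; the module names come from transformer_engine.pytorch, while the shapes, dtype, and the exact RotaryPositionEmbedding import path are illustrative assumptions, not something this PR prescribes.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

seqlen, batch, hidden_size, num_heads = 512, 2, 1024, 16
head_dim = hidden_size // num_heads

# Build the rotary-embedding frequency table once and hand it to the attention
# module, as the updated test does with PE(config.max_seqlen_q).
rope = RotaryPositionEmbedding(dim=head_dim)
rotary_pos_emb = rope(seqlen).to(device="cuda")

# fp8_model_init(enabled=True) mirrors the fp8_mha=True case in the test.
with te.fp8_model_init(enabled=True):
    mha = te.MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_heads,
        params_dtype=torch.bfloat16,
    ).cuda()

# Sequence-first layout: [seqlen, batch, hidden].
x = torch.randn(seqlen, batch, hidden_size, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True):
    out = mha(x, rotary_pos_emb=rotary_pos_emb)
out.backward(torch.randn_like(out))
```

In the tests below, the same pattern is run with fp8_mha toggled on and off so that the FP8 and F16 outputs (and gradients, when is_training) can be compared within the stated tolerances.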
tests/pytorch/fused_attn/test_fused_attn.py: 36 changes (25 additions, 11 deletions)
@@ -1344,32 +1344,35 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
 @pytest.mark.parametrize("input_layernorm", [True, False])
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
+@pytest.mark.parametrize("RoPE", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training):
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]

     if _flash_attn_3_plus and not is_training:
+        if RoPE:
+            pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-            dtype, config, True, qkv_format, input_layernorm, is_training
+            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
         )

     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm, is_training
+        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
     )

     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm, is_training
+        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
     )

     atol = 5e-1
@@ -1410,7 +1413,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
         )


-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training):
+def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1429,6 +1432,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     )

     with fp8_model_init(enabled=fp8_mha):
+        rotary_pos_emb = None
+        if RoPE:
+            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
+            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
         mha = MultiheadAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_heads,
@@ -1489,6 +1496,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         checkpoint_core_attention=False,
         core_attention_bias_type=config.attn_bias_type,
         is_first_microbatch=None,
+        rotary_pos_emb=rotary_pos_emb,
     )
     if is_training:
         out.backward(out_grad)
@@ -1977,12 +1985,18 @@ def forward(
             None,
             None,
             None,
-            fp8_meta["scaling_fwd"].scale_inv[META_QKV],
-            fp8_meta["scaling_fwd"].scale_inv[META_S],
-            fp8_meta["scaling_fwd"].scale[META_S],
-            fp8_meta["scaling_fwd"].scale[META_O],
-            fp8_meta["scaling_fwd"].amax_history[0][META_S],
-            fp8_meta["scaling_fwd"].amax_history[0][META_O],
+            fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
+            META_QKV,  # d_scale_qkv_offset
+            fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
+            META_S,  # d_scale_s_offset
+            fp8_meta["scaling_fwd"].scale,  # q_scale_s
+            META_S,  # q_scale_s_offset
+            fp8_meta["scaling_fwd"].scale,  # q_scale_o
+            META_O,  # q_scale_o_offset
+            fp8_meta["scaling_fwd"].amax_history,  # amax_s
+            META_S,  # amax_s_offset
+            fp8_meta["scaling_fwd"].amax_history,  # amax_o
+            META_O,  # amax_o_offset
             attn_scale=None,
             dropout=p_dropout,
             fast_zero_fill=fast_zero_fill,
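
The reworked argument list above is what the "avoid index select in fused_attn_fwd" commit refers to: rather than slicing scalar entries such as fp8_meta["scaling_fwd"].scale_inv[META_QKV] out of the FP8 meta tensors at every call site, the caller now passes the whole scale, scale_inv, and amax_history tensors together with integer offsets, and the indexing is left to the fused-attention call itself. A hypothetical sketch of that calling convention follows; the wrapper name, offset values, and parameter list are illustrative only, not the real fused_attn_fwd signature.

```python
import torch

# Illustrative offsets into the forward-scaling tensors; in Transformer Engine
# the real values come from the FP8 tensor enums behind META_QKV, META_S, META_O.
META_QKV, META_O, META_S = 0, 1, 2


def attn_fwd_sketch(
    scale: torch.Tensor,         # whole fp8_meta["scaling_fwd"].scale
    scale_inv: torch.Tensor,     # whole fp8_meta["scaling_fwd"].scale_inv
    amax_history: torch.Tensor,  # whole fp8_meta["scaling_fwd"].amax_history
    d_scale_qkv_offset: int,
    q_scale_s_offset: int,
    amax_s_offset: int,
):
    # Indexing happens once, inside the call, instead of producing a separate
    # sliced tensor per argument at the Python call site.
    d_scale_qkv = scale_inv[d_scale_qkv_offset]
    q_scale_s = scale[q_scale_s_offset]
    amax_s = amax_history[0][amax_s_offset]
    return d_scale_qkv, q_scale_s, amax_s


# The caller hands over full tensors plus offsets, mirroring the new test code.
scale = torch.ones(3)
scale_inv = torch.ones(3)
amax_history = torch.zeros(1, 3)
attn_fwd_sketch(scale, scale_inv, amax_history, META_QKV, META_S, META_S)
```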