commit 521c77a (parent fae44b6)
Author: yaox12
Date:   Aug 21, 2024

    resolve comments

    Signed-off-by: Xin Yao <[email protected]>

Showing 2 changed files with 128 additions and 111 deletions.
14 changes: 10 additions & 4 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1320,7 +1320,8 @@ def _rmse(a, b):
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
 @pytest.mark.parametrize("input_layernorm", [True, False])
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
+@pytest.mark.parametrize("RoPE", [True, False])
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE):
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "1"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
@@ -1332,12 +1333,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
 
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm
+        dtype, config, True, qkv_format, input_layernorm, RoPE
     )
 
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm
+        dtype, config, False, qkv_format, input_layernorm, RoPE
     )
 
     tols = dict(atol=5e-1, rtol=5e-1)
@@ -1399,7 +1400,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
         )
 
 
-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
+def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1418,6 +1419,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     )
 
     with fp8_model_init(enabled=fp8_mha):
+        rotary_pos_emb = None
+        if RoPE:
+            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
+            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
         mha = MultiheadAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_heads,
@@ -1475,6 +1480,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         checkpoint_core_attention=False,
         core_attention_bias_type=config.attn_bias_type,
         is_first_microbatch=None,
+        rotary_pos_emb=rotary_pos_emb,
     )
     out.backward(out_grad)
 
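To exercise the updated test across all of its parametrizations, including the new RoPE axis, standard pytest selection works; for example:

pytest -v tests/pytorch/fused_attn/test_fused_attn.py -k test_mha_fp8_vs_f16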
