Skip to content

Commit

Permalink
fix and update ut
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Nov 1, 2024
1 parent 7b0fe30 commit 723364e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 15 deletions.
24 changes: 20 additions & 4 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,24 @@
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.common.recipe import DelayedScaling

dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}


def run_dpa_with_cp(
dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
):
"""Test DotProductAttention module with context parallelism"""

# args are passing as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
Expand Down Expand Up @@ -72,7 +80,7 @@ def run_dpa_with_cp(
cp_comm_sub_groups.append(sub_group)

if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True)
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)

# instantiate core attn module
core_attn = DotProductAttention(
Expand Down Expand Up @@ -201,7 +209,11 @@ def run_dpa_with_cp(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
)
out.backward(dout)
if fp8_mha:
dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
out.backward(dout_fp8)
else:
out.backward(dout)

# run core_attn wit CP
q_, k_, v_, dout_, *rest = [
Expand Down Expand Up @@ -269,7 +281,11 @@ def run_dpa_with_cp(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
)
out_.backward(dout_)
if fp8_mha:
dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
out_.backward(dout_fp8_)
else:
out_.backward(dout_)

for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x))
Expand Down
6 changes: 5 additions & 1 deletion tests/pytorch/fused_attn/test_fused_attn_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("fp8_mha", [False, True])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
Expand Down Expand Up @@ -153,6 +154,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
if dtype != "fp8" and fp8_mha:
pytest.skip("Only fp8 works with fp8_mha=True!")

subprocess.run(
get_bash_arguments(
Expand All @@ -162,6 +165,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
fp8_mha=fp8_mha,
),
check=True,
)
19 changes: 9 additions & 10 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,7 @@ def forward(
fused_attn_qkv_dtype = None
fused_attn_backend = None
amax_per_step = None
qkv_dtype = q.dtype
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
Expand Down Expand Up @@ -1882,11 +1883,7 @@ def forward(
batch_p2p_comm,
)

if (
not fp8
or is_input_fp8
or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
):
if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
kv_inputs[i % 2] = p2p_comm_buffers[i]
else:
# KV exchange is in BF16/FP16, cast received KV in each step
Expand Down Expand Up @@ -2438,7 +2435,7 @@ def forward(
fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]

out_fp8 = None
out_f16 = out.to(q_fp8.dtype if fp8 and is_output_fp8 else q_f16.dtype)
out_f16 = out.to(qkv_dtype)
if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)

Expand All @@ -2449,7 +2446,7 @@ def forward(
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
dtype=qkv_dtype,
)
else:
out_ret = out_f16
Expand Down Expand Up @@ -3856,6 +3853,7 @@ def forward(
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"

qkv_dtype = q.dtype
fused_attn_backend = None
fused_attn_qkv_dtype = None
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
Expand Down Expand Up @@ -3978,7 +3976,7 @@ def forward(
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
dtype=qkv_dtype,
)
out = out_fp8._data
out_ret = out_fp8
Expand Down Expand Up @@ -4072,6 +4070,7 @@ def backward(ctx, dout):
fused_attn_backend = None
fused_attn_dqkv_dtype = None
fused_attn_qkv_dtype = None
dout_dtype = dout.dtype
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
Expand Down Expand Up @@ -4210,7 +4209,7 @@ def backward(ctx, dout):
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=dout_fp8.dtype,
dtype=dout_dtype,
)
for x in [dq, dk, dv]
]
Expand All @@ -4221,7 +4220,7 @@ def backward(ctx, dout):
ctx.fp8_meta["scaling_bwd"],
META_DQKV,
fp8_dtype_backward,
TE_DType[dout_f16.dtype],
TE_DType[dout_dtype],
)
for x in [dq, dk, dv]
]
Expand Down

0 comments on commit 723364e

Please sign in to comment.