From 26c8fcc92dfd43ea19e308a97062211aab711066 Mon Sep 17 00:00:00 2001
From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
Date: Tue, 20 Aug 2024 20:17:03 -0700
Subject: [PATCH] Add FP8 support to CP implementation with KV P2P (#1114)

* add window_size to AttnFuncWithCP Signed-off-by: Xiaowei Ren
* add seq_offsets_qkvo for cudnn thd Signed-off-by: Xiaowei Ren
* add seq_offsets_qkvo to AttnFuncWithCP Signed-off-by: Xiaowei Ren
* fix seq_offsets calculation of cudnn thd Signed-off-by: Xiaowei Ren
* remove a thd assert Signed-off-by: Xiaowei Ren
* fix bias for thd test Signed-off-by: Xiaowei Ren
* add thd test for cudnn FA with CP Signed-off-by: Xiaowei Ren
* skip GQA/MQA test for cuDNN THD Signed-off-by: Xiaowei Ren
* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1 Signed-off-by: Xiaowei Ren
* fix seq_offsets inputs Signed-off-by: Xiaowei Ren
* remove two comments Signed-off-by: Xiaowei Ren
* fix attn mask type for cudnn thd with cp Signed-off-by: Xiaowei Ren
* fix attn_mask_type check Signed-off-by: Xiaowei Ren
* fix attn_mask_type for cudnn fa with thd Signed-off-by: Xiaowei Ren
* fix a typo Signed-off-by: Xiaowei Ren
* fix out dout in bwd Signed-off-by: Xiaowei Ren
* assert cudnn+thd does not support attn bias Signed-off-by: Xiaowei Ren
* check if attn_mask_type has padding Signed-off-by: Xiaowei Ren
* minor change Signed-off-by: Xiaowei Ren
* change cp test batch size to 2 Signed-off-by: Xiaowei Ren
* fix code format Signed-off-by: Xiaowei Ren
* fix two assert info Signed-off-by: Xiaowei Ren
* fix assert comment Signed-off-by: Xiaowei Ren
* fix assert comments Signed-off-by: Xiaowei Ren
* minor fix Signed-off-by: Xiaowei Ren
* fix assert comments Signed-off-by: Xiaowei Ren
* assert swa+CP cannot work with thd format Signed-off-by: Xiaowei Ren
* add a new CP function for swa Signed-off-by: Xiaowei Ren
* add a missing dgrads Signed-off-by: Xiaowei Ren
* minor change Signed-off-by: Xiaowei Ren
* add draft fwd function for swa+cp Signed-off-by: Xiaowei Ren
* minor change Signed-off-by: Xiaowei Ren
* enable flash attention for swa+cp Signed-off-by: Xiaowei Ren
* remove an assert of swa+cp Signed-off-by: Xiaowei Ren
* call SWAFuncWithCP for swa+cp Signed-off-by: Xiaowei Ren
* typo fix Signed-off-by: Xiaowei Ren
* use 2hd layout Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* change qkv_format check Signed-off-by: Xiaowei Ren
* add a code comment Signed-off-by: Xiaowei Ren
* tensor shape bug fix Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* tensor shape fix Signed-off-by: Xiaowei Ren
* add function to compute cu_seqlens of a cp rank Signed-off-by: Xiaowei Ren
* add cu_seqlens and cu_seqlens_padded to context parallelism Signed-off-by: Xiaowei Ren
* typo fix Signed-off-by: Xiaowei Ren
* minor change Signed-off-by: Xiaowei Ren
* fix FlashAttention output sequence length Signed-off-by: Xiaowei Ren
* fix cu_seqlens_kv_per_step calculation Signed-off-by: Xiaowei Ren
* zero dQKV for ending padded tokens Signed-off-by: Xiaowei Ren
* zero dQKV tensors of FlashAttention Signed-off-by: Xiaowei Ren
* fix softmax_lse correction Signed-off-by: Xiaowei Ren
* remove padded tokens of KV to save communication Signed-off-by: Xiaowei Ren
* do not need to zero dkv for FlashAttention any more Signed-off-by: Xiaowei Ren
* zero out tensors Signed-off-by: Xiaowei Ren
* remove redundant code Signed-off-by: Xiaowei Ren
* fix CP unit test Signed-off-by: Xiaowei Ren
* fix kv shape of cp test with thd format Signed-off-by: Xiaowei Ren
* update cp unit test Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add simple code framework Signed-off-by: Xiaowei Ren
* try not to have a separate CP function for SWA Signed-off-by: Xiaowei Ren
* backup some code change Signed-off-by: Xiaowei Ren
* back up code Signed-off-by: Xiaowei Ren
* clean up fwd implementation of SWAFuncWithCP Signed-off-by: Xiaowei Ren
* remove redundant code Signed-off-by: Xiaowei Ren
* code cleaning Signed-off-by: Xiaowei Ren
* fix assert info Signed-off-by: Xiaowei Ren
* reduce kv chunk concat overheads Signed-off-by: Xiaowei Ren
* minor change Signed-off-by: Xiaowei Ren
* make AttnFuncWithCP and SWAFuncWithCP have same API Signed-off-by: Xiaowei Ren
* add a docstring Signed-off-by: Xiaowei Ren
* preliminary implementation of SWAFuncWithCP forward seems working Signed-off-by: Xiaowei Ren
* fix output shape of SWAFuncWithCP Signed-off-by: Xiaowei Ren
* code refactoring for FlashAttention and add a code placeholder for bwd Signed-off-by: Xiaowei Ren
* use gather_along_first_dim Signed-off-by: Xiaowei Ren
* finish the preliminary implementation of bwd Signed-off-by: Xiaowei Ren
* remove redundant code Signed-off-by: Xiaowei Ren
* fix assert condition Signed-off-by: Xiaowei Ren
* add draft implementation of SWA+CP with FusedAttention Signed-off-by: Xiaowei Ren
* fix attention mask type of swa+cp Signed-off-by: Xiaowei Ren
* code cleaning Signed-off-by: Xiaowei Ren
* add qkv_layout Signed-off-by: Xiaowei Ren
* add missing window_size argument Signed-off-by: Xiaowei Ren
* typo fix Signed-off-by: Xiaowei Ren
* fix kv shape of swa+cp Signed-off-by: Xiaowei Ren
* bug and typo fix Signed-off-by: Xiaowei Ren
* fix dout shape Signed-off-by: Xiaowei Ren
* add multi stream in fwd of swa+cp Signed-off-by: Xiaowei Ren
* save chunk_ids_to_kv_ag in fwd Signed-off-by: Xiaowei Ren
* add multi stream in bwd of swa+cp Signed-off-by: Xiaowei Ren
* minor fix to cp stream sync Signed-off-by: Xiaowei Ren
* rename AttnFuncWithCP Signed-off-by: Xiaowei Ren
* check if window size is None Signed-off-by: Xiaowei Ren
* fix docstring of AttnFuncWithCP Signed-off-by: Xiaowei Ren
* minor fix Signed-off-by: Xiaowei Ren
* add env var for users to choose KV ag or KV p2p Signed-off-by: Xiaowei Ren
* update cp tests Signed-off-by: Xiaowei Ren
* fix window size in cp unit test Signed-off-by: Xiaowei Ren
* fix pytest skip messages Signed-off-by: Xiaowei Ren
* add cp_comm_type into API Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* code cleaning Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add deterministic knob in cuDNN fused attn backend Signed-off-by: Xiaowei Ren
* pass fp8 and fp8_meta to attn_func_with_cp Signed-off-by: Xiaowei Ren
* assert only Fused Attn can support FP8+CP Signed-off-by: Xiaowei Ren
* remove redundant assert Signed-off-by: Xiaowei Ren
* add a fwd draft implementation of FP8 + CP Signed-off-by: Xiaowei Ren
* save fp8 and fp8_meta Signed-off-by: Xiaowei Ren
* assert sequence length divisible requirements Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* remove a redundant qkv_layout compute Signed-off-by: Xiaowei Ren
* if condition change Signed-off-by: Xiaowei Ren
* some typo fix Signed-off-by: Xiaowei Ren
* add support table of context parallelism Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* typo and code format fix Signed-off-by: Xiaowei Ren
* do not print multiple disabling messages Signed-off-by: Xiaowei Ren
* bug fix Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix aux_ctx_tensors of FP8 Signed-off-by: Xiaowei Ren
* bug fix Signed-off-by: Xiaowei Ren
* fix device in torch.arange and adjust code for the PR of MLA Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* commit code change for FP8+CP Signed-off-by: Xiaowei Ren
* commit more code change for FP8+CP Signed-off-by: Xiaowei Ren
* commit more fp8 code for FP8+CP Signed-off-by: Xiaowei Ren
* bug fixes Signed-off-by: Xiaowei Ren
* bug fix Signed-off-by: Xiaowei Ren
* cast merged CP results from FP32 to BF16 Signed-off-by: Xiaowei Ren
* typo fix Signed-off-by: Xiaowei Ren
* minor change Signed-off-by: Xiaowei Ren
* fix softmax_lse Signed-off-by: Xiaowei Ren
* fix some bugs of FP8 dkv exchange Signed-off-by: Xiaowei Ren
* typo fix Signed-off-by: Xiaowei Ren
* add FP8 unit test Signed-off-by: Xiaowei Ren
* fix typos and clean asserts Signed-off-by: Xiaowei Ren
* fix get_p2p_comm_info Signed-off-by: Xiaowei Ren
* fix dkv p2p exchange Signed-off-by: Xiaowei Ren
* minor fix Signed-off-by: Xiaowei Ren
* change FP8 dkv P2P to A2A Signed-off-by: Xiaowei Ren
* add FP8+CP unit test Signed-off-by: Xiaowei Ren
* typo fix Signed-off-by: Xiaowei Ren
* assert amax reduction is needed for FP8+CP Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* remove duplicated code Signed-off-by: Xiaowei Ren
* destroy process group in CP unit test Signed-off-by: Xiaowei Ren
* remove interval from fp8_recipe because it has been deprecated Signed-off-by: Xiaowei Ren
* try to fix the failed CP test with the latest CI pipeline Signed-off-by: Xiaowei Ren
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* remove redundant f before string Signed-off-by: Xiaowei Ren
* change META_O_CP Signed-off-by: Xiaowei Ren

---------

Signed-off-by: Xiaowei Ren
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren
---
 .../fused_attn/run_fused_attn_with_cp.py  | 147 ++--
 .../fused_attn/test_fused_attn_with_cp.py |  12 +-
 transformer_engine/pytorch/attention.py   | 696 ++++++++++++------
 3 files changed, 592 insertions(+), 263 deletions(-)

diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
index 2433a8a09d..6c775fb127 100644
--- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -2,15 +2,18 @@
 #
 # See LICENSE for license information.
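Before the test diff below, the usage pattern this patch enables looks roughly like the following sketch: FP8 dot-product attention is requested through a DelayedScaling recipe with fp8_dpa=True inside fp8_autocast, and context parallelism is attached to the module via set_context_parallel_group with the "p2p" KV-exchange path that this patch teaches to run in FP8. This is a minimal sketch distilled from run_fused_attn_with_cp.py; the tensor sizes, head counts, and process-group bootstrap are illustrative assumptions, not part of the patch.

# Minimal usage sketch (assumes a torchrun launch on GPUs with FP8 fused-attention support).
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())
cp_ranks = list(range(dist.get_world_size()))
cp_group = dist.new_group(cp_ranks, backend="nccl")

core_attn = DotProductAttention(
    16,                       # attention heads (illustrative)
    64,                       # head dimension (illustrative)
    qkv_format="bshd",
    attn_mask_type="causal",
)
# "p2p" selects the ring KV-exchange implementation extended by this patch.
core_attn.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream(), "p2p")

# Each rank feeds its local shard of the sequence; shapes here are placeholders.
q, k, v = [
    torch.randn(2, 1024, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
]
fp8_recipe = DelayedScaling(fp8_dpa=True)  # FP8 for the dot-product attention itself
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_group):
    out = core_attn(q, k, v)
out.backward(torch.randn_like(out))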
-import os, sys +import os, sys, logging +from contextlib import nullcontext import torch import torch.distributed as dist from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling -dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} +dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} def run_dpa_with_cp( @@ -57,6 +60,9 @@ def run_dpa_with_cp( assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if dtype == "fp8": + fp8_recipe = DelayedScaling(fp8_dpa=True) + # instantiate core attn module core_attn = DotProductAttention( config.num_heads, @@ -171,18 +177,27 @@ def run_dpa_with_cp( # run core_attn without CP for x in [q, k, v]: x.requires_grad = True - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], - ) - out.backward(dout) + + if dtype == "fp8": + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out.backward(dout) # run core_attn wit CP q_, k_, v_, dout_, *rest = [ @@ -226,31 +241,34 @@ def run_dpa_with_cp( core_attn.set_context_parallel_group( cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type ) - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], - ) - out_.backward(dout_) + + if dtype == "fp8": + core_attn.reset_fp8_meta_tensors() + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) assert torch.all(~torch.isinf(x)) # compare results with and without CP - tols = dict(atol=5e-3, rtol=5e-3) - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - if qkv_format == 
"bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -309,32 +327,55 @@ def run_dpa_with_cp( else: assert False, f"{qkv_format} is an unsupported qkv_format!" + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + tols = dict(atol=2.5e-2, rtol=2.5e-2) + else: + tols = dict(atol=3.5e-2, rtol=3.5e-2) + elif dtype == "fp16": + tols = dict(atol=5e-3, rtol=5e-3) + elif dtype == "fp8": + tols = dict(atol=5e-1, rtol=5e-1) + rmse_tol = 0.1 + else: + assert False, f"{dtype} is an unsupported dtype!" + + def _rmse(a, b): + return torch.sqrt((a - b).square().mean()).item() + + def _error(a, b): + if dtype != "fp8": + torch.testing.assert_close(a, b, **tols) + else: + try: + torch.testing.assert_close(a, b, **tols) + except Exception as e: + logging.debug(e) + + rmse = _rmse(a, b) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert ( + rmse < rmse_tol * rmse_range + ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + if qkv_format == "bshd": - torch.testing.assert_close(out_[:, 0], out[:, 0], **tols) - torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols) - torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols) - torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols) - torch.testing.assert_close(out_[:, 1], out[:, 1], **tols) - torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols) - torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols) - torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[:, 0], b[:, 0]) + _error(a[:, 1], b[:, 1]) elif qkv_format == "sbhd": - torch.testing.assert_close(out_[0], out[0], **tols) - torch.testing.assert_close(dq_[0], dq[0], **tols) - torch.testing.assert_close(dk_[0], dk[0], **tols) - torch.testing.assert_close(dv_[0], dv[0], **tols) - torch.testing.assert_close(out_[1], out[1], **tols) - torch.testing.assert_close(dq_[1], dq[1], **tols) - torch.testing.assert_close(dk_[1], dk[1], **tols) - torch.testing.assert_close(dv_[1], dv[1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[0], b[0]) + _error(a[1], b[1]) elif qkv_format == "thd": - torch.testing.assert_close(out_, out, **tols) - torch.testing.assert_close(dq_, dq, **tols) - torch.testing.assert_close(dk_, dk, **tols) - torch.testing.assert_close(dv_, dv, **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a, b) else: assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ dist.destroy_process_group() + def main(**kwargs): run_dpa_with_cp(**kwargs) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 0074d18cec..82875e2791 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -90,7 +90,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) @@ -121,8 +121,16 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): ) if config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip( - f"Fused attention does not support sliding window attention + context parallelism yet!" + "Fused attention does not support sliding window attention + context parallelism yet!" + ) + if cp_comm_type == "all_gather" and dtype == "fp8": + pytest.skip( + "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" ) + if dtype == "fp8" and qkv_format == "thd": + pytest.skip("FP8 attention cannot work with THD format yet!") + if dtype == "fp8" and config.attn_bias_type != "no_bias": + pytest.skip("FP8 attention cannot work with bias yet!") subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 71bc15fdad..8fac4778c8 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -95,6 +95,9 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -654,18 +657,6 @@ def get_attention_backend( logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None - if ( - use_fused_attention - and context_parallel - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "context parallellism", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None if ( use_fused_attention and window_size is not None @@ -1322,6 +1313,8 @@ def forward( attn_bias, deterministic, use_fused_attention, + fp8, + fp8_meta, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -1407,6 +1400,43 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, 
Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + else: + q_f16, k_f16, v_f16 = q, k, v + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + q_f16 = q + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + p2p_comm_buffers = [None for _ in range(cp_size)] if use_fused_attention and qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) @@ -1433,7 +1463,23 @@ def forward( batch_p2p_comm, ) - kv_inputs[i % 2] = p2p_comm_buffers[i] + if ( + not fp8 + or fp8_meta["recipe"].fp8_mha + or int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ): + kv_inputs[i % 2] = p2p_comm_buffers[i] + else: + # KV exchange is in BF16/FP16, cast received KV in each step + kv_inputs[i % 2] = cast_to_fp8( + p2p_comm_buffers[i], + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + ) + if fp8 and use_fused_attention: + fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] if causal: if i == 0: if pad_between_seqs_q: @@ -1474,38 +1520,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = 
rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1572,42 +1620,44 @@ def forward( if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv // 2, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None + if cu_seqlens_kv_padded is None + else cu_seqlens_kv_padded // 2 + ), + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1693,42 +1743,44 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q // 2, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( 
+ kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=( + None + if cu_seqlens_q_padded is None + else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: if qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] @@ -1795,38 +1847,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, sq, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1866,8 +1920,16 @@ def forward( softmax_lse_per_step[i - 1].squeeze_(-1) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): + if fp8: + out_per_step[i - 1] = cast_from_fp8( + out_per_step[i - 1], + fp8_meta["scaling_fwd"], + META_O_CP, + fp8_dtype_forward, + TE_DType[torch.float32], + ) if i == 1: - out = torch.zeros_like(q) + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] @@ -1951,13 +2013,55 @@ def forward( else: out = out.view(-1, *out.shape[-2:]) + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + fp8_meta["scaling_fwd"].amax_history[0][META_S] = 
amax_cp_fwd[0] + fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] + + out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) + if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) + + if fp8 and fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor( + data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) + else: + out_ret = out_f16 + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, kv_save, out_save = q, kv, out_fp8 + fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() + fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + elif fp8 and fp8_meta["recipe"].fp8_mha: + kv_fp8 = Float8Tensor( + data=kv, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=k_fp8.dtype, + ) + q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + else: + q_save, kv_save, out_save = q_f16, kv, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + ctx.save_for_backward( - q, - kv, - out, + q_save, + kv_save, + out_save, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, + fp8_fwd_scales, + fp8_fwd_scale_invs, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, *rng_states, @@ -1976,7 +2080,9 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - return out + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + return out_ret @staticmethod def backward(ctx, dout): @@ -1987,10 +2093,11 @@ def backward(ctx, dout): batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] - cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size] - cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2] - rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3] - attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4] + (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] + cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] + cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] + rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] + attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -2025,22 +2132,60 @@ def backward(ctx, dout): if ctx.use_fused_attention: # [b, np, sq//2] -> [b, np, sq//2, 1] softmax_lse_.unsqueeze_(-1) - if ctx.use_fused_attention: # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) + + if ctx.fp8: + if ctx.use_fused_attention: + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_qkv_dtype = fp8_dtype_backward + fused_attn_dqkv_dtype = fp8_dtype_backward + fused_attn_backend = FusedAttnBackend["FP8"] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) + dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) + dkv_fp8_ = torch.empty_like(dkv_fp8) + dout_dtype = dout.dtype + if ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 
MHA!" + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + dout = dout._data + else: + dout = cast_to_fp8( + dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] + fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] + fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] + fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] + fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] + fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] + fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]] + dq = torch.empty_like(q) + if ctx.qkv_format == "thd" and causal: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + p2p_comm_buffers = [ + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + ] + p2p_comm_buffers[0][0].copy_(kv) + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + out = out.view(*q.shape) dout = dout.view(*q.shape) - # Flash Attn outputs - dq = torch.empty_like(q) - if ctx.qkv_format == "thd" and causal: - dq[cu_seqlens_q_padded[-1] :].fill_(0) - - p2p_comm_buffers = [ - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - ] - p2p_comm_buffers[0][0].copy_(kv) send_recv_reqs = [] fa_optional_backward_kwargs = {} @@ -2056,18 +2201,40 @@ def backward(ctx, dout): send_tensor = p2p_comm_buffers[i % 2] recv_tensor = p2p_comm_buffers[(i + 1) % 2] - if i == 0: - send_tensor = send_tensor[0] - recv_tensor = recv_tensor[0] - if i == (cp_size - 1): - send_tensor = send_tensor[1] - recv_tensor = recv_tensor[1] - - send_recv_reqs = flash_attn_p2p_communicate( - rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm - ) + if ctx.fp8: + if i < cp_size - 1: + send_recv_reqs = flash_attn_p2p_communicate( + rank, + send_tensor[0], + send_dst, + recv_tensor[0], + recv_src, + ctx.cp_group, + batch_p2p_comm, + ) + else: + dkv_a2a_req = torch.distributed.all_to_all_single( + dkv_fp8, + dkv_fp8_, + group=ctx.cp_group, + async_op=True, + ) + send_recv_reqs = [dkv_a2a_req] + else: + if i == 0: + send_tensor = send_tensor[0] + recv_tensor = recv_tensor[0] + if i == (cp_size - 1): + send_tensor = send_tensor[1] + recv_tensor = recv_tensor[1] + send_recv_reqs = flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm + ) kv = p2p_comm_buffers[i % 2][0] + if ctx.fp8 and ctx.use_fused_attention: + fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] # In reversed order of fwd if causal: if i == (cp_size - 1): @@ -2090,7 +2257,14 @@ def backward(ctx, dout): dout_ = dout.view(-1, *dout.shape[-3:]) elif ctx.qkv_format == "thd": q_, kv_, out_, dout_ = q, kv, out, 
dout - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2103,10 +2277,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -2114,6 +2288,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -2169,7 +2345,14 @@ def backward(ctx, dout): q_, out_, dout_ = q, out, dout # [2, t, np, hn] -> [2, t/2, np, hn] kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2182,10 +2365,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 @@ -2195,6 +2378,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -2256,7 +2441,14 @@ def backward(ctx, dout): out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) kv_ = kv - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2269,10 +2461,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=( None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 ), @@ -2282,6 +2474,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: if ctx.qkv_format == "thd": @@ -2325,7 +2519,10 @@ def backward(ctx, dout): ) else: if 
ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2338,10 +2535,10 @@ def backward(ctx, dout): kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], out, dout, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -2349,6 +2546,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, sq, np, hn] -> [b*sq, np, hn] @@ -2383,6 +2582,8 @@ def backward(ctx, dout): **fa_optional_backward_kwargs, ) + if ctx.fp8: + dq = dq_fp8[(rank + i + 1) % cp_size] if i >= (cp_size - rank - 1) or not causal: # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal # [b*sq, np, hn] -> [b, sq, np, hn] if not causal @@ -2395,7 +2596,17 @@ def backward(ctx, dout): # [b*sq//2, np, hn] -> [sq//2, b, np, hn] dq_ = dq_.view(-1, *dq.shape[-3:]) - if causal: + if ctx.fp8: + if i >= (cp_size - rank - 1) or not causal: + dq.copy_(dq_) + else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].fill_(0) + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[0].fill_(0) + dq[1].copy_(dq_) + elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) elif i == (cp_size - rank - 1): @@ -2450,7 +2661,13 @@ def backward(ctx, dout): for req in send_recv_reqs: req.wait() - dkv = p2p_comm_buffers[(i + 1) % 2][1] + if ctx.fp8: + if i < cp_size - 1: + dkv = dkv_fp8_[(rank + i + 1) % cp_size] + else: + dkv = dkv_fp8[(rank + i + 1) % cp_size] + else: + dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) if ctx.qkv_format in ["bshd", "sbhd"]: @@ -2469,7 +2686,17 @@ def backward(ctx, dout): # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal dkv_ = dkv_.view(*dkv.shape) - if causal: + if ctx.fp8: + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + dkv[:, 1, ...].fill_(0) + else: + dkv.copy_(dkv_) + elif causal: if i == (cp_size - 1): if rank == 0: if ctx.qkv_format == "bshd": @@ -2507,6 +2734,26 @@ def backward(ctx, dout): else: dkv.add_(dkv_) + if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0] + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1] + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dq, dkv = [ + cast_from_fp8( + x, + ctx.fp8_meta["scaling_bwd"], + META_DQKV_CP, + fp8_dtype_backward, + TE_DType[torch.float32], + ) + for x in [dq_fp8, dkv_fp8] + ] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + if causal: if ctx.qkv_format == "bshd": # 
[b, 2, sq//2, np, hn] -> [b, sq, np, hn] @@ -2527,6 +2774,25 @@ def backward(ctx, dout): dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv = dkv_ + if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + dq, dkv = [ + cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) + for x in [dq, dkv] + ] + dq, dk, dv = [ + Float8Tensor( + data=x, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=dout_dtype, + ) + for x in [dq, dkv[0], dkv[1]] + ] + else: + dk, dv = dkv[0], dkv[1] + if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) @@ -2534,8 +2800,8 @@ def backward(ctx, dout): return ( None, dq, - dkv[0], - dkv[1], + dk, + dv, None, None, None, @@ -2553,12 +2819,14 @@ def backward(ctx, dout): attn_dbias, None, None, + None, + None, ) -@jit_fuser +@torch.compile def get_seq_chunk_ids_to_all_gathered_kv( - local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device ): """Compute sequence chunk ids to the all-gathered KV.""" seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv @@ -2569,7 +2837,7 @@ def get_seq_chunk_ids_to_all_gathered_kv( local_chunk_id - num_chunks + 1, local_chunk_id + 1, dtype=torch.int32, - device="cuda", + device=device, ) chunk_ids_to_all_gathered_kv = torch.where( chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 @@ -2683,6 +2951,7 @@ def forward( if (window_size is None or window_size[0] == -1) else window_size[0] ), + k.device, ) chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag num_kv_chunks = chunk_ids_to_kv_ag.numel() @@ -3029,6 +3298,8 @@ def attn_forward_func_with_cp( deterministic=False, use_fused_attention=False, window_size=None, + fp8=False, + fp8_meta=None, ) -> torch.Tensor: """ Attention implementation with context parallelism. @@ -3109,6 +3380,8 @@ def attn_forward_func_with_cp( attn_bias, deterministic, use_fused_attention, + fp8, + fp8_meta, ) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -5638,9 +5911,21 @@ def forward( and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen) ) + if fp8: + assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( + f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" + " is required for FP8 attention!" + ) + assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" + assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism with" + " FP8!" + ) + if context_parallel: assert ( - fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8 + or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen ), f"{fused_attention_backend} does not work with context parallelism!" 
assert core_attention_bias_type not in [ "alibi" @@ -5670,19 +5955,14 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=core_attention_bias_type, attn_bias=core_attention_bias, + deterministic=self.deterministic, use_fused_attention=True, window_size=window_size, + fp8=fp8, + fp8_meta=fp8_meta, ) else: with self.attention_dropout_ctx(): - if fp8: - assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( - f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" - " is required for FP8 attention!" - ) - assert ( - fp8_meta is not None - ), "FP8 metadata fp8_meta is required for FP8 attention!" output = FusedAttnFunc.apply( self.training, max_seqlen_q,
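One mechanism worth calling out from the forward and backward changes above: because each CP rank computes attention in cp_size sub-steps, amaxes are collected per step into a (2, cp_size) buffer (S and O in forward, dP and dQKV in backward) and folded into the current amax_history row only once after the ring loop. The sketch below reproduces that bookkeeping in plain PyTorch; the slot indices and the random per-step values are illustrative stand-ins for the tex.FP8FwdTensors slots and for the amaxes that fused_attn_fwd reports through its amax_s / amax_o arguments.

import torch

cp_size = 4
META_S, META_O_CP = 3, 2                # illustrative slot indices
amax_history = torch.zeros(1024, 8)     # stand-in for fp8_meta["scaling_fwd"].amax_history

amax_per_step = torch.zeros(2, cp_size)
for i in range(cp_size):
    # in the patch, fused_attn_fwd writes these via the per-step amax slots it is handed
    step_amax_s, step_amax_o = torch.rand(()), torch.rand(())
    amax_per_step[0][i].copy_(step_amax_s)
    amax_per_step[1][i].copy_(step_amax_o)

# a single reduction after the CP ring loop, as in the patch
amax_cp_fwd = amax_per_step.amax(dim=1)
amax_history[0][META_S] = amax_cp_fwd[0]
amax_history[0][META_O_CP] = amax_cp_fwd[1]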