From 454e389502ad4ed4f90b0990a631fe12bdf968fd Mon Sep 17 00:00:00 2001
From: Selvaraj Anandaraj
Date: Thu, 5 Sep 2024 10:54:08 -0700
Subject: [PATCH] Added offloading support FP8 attention (#1131)

* Added offloading support FP8 attention

Signed-off-by: Selvaraj Anandaraj

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Kirthi Shankar Sivamani
Signed-off-by: Selvaraj Anandaraj

* Fix

Signed-off-by: Kirthi Shankar Sivamani

---------

Signed-off-by: Selvaraj Anandaraj
Signed-off-by: Selvaraj Anandaraj
Signed-off-by: Kirthi Shankar Sivamani
Co-authored-by: Selvaraj Anandaraj
Co-authored-by: Kirthi Shankar Sivamani
---
 transformer_engine/pytorch/attention.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 59bc26140d..91c14899ec 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5698,16 +5698,23 @@ def forward(
             out_save = out_ret
             fp8_tensors = (None, None, None, None, None, None)
 
+        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+
         from .cpu_offload import CPUOffloadEnabled
 
         if CPUOffloadEnabled:
-            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
+            if ctx.fp8:
+                tensor_list = fp8_tensors
+            else:
+                tensor_list = [q, k, v, out_save]
+
+            tensor_list.extend(aux_ctx_tensors)
+
             qkv_layout = "sbhd_sbhd_sbhd"
             for tensor in tensor_list:
                 if tensor is not None:
                     tensor.activation_offloading = True
 
-        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
         qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
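
Editor's note: for readers skimming the diff, below is a minimal, self-contained sketch of the offload-marking logic this patch introduces. It is an illustration only: CPUOffloadEnabled, fp8_tensors, aux_ctx_tensors, NVTE_FP8_DPA_BWD, and the activation_offloading attribute are the names from the diff above, while the toy tensors, the module-level structure, and the plain ctx_fp8 variable are assumptions made for a standalone example, not Transformer Engine's actual code.

# Sketch of the offload-marking logic added by this patch; not TE's real code.
import os
import torch

CPUOffloadEnabled = True  # stand-in for transformer_engine.pytorch.cpu_offload

# Toy [s, b, h, d] activations standing in for the real attention tensors.
q, k, v, out_save = (torch.randn(2, 1, 4, 8) for _ in range(4))
fp8_tensors = [torch.zeros(1) for _ in range(6)]  # placeholders for FP8 copies
aux_ctx_tensors = [torch.zeros(1)]                # e.g. softmax statistics

fp8 = True
# FP8 backward is on by default; setting NVTE_FP8_DPA_BWD=0 disables it.
# The patch moves this computation ahead of the offload block so the
# offload path can branch on it.
ctx_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

if CPUOffloadEnabled:
    # With an FP8 backward, the FP8 copies are what gets saved for backward,
    # so those are offloaded; otherwise the high-precision q/k/v/output are.
    tensor_list = list(fp8_tensors) if ctx_fp8 else [q, k, v, out_save]
    tensor_list.extend(aux_ctx_tensors)
    for tensor in tensor_list:
        if tensor is not None:
            # The CPU-offload hooks check this attribute on tensors saved
            # for backward to decide which ones to move to host memory.
            tensor.activation_offloading = True

Two observations on the diff itself: cu_seqlens_q and cu_seqlens_kv are dropped from the offload list, and the auxiliary context tensors (aux_ctx_tensors) are now offloaded in both precisions. The sketch normalizes fp8_tensors to a list before calling extend, purely so the example runs standalone.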