diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 59bc26140d..91c14899ec 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5698,16 +5698,23 @@ def forward( out_save = out_ret fp8_tensors = (None, None, None, None, None, None) + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: - tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv] + if ctx.fp8: + tensor_list = fp8_tensors + else: + tensor_list = [q, k, v, out_save] + + tensor_list.extend(aux_ctx_tensors) + qkv_layout = "sbhd_sbhd_sbhd" for tensor in tensor_list: if tensor is not None: tensor.activation_offloading = True - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)