Added offloading support for FP8 attention (#1131)
* Added offloading support for FP8 attention

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
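
For context, here is a minimal usage sketch of the feature this commit enables: CPU offloading of activations while FP8 attention is active. get_cpu_offload_context, fp8_autocast, TransformerLayer, and DelayedScaling are existing Transformer Engine APIs, but the layer sizes, input shape, and the exact offload arguments below are illustrative assumptions and may differ between versions.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative sketch only: argument names and defaults of
# get_cpu_offload_context can differ between Transformer Engine versions.
offload_ctx, sync_fn = te.get_cpu_offload_context(enabled=True)

layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
).cuda()
x = torch.randn(128, 2, 1024, device="cuda")  # (seq, batch, hidden); assumed shape

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    with offload_ctx:
        y = layer(x)
    y = sync_fn(y)  # commit/synchronize the offloaded activations for this layer

y.sum().backward()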
3 people committed Sep 5, 2024
1 parent 5fafeb0 commit 454e389
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions transformer_engine/pytorch/attention.py
@@ -5698,16 +5698,23 @@ def forward(
             out_save = out_ret
             fp8_tensors = (None, None, None, None, None, None)

+        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+
         from .cpu_offload import CPUOffloadEnabled

         if CPUOffloadEnabled:
-            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
+            if ctx.fp8:
+                tensor_list = fp8_tensors
+            else:
+                tensor_list = [q, k, v, out_save]
+
+            tensor_list.extend(aux_ctx_tensors)
+
             qkv_layout = "sbhd_sbhd_sbhd"
             for tensor in tensor_list:
                 if tensor is not None:
                     tensor.activation_offloading = True

-        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
         qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
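
Read as plain code, the new forward-path logic amounts to the following. This is a simplified paraphrase of the hunk above, wrapped in a hypothetical helper and with explanatory comments added; it is not the verbatim source (for instance, fp8_tensors is copied into a list here so the extend call is always valid).

# Simplified paraphrase of the offloading selection added in this commit.
def mark_for_offload(ctx, q, k, v, out_save, fp8_tensors, aux_ctx_tensors):
    if ctx.fp8:
        # With FP8 backward enabled (NVTE_FP8_DPA_BWD=1), the saved FP8 copies
        # are what the backward pass consumes, so those are offloaded.
        tensor_list = list(fp8_tensors)
    else:
        # Otherwise the high-precision q/k/v and attention output are saved
        # for backward and offloaded instead.
        tensor_list = [q, k, v, out_save]
    tensor_list.extend(aux_ctx_tensors)
    for tensor in tensor_list:
        if tensor is not None:
            # Flag picked up by Transformer Engine's CPU-offload hooks.
            tensor.activation_offloading = True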
