[PyTorch] Draft of new activation offloading API #1762


Draft: wants to merge 13 commits into main
679 changes: 552 additions & 127 deletions tests/pytorch/test_cpu_offloading.py

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions tests/pytorch/test_sanity.py
@@ -30,7 +30,7 @@
     TransformerLayer,
     RMSNorm,
     LayerNorm,
-    get_cpu_offload_context,
+    CPUOffload
 )
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
@@ -289,15 +289,14 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
         _disable_wgrads(block)

     if cpu_offload:
-        offload_context, sync_function = get_cpu_offload_context(enabled=True)
-    else:
-        offload_context = nullcontext()
-        sync_function = lambda x: x
+        cpu_offload = CPUOffload()
+        block = cpu_offload(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states)
-    te_out = sync_function(te_out)
+    if cpu_offload:
+        cpu_offload.sync_before_bwd()
     loss = te_out.sum()
     loss.backward()
     torch.cuda.synchronize()
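For context, a minimal end-to-end sketch of the draft API as this test now exercises it. The toy Linear block and tensor shapes are illustrative assumptions; only CPUOffload(), the module-wrapping call, and sync_before_bwd() come from the PR:

import torch
import transformer_engine.pytorch as te

# Assumed toy setup; any TE module would presumably be wrapped the same way.
block = te.Linear(1024, 1024)
inp = torch.randn(16, 1024, device="cuda", requires_grad=True)

cpu_offload = te.CPUOffload()   # create the offload handler
block = cpu_offload(block)      # wrap the module so its activations are offloaded to CPU

out = block(inp)
cpu_offload.sync_before_bwd()   # per the updated test, synchronize offload work before backward
out.sum().backward()
torch.cuda.synchronize()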
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/__init__.py
@@ -105,7 +105,7 @@ def _load_library():
 from transformer_engine.pytorch.graph import make_graphed_callables
 from transformer_engine.pytorch.distributed import checkpoint
 from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
-from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
+from transformer_engine.pytorch.cpu_offload import CPUOffload
 from transformer_engine.pytorch import ops
 from transformer_engine.pytorch import optimizers
 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
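With this re-export, downstream code would presumably switch imports along these lines; a short sketch under that assumption:

# Old entry point (removed by this PR):
#   from transformer_engine.pytorch import get_cpu_offload_context
# New entry point, re-exported at the package top level by the hunk above:
from transformer_engine.pytorch import CPUOffload

cpu_offload = CPUOffload()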
@@ -683,12 +683,12 @@ def forward(
             )
         else:
             from transformer_engine.pytorch.cpu_offload import (
-                CPUOffloadEnabled,
-                mark_activation_offload,
+                is_cpu_offload_enabled,
+                offload,
             )

-            if CPUOffloadEnabled:
-                mark_activation_offload(
+            if is_cpu_offload_enabled():
+                offload(
                     query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
                 )

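The renamed hooks are module-level helpers, so other autograd code could presumably tag its saved tensors the same way. A rough, invented sketch; nothing below besides is_cpu_offload_enabled and offload comes from the PR:

import torch
from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled, offload


class _ScaledDot(torch.autograd.Function):
    """Toy autograd function showing where the renamed hooks would be called."""

    @staticmethod
    def forward(ctx, q, k):
        out = torch.matmul(q, k.transpose(-2, -1))
        if is_cpu_offload_enabled():
            # Tag the saved activations; the CPUOffload handler (if the module was
            # wrapped) can then move them to host memory after forward and restore
            # them before backward.
            offload(q, k)
        ctx.save_for_backward(q, k)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k = ctx.saved_tensors
        return torch.matmul(grad_out, k), torch.matmul(grad_out.transpose(-2, -1), q)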
@@ -1054,19 +1054,19 @@ def forward(
         ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

         from transformer_engine.pytorch.cpu_offload import (
-            CPUOffloadEnabled,
-            mark_activation_offload,
+            is_cpu_offload_enabled,
+            offload,
         )

-        if CPUOffloadEnabled:
+        if is_cpu_offload_enabled():
             if ctx.fp8:
                 tensor_list = fp8_tensors
             else:
                 tensor_list = [q, k, v, out_save]

             qkv_layout = "sbhd_sbhd_sbhd"
-            mark_activation_offload(*tensor_list)
-            mark_activation_offload(*aux_ctx_tensors)
+            offload(*tensor_list)
+            offload(*aux_ctx_tensors)

         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
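A small illustration of the pattern in the hunk above: offload() takes a variable number of tensors, and the fused path tags both the chosen saved-tensor list and the auxiliary context tensors. The placeholder tensors here are assumptions; only the two helper names come from the PR:

import torch
from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled, offload

# Placeholders standing in for q, k, v, out_save (or fp8_tensors when ctx.fp8 is set)
# and for the backend's aux_ctx_tensors.
tensor_list = [torch.randn(8, 8, device="cuda") for _ in range(4)]
aux_ctx_tensors = [torch.randn(8, 8, device="cuda") for _ in range(2)]

if is_cpu_offload_enabled():
    offload(*tensor_list)       # main activations saved for backward
    offload(*aux_ctx_tensors)   # auxiliary tensors saved by the attention backend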
@@ -1117,15 +1117,15 @@ def forward(
                 cp_stream=self.cp_stream,
                 cp_comm_type=self.cp_comm_type,
                 fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
-                fp8_meta=self.fp8_meta,
+                fp8_meta=self.fp8_meta,
                 quantizers=self.quantizers,
                 pad_between_seqs=pad_between_seqs,
                 inference_params=inference_params,
             )

-            from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
+            from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled

-            if CPUOffloadEnabled:
+            if is_cpu_offload_enabled():
                 warnings.warn(
                     "Attention activation Offloading is only implemented"
                     "with Flash Attention and Fused Attention!"