Skip to content

Added MCore FSDP support for TE #1890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,11 @@ def forward(
ctx.requires_wgrad = weight.requires_grad
ctx.quantized_weight = quantized_weight
if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
if not hasattr(param, "__fsdp_param__"):
ctx.main_grad = weight.main_grad
ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
Expand Down Expand Up @@ -527,11 +531,14 @@ def backward(
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
ctx.main_grad
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None
)
if not hasattr(param, "__fsdp_param__"):
main_grad = (
ctx.main_grad
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None
)
else:
main_grad = origin_weight.get_main_grad()

# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
Expand Down
19 changes: 13 additions & 6 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,11 @@ def forward(
ctx.grad_output_quantizer = grad_output_quantizer
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
if not hasattr(param, "__fsdp_param__"):
ctx.main_grad = weight.main_grad

ctx.debug = debug
ctx.cpu_offloading = cpu_offloading
Expand Down Expand Up @@ -452,11 +456,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
ctx.main_grad
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None
)
if not hasattr(param, "__fsdp_param__"):
main_grad = (
ctx.main_grad
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None
)
else:
main_grad = weight.get_main_grad()

if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
Expand Down
Loading