Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Fix wgrads for GroupedLinear when weights don't require grad #1258

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,36 +443,38 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)

if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms

def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
wgrad = None
return wgrad

wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
else:
wgrad = None
return wgrad
wgrad_list = [None] * ctx.num_gemms

wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms

if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
Expand Down
Loading