
Commit

[fix] fix incorrect number of gradients ;
duanjunwen committed Dec 10, 2024
1 parent 70b0ae1 commit 37b670e
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/_operation.py
@@ -736,7 +736,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
        if ctx.async_grad_reduce_scatter:
            handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


def _ring_as_reducescatter(
@@ -930,7 +930,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
        # grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None


class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
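Why one more `None`: `torch.autograd.Function.backward` must return exactly one value for each argument that was passed to `forward` (excluding `ctx`), with `None` for inputs that need no gradient, such as process groups or boolean flags. The `forward` of the ops touched here takes one more argument than the old `backward` accounted for (presumably the `use_zbv` flag threaded through in linear.py below), so each `backward` needs one additional trailing `None`. A minimal sketch of that contract, using illustrative names rather than the actual ColossalAI ops:

import torch

class ScaleWithFlag(torch.autograd.Function):
    """Toy op whose forward takes three inputs: (x, scale, some_flag)."""

    @staticmethod
    def forward(ctx, x, scale, some_flag):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument: a gradient for `x`,
        # then None for the non-tensor `scale` and `some_flag`.
        # Returning too few values makes autograd raise a RuntimeError
        # complaining about an incorrect number of gradients.
        return grad_output * ctx.scale, None, None

x = torch.randn(4, requires_grad=True)
ScaleWithFlag.apply(x, 2.0, True).sum().backward()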
2 changes: 2 additions & 0 deletions colossalai/shardformer/layer/linear.py
@@ -350,6 +350,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
                True,
                self.seq_parallel_dim,
                ring=self.seq_parallel_mode == "ring",
+                use_zbv=self.use_zbv,
            )
        else:
            output_parallel = linear_with_async_comm(
@@ -580,6 +581,7 @@ def forward(self, input_: Tensor) -> Tensor:
                process_group=self.process_group,
                dim=self.seq_parallel_dim,
                ring=self.seq_parallel_mode == "ring",
+                use_zbv=self.use_zbv,
            )
        else:
            output_parallel = F.linear(input_, self.weight)
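On the layer side, `forward` now passes `self.use_zbv` through to the sequence-parallel linear ops, so the underlying autograd Functions receive one extra input; the added trailing `None`s in `_operation.py` keep `backward` in step with that call. A quick, hypothetical sanity check (not part of the repo) for keeping the two counts in sync:

import inspect

def expected_backward_returns(fn_cls: type) -> int:
    """How many values `backward` must return for a torch.autograd.Function:
    one per `forward` parameter, excluding `ctx`."""
    params = inspect.signature(fn_cls.forward).parameters
    return len(params) - 1  # drop ctx

# With the toy Function above: expected_backward_returns(ScaleWithFlag) == 3.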
