Commit
[shardformer] fix bug incorrect number of gradients; add fusedLinear base testcase;
duanjunwen committed Dec 10, 2024
1 parent fc77b24 commit 70b0ae1
Showing 3 changed files with 36 additions and 6 deletions.
3 changes: 1 addition & 2 deletions colossalai/shardformer/layer/_operation.py
@@ -173,7 +173,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
            grad_weight = None
        else:
            grad_weight = total_input.t().matmul(grad_output)
-
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_allreduce and not fp8_communication:
@@ -1114,7 +1113,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
        if ctx.async_grad_reduce_scatter:
            handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
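The second hunk above is the gradient-count fix named in the commit title: the backward of a torch.autograd.Function must return exactly one value per argument that forward received, with None for inputs that need no gradient, so a forward that gained an extra argument needs an extra trailing None here. Which forward argument gained the slot is not visible in this diff; the fix only aligns the count. A minimal, self-contained sketch of the rule (a toy function, not the shardformer code):

import torch


class ScaleByConstant(torch.autograd.Function):
    # Toy autograd function: forward takes three inputs, so backward
    # must return exactly three values (None for non-differentiable ones).

    @staticmethod
    def forward(ctx, x, scale, some_flag):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One entry per forward input; returning only two values here
        # would make autograd report an incorrect number of gradients.
        return grad_output * ctx.scale, None, None


x = torch.randn(4, requires_grad=True)
ScaleByConstant.apply(x, 2.0, True).sum().backward()
print(x.grad)  # a tensor of 2.0s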
4 changes: 1 addition & 3 deletions colossalai/shardformer/layer/qkv_fused_linear.py
@@ -1349,9 +1349,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None

-        output_parallel = linear_with_grad_accum(
-            input_parallel, self.weight, bias, True, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
-        )
+        output_parallel = linear_with_grad_accum(input_parallel, self.weight, bias, True, use_zbv=self.use_zbv)

        output = output_parallel

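For reference, the forward shown above ultimately computes a standard affine map, which is what the new base testcase below compares against a plain torch.nn.Linear. A minimal sketch of that reference computation (shapes borrowed from the test; the variable names are illustrative only):

import torch
import torch.nn.functional as F

weight = torch.randn(80, 8)  # (out_features, in_features), matching the 8 -> 80 layer in the test
bias = torch.randn(80)
x = torch.rand(4, 8)

# y = x @ weight.T + bias -- the result the fused layer is expected to reproduce.
y = F.linear(x, weight, bias)
assert y.shape == (4, 80)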
35 changes: 34 additions & 1 deletion tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -7,7 +7,7 @@

import colossalai
from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
+from colossalai.shardformer.layer import FusedLinear1D, FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -120,12 +120,45 @@ def check_linear_1d_col_row(lazy_init: bool):
    assert_close(target_grad2, linear_row.weight.grad)


+@parameterize("lazy_init", [False, True])
+def check_linear_1d_base(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    linear = nn.Linear(8, 80).cuda()
+    with ctx:
+        linear_copy = nn.Linear(8, 80).cuda()
+        linear_base = FusedLinear1D.from_native_module(linear_copy)
+
+    assert linear.weight.shape == torch.Size([80, 8])
+    assert linear.bias.shape == torch.Size([80])
+    assert linear_base.weight.shape == torch.Size([80, 8])
+    assert linear_base.bias.shape == torch.Size([80])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    # ensure weights are reversibly loadable
+    linear_base.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_base.state_dict())
+
+    # check computation correctness
+    x = torch.rand(4, 8).cuda()
+    out = linear(x)
+    base_out = linear_base(x)
+    assert_close(out, base_out)
+
+    # check backward correctness
+    out.sum().backward()
+    base_out.sum().backward()
+
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+
def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    check_linear_1d_col()
    check_linear_1d_row()
    check_linear_1d_col_row()
+    check_linear_1d_base()


@rerun_if_address_is_in_use()
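The decorated entry point after this hunk is collapsed in the view above. In ColossalAI tests it is typically a thin driver that spawns run_dist over a small process group; a hedged sketch of what such a driver usually looks like (the function name and process count are assumptions, not part of this diff):

@rerun_if_address_is_in_use()
def test_fused_linear_1d():
    # spawn launches run_dist on each rank, passing rank, world_size and a free port.
    spawn(run_dist, nprocs=2)


if __name__ == "__main__":
    test_fused_linear_1d()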
