From 70b0ae1e9de73638fa59fa4e68fc22a3f4da4c4e Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Tue, 10 Dec 2024 16:48:07 +0800
Subject: [PATCH] [shardformer] fix bug incorrect number of gradients; add
 fusedLinear base testcase;

---
 colossalai/shardformer/layer/_operation.py    |  3 +-
 .../shardformer/layer/qkv_fused_linear.py     |  4 +--
 .../test_layer/test_qkv_fused_linear_1d.py    | 35 ++++++++++++++++++-
 3 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 921f92b025df..86970c641229
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -173,7 +173,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 grad_weight = None
         else:
             grad_weight = total_input.t().matmul(grad_output)
-
         grad_bias = grad_output.sum(dim=0) if use_bias else None

         if ctx.async_grad_allreduce and not fp8_communication:
@@ -1114,7 +1113,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
         if ctx.async_grad_reduce_scatter:
             handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 149e3d66a57f..74470ace15d8
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -1349,9 +1349,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = linear_with_grad_accum(
-            input_parallel, self.weight, bias, True, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
-        )
+        output_parallel = linear_with_grad_accum(input_parallel, self.weight, bias, True, use_zbv=self.use_zbv)

         output = output_parallel

diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
index fccba564f7c7..43fca1ce65cc
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -7,7 +7,7 @@

 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
+from colossalai.shardformer.layer import FusedLinear1D, FusedLinear1D_Col, FusedLinear1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -120,12 +120,45 @@ def check_linear_1d_col_row(lazy_init: bool):
     assert_close(target_grad2, linear_row.weight.grad)


+@parameterize("lazy_init", [False, True])
+def check_linear_1d_base(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    linear = nn.Linear(8, 80).cuda()
+    with ctx:
+        linear_copy = nn.Linear(8, 80).cuda()
+        linear_base = FusedLinear1D.from_native_module(linear_copy)
+
+    assert linear.weight.shape == torch.Size([80, 8])
+    assert linear.bias.shape == torch.Size([80])
+    assert linear_base.weight.shape == torch.Size([80, 8])
+    assert linear_base.bias.shape == torch.Size([80])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    # ensure weights are reversibly loadable
+    linear_base.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_base.state_dict())
+
+    # check computation correctness
+    x = torch.rand(4, 8).cuda()
+    out = linear(x)
+    base_out = linear_base(x)
+    assert_close(out, base_out)
+
+    # check backward correctness
+    out.sum().backward()
+    base_out.sum().backward()
+
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     check_linear_1d_col()
     check_linear_1d_row()
     check_linear_1d_col_row()
+    check_linear_1d_base()


 @rerun_if_address_is_in_use()
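
Reviewer note (illustrative sketch, not part of the patch): the "incorrect number of gradients" bug comes from the torch.autograd.Function contract, which requires backward() to return exactly one value per argument passed to forward(), using None placeholders for non-tensor arguments (process groups, booleans such as use_zbv). The extra trailing None in the corrected return statement above restores that one-to-one mapping. Below is a minimal, self-contained sketch of the rule; the _ToyLinear class and its argument names are hypothetical and only mirror the pattern, they are not ColossalAI code.

import torch


class _ToyLinear(torch.autograd.Function):
    """Toy example: forward() takes 5 inputs after ctx, so backward() must return 5 values."""

    @staticmethod
    def forward(ctx, x, weight, bias, async_grad_allreduce, use_zbv):
        ctx.save_for_backward(x, weight)
        ctx.use_bias = bias is not None
        out = x.matmul(weight.t())
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output.matmul(weight)
        grad_w = grad_output.t().matmul(x)
        grad_b = grad_output.sum(dim=0) if ctx.use_bias else None
        # One return value per forward() input: tensors get gradients, while
        # plain flags (async_grad_allreduce, use_zbv) get None placeholders.
        # Returning too few or too many values makes autograd raise an
        # "incorrect number of gradients" error at backward time.
        return grad_x, grad_w, grad_b, None, None


if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    b = torch.randn(16, requires_grad=True)
    _ToyLinear.apply(x, w, b, True, False).sum().backward()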