diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 800364003cef..eb063a15856f 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -6,7 +6,13 @@
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
-from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
+from .qkv_fused_linear import (
+    FusedLinear1D_Col,
+    FusedLinear1D_Row,
+    GPT2FusedLinearConv1D,
+    GPT2FusedLinearConv1D_Col,
+    GPT2FusedLinearConv1D_Row,
+)
 
 __all__ = [
     "Embedding1D",
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index e29961c85997..3dc850ac5a54 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -79,7 +79,7 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
         ctx.fp8_communication = fp8_communication
-        ctx.use_zbv=use_zbv
+        ctx.use_zbv = use_zbv
 
         output = torch.matmul(input_, weight)
 
@@ -93,8 +93,8 @@ def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
         fp8_communication = ctx.fp8_communication
-        use_zbv=ctx.use_zbv
-        
+        use_zbv = ctx.use_zbv
+
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
         weight = weight.view(weight.shape)
         if bias is not None:
@@ -173,7 +173,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 grad_weight = None
             else:
                 grad_weight = total_input.t().matmul(grad_output)
-        
+
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         if ctx.async_grad_allreduce and not fp8_communication:
@@ -205,7 +205,7 @@ def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
         use_zbv = ctx.use_zbv
-        
+
         def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
             wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
 
@@ -224,7 +224,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
         if len(grad_output.shape) > 2:
             grad_output = grad_output.view(-1, grad_output.shape[-1])
             total_input = total_input.view(-1, total_input.shape[-1])
-        
+
         # split dx & dw
         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
@@ -276,7 +276,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 grad_weight = None
             else:
                 grad_weight = total_input.t().matmul(grad_output)
-        
+
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         return grad_input, grad_weight, grad_bias, None, None, None, None
@@ -882,7 +882,9 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False):
+    def forward(
+        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False
+    ):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -946,7 +948,6 @@ def backward(ctx, grad_output):
 
             # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated
-
         def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
             wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
 
@@ -1004,7 +1005,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 grad_weight = None
             else:
                 grad_weight = total_input.t().matmul(grad_output)
-        
+
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         if ctx.async_grad_reduce_scatter:
@@ -1327,15 +1328,17 @@ def _all_to_all_single(
     ).contiguous()
 
 
-def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+def matmul_with_async_comm(
+    input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
+):
     return MatmulWithAsyncCommunication.apply(
         input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
     )
 
+
 def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False):
-    return MatmulWithGradAccum.apply(
-        input_, weight, bias, async_grad_allreduce, use_zbv
-    )
+    return MatmulWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
+
 
 def linear_with_async_comm(
     input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
@@ -1370,7 +1373,15 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
 
 
 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False, use_zbv=False
+    input_,
+    weight,
+    bias,
+    process_group,
+    async_grad_reduce_scatter,
+    dim,
+    ring=False,
+    fp8_communication=False,
+    use_zbv=False,
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
         input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index aa83fb993ab4..1967417b5062 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -38,7 +38,13 @@
 from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset, is_share_sp_tp
 
-__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", "GPT2FusedLinearConv1D"]
+__all__ = [
+    "FusedLinear1D_Col",
+    "FusedLinear1D_Row",
+    "GPT2FusedLinearConv1D_Col",
+    "GPT2FusedLinearConv1D_Row",
+    "GPT2FusedLinearConv1D",
+]
 
 # ====================================
 # For GPT Only
@@ -647,6 +653,7 @@ class GPT2FusedLinearConv1D(ParallelModule):
 
     More details about ``initializer`` please refer to `init `_.
     """
+
     def __init__(
         self,
         in_features: int,
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 1ee0bb4cef4f..148e0ff2d33d 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -68,7 +68,7 @@ def module_policy(self):
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         use_flash_attention = self.shard_config.enable_flash_attention
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
-        
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -96,13 +96,17 @@ def module_policy(self):
                             "split_sizes": [self.model.config.hidden_size] * 3,
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_fc",
@@ -112,13 +116,17 @@ def module_policy(self):
                             "seq_parallel_mode": sp_mode,
                             "skip_bias_add": self.enable_bias_gelu_fused,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",
@@ -160,13 +168,17 @@ def module_policy(self):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_fc",
@@ -175,13 +187,17 @@ def module_policy(self):
                             "seq_parallel_mode": sp_mode,
                             "skip_bias_add": self.enable_bias_gelu_fused,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",
@@ -205,7 +221,7 @@ def module_policy(self):
                 policy=policy,
                 target_key=GPT2MLP,
             )
-        
+
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
             self.append_or_create_submodule_replacement(
@@ -617,7 +633,7 @@ def get_held_layers(self) -> List[nn.Module]:
         else:
             if self.pipeline_stage_manager.is_last_stage():
                 held_layers.append(self.model.dropout)
-                held_layers.append(self.model.classifier) 
+                held_layers.append(self.model.classifier)
         # if self.pipeline_stage_manager.is_last_stage():
         #     held_layers.append(self.model.dropout)
         #     held_layers.append(self.model.classifier)
@@ -654,7 +670,7 @@ def get_held_layers(self) -> List[nn.Module]:
         else:
             if self.pipeline_stage_manager.is_last_stage():
                 held_layers.append(self.model.score)
-        
+
         # if self.pipeline_stage_manager.is_last_stage():
         #     held_layers.append(self.model.score)
         return held_layers
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index 0f68f6c639f4..34074642c58d 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -9,7 +9,7 @@
 import colossalai
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.weight_grad_store import WeightGradStore
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
+from colossalai.shardformer.layer import GPT2FusedLinearConv1D, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
@@ -125,9 +125,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
     linear = Conv1D(192, 48).cuda()
     with ctx:
         linear_copy = Conv1D(192, 48).cuda()
-        linear_base = GPT2FusedLinearConv1D.from_native_module(
-            linear_copy, seq_parallel_mode=seq_parallel_mode
-        )
+        linear_base = GPT2FusedLinearConv1D.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)
 
     assert linear.weight.shape == torch.Size([48, 192])
     assert linear_base.weight.shape == torch.Size([48, 192])
@@ -153,6 +151,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
     assert_close(out.grad, gather_out.grad)
     assert_close(linear.weight.grad, linear_base.weight.grad)
 
+
 def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
@@ -182,7 +181,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
     # check backward correctness
     out.sum().backward()
     gather_out.sum().backward()
-    
+
     WeightGradStore.flush(chunk=0)  # flush buffer to chunk 0 Queue
     WeightGradStore.pop(chunk=0)
 
@@ -191,6 +190,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
     # TODO:linear_base.weight.grad is None; But not none in WeightGradStore
     # assert_close(linear.weight.grad, linear_base.weight.grad)
 
+
 @parameterize("lazy_init", [False, True])
 @parameterize("seq_parallel_mode", ["split_gather", None])
 def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
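Usage sketch (not part of the diff above): the test file shows how the re-exported GPT2FusedLinearConv1D and the WeightGradStore queue are exercised together. The snippet below is a minimal sketch of that flow, assuming an already-initialized ColossalAI distributed context (the tests get one via spawn), a CUDA device, and that from_native_module forwards a use_zbv=True kwarg the same way the GPT-2 policy passes "use_zbv": use_zbv; that kwarg forwarding is an assumption, not something the hunks above confirm.

import torch
from transformers.pytorch_utils import Conv1D

from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import GPT2FusedLinearConv1D


def run_deferred_wgrad_sketch():
    # Wrap a native GPT-2 style Conv1D (weight stored as [in_features, out_features]),
    # mirroring check_linear_conv_1d_with_weight_grad_store in the test file.
    native = Conv1D(192, 48).cuda()
    wrapped = GPT2FusedLinearConv1D.from_native_module(
        native,
        seq_parallel_mode=None,
        use_zbv=True,  # assumption: kwarg is forwarded, mirroring the policy kwargs in this PR
    )

    x = torch.rand(1, 4, 48).cuda()
    out = wrapped(x)

    # With ZBV enabled, backward computes dx right away and queues the dw work ("split dx & dw").
    out.sum().backward()

    # Drain the queued weight-gradient computations for pipeline chunk 0,
    # exactly as the test does after backward.
    WeightGradStore.flush(chunk=0)
    WeightGradStore.pop(chunk=0)

The point of the split is visible in the _operation.py hunks: when use_zbv is set, the weight gradient is not materialized inside autograd's backward pass, so the scheduler can run the queued dw computations later to fill pipeline bubbles.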