From 44b5786566d39a85b8bd3ad35e239cc6abbe73da Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 10 Dec 2024 08:51:48 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/policies/gpt2.py | 40 +++++++++++++------
 .../test_gpt2_qkv_fused_linear_1d.py    | 10 ++---
 2 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 1ee0bb4cef4f..148e0ff2d33d 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -68,7 +68,7 @@ def module_policy(self):
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         use_flash_attention = self.shard_config.enable_flash_attention
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
-
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -96,13 +96,17 @@ def module_policy(self):
                             "split_sizes": [self.model.config.hidden_size] * 3,
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_fc",
@@ -112,13 +116,17 @@ def module_policy(self):
                             "seq_parallel_mode": sp_mode,
                             "skip_bias_add": self.enable_bias_gelu_fused,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",
@@ -160,13 +168,17 @@ def module_policy(self):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_fc",
@@ -175,13 +187,17 @@ def module_policy(self):
                             "seq_parallel_mode": sp_mode,
                             "skip_bias_add": self.enable_bias_gelu_fused,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv":use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",
@@ -205,7 +221,7 @@ def module_policy(self):
                 policy=policy,
                 target_key=GPT2MLP,
             )
-
+
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
             self.append_or_create_submodule_replacement(
@@ -617,7 +633,7 @@ def get_held_layers(self) -> List[nn.Module]:
         else:
            if self.pipeline_stage_manager.is_last_stage():
                held_layers.append(self.model.dropout)
-               held_layers.append(self.model.classifier) 
+               held_layers.append(self.model.classifier)
        # if self.pipeline_stage_manager.is_last_stage():
        #     held_layers.append(self.model.dropout)
        #     held_layers.append(self.model.classifier)
@@ -654,7 +670,7 @@ def get_held_layers(self) -> List[nn.Module]:
         else:
            if self.pipeline_stage_manager.is_last_stage():
                held_layers.append(self.model.score)
-
+
        # if self.pipeline_stage_manager.is_last_stage():
        #     held_layers.append(self.model.score)
        return held_layers
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index 0f68f6c639f4..34074642c58d 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -9,7 +9,7 @@
 import colossalai
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.weight_grad_store import WeightGradStore
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
+from colossalai.shardformer.layer import GPT2FusedLinearConv1D, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
@@ -125,9 +125,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
     linear = Conv1D(192, 48).cuda()
     with ctx:
         linear_copy = Conv1D(192, 48).cuda()
-        linear_base = GPT2FusedLinearConv1D.from_native_module(
-            linear_copy, seq_parallel_mode=seq_parallel_mode
-        )
+        linear_base = GPT2FusedLinearConv1D.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)
 
     assert linear.weight.shape == torch.Size([48, 192])
     assert linear_base.weight.shape == torch.Size([48, 192])
@@ -153,6 +151,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
     assert_close(out.grad, gather_out.grad)
     assert_close(linear.weight.grad, linear_base.weight.grad)
 
+
 def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
@@ -182,7 +181,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
     # check backward correctness
     out.sum().backward()
     gather_out.sum().backward()
-
+
     WeightGradStore.flush(chunk=0)  # flush buffer to chunk 0 Queue
     WeightGradStore.pop(chunk=0)
 
@@ -191,6 +190,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
     # TODO:linear_base.weight.grad is None; But not none in WeightGradStore
     # assert_close(linear.weight.grad, linear_base.weight.grad)
 
+
 @parameterize("lazy_init", [False, True])
 @parameterize("seq_parallel_mode", ["split_gather", None])
 def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
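
Note (not part of the patch): the check_linear_conv_1d_with_weight_grad_store test touched above relies on ColossalAI's WeightGradStore to defer the weight-gradient computation, which is why the TODO in the last hunk observes that linear_base.weight.grad is still None after backward() even though the gradient exists in the store. The sketch below reconstructs that flow only from calls visible in the hunks above; the use_zbv=True kwarg and the colossalai.launch-initialized distributed environment (normally provided by the test's spawn harness) are assumptions, since the test's setup code is outside this diff.

import torch
from transformers.pytorch_utils import Conv1D

from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import GPT2FusedLinearConv1D

# Native GPT-2 Conv1D stores its weight transposed: shape [in_features, out_features] = [48, 192].
linear = Conv1D(192, 48).cuda()

# Wrap the native module as the test does; use_zbv=True is an assumption about how the
# deferred weight-gradient path is enabled (the policy hunks above pass the same kwarg).
linear_base = GPT2FusedLinearConv1D.from_native_module(
    linear, seq_parallel_mode=None, use_zbv=True
)

x = torch.rand(4, 48).cuda().requires_grad_(True)
out = linear_base(x)
out.sum().backward()

# With the deferred path, backward() does not populate the weight's .grad directly;
# the weight-gradient work is buffered in WeightGradStore until it is flushed into a
# chunk queue and popped (executed), mirroring the two calls in the test.
WeightGradStore.flush(chunk=0)  # move buffered weight-grad work into the chunk-0 queue
WeightGradStore.pop(chunk=0)    # run the queued weight-grad computation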