
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 10, 2024
1 parent 70b0ae1 commit 44b5786
Showing 2 changed files with 33 additions and 17 deletions.
40 changes: 28 additions & 12 deletions colossalai/shardformer/policies/gpt2.py
@@ -68,7 +68,7 @@ def module_policy(self):
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -96,13 +96,17 @@ def module_policy(self):
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@@ -112,13 +116,17 @@ def module_policy(self):
"seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@@ -160,13 +168,17 @@ def module_policy(self):
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@@ -175,13 +187,17 @@ def module_policy(self):
"seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@@ -205,7 +221,7 @@ def module_policy(self):
policy=policy,
target_key=GPT2MLP,
)

if embedding_cls is not None:
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
self.append_or_create_submodule_replacement(
@@ -617,7 +633,7 @@ def get_held_layers(self) -> List[nn.Module]:
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
held_layers.append(self.model.classifier)
# if self.pipeline_stage_manager.is_last_stage():
# held_layers.append(self.model.dropout)
# held_layers.append(self.model.classifier)
@@ -654,7 +670,7 @@ def get_held_layers(self) -> List[nn.Module]:
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)

# if self.pipeline_stage_manager.is_last_stage():
# held_layers.append(self.model.score)
return held_layers
10 changes: 5 additions & 5 deletions (second changed file: GPT2 fused linear layer test)
@@ -9,7 +9,7 @@
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
from colossalai.shardformer.layer import GPT2FusedLinearConv1D, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -125,9 +125,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_base = GPT2FusedLinearConv1D.from_native_module(
linear_copy, seq_parallel_mode=seq_parallel_mode
)
linear_base = GPT2FusedLinearConv1D.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)

assert linear.weight.shape == torch.Size([48, 192])
assert linear_base.weight.shape == torch.Size([48, 192])
@@ -153,6 +151,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
assert_close(out.grad, gather_out.grad)
assert_close(linear.weight.grad, linear_base.weight.grad)


def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
ctx = LazyInitContext() if lazy_init else nullcontext()

@@ -182,7 +181,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
# check backward correctness
out.sum().backward()
gather_out.sum().backward()

WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
WeightGradStore.pop(chunk=0)

@@ -191,6 +190,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
# TODO:linear_base.weight.grad is None; But not none in WeightGradStore
# assert_close(linear.weight.grad, linear_base.weight.grad)


@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):