Commit 3fd2402

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 10, 2024
1 parent 8cb74e7 commit 3fd2402
Showing 5 changed files with 74 additions and 34 deletions.
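The changes in this commit are automatic style fixes from the repository's pre-commit hooks; they do not alter runtime behavior. As a rough illustration of the two rules at work in the hunks below (PEP 8 spacing around "=" and reflowing of over-long signatures), here is a hypothetical before/after pair; the function name and arguments are placeholders, not code from this repository.

# Hypothetical illustration of the formatting rules applied by this commit.
# Both definitions behave identically; only the layout differs.


def long_signature_before(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
    flag=use_zbv  # the formatter rewrites this as `flag = use_zbv`
    return flag


def long_signature_after(
    input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
    flag = use_zbv
    return flag
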
8 changes: 7 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -6,7 +6,13 @@
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
from .qkv_fused_linear import (
FusedLinear1D_Col,
FusedLinear1D_Row,
GPT2FusedLinearConv1D,
GPT2FusedLinearConv1D_Col,
GPT2FusedLinearConv1D_Row,
)

__all__ = [
"Embedding1D",
41 changes: 26 additions & 15 deletions colossalai/shardformer/layer/_operation.py
@@ -79,7 +79,7 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv=use_zbv
ctx.use_zbv = use_zbv

output = torch.matmul(input_, weight)

@@ -93,8 +93,8 @@ def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
use_zbv=ctx.use_zbv
use_zbv = ctx.use_zbv

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
if bias is not None:
@@ -173,7 +173,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce and not fp8_communication:
@@ -205,7 +205,7 @@ def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv

def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)

@@ -224,7 +224,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

# split dx & dw
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
@@ -276,7 +276,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)

grad_bias = grad_output.sum(dim=0) if use_bias else None

return grad_input, grad_weight, grad_bias, None, None, None, None
@@ -882,7 +882,9 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False):
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False
):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
@@ -946,7 +948,6 @@ def backward(ctx, grad_output):
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated


def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)

@@ -1004,7 +1005,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_reduce_scatter:
@@ -1327,15 +1328,17 @@ def _all_to_all_single(
).contiguous()


def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
def matmul_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return MatmulWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
)


def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return MatmulWithGradAccum.apply(
input_, weight, bias, async_grad_allreduce, use_zbv
)
return MatmulWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)


def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
@@ -1370,7 +1373,15 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc


def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False, use_zbv=False
input_,
weight,
bias,
process_group,
async_grad_reduce_scatter,
dim,
ring=False,
fp8_communication=False,
use_zbv=False,
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv
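For context, a minimal sketch of how the wrapper functions reflowed above are called. The argument names follow the signatures visible in this diff; the tensors x, w, b and the process group pg are hypothetical placeholders that would come from an initialized tensor-parallel setup, so this is an illustration rather than repository code.

# Illustrative call sketch; argument names mirror the wrapper signatures above,
# everything else is a placeholder.
from colossalai.shardformer.layer._operation import (
    matmul_gather_forward_reducescatter_backward,
    matmul_with_async_comm,
)


def call_wrappers_example(x, w, b, pg):
    # Matmul whose backward pass overlaps an all-reduce, per the wrapper name.
    out = matmul_with_async_comm(
        x, w, b, pg, async_grad_allreduce=True, fp8_communication=False, use_zbv=False
    )
    # Gather-forward / reduce-scatter-backward variant along dimension 1.
    out_rs = matmul_gather_forward_reducescatter_backward(
        x, w, b, pg, async_grad_reduce_scatter=True, dim=1, ring=False, fp8_communication=False, use_zbv=False
    )
    return out, out_rs
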
9 changes: 8 additions & 1 deletion colossalai/shardformer/layer/qkv_fused_linear.py
@@ -38,7 +38,13 @@
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset, is_share_sp_tp

__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", "GPT2FusedLinearConv1D"]
__all__ = [
"FusedLinear1D_Col",
"FusedLinear1D_Row",
"GPT2FusedLinearConv1D_Col",
"GPT2FusedLinearConv1D_Row",
"GPT2FusedLinearConv1D",
]

# ====================================
# For GPT Only
@@ -647,6 +653,7 @@ class GPT2FusedLinearConv1D(ParallelModule):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(
self,
in_features: int,
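Since GPT2FusedLinearConv1D is now exported from colossalai.shardformer.layer, a minimal usage sketch follows, mirroring the from_native_module pattern exercised in the test changes at the bottom of this commit. It assumes transformers' Conv1D, a CUDA device, and a distributed environment initialized via colossalai (as in the tests); the (192, 48) sizes simply follow those tests.

# Usage sketch based on the test diff below; assumes colossalai.launch has
# initialized the distributed environment and a GPU is available.
import torch
from transformers.pytorch_utils import Conv1D  # HF GPT-2 style Conv1D

from colossalai.shardformer.layer import GPT2FusedLinearConv1D

native = Conv1D(192, 48).cuda()  # weight shape [48, 192], as asserted in the tests
fused = GPT2FusedLinearConv1D.from_native_module(native, seq_parallel_mode=None)

x = torch.rand(4, 48).cuda()
out = fused(x)  # the tests compare this layer's behavior against the wrapped Conv1D
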
40 changes: 28 additions & 12 deletions colossalai/shardformer/policies/gpt2.py
@@ -68,7 +68,7 @@ def module_policy(self):
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -96,13 +96,17 @@
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@@ -112,13 +116,17 @@
"seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@@ -160,13 +168,17 @@
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@@ -175,13 +187,17 @@
"seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv":use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
Expand All @@ -205,7 +221,7 @@ def module_policy(self):
policy=policy,
target_key=GPT2MLP,
)

if embedding_cls is not None:
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
self.append_or_create_submodule_replacement(
@@ -617,7 +633,7 @@ def get_held_layers(self) -> List[nn.Module]:
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
held_layers.append(self.model.classifier)
# if self.pipeline_stage_manager.is_last_stage():
# held_layers.append(self.model.dropout)
# held_layers.append(self.model.classifier)
@@ -654,7 +670,7 @@ def get_held_layers(self) -> List[nn.Module]:
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)

# if self.pipeline_stage_manager.is_last_stage():
# held_layers.append(self.model.score)
return held_layers
@@ -9,7 +9,7 @@
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
from colossalai.shardformer.layer import GPT2FusedLinearConv1D, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -125,9 +125,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_base = GPT2FusedLinearConv1D.from_native_module(
linear_copy, seq_parallel_mode=seq_parallel_mode
)
linear_base = GPT2FusedLinearConv1D.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)

assert linear.weight.shape == torch.Size([48, 192])
assert linear_base.weight.shape == torch.Size([48, 192])
@@ -153,6 +151,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
assert_close(out.grad, gather_out.grad)
assert_close(linear.weight.grad, linear_base.weight.grad)


def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
ctx = LazyInitContext() if lazy_init else nullcontext()

@@ -182,7 +181,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
# check backward correctness
out.sum().backward()
gather_out.sum().backward()

WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
WeightGradStore.pop(chunk=0)

@@ -191,6 +190,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
# TODO:linear_base.weight.grad is None; But not none in WeightGradStore
# assert_close(linear.weight.grad, linear_base.weight.grad)


@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
