
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 9, 2024
1 parent a84fc41 commit ba7fc35
Showing 4 changed files with 26 additions and 13 deletions.
8 changes: 7 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -6,7 +6,13 @@
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
from .qkv_fused_linear import (
FusedLinear1D_Col,
FusedLinear1D_Row,
GPT2FusedLinearConv1D,
GPT2FusedLinearConv1D_Col,
GPT2FusedLinearConv1D_Row,
)

__all__ = [
"Embedding1D",
12 changes: 6 additions & 6 deletions colossalai/shardformer/layer/_operation.py
@@ -146,7 +146,7 @@ def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv

def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)

@@ -165,7 +165,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

# split dx & dw
if weight.grad is not None:
grad = weight.grad
@@ -217,7 +217,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)

grad_bias = grad_output.sum(dim=0) if use_bias else None

return grad_input, grad_weight, grad_bias, None, None, None, None
@@ -1213,10 +1213,10 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
)


def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return MatmulWithGradAccum.apply(
input_, weight, bias, async_grad_allreduce, use_zbv
)
return MatmulWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)


def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
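The backward hunk above computes the weight and bias gradients with a plain flattened matmul: grad_weight = total_input.t().matmul(grad_output) and grad_bias = grad_output.sum(dim=0). Below is a minimal stand-alone sketch of that math; the shapes are invented, the weight follows the GPT2 Conv1D convention of (in_features, out_features), and the sharding, async all-reduce, and ZBV grad-accumulation paths of the real code are deliberately omitted.

import torch

# Hypothetical shapes, for illustration only.
x = torch.randn(4, 3, 8, requires_grad=True)   # (batch, seq, in_features)
w = torch.randn(8, 16, requires_grad=True)     # (in_features, out_features), Conv1D-style
b = torch.randn(16, requires_grad=True)

out = x.matmul(w) + b
grad_output = torch.randn_like(out)
out.backward(grad_output)

# Flatten the leading dims, then reproduce the grad_weight / grad_bias formulas from the diff.
go = grad_output.reshape(-1, grad_output.shape[-1])   # (batch*seq, out_features)
xi = x.detach().reshape(-1, x.shape[-1])              # (batch*seq, in_features)
assert torch.allclose(w.grad, xi.t().matmul(go), atol=1e-5)
assert torch.allclose(b.grad, go.sum(dim=0), atol=1e-5)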
9 changes: 8 additions & 1 deletion colossalai/shardformer/layer/qkv_fused_linear.py
@@ -38,7 +38,13 @@
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset, is_share_sp_tp

__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", "GPT2FusedLinearConv1D"]
__all__ = [
"FusedLinear1D_Col",
"FusedLinear1D_Row",
"GPT2FusedLinearConv1D_Col",
"GPT2FusedLinearConv1D_Row",
"GPT2FusedLinearConv1D",
]

# ====================================
# For GPT Only
@@ -641,6 +647,7 @@ class GPT2FusedLinearConv1D(ParallelModule):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(
self,
in_features: int,
@@ -9,7 +9,7 @@
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D
from colossalai.shardformer.layer import GPT2FusedLinearConv1D, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -125,9 +125,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_base = GPT2FusedLinearConv1D.from_native_module(
linear_copy, seq_parallel_mode=seq_parallel_mode
)
linear_base = GPT2FusedLinearConv1D.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)

assert linear.weight.shape == torch.Size([48, 192])
assert linear_base.weight.shape == torch.Size([48, 192])
@@ -153,6 +151,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel
assert_close(out.grad, gather_out.grad)
assert_close(linear.weight.grad, linear_base.weight.grad)


def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
ctx = LazyInitContext() if lazy_init else nullcontext()

@@ -182,7 +181,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
# check backward correctness
out.sum().backward()
gather_out.sum().backward()

WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
WeightGradStore.pop(chunk=0)

@@ -191,6 +190,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
# TODO:linear_base.weight.grad is None; But not none in WeightGradStore
# assert_close(linear.weight.grad, linear_base.weight.grad)


@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
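The check_linear_conv_1d_with_weight_grad_store test above only materializes the weight gradient after WeightGradStore.flush(chunk=0) and WeightGradStore.pop(chunk=0); dX is produced during backward while dW is deferred. A toy sketch of that split-dx/dw idea follows; the queue and function names are invented for illustration and are not ColossalAI's WeightGradStore API.

import torch
from collections import deque

# Hypothetical deferred-dW store for out = x @ w, with w shaped (in_features, out_features).
_pending_dw = deque()

def backward_dx_only(x, w, grad_output):
    # Return dX immediately and stash the tensors needed to build dW later.
    _pending_dw.append((x.detach().reshape(-1, x.shape[-1]),
                        grad_output.reshape(-1, grad_output.shape[-1])))
    return grad_output.matmul(w.t())

def flush_pending_dw(w):
    # Later (e.g. in a pipeline bubble), drain the queue and accumulate dW into w.grad.
    while _pending_dw:
        xi, go = _pending_dw.popleft()
        dw = xi.t().matmul(go)
        w.grad = dw if w.grad is None else w.grad + dw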
