[Sharderformer] Support zbv in Sharderformer Policy #6150

Merged: 59 commits, Jan 2, 2025
Changes from all 59 commits:
b31a052
[feat] Sharderformer support zbv
duanjunwen Nov 21, 2024
5f89e7f
[feat] support chatglm2, command, deepseek for zbv
duanjunwen Nov 21, 2024
41e1972
[feat] support zbv in shardformer policy:
duanjunwen Nov 22, 2024
37a5a66
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Nov 29, 2024
efffe6b
[feat] support GPT2FusedLinearConv1D
duanjunwen Dec 9, 2024
2b94e00
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Dec 9, 2024
a84fc41
[feat] support GPT2FusedLinear (without tp)
duanjunwen Dec 9, 2024
014cc27
[fix] debug FusedConvLinear
duanjunwen Dec 10, 2024
778d4df
[shardformer] support gpt2 policy for zbv, support GPT2FusedLinearConv
duanjunwen Dec 10, 2024
8cb74e7
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Dec 10, 2024
d168b73
[Shardformer] support FusedLinear1D base for zbv
duanjunwen Dec 10, 2024
01a9cb3
[shardformer] support zbv in FusedLinear1D base, Col, Row
duanjunwen Dec 10, 2024
fc77b24
[shardformer] support zbv in blip2 and sam policy
duanjunwen Dec 10, 2024
70b0ae1
[shardformer] fix bug incorrect number of gradients; add fusedLinear
duanjunwen Dec 10, 2024
37b670e
[fix] fix incorrect number of gradients;
duanjunwen Dec 10, 2024
94bb9ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
dee1878
[Shardformer] add en doc for zbv;
duanjunwen Dec 16, 2024
83e670e
[fix] fix typo in Model compatibility table
duanjunwen Dec 16, 2024
2a55566
[fix] fix API Reference typo
duanjunwen Dec 16, 2024
5430eb0
[Shardformer] add zh-Han doc for zbv
duanjunwen Dec 17, 2024
25da23d
[fix] fix Linear name; update en & zh doc
duanjunwen Dec 18, 2024
fd5bd33
[fix] fix shardformer doc import err
duanjunwen Dec 18, 2024
c749a7c
[fix] fix shardconfig import in doc
duanjunwen Dec 18, 2024
eba4e33
[fix] fix shardformer doc
duanjunwen Dec 18, 2024
3c5ce9e
[fix] fix shardconfig doc
duanjunwen Dec 18, 2024
6bbe666
[fix] fix config
duanjunwen Dec 18, 2024
3946366
[fix] remove shardconfig
duanjunwen Dec 18, 2024
b99c733
[fix] fix doc
duanjunwen Dec 18, 2024
99a7829
[feat] add zbv doc string
duanjunwen Dec 18, 2024
f67ce86
[fix] rm doc
duanjunwen Dec 18, 2024
bbdcca1
[fix] fix doc
duanjunwen Dec 18, 2024
9665f66
[fix] empty zbv doc
duanjunwen Dec 18, 2024
568e2c5
[fix] fix torch version
duanjunwen Dec 20, 2024
f8dc150
[fix] fix torch version
duanjunwen Dec 20, 2024
1481b8d
[fix] fix torch versions
duanjunwen Dec 20, 2024
cb52e28
[fix] fix torch versions
duanjunwen Dec 20, 2024
30e65e7
[fix] fix pyramid versions
duanjunwen Dec 23, 2024
541664a
[fix] fix pyramid, zope version
duanjunwen Dec 23, 2024
ed76d69
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Dec 23, 2024
e592884
[fix] try fix workflow
duanjunwen Dec 23, 2024
3b0669a
[fix] try import ShardConfig in yml
duanjunwen Dec 23, 2024
1cd60a0
[fix] fix workflow
duanjunwen Dec 23, 2024
573d5ce
[fix] fix workflow
duanjunwen Dec 23, 2024
938bf6d
[fix] fix workflow
duanjunwen Dec 23, 2024
90d1d53
[fix] fix workflow
duanjunwen Dec 23, 2024
9b7940f
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Dec 23, 2024
aab6275
[fix] fix ci
duanjunwen Dec 23, 2024
63b7db5
[fix] fix zbv doc
duanjunwen Dec 23, 2024
7fb23a5
[fix] fix param for qkv linear, gpt2fused linear; fix requirements;
duanjunwen Dec 24, 2024
f0a8d78
[fix] fix policy use fused_linear
duanjunwen Dec 24, 2024
ff316c9
[fix] fix weight grad none, err caused by weight ptr change
duanjunwen Dec 24, 2024
f52c36e
[fix] fix comm in WeightGradStore
duanjunwen Dec 24, 2024
feca06e
[fix] fix WeightGradStore pop param
duanjunwen Dec 25, 2024
d74071a
[fix] remove useless param in doc; fix gpt2 qkv test;
duanjunwen Dec 25, 2024
c0b6fbc
[shardformer] simplify execute_w_pass_grad_accum;
duanjunwen Dec 25, 2024
130b50c
[fix] rm useless comments
duanjunwen Dec 25, 2024
c4df1cc
[shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
duanjunwen Dec 25, 2024
52a3b88
[shardformer] Run meaningful doc test
duanjunwen Dec 25, 2024
ee6bba9
[shardformer] fix doc test cmd;
duanjunwen Dec 25, 2024
1 change: 1 addition & 0 deletions .github/workflows/doc_check_on_pr.yml
@@ -58,6 +58,7 @@ jobs:
# there is no main branch, so it's safe to checkout the main branch from the merged branch
# docer will rebase the remote main branch to the merged branch, so we have to config user
- name: Make the merged branch main

run: |
cd ColossalAI
git checkout -b main
13 changes: 13 additions & 0 deletions colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -38,6 +38,19 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:


class ZeroBubbleVPipeScheduler(PipelineSchedule):
r"""
ZeroBubbleVPipeScheduler

Args:
stage_manager (PipelineStageManager): The pipeline stage manager used for inter-process communication in pipeline parallelism. Defaults to None, which means pipeline parallelism is not used.
schedule (List[ScheduledNode]): Schedule for ZeroBubbleVPipe.
num_model_chunks (int): The number of model chunks on each device.
num_microbatch (Optional[int]): The number of microbatches.
microbatch_size (Optional[int]): The size of each microbatch.
enable_metadata_cache (bool): Whether to enable the metadata cache to accelerate communication.
overlap_p2p (bool): Whether to overlap p2p communication with computation.
"""

def __init__(
self,
stage_manager: PipelineStageManager,
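For orientation, here is a minimal construction sketch based on the docstring above. It is an assumption-laden example, not code from the PR: `stage_manager` and `zbv_schedule` are hypothetical placeholders that would come from the surrounding pipeline setup, and the real signature may have different defaults.

```python
# Hypothetical sketch; stage_manager and zbv_schedule are assumed to be built
# by the surrounding pipeline setup and are not defined in this diff.
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler

scheduler = ZeroBubbleVPipeScheduler(
    stage_manager=stage_manager,   # PipelineStageManager for inter-stage p2p
    schedule=zbv_schedule,         # List[ScheduledNode] describing the ZBV schedule
    num_model_chunks=2,            # the V schedule places two model chunks per device
    num_microbatch=8,
    microbatch_size=1,
    enable_metadata_cache=True,    # cache p2p metadata to accelerate communication
    overlap_p2p=True,              # overlap p2p communication with computation
)
```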
24 changes: 17 additions & 7 deletions colossalai/pipeline/weight_grad_store.py
@@ -8,7 +8,6 @@ class WeightGradStore:

@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))

@classmethod
@@ -18,15 +17,26 @@ def flush(cls, chunk=0):

@classmethod
def pop(cls, chunk=0):
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
if isinstance(weight, tuple):
# A view op is added to weight and bias so the call can be hooked by Gemini's '__torch_function__'.
# The view changes the weight pointer, so the tuple carries both tensors:
# weight_cal is used to compute dw, weight_origin is used to apply the update.
_, weight_origin = weight
if weight_origin.grad is not None:
func(total_input, grad_output, weight_origin.grad)
# for the first backward pass, weight.grad is None; assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight_origin.grad = grad_weight
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for the first backward pass, weight.grad is None; assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")
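This class is the mechanism behind zero bubble's B/W split: `put` records everything needed for a weight gradient during the input-gradient (B) pass, `flush` moves the cached entries into the per-chunk queue at a microbatch boundary (only its signature appears in this hunk), and `pop` performs the deferred weight-gradient (W) computation later, accumulating into `weight.grad` when it already exists. Below is a self-contained toy sketch of that flow, with a made-up `dw` standing in for the real fused weight-grad kernels.

```python
# Toy reimplementation for illustration only; it mirrors the put/flush/pop
# shape of WeightGradStore, not its real call sites inside ColossalAI.
import queue

import torch

class ToyWeightGradStore:
    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]  # one queue per model chunk

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        stored_grads = cls.weight_grad_queue[chunk].get()
        for total_input, grad_output, weight, func in stored_grads:
            if weight.grad is not None:
                func(total_input, grad_output, weight.grad)  # accumulate in place
            else:  # first backward: no grad buffer yet, so create one
                weight.grad = func(total_input, grad_output)

def dw(total_input, grad_output, grad=None):
    # dW = grad_output^T @ input for a linear layer; accumulates if grad is given
    g = grad_output.t() @ total_input
    if grad is not None:
        grad.add_(g)
        return None
    return g

w = torch.nn.Parameter(torch.randn(4, 3))
x, gy = torch.randn(2, 3), torch.randn(2, 4)
ToyWeightGradStore.put(x, gy, w, dw)  # during the B pass: defer dW
ToyWeightGradStore.flush()            # at the microbatch boundary
ToyWeightGradStore.pop()              # during the W pass: compute dW now
assert w.grad.shape == (4, 3)
```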
13 changes: 11 additions & 2 deletions colossalai/shardformer/layer/__init__.py
@@ -6,16 +6,24 @@
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from .qkv_fused_linear import (
FusedLinear,
FusedLinear1D_Col,
FusedLinear1D_Row,
GPT2FusedLinearConv,
GPT2FusedLinearConv1D_Col,
GPT2FusedLinearConv1D_Row,
)

__all__ = [
"Embedding1D",
"VocabParallelEmbedding1D",
"LinearWithGradAccum",
"Linear1D_Col",
"Linear1D_Row",
"GPT2FusedLinearConv1D_Col",
"GPT2FusedLinearConv",
"GPT2FusedLinearConv1D_Row",
"GPT2FusedLinearConv1D_Col",
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
@@ -26,6 +34,7 @@
"FusedLayerNorm",
"FusedRMSNorm",
"FusedLinear1D_Col",
"FusedLinear",
"ParallelModule",
"PaddingEmbedding",
"PaddingLMHead",
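With the expanded `__init__.py`, both the sharded 1D variants and the new unsharded fused layers are importable from the package. A quick smoke check, assuming ColossalAI is installed from this branch:

```python
# All six names come straight from the import list in this diff; FusedLinear and
# GPT2FusedLinearConv are the new non-tensor-parallel variants added by this PR.
from colossalai.shardformer.layer import (
    FusedLinear,
    FusedLinear1D_Col,
    FusedLinear1D_Row,
    GPT2FusedLinearConv,
    GPT2FusedLinearConv1D_Col,
    GPT2FusedLinearConv1D_Row,
)
```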