[Shardformer] Support zbv in Shardformer Policy #6150

Open

duanjunwen wants to merge 36 commits into main from feature/sharderformer_support_zbv

Changes from 20 commits

Commits (36)
b31a052
[feat] Shardformer support zbv
duanjunwen Nov 21, 2024
5f89e7f
[feat] support chatglm2, command, deepseek for zbv
duanjunwen Nov 21, 2024
41e1972
[feat] support zbv in shardformer policy:
duanjunwen Nov 22, 2024
37a5a66
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Nov 29, 2024
efffe6b
[feat] support GPT2FusedLinearConv1D
duanjunwen Dec 9, 2024
2b94e00
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Dec 9, 2024
a84fc41
[feat] support GPT2FusedLinear (without tp)
duanjunwen Dec 9, 2024
014cc27
[fix] debug FusedConvLinear
duanjunwen Dec 10, 2024
778d4df
[shardformer] support gpt2 policy for zbv, support GPT2FusedLinearConv
duanjunwen Dec 10, 2024
8cb74e7
Merge branch 'main' into feature/sharderformer_support_zbv
duanjunwen Dec 10, 2024
d168b73
[Shardformer] support FusedLinear1D base for zbv
duanjunwen Dec 10, 2024
01a9cb3
[shardformer] support zbv in FusedLinear1D base, Col, Row
duanjunwen Dec 10, 2024
fc77b24
[shardformer] support zbv in blip2 and sam policy
duanjunwen Dec 10, 2024
70b0ae1
[shardformer] fix bug incorrect number of gradients; add fusedLinear
duanjunwen Dec 10, 2024
37b670e
[fix] fix incorrect number of gradients;
duanjunwen Dec 10, 2024
94bb9ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
dee1878
[Shardformer] add en doc for zbv;
duanjunwen Dec 16, 2024
83e670e
[fix] fix typo in Model compatibility table
duanjunwen Dec 16, 2024
2a55566
[fix] fix API Reference typo
duanjunwen Dec 16, 2024
5430eb0
[Shardformer] add zh-Han doc for zbv
duanjunwen Dec 17, 2024
25da23d
[fix] fix Linear name; update en & zh doc
duanjunwen Dec 18, 2024
fd5bd33
[fix] fix shardformer doc import err
duanjunwen Dec 18, 2024
c749a7c
[fix] fix shardconfig import in doc
duanjunwen Dec 18, 2024
eba4e33
[fix] fix shardformer doc
duanjunwen Dec 18, 2024
3c5ce9e
[fix] fix shardconfig doc
duanjunwen Dec 18, 2024
6bbe666
[fix] fix config
duanjunwen Dec 18, 2024
3946366
[fix] remove shardconfig
duanjunwen Dec 18, 2024
b99c733
[fix] fix doc
duanjunwen Dec 18, 2024
99a7829
[feat] add zbv doc string
duanjunwen Dec 18, 2024
f67ce86
[fix] rm doc
duanjunwen Dec 18, 2024
bbdcca1
[fix] fix doc
duanjunwen Dec 18, 2024
9665f66
[fix] empty zbv doc
duanjunwen Dec 18, 2024
568e2c5
[fix] fix torch version
duanjunwen Dec 20, 2024
f8dc150
[fix] fix torch version
duanjunwen Dec 20, 2024
1481b8d
[fix] fix torch versions
duanjunwen Dec 20, 2024
cb52e28
[fix] fix torch versions
duanjunwen Dec 20, 2024
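
For reviewers who want to exercise the branch, the sketch below shows how a zbv run is typically wired up with the `HybridParallelPlugin` / `PipelineGraph` support already on `main`, which the policies in this PR plug into. The parallel sizes, cost weights, and memory constants are illustrative assumptions, not values taken from this PR.

```python
# Minimal sketch: enable the zbv (ZeroBubble-V) pipeline schedule that these
# shardformer policies are being extended to support.
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

pp_size, num_microbatches = 4, 8           # illustrative sizes
hidden, heads, seq_len = 4096, 32, 4096    # illustrative model shape

# Rough per-stage activation-memory weights used to search the V schedule;
# the 34h / 5as / -32h constants follow the ZeroBubble paper's estimates.
mem_f = 34 * hidden + 5 * heads * seq_len
mem_w = -32 * hidden
mem_b = -mem_w - mem_f

scheduler_nodes = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,  # relative F/B/W/comm costs
    f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
).get_v_schedule()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=pp_size,
    pp_style="zbv",            # ZeroBubble-V schedule
    num_model_chunks=2,        # zbv runs two model chunks per stage
    num_microbatches=num_microbatches,
    scheduler_nodes=scheduler_nodes,
)
```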
2 changes: 0 additions & 2 deletions colossalai/pipeline/weight_grad_store.py
@@ -8,7 +8,6 @@ class WeightGradStore:

     @classmethod
     def put(cls, total_input, grad_output, weight, func):
-        # func(total_input, grad_output, weight.main_grad)
         cls.cache.append((total_input, grad_output, weight, func))

     @classmethod
@@ -18,7 +17,6 @@ def flush(cls, chunk=0):

     @classmethod
     def pop(cls, chunk=0):
-        # print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
         if cls.weight_grad_queue[chunk].qsize() > 0:
             stored_grads = cls.weight_grad_queue[chunk].get()
             for total_input, grad_output, weight, func in stored_grads:
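
The two deletions above only drop leftover debugging comments; the class itself is the heart of zbv's B/W split: the backward step computes activation gradients immediately and defers each weight-gradient kernel by `put`-ting its arguments, `flush` seals the deferred work per model chunk, and the scheduler later `pop`s a chunk's queue during a dedicated W slot. Below is a self-contained toy re-implementation of that pattern for intuition only; the real class lives in `colossalai/pipeline/weight_grad_store.py`, and the exact `func` call signature (e.g. whether it receives `weight` or `weight.main_grad`) is an assumption here.

```python
import queue

import torch

class ToyWeightGradStore:
    """Toy version of the deferred weight-gradient store used by zbv."""

    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]  # one queue per model chunk

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # Defer the weight-gradient kernel: remember its arguments only.
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # Seal everything deferred so far as one unit of W work for this chunk.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        # Execute one deferred unit of W work, if any is queued.
        if cls.weight_grad_queue[chunk].qsize() > 0:
            for total_input, grad_output, weight, func in cls.weight_grad_queue[chunk].get():
                func(total_input, grad_output, weight)

def linear_weight_grad(x, dy, w):
    # dW for y = x @ w.T, accumulated into .grad like a deferred W kernel.
    w.grad = dy.t() @ x if w.grad is None else w.grad + dy.t() @ x

x, dy = torch.randn(2, 4), torch.randn(2, 3)
w = torch.nn.Parameter(torch.randn(3, 4))
ToyWeightGradStore.put(x, dy, w, linear_weight_grad)  # during the B pass
ToyWeightGradStore.flush(chunk=0)                     # at the end of the B step
ToyWeightGradStore.pop(chunk=0)                       # later, in a scheduled W slot
```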
11 changes: 10 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -6,7 +6,14 @@
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
-from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from .qkv_fused_linear import (
+    FusedLinear1D,
+    FusedLinear1D_Col,
+    FusedLinear1D_Row,
+    GPT2FusedLinearConv1D,
+    GPT2FusedLinearConv1D_Col,
+    GPT2FusedLinearConv1D_Row,
+)

 __all__ = [
     "Embedding1D",
@@ -16,6 +23,7 @@
     "Linear1D_Row",
     "GPT2FusedLinearConv1D_Col",
     "GPT2FusedLinearConv1D_Row",
+    "GPT2FusedLinearConv1D",
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
@@ -26,6 +34,7 @@
     "FusedLayerNorm",
     "FusedRMSNorm",
     "FusedLinear1D_Col",
+    "FusedLinear1D",
     "ParallelModule",
     "PaddingEmbedding",
     "PaddingLMHead",
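
The re-export above makes the new non-parallel variants (`FusedLinear1D`, `GPT2FusedLinearConv1D`) importable alongside their `_Col`/`_Row` siblings, which is what lets a policy keep a fused QKV projection unsplit when tensor parallelism is off under zbv. A hedged sketch of how a policy entry might reference them follows; the `attn.c_attn` suffix mirrors the existing GPT-2 policy, but this exact entry is illustrative, not code from this PR.

```python
from colossalai.shardformer.layer import GPT2FusedLinearConv1D
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

# Illustrative policy entry: with tp disabled, swap GPT-2's fused QKV Conv1D
# for the zbv-aware non-parallel wrapper instead of the _Col/_Row splits.
qkv_replacement = SubModuleReplacementDescription(
    suffix="attn.c_attn",
    target_module=GPT2FusedLinearConv1D,
)
```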