[zero bubble]support zbv all #6081

Closed
Changes from all commits (167 commits)
f5a52e1
fp8 operators for compressed communication
BurkeHulk Jul 1, 2024
6991819
Merge branch 'hpcaitech:main' into feature/fp8_comm
BurkeHulk Jul 4, 2024
e17f835
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
dbfa7d3
fix typo
GuangyaoZhang Jul 10, 2024
1e19594
fix scaling algorithm in FP8 casting
BurkeHulk Jul 12, 2024
e881901
support fp8 communication in pipeline parallelism
BurkeHulk Jul 12, 2024
6601874
add fp8_communication flag in the script
BurkeHulk Jul 12, 2024
1f1b856
Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/f…
BurkeHulk Jul 12, 2024
51f916b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2024
9470701
Merge pull request #5885 from BurkeHulk/feature/fp8_comm
BurkeHulk Jul 16, 2024
457a0de
shardformer fp8
GuangyaoZhang Jul 8, 2024
5a310b9
fix rebase
GuangyaoZhang Jul 17, 2024
6a20f07
remove all to all
GuangyaoZhang Jul 17, 2024
d0bdb51
Merge pull request #5899 from BurkeHulk/SP_fp8
GuangyaoZhang Jul 18, 2024
5b969fd
fix shardformer fp8 communication training degradation
GuangyaoZhang Jul 18, 2024
62661cd
Merge pull request #5921 from BurkeHulk/fp8_fix
GuangyaoZhang Jul 18, 2024
5fd0592
[fp8] support all-gather flat tensor (#5932)
ver217 Jul 24, 2024
ae486ce
[fp8] add fp8 comm for low level zero
ver217 Aug 2, 2024
91e596d
[test] add zero fp8 test case
ver217 Aug 2, 2024
c297e21
Merge pull request #5961 from ver217/feature/zeor-fp8
BurkeHulk Aug 2, 2024
53cb960
[Feature] llama shardformer fp8 support (#5938)
GuangyaoZhang Aug 5, 2024
0c10afd
[FP8] rebase main (#5963)
flybird11111 Aug 6, 2024
afb26de
[fp8]support all2all fp8 (#5953)
flybird11111 Aug 6, 2024
76ea164
[fp8] add fp8 linear (#5967)
ver217 Aug 7, 2024
ccabcf6
[fp8] support fp8 amp for hybrid parallel plugin (#5975)
ver217 Aug 7, 2024
7739629
fix (#5976)
flybird11111 Aug 7, 2024
b480eec
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
BurkeHulk Aug 8, 2024
4b9bec8
[test ci]Feature/fp8 comm (#5981)
flybird11111 Aug 8, 2024
8241c0c
[fp8] support gemini plugin (#5978)
ver217 Aug 9, 2024
e4aadee
[fp8] use torch compile (torch >= 2.3.0) (#5979)
botbw Aug 9, 2024
f1a3a32
[fp8]Moe support fp8 communication (#5977)
flybird11111 Aug 9, 2024
b2483c8
[fp8] support hybrid parallel plugin (#5982)
wangbluo Aug 12, 2024
0978080
[fp8] refactor fp8 linear with compile (#5993)
ver217 Aug 13, 2024
597b206
[fp8] support asynchronous FP8 communication (#5997)
flybird11111 Aug 14, 2024
88fa096
[fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)
botbw Aug 15, 2024
1a2e90d
[fp8] linear perf enhancement
botbw Aug 15, 2024
20722a8
[fp8]update reduce-scatter test (#6002)
flybird11111 Aug 15, 2024
3f09a61
[fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)
wangbluo Aug 16, 2024
0a51319
[fp8] zero support fp8 linear. (#6006)
flybird11111 Aug 16, 2024
4cf79fa
merge
wangbluo Aug 17, 2024
81272e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
02636c5
fix the merge
wangbluo Aug 19, 2024
52289e4
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
1a5847e
fix the merge
wangbluo Aug 19, 2024
3353042
fix the merge
wangbluo Aug 19, 2024
64aad96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
4c82bfc
fix the merge
wangbluo Aug 19, 2024
0d8e82a
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
12b4401
fix
wangbluo Aug 19, 2024
2eb3683
fix
wangbluo Aug 19, 2024
88b3f06
fix the merge
wangbluo Aug 19, 2024
1f703e0
fix
wangbluo Aug 19, 2024
5382311
fix
wangbluo Aug 20, 2024
f7acfa1
fix
wangbluo Aug 20, 2024
2ee6235
fix
wangbluo Aug 20, 2024
2e4cbe3
fix
wangbluo Aug 20, 2024
2d362ac
fix merge
wangbluo Aug 20, 2024
eb5ba40
fix the merge
wangbluo Aug 21, 2024
193030f
fix
wangbluo Aug 21, 2024
6aface9
fix
wangbluo Aug 21, 2024
698c8b9
fix
wangbluo Aug 21, 2024
8b8e282
fix
wangbluo Aug 21, 2024
eea37da
[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
wangbluo Aug 22, 2024
d77e66a
Merge pull request #6023 from wangbluo/fp8_merge
wangbluo Aug 22, 2024
971b16a
fix
wangbluo Aug 22, 2024
a292554
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
afe845f
Merge pull request #6024 from wangbluo/fix_merge
wangbluo Aug 22, 2024
caab4a3
Merge branch 'main' into feature/fp8_comm
ver217 Aug 22, 2024
0bc9a87
Update train_dpo.py
flybird11111 Aug 23, 2024
3b0df30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
9e76764
Update low_level_zero_plugin.py
flybird11111 Aug 23, 2024
0bf46c5
Merge pull request #6029 from hpcaitech/flybird11111-patch-1
wangbluo Aug 23, 2024
dae3999
fix
wangbluo Aug 26, 2024
80d24ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2024
4a6f31e
Merge pull request #6033 from wangbluo/fix
wangbluo Aug 26, 2024
17904cb
Merge pull request #6012 from hpcaitech/feature/fp8_comm
ver217 Aug 27, 2024
d383449
[CI] Remove triton version for compatibility bug; update req torch >=…
Edenzzzz Aug 27, 2024
cc1b0ef
[plugin] hotfix zero plugin (#6036)
ver217 Aug 28, 2024
4a68efb
[Colossal-LLaMA] Refactor latest APIs (#6030)
TongLi3701 Aug 28, 2024
0d3a85d
add fused norm (#6038)
TongLi3701 Aug 28, 2024
e96a076
[FP8] unsqueeze scale to make it compatible with torch.compile (#6040)
GuangyaoZhang Aug 29, 2024
e9032fb
[colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model;…
flymin Sep 2, 2024
c650a90
[Hotfix] Remove deprecated install (#6042)
TongLi3701 Sep 3, 2024
c3b5caf
[fp8] optimize all-gather (#6043)
ver217 Sep 3, 2024
26e5539
[fp8] fix linear hook (#6046)
ver217 Sep 3, 2024
5ce6dd7
[fp8] disable all_to_all_fp8 in intranode (#6045)
BurkeHulk Sep 9, 2024
b3db105
[release] update version (#6041)
ver217 Sep 10, 2024
8fd25d6
[Feature] Split cross-entropy computation in SP (#5959)
Edenzzzz Sep 10, 2024
c54c4fc
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
botbw Sep 10, 2024
13946c4
[fp8] hotfix backward hook (#6053)
ver217 Sep 11, 2024
a35a078
[doc] update sp doc (#6055)
flybird11111 Sep 11, 2024
fdd84b9
fix the sp
wangbluo Sep 13, 2024
216d54e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
0a01e2a
fix the attn
wangbluo Sep 13, 2024
683179c
fix
wangbluo Sep 13, 2024
6eb8832
fix
wangbluo Sep 13, 2024
f393867
fix
wangbluo Sep 13, 2024
dc03217
fix
wangbluo Sep 13, 2024
e79d442
[zerobubble]Support ZeroBubble Pipeline (#6034)
duanjunwen Sep 10, 2024
696fced
[fp8] fix missing fp8_comm flag in mixtral (#6057)
botbw Sep 13, 2024
0b14a55
fix
wangbluo Sep 13, 2024
0ad3129
fix
wangbluo Sep 13, 2024
b582319
fix
wangbluo Sep 13, 2024
f20b066
[fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 …
GuangyaoZhang Sep 14, 2024
bdb125f
[doc] FP8 training and communication document (#6050)
GuangyaoZhang Sep 14, 2024
827ef3e
fix
wangbluo Sep 14, 2024
37e3523
Merge pull request #6061 from wangbluo/sp_fix
wangbluo Sep 14, 2024
10e4f7d
fix
wangbluo Sep 16, 2024
63314ce
Merge pull request #6064 from wangbluo/fix_attn
wangbluo Sep 18, 2024
4fa6b95
[moe] add parallel strategy for shared_expert && fix test for deepsee…
botbw Sep 18, 2024
f9546ba
[ColossalEval] support for vllm (#6056)
Camille7777 Sep 18, 2024
dabc2e7
[release] update version (#6062)
ver217 Sep 19, 2024
b3b3278
Merge branch 'main' of github.com:flybird11111/ColossalAI into main
flybird11111 Sep 29, 2024
da4595a
[feat] add zerobubble pp (just a frame now); add POC test for dx_dw; …
duanjunwen Aug 22, 2024
e450dd2
[update] update text;
duanjunwen Aug 26, 2024
ccc37a4
[feat] add test run_fwd_bwd automatic scheduling;
duanjunwen Aug 26, 2024
228d71e
[feat] fix poc format
duanjunwen Aug 28, 2024
8b0ffed
[fix] fix poc test; add comments in poc;
duanjunwen Aug 28, 2024
97f2443
[feat] add optim backward_b_by_grad
duanjunwen Aug 29, 2024
c90bd98
[feat] fix optimizer bwd b & w; support return accum loss & output
duanjunwen Aug 29, 2024
5df5965
[fix] fix optim bwd; add license for v_schedule; remove redundant att…
duanjunwen Aug 30, 2024
94a12f6
[feat] update test; rm comments;
duanjunwen Sep 2, 2024
cc5e7dc
[fix] rm zbv in hybridplugin
duanjunwen Sep 2, 2024
ad8ad64
[fix] fix optim bwd;
duanjunwen Sep 2, 2024
f347591
[fix] fix optim bwd;
duanjunwen Sep 3, 2024
4249a36
[fix] rm output.data after send fwd;
duanjunwen Sep 3, 2024
497d545
[fix] fix bwd step if condition; remove useless comments and format i…
duanjunwen Sep 3, 2024
0825700
[fix] fix mem check;
duanjunwen Sep 4, 2024
ae4cf5b
[fix] fix mem assertation
duanjunwen Sep 9, 2024
e80179c
[fix] fix mem; use a new model shape; only assert mem less and equal …
duanjunwen Sep 9, 2024
2683d26
[fix] fix model zoo import;
duanjunwen Sep 9, 2024
9094cc3
[feat] moehybrid support zerobubble;
duanjunwen Sep 12, 2024
3e2f260
[fix] fix zerobubble pp for shardformer type input;
duanjunwen Sep 18, 2024
8ce22ae
[fix] fix require_grad & deallocate call;
duanjunwen Sep 19, 2024
f8d6f98
[fix] fix mem assert;
duanjunwen Sep 19, 2024
78a439b
[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'
duanjunwen Sep 20, 2024
8bc8bb0
[fix] fix pipeline util func deallocate --> release_tensor_data; fix …
duanjunwen Sep 20, 2024
a3a797d
[fix] fix zerobubble; support shardformer model type;
duanjunwen Sep 26, 2024
4d3eaee
[fix] fix test_pipeline_utils ci;
duanjunwen Sep 26, 2024
2fd9d3e
[plugin] hybrid support zero bubble pipeline (#6060)
flybird11111 Sep 27, 2024
21c62b6
[feat] add zerobubble pp (just a frame now); add POC test for dx_dw; …
duanjunwen Aug 22, 2024
28ee5a7
[update] update text;
duanjunwen Aug 26, 2024
d44e7e6
[feat] add test run_fwd_bwd automatic scheduling;
duanjunwen Aug 26, 2024
49d68eb
[feat] fix poc format
duanjunwen Aug 28, 2024
0055c47
[fix] fix poc test; add comments in poc;
duanjunwen Aug 28, 2024
21bf510
[feat] add optim backward_b_by_grad
duanjunwen Aug 29, 2024
93ede6b
[feat] fix optimizer bwd b & w; support return accum loss & output
duanjunwen Aug 29, 2024
4ac0d6e
[fix] fix optim bwd; add license for v_schedule; remove redundant att…
duanjunwen Aug 30, 2024
262b27e
[feat] update test; rm comments;
duanjunwen Sep 2, 2024
fe99ca3
[fix] fix optim bwd;
duanjunwen Sep 2, 2024
355a3af
[fix] fix optim bwd;
duanjunwen Sep 3, 2024
4420dc1
[fix] rm output.data after send fwd;
duanjunwen Sep 3, 2024
7ba031d
[fix] fix bwd step if condition; remove useless comments and format i…
duanjunwen Sep 3, 2024
e666f5c
[fix] fix mem check;
duanjunwen Sep 4, 2024
93b3604
[fix] fix mem assertation
duanjunwen Sep 9, 2024
78ed432
[fix] fix mem; use a new model shape; only assert mem less and equal …
duanjunwen Sep 9, 2024
df12ae7
[fix] fix model zoo import;
duanjunwen Sep 9, 2024
9e90356
[fix] fix mem assert;
duanjunwen Sep 19, 2024
993f3db
[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'
duanjunwen Sep 20, 2024
0767948
[plugin] hybrid support zero bubble pipeline (#6060)
flybird11111 Sep 27, 2024
3251e68
Merge branch 'feature/zerobubble' into feature/zerobubble
flybird11111 Sep 29, 2024
797d1ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2024
1637c14
Merge branch 'feature/zerobubble' of github.com:flybird11111/Colossal…
flybird11111 Oct 8, 2024
a5f0670
zbv support zero
flybird11111 Oct 8, 2024
42f2d0b
suport zbv all
flybird11111 Oct 8, 2024
d50c0a1
example support zbv
flybird11111 Oct 8, 2024
b9ac0a6
fix
flybird11111 Oct 8, 2024
16 changes: 0 additions & 16 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1166,22 +1166,6 @@ def __init__(
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
elif pp_style == "zbv":
self.scheduler = ZeroBubbleVPipeScheduler(
stage_manager=self.stage_manager,
schedule=scheduler_nodes,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
elif pp_style == "zbv":
self.scheduler = ZeroBubbleVPipeScheduler(
stage_manager=self.stage_manager,
schedule=scheduler_nodes,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
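The hunk above removes a duplicated `elif pp_style == "zbv"` branch so that HybridParallelPlugin keeps a single ZBV path. A minimal sketch of the dispatch that remains, assuming the surviving branch uses the same constructor arguments as the deleted copies and that the import path shown is correct (both are assumptions, not confirmed by this diff):

def build_pp_scheduler(pp_style, stage_manager, scheduler_nodes,
                       num_model_chunks, num_microbatches, microbatch_size):
    # Hypothetical helper mirroring the scheduler dispatch in
    # HybridParallelPlugin.__init__ after the duplicate branch is removed.
    # The import location is assumed, not taken from this diff.
    from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler

    if pp_style == "zbv":
        # Zero-bubble V-schedule, driven by the precomputed scheduler_nodes graph.
        return ZeroBubbleVPipeScheduler(
            stage_manager=stage_manager,
            schedule=scheduler_nodes,
            num_model_chunks=num_model_chunks,
            num_microbatch=num_microbatches,
            microbatch_size=microbatch_size,
        )
    raise NotImplementedError(f"unsupported pp_style: {pp_style!r}")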
39 changes: 30 additions & 9 deletions colossalai/shardformer/policies/bert.py
@@ -371,7 +371,9 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.pooler)

else:
@@ -430,7 +432,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)

return held_layers
@@ -471,7 +475,9 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
return held_layers

@@ -511,7 +517,9 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
return held_layers

@@ -563,7 +571,10 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
@@ -607,7 +618,10 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
@@ -638,7 +652,9 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
return held_layers

@@ -681,7 +697,10 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
@@ -711,7 +730,9 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.qa_outputs)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.qa_outputs)
return held_layers

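Every policy touched in this file repeats the same check: with ZBV enabled, the V-schedule runs the final virtual stage on the first physical stage, so the output modules (pooler, cls, classifier, qa_outputs) are held there instead of on the last stage. A hypothetical predicate, not part of this PR, that captures the repeated condition:

def holds_output_head(stage_manager) -> bool:
    # Hypothetical helper summarizing the check repeated across the
    # shardformer policies in this PR: under a zero-bubble V-schedule
    # (use_zbv) the last virtual stage lives on the first physical stage,
    # so head layers are held there; otherwise they stay on the last stage.
    if stage_manager.use_zbv:
        return stage_manager.is_first_stage(ignore_chunk=True)
    return stage_manager.is_last_stage(ignore_chunk=True)

With such a helper each branch would collapse to a single `if holds_output_head(stage_manager):`, which is effectively what the combined condition in the command.py and llama.py hunks below does inline.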
8 changes: 6 additions & 2 deletions colossalai/shardformer/policies/command.py
@@ -258,7 +258,9 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)

else:
@@ -351,7 +353,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

30 changes: 24 additions & 6 deletions colossalai/shardformer/policies/gpt2.py
@@ -233,7 +233,9 @@ def get_held_layers(self) -> List[nn.Module]:
held_layers.append(module.drop)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.ln_f)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.h))
@@ -355,7 +357,9 @@ def module_policy(self):

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
if self.pipeline_stage_manager.use_zbv and self.pipeline_stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

@@ -423,7 +427,14 @@ def module_policy(self):

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
if self.pipeline_stage_manager.use_zbv and self.pipeline_stage_manager.is_first_stage(ignore_chunk=True):
multiple_choice_head = self.model.multiple_choice_head
held_layers.append(self.model.lm_head)
held_layers.append(multiple_choice_head.summary)
held_layers.append(multiple_choice_head.activation)
held_layers.append(multiple_choice_head.first_dropout)
held_layers.append(multiple_choice_head.last_dropout)
elif self.pipeline_stage_manager.is_last_stage():
multiple_choice_head = self.model.multiple_choice_head
held_layers.append(self.model.lm_head)
held_layers.append(multiple_choice_head.summary)
@@ -467,7 +478,9 @@ def module_policy(self):

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
if self.pipeline_stage_manager.use_zbv and self.pipeline_stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.qa_outputs)
elif self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers

@@ -506,7 +519,10 @@ def module_policy(self):

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
if self.pipeline_stage_manager.use_zbv and self.pipeline_stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
elif self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
@@ -533,7 +549,9 @@ def module_policy(self):

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
if self.pipeline_stage_manager.use_zbv and self.pipeline_stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.score)
elif self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers

6 changes: 3 additions & 3 deletions colossalai/shardformer/policies/llama.py
@@ -261,9 +261,9 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.norm)
elif stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)

else:
12 changes: 9 additions & 3 deletions colossalai/shardformer/policies/mistral.py
@@ -253,7 +253,9 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)

else:
@@ -348,7 +350,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

@@ -404,7 +408,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.score)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers

12 changes: 9 additions & 3 deletions colossalai/shardformer/policies/qwen2.py
@@ -278,7 +278,9 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)

else:
@@ -347,7 +349,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

@@ -399,7 +403,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.score)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers

30 changes: 25 additions & 5 deletions examples/language/llama/benchmark.py
@@ -21,6 +21,7 @@
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig

warnings.filterwarnings("ignore")
@@ -91,7 +92,7 @@ def main():
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)

parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
@@ -137,6 +138,28 @@ def empty_init():
# ==============================
# Initialize Booster
# ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)

scheduler_nodes = None
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.b // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()

use_empty_init = True
if args.plugin == "gemini":
plugin = GeminiPlugin(
@@ -227,6 +250,7 @@ def empty_init():
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@@ -256,10 +280,6 @@ def empty_init():
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)

if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
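Putting the benchmark changes together, a ZBV run builds a PipelineGraph from the model shape, solves for the V-schedule, and passes the resulting scheduler_nodes to the plugin along with pp_style="zbv". The sketch below mirrors that flow; the model dimensions, parallel sizes, and num_model_chunks=2 are illustrative assumptions rather than values taken from the diff, and it assumes the script is launched with torchrun so that launch_from_torch() can initialize distributed state.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

colossalai.launch_from_torch()  # assumes a torchrun launch

pp_size, global_batch, micro_batch = 4, 32, 1
hidden_size, num_attention_heads, seq_len = 4096, 32, 4096  # illustrative model shape

# Memory weights follow the formulas used in benchmark.py above.
mem_f = 34 * hidden_size + 5 * num_attention_heads * seq_len
mem_w = -32 * hidden_size
mem_b = -mem_w - mem_f

scheduler_nodes = PipelineGraph(
    n_stage=pp_size,
    n_micro=global_batch // micro_batch,
    f_cost=1000,
    b_cost=1000,
    w_cost=1000,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=pp_size,
    pp_style="zbv",
    num_model_chunks=2,  # assumed: the V-schedule uses two virtual chunks per stage
    num_microbatches=global_batch // micro_batch,
    scheduler_nodes=scheduler_nodes,
)
booster = Booster(plugin=plugin)

From the command line this corresponds roughly to running benchmark.py with --plugin 3d --pp_style zbv --pp 4 plus the usual batch-size flags, though the exact flags depend on the script's defaults.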