[hotfix] fix hybrid checkpointio for sp+dp (#6184)
* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update build_on_pr.yml

* Update test_zerobubble_pp.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 and pre-commit-ci[bot] authored Feb 6, 2025
1 parent ca0aa23 commit 17062c8
Showing 6 changed files with 35 additions and 30 deletions.
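Context for the change (editor's note): when sequence parallelism (SP) is combined with data parallelism (DP) in a mode that does not share the tensor-parallel group, gradients are synchronized and hybrid ZeRO is applied across the product of the DP and SP ranks. Checkpoint IO therefore has to shard and gather states over that same combined group. This commit builds a `mixed_dp_group` once in `__init__` and reuses it for the model wrapper, the optimizer, and `get_checkpoint_io`, instead of recreating a local `dp_group` inside `configure` and leaving checkpoint IO on the plain `dp_group`. The sketch below illustrates the underlying idea with plain `torch.distributed`; it is not the plugin's `ProcessGroupMesh` implementation and assumes a contiguous rank layout.

```python
import torch.distributed as dist


def build_mixed_dp_group(dp_size: int, sp_size: int):
    """Illustrative sketch: create one process group per block of dp_size * sp_size
    ranks, mimicking pg_mesh.create_group_along_axis([dp_axis, sp_axis]).

    Assumes dist.init_process_group() has already been called and that ranks are
    laid out contiguously along the (dp, sp) axes; the real plugin derives the
    layout from its ProcessGroupMesh instead.
    """
    world_size = dist.get_world_size()
    span = dp_size * sp_size
    mixed_dp_group = None
    # dist.new_group must be called by every rank for every group, so loop over
    # all blocks and keep only the group this rank belongs to.
    for start in range(0, world_size, span):
        ranks = list(range(start, start + span))
        group = dist.new_group(ranks=ranks)
        if dist.get_rank() in ranks:
            mixed_dp_group = group
    return mixed_dp_group
```

Gradient buckets, ZeRO shards, and checkpoint shards are then all expressed against this one group, which is what the diff below threads through `HybridParallelModule`, the optimizer wrappers, and `HybridParallelCheckpointIO`.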
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -199,7 +199,7 @@ jobs:
fi
- name: Upload test coverage artifact
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: report
path: report/
25 changes: 14 additions & 11 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1188,6 +1188,15 @@ def __init__(
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

# sync gradients across DP * SP ranks
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(self.mixed_dp_group)
else:
self.mixed_dp_group = self.dp_group

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
@@ -1298,19 +1307,11 @@ def configure(
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(dp_group)
else:
dp_group = self.dp_group
model = HybridParallelModule(
model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=dp_group,
dp_group=self.mixed_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@ def configure(
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=dp_group,
dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
@@ -1488,7 +1489,9 @@ def seed_worker(worker_id):
)

def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
return HybridParallelCheckpointIO(
self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
)

def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
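With the group promoted to a plugin attribute, `get_checkpoint_io` now hands the same `mixed_dp_group` to `HybridParallelCheckpointIO`, so checkpoints saved under sp+dp are sharded along the group that actually owns the gradients and optimizer states. A minimal usage sketch of how the fix surfaces to users, assuming a recent ColossalAI release and the usual Booster workflow; the model, sizes, and hyperparameters below are placeholders, not part of this commit:

```python
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes one process per GPU, e.g. via torchrun

# sp + dp without TP-shared sequence parallelism: the plugin now builds the
# mixed DP*SP group internally and passes it to HybridParallelCheckpointIO.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

# Before this fix these saves went through plugin.dp_group; they now use
# plugin.mixed_dp_group, matching the DP*SP layout of the optimizer shards.
booster.save_model(model, "ckpt/model", shard=True)
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True)
```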
22 changes: 12 additions & 10 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -351,6 +351,14 @@ def __init__(
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
self.dp_size = dist.get_world_size(self.mixed_dp_group)
else:
self.mixed_dp_group = self.dp_group

self.use_fp8 = use_fp8

self.shard_config = ShardConfig(
@@ -404,7 +412,7 @@ def __init__(

def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group,
self.mixed_dp_group,
self.pp_group,
self.tp_group,
self.sp_group,
@@ -435,20 +443,14 @@ def configure(
and self.sequence_parallelism_mode == "all_to_all"
)

# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group

if use_ddp:
self.logger.warning(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
ranks=[0],
)
self.ddp_config["find_unused_parameters"] = True

if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
)
@@ -457,7 +459,7 @@
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=dp_group,
dp_group=self.mixed_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@@ -507,7 +509,7 @@ def configure(
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=dp_group,
dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,
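The MoE plugin follows the same pattern but merges the MoE-DP, EP, and SP axes into its `mixed_dp_group`, and `configure` keeps its safeguard: PyTorch DDP can only reduce gradients over a single group, so when DDP is enabled the mixed group must contain exactly the ranks of `moe_dp_group`. A condensed sketch of that guard (error text shortened; the groups are the ones created by `MoeHybridParallelPlugin`):

```python
import torch.distributed as dist


def check_ddp_groups(mixed_dp_group, moe_dp_group) -> None:
    """Sketch of the guard in MoeHybridParallelPlugin.configure(): DDP reduces
    gradients over one group only, so the mixed DP*SP group and the MoE DP
    group must cover the same ranks when DDP is enabled."""
    mixed_ranks = dist.get_process_group_ranks(mixed_dp_group)
    moe_ranks = dist.get_process_group_ranks(moe_dp_group)
    if mixed_ranks != moe_ranks:
        raise ValueError(
            f"dp_group {mixed_ranks} and moe_dp_group {moe_ranks} must match "
            "when PyTorch DDP is used; adjust the parallel config to bypass DDP."
        )
```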
8 changes: 4 additions & 4 deletions tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
@@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
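The test updates keep the numerical parity checks aligned with the new group: the parallel loss is reduced over `plugin.mixed_dp_group`, and the per-rank inputs are gathered over that same group before being replayed through the unsharded reference model. A condensed sketch of the pattern, with the tests' own closeness helper replaced by `torch.testing.assert_close` and placeholder tolerances:

```python
import torch
import torch.distributed as dist


def check_parallel_matches_reference(parallel_output, input_embeddings,
                                     torch_model, mixed_dp_group, dtype):
    """Condensed sketch of the comparison in the updated tests: sum the
    parallel loss across the mixed DP*SP group, replay every rank's input
    through the unsharded reference model, and compare the sums."""
    dp_size = dist.get_world_size(mixed_dp_group)
    dist.all_reduce(parallel_output, group=mixed_dp_group)

    all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
    dist.all_gather(all_inputs, input_embeddings, group=mixed_dp_group)

    torch_output_sum = 0
    for input_data_ in all_inputs:
        torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
        torch_output_sum += torch_output
    # Placeholder tolerances; the real tests use their own closeness helper.
    torch.testing.assert_close(parallel_output, torch_output_sum, rtol=5e-3, atol=5e-3)
```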
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -125,12 +125,12 @@ def run_deepseek_commom(parallel_config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -118,12 +118,12 @@ def run_mixtral_commom(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
