Commit 00a3c21

Fix 8gpu PP failure due to 2D DCP disablement
DCP recently added safety checks to avoid using it for 2D/3D, since strided sharding (a feature needed for safe 2D/3D resharding) is not ready yet. PP uses DCP to load a seed checkpoint, so these checks break it. Disabling the safety mechanism is enough to make 3D/PP work for the case where we train from the beginning or do not re-shard. (Resharding refers to saving a checkpoint under one world size/parallelism config and loading/resuming under a different one.)

ghstack-source-id: c069d21
Pull Request resolved: #460
1 parent 0f70507 commit 00a3c21
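
For context, the "resharding" scenario that DCP's safety check guards against can be sketched as follows. This is a hypothetical illustration, not code from this commit: the checkpoint path `ckpt/step-1000`, the GPU counts, and the stand-in `model` are all assumptions; only the `torch.distributed.checkpoint` save/load APIs are real.

```python
# Hypothetical illustration of "resharding" (not from this commit): save a
# checkpoint under one world size/parallelism config, then resume under a
# different one. Without strided sharding, DCP cannot do this safely for
# 2D/3D-sharded models, which is what the new safety check guards against.
import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(8, 8)  # stand-in for a sharded model

# Run 1 (e.g. 8 GPUs, FSDP only): save a checkpoint.
dcp.save({"model": model.state_dict()}, checkpoint_id="ckpt/step-1000")

# Run 2 (e.g. 16 GPUs, FSDP + TP): loading the same files is a reshard.
state = {"model": model.state_dict()}
dcp.load(state, checkpoint_id="ckpt/step-1000")
```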

File tree

1 file changed: +11 −0 lines changed

torchtitan/parallelisms/parallelize_llama.py

```diff
@@ -507,6 +507,17 @@ def apply_fsdp(
         model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
     )
 
+    if parallel_dims.pp_enabled:
+        # TODO
+        # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
+        # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
+        # without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
+        # removed after strided sharding is landed in DCP.
+        for module in model.modules():
+            assert len(module._load_state_dict_pre_hooks) <= 1
+            module._load_state_dict_pre_hooks.clear()
+            assert len(module._state_dict_pre_hooks) <= 1
+            module._state_dict_pre_hooks.clear()
     logger.info("Applied FSDP to the model")
     return model
 
```
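
A minimal sketch of the seed-checkpoint load this hack unblocks is shown below. It is an illustration under assumptions, not code from this commit: `load_seed_checkpoint` and `checkpoint_dir` are made up; only the DCP APIs (`torch.distributed.checkpoint.load`, `get_model_state_dict`, `set_model_state_dict`) are real.

```python
# Hypothetical sketch (not from this commit): once apply_fsdp has cleared
# the DCP safety pre-hooks on a PP-enabled model, a plain DCP load of the
# seed checkpoint should work for the no-resharding case described above.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

def load_seed_checkpoint(model, checkpoint_dir: str) -> None:
    # Materialize a state dict; without the hook-clearing hack above, this
    # call path would trip DCP's safety check on a 2D/3D-sharded model.
    state_dict = {"model": get_model_state_dict(model)}
    # Load the seed checkpoint in place from an assumed directory layout.
    dcp.load(state_dict, checkpoint_id=checkpoint_dir)
    set_model_state_dict(model, state_dict["model"])
```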
