Skip to content

Commit

Permalink
[cp] apply fsdp to model when CP is enabled without DP for correct lo…
Browse files Browse the repository at this point in the history
…ss and lower mem usage (#685)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #684
* __->__ #685

**Summary**
Previously CP forgot to shard the model via `apply_fsdp` when DP is not
combined with CP. This leads to high peak memory usage and diverging
loss.

**Test**
1. modify `train_configs/llama3_8b.toml`
```
steps = 20
context_parallel_degree = 8
```
2.  run training on 8xH100 GPUs
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8
LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`
Before: CUDA OutOfMemory 
After: successful 20-steps training
  • Loading branch information
XilunWu authored Dec 11, 2024
1 parent cb633e3 commit 40a0873
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 40 deletions.
15 changes: 10 additions & 5 deletions torchtitan/parallelisms/parallel_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,16 @@ def build_mesh(self, device_type):
if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

if self.cp > 1:
if self.dp_replicate > 1 and self.dp_shard > 1: # HSDP
mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
elif self.dp_shard > 1: # FSDP
mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
# Mesh for param sharding
dp_shard_cp_mesh_dim_name = []
if self.dp_shard_enabled:
dp_shard_cp_mesh_dim_name.append("dp_shard")

if self.cp_enabled:
dp_shard_cp_mesh_dim_name.append("cp")

if dp_shard_cp_mesh_dim_name != []:
mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")

return mesh

Expand Down
48 changes: 13 additions & 35 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.utils import check_if_feature_in_pytorch


def parallelize_llama(
Expand Down Expand Up @@ -78,44 +77,23 @@ def parallelize_llama(
apply_compile(model)

if (
parallel_dims.dp_shard_enabled
parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
): # apply FSDP or HSDP, potentially with Context Parallel
try:
dp_mesh_dim_names = (
("dp_replicate", "dp_shard")
if parallel_dims.dp_replicate_enabled
else ("dp",)
)
dp_mesh = (
world_mesh["dp_cp"]
if parallel_dims.cp_enabled
else world_mesh[(*dp_mesh_dim_names,)]
)
except IndexError:
# note: this is a workaround of the above logic for old pytorch version
# where https://github.com/pytorch/pytorch/pull/138945 is not included
# throw a warning to encourage users to upgrade to a newer pytorch version
check_if_feature_in_pytorch(
"DeviceMesh flattening over 3D+ meshes",
"https://github.com/pytorch/pytorch/pull/138945",
"2.6.0.dev20241030",
)
# TODO: remove this workaround once PyTorch 2.6 is released
dp_mesh_dim_names = (
("dp_replicate", "dp_shard")
if parallel_dims.dp_replicate_enabled
else ("dp",)
)
# note that mesh can only be flattened from the finest-grained mesh dimensions
dp_mesh = (
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)

if not parallel_dims.dp_shard_enabled and parallel_dims.dp_replicate_enabled:
# Composability of DDP + CP is not supported.
raise RuntimeError("Composability of DDP + CP is not supported.")

# the mesh dim names of which the model params are sharded on
dp_mesh_dim_names = []
if parallel_dims.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")

dp_mesh_dim_names.append("dp_shard_cp")

apply_fsdp(
model,
dp_mesh,
world_mesh[tuple(dp_mesh_dim_names)],
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
pp_enabled=parallel_dims.pp_enabled,
Expand Down

0 comments on commit 40a0873

Please sign in to comment.