Skip to content

Commit

Permalink
fix loss with CP enabled
Browse files Browse the repository at this point in the history
ghstack-source-id: 1564496e6320150e0180a579a874ed5636b231c2
Pull Request resolved: #761
  • Loading branch information
tianyu-l committed Dec 27, 2024
1 parent f6a9daa commit 6b03ffa
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 42 deletions.
27 changes: 13 additions & 14 deletions scripts/estimate/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@ def estimate_memory(job_config: JobConfig):
# Get the world size
world_size = int(os.environ["WORLD_SIZE"])

# if tp > or pp > 1, we exit
if (
job_config.training.tensor_parallel_degree > 1
or job_config.experimental.pipeline_parallel_degree > 1
):
logger.info(
"Tensor parallelism and pipeline parallelism are not supported yet."
)
return

# fake tensor doesn't work with fused rmsnorm
if (
job_config.model.norm_type == "fused_rmsnorm"
Expand Down Expand Up @@ -73,6 +63,19 @@ def estimate_memory(job_config: JobConfig):
enable_loss_parallel=not job_config.training.disable_loss_parallel,
)

# only FSDP and HSDP are supported
if (
(parallel_dims.dp_replicate_enabled and not parallel_dims.dp_shard_enabled)
or parallel_dims.tp_enabled
or parallel_dims.pp_enabled
or parallel_dims.cp_enabled
):
logger.warning("DDP, TP, PP, CP are not supported yet.")
return
if not parallel_dims.dp_shard_enabled:
logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.")
return

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)

Expand All @@ -85,10 +88,6 @@ def estimate_memory(job_config: JobConfig):
# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")

if not parallel_dims.dp_enabled:
logger.info("Data parallelism is not enabled. Skipping memory estimation.")
return

model_name = job_config.model.name

# build tokenizer
Expand Down
16 changes: 14 additions & 2 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,18 @@ def build_test_list():
"fsdp+cp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=1",
"--training.data_parallel_replicate_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"HSDP+CP (with dp_shard)",
"hsdp+cp_without_dp_shard",
ngpu=4,
),
OverrideDefinitions(
[
[
Expand All @@ -346,8 +358,8 @@ def build_test_list():
"--experimental.context_parallel_degree=2",
]
],
"HSDP+CP",
"hsdp+cp",
"HSDP+CP (without dp_shard)",
"hsdp+cp_with_dp_shard",
ngpu=8,
),
OverrideDefinitions(
Expand Down
31 changes: 18 additions & 13 deletions torchtitan/parallelisms/parallel_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,32 @@ def build_mesh(self, device_type):

# Create all the submesh here to ensure all required process groups are
# initialized:
# Mesh for data loading
# Mesh for data loading (no communication on this mesh)
dp_mesh_dim_names = []
# Mesh for param sharding
dp_shard_cp_mesh_dim_names = []
# Mesh for loss all-reduce
dp_cp_mesh_dim_names = []

if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")

dp_cp_mesh_dim_names.append("dp_replicate")
if self.dp_shard_enabled:
dp_mesh_dim_names.append("dp_shard")
dp_shard_cp_mesh_dim_names.append("dp_shard")
dp_cp_mesh_dim_names.append("dp_shard")
if self.cp_enabled:
dp_shard_cp_mesh_dim_names.append("cp")
dp_cp_mesh_dim_names.append("cp")

if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

# Mesh for param sharding
dp_shard_cp_mesh_dim_name = []
if self.dp_shard_enabled:
dp_shard_cp_mesh_dim_name.append("dp_shard")

if self.cp_enabled:
dp_shard_cp_mesh_dim_name.append("cp")

if dp_shard_cp_mesh_dim_name != []:
mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")
if dp_shard_cp_mesh_dim_names != []:
mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
mesh_dim_name="dp_shard_cp"
)
if dp_cp_mesh_dim_names != []:
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")

return mesh

Expand Down
13 changes: 3 additions & 10 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,10 @@ def parallelize_llama(
if (
parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
): # apply FSDP or HSDP, potentially with Context Parallel

if not parallel_dims.dp_shard_enabled and parallel_dims.dp_replicate_enabled:
# Composability of DDP + CP is not supported.
raise RuntimeError("Composability of DDP + CP is not supported.")

# the mesh dim names of which the model params are sharded on
dp_mesh_dim_names = []
if parallel_dims.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")

dp_mesh_dim_names.append("dp_shard_cp")
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
else:
dp_mesh_dim_names = ("dp_shard_cp",)

apply_fsdp(
model,
Expand Down
10 changes: 7 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,14 @@ def loss_fn(pred, labels):
):
losses = [loss.item() for loss in losses_since_last_log]
avg_loss, max_loss = sum(losses) / len(losses), max(losses)
if parallel_dims.dp_enabled:
if (
parallel_dims.dp_replicate_enabled
or parallel_dims.dp_shard_enabled
or parallel_dims.cp_enabled
):
global_avg_loss, global_max_loss = (
utils.dist_mean(avg_loss, dp_mesh),
utils.dist_max(max_loss, dp_mesh),
utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
utils.dist_max(max_loss, world_mesh["dp_cp"]),
)
else:
global_avg_loss, global_max_loss = avg_loss, max_loss
Expand Down

0 comments on commit 6b03ffa

Please sign in to comment.