From 6b03ffa321b2208fc54fd61b7e0de204dff5dc9e Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 26 Dec 2024 15:40:38 -0800 Subject: [PATCH] fix loss with CP enabled ghstack-source-id: 1564496e6320150e0180a579a874ed5636b231c2 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/761 --- scripts/estimate/estimation.py | 27 ++++++++--------- tests/integration_tests.py | 16 ++++++++-- torchtitan/parallelisms/parallel_dims.py | 31 ++++++++++++-------- torchtitan/parallelisms/parallelize_llama.py | 13 ++------ train.py | 10 +++++-- 5 files changed, 55 insertions(+), 42 deletions(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index c4d1e9c9..fbdc9f09 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -33,16 +33,6 @@ def estimate_memory(job_config: JobConfig): # Get the world size world_size = int(os.environ["WORLD_SIZE"]) - # if tp > or pp > 1, we exit - if ( - job_config.training.tensor_parallel_degree > 1 - or job_config.experimental.pipeline_parallel_degree > 1 - ): - logger.info( - "Tensor parallelism and pipeline parallelism are not supported yet." - ) - return - # fake tensor doesn't work with fused rmsnorm if ( job_config.model.norm_type == "fused_rmsnorm" @@ -73,6 +63,19 @@ def estimate_memory(job_config: JobConfig): enable_loss_parallel=not job_config.training.disable_loss_parallel, ) + # only FSDP and HSDP are supported + if ( + (parallel_dims.dp_replicate_enabled and not parallel_dims.dp_shard_enabled) + or parallel_dims.tp_enabled + or parallel_dims.pp_enabled + or parallel_dims.cp_enabled + ): + logger.warning("DDP, TP, PP, CP are not supported yet.") + return + if not parallel_dims.dp_shard_enabled: + logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.") + return + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) @@ -85,10 +88,6 @@ def estimate_memory(job_config: JobConfig): # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") - if not parallel_dims.dp_enabled: - logger.info("Data parallelism is not enabled. Skipping memory estimation.") - return - model_name = job_config.model.name # build tokenizer diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 41c9d209..d6655331 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -338,6 +338,18 @@ def build_test_list(): "fsdp+cp", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=1", + "--training.data_parallel_replicate_degree=2", + "--experimental.context_parallel_degree=2", + ] + ], + "HSDP+CP (with dp_shard)", + "hsdp+cp_without_dp_shard", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -346,8 +358,8 @@ def build_test_list(): "--experimental.context_parallel_degree=2", ] ], - "HSDP+CP", - "hsdp+cp", + "HSDP+CP (without dp_shard)", + "hsdp+cp_with_dp_shard", ngpu=8, ), OverrideDefinitions( diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 9af771a2..13d066a8 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -62,27 +62,32 @@ def build_mesh(self, device_type): # Create all the submesh here to ensure all required process groups are # initialized: - # Mesh for data loading + # Mesh for data loading (no communication on this mesh) dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + if self.dp_replicate_enabled: dp_mesh_dim_names.append("dp_replicate") - + dp_cp_mesh_dim_names.append("dp_replicate") if self.dp_shard_enabled: dp_mesh_dim_names.append("dp_shard") + dp_shard_cp_mesh_dim_names.append("dp_shard") + dp_cp_mesh_dim_names.append("dp_shard") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") if dp_mesh_dim_names != []: mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - - # Mesh for param sharding - dp_shard_cp_mesh_dim_name = [] - if self.dp_shard_enabled: - dp_shard_cp_mesh_dim_name.append("dp_shard") - - if self.cp_enabled: - dp_shard_cp_mesh_dim_name.append("cp") - - if dp_shard_cp_mesh_dim_name != []: - mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp") + if dp_shard_cp_mesh_dim_names != []: + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( + mesh_dim_name="dp_shard_cp" + ) + if dp_cp_mesh_dim_names != []: + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") return mesh diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index fce22c48..9728569a 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -79,17 +79,10 @@ def parallelize_llama( if ( parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled ): # apply FSDP or HSDP, potentially with Context Parallel - - if not parallel_dims.dp_shard_enabled and parallel_dims.dp_replicate_enabled: - # Composability of DDP + CP is not supported. - raise RuntimeError("Composability of DDP + CP is not supported.") - - # the mesh dim names of which the model params are sharded on - dp_mesh_dim_names = [] if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - - dp_mesh_dim_names.append("dp_shard_cp") + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) apply_fsdp( model, diff --git a/train.py b/train.py index 0b69690a..2874d0d5 100644 --- a/train.py +++ b/train.py @@ -336,10 +336,14 @@ def loss_fn(pred, labels): ): losses = [loss.item() for loss in losses_since_last_log] avg_loss, max_loss = sum(losses) / len(losses), max(losses) - if parallel_dims.dp_enabled: + if ( + parallel_dims.dp_replicate_enabled + or parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + ): global_avg_loss, global_max_loss = ( - utils.dist_mean(avg_loss, dp_mesh), - utils.dist_max(max_loss, dp_mesh), + utils.dist_mean(avg_loss, world_mesh["dp_cp"]), + utils.dist_max(max_loss, world_mesh["dp_cp"]), ) else: global_avg_loss, global_max_loss = avg_loss, max_loss