From 726efc4ee6942beed4070f7f69e73c80ac1d0d95 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 7 Jan 2025 15:10:36 +0100
Subject: [PATCH 1/4] Do not aggregate the losses since last log step

---
 torchtitan/utils.py |  4 ++--
 train.py            | 12 +++---------
 2 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 88663c00..daf7050f 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -34,12 +34,12 @@ def get_device_info():
 device_type, device_module = get_device_info()
 
 
-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
+def dist_max(x: Union[int, float, torch.Tensor], mesh: DeviceMesh) -> float:
     tensor = torch.tensor(x).to(device_type)
     return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
 
 
-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
+def dist_mean(x: Union[int, float, torch.Tensor], mesh: DeviceMesh) -> float:
     tensor = torch.tensor(x).to(device_type)
     return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()
 
diff --git a/train.py b/train.py
index 2874d0d5..314c27aa 100644
--- a/train.py
+++ b/train.py
@@ -228,7 +228,6 @@ def loss_fn(pred, labels):
     )
 
     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()
@@ -327,26 +326,22 @@ def loss_fn(pred, labels):
         # it issues a single all-reduce for all parameters at once for better performance
         float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
 
-        losses_since_last_log.append(loss)
-
         # log metrics
         if (
             train_state.step == 1
             or train_state.step % job_config.metrics.log_freq == 0
         ):
-            losses = [loss.item() for loss in losses_since_last_log]
-            avg_loss, max_loss = sum(losses) / len(losses), max(losses)
             if (
                 parallel_dims.dp_replicate_enabled
                 or parallel_dims.dp_shard_enabled
                 or parallel_dims.cp_enabled
             ):
                 global_avg_loss, global_max_loss = (
-                    utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                    utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                    utils.dist_mean(loss, world_mesh["dp_cp"]),
+                    utils.dist_max(loss, world_mesh["dp_cp"]),
                 )
             else:
-                global_avg_loss, global_max_loss = avg_loss, max_loss
+                global_avg_loss = global_max_loss = loss.item()
 
             # update train state
             train_state.log_steps.append(train_state.step)
@@ -396,7 +391,6 @@ def loss_fn(pred, labels):
                 f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
             )
 
-            losses_since_last_log.clear()
             ntokens_since_last_log = 0
             data_loading_times.clear()
             time_last_log = time.perf_counter()

From 7a3e1fb059192e626cbc555f8316919457e8f6a4 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Thu, 9 Jan 2025 17:01:51 +0100
Subject: [PATCH 2/4] Dtensor

---
 train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/train.py b/train.py
index 314c27aa..d8c2a409 100644
--- a/train.py
+++ b/train.py
@@ -11,6 +11,7 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor import DTensor
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -336,6 +337,9 @@ def loss_fn(pred, labels):
                 or parallel_dims.dp_shard_enabled
                 or parallel_dims.cp_enabled
             ):
+                loss = loss.detach()
+                if isinstance(loss, DTensor):
+                    loss = loss.full_tensor()
                 global_avg_loss, global_max_loss = (
                     utils.dist_mean(loss, world_mesh["dp_cp"]),
                     utils.dist_max(loss, world_mesh["dp_cp"]),

From d16d42ff2aee10554098046361c1e4ea67a55068 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 10 Jan 2025 12:39:13 +0100
Subject: [PATCH 3/4] Review comments

---
 torchtitan/utils.py | 18 ++++++++++++------
 train.py            |  3 ---
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index daf7050f..06f07090 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -34,14 +34,20 @@ def get_device_info():
 device_type, device_module = get_device_info()
 
 
-def dist_max(x: Union[int, float, torch.Tensor], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
+def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
+    if isinstance(x, DTensor):
+        # DTensors do not support functional collectives
+        x = x.full_tensor()
+    assert x.numel() == 1  # required by `.item()`
+    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
 
 
-def dist_mean(x: Union[int, float, torch.Tensor], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()
+def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh)
+
+
+def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh)
 
 
 def _warn_overwrite_env(env, val):
diff --git a/train.py b/train.py
index d8c2a409..6a09f60e 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,6 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.tensor import DTensor
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -338,8 +337,6 @@ def loss_fn(pred, labels):
                 or parallel_dims.cp_enabled
             ):
                 loss = loss.detach()
-                if isinstance(loss, DTensor):
-                    loss = loss.full_tensor()
                 global_avg_loss, global_max_loss = (
                     utils.dist_mean(loss, world_mesh["dp_cp"]),
                     utils.dist_max(loss, world_mesh["dp_cp"]),

From 9fd2b5172bcb9a5c49d18c75c3ab833abe990819 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 15 Jan 2025 16:44:11 +0100
Subject: [PATCH 4/4] Address comments

---
 torchtitan/utils.py | 2 +-
 train.py            | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 06f07090..c9dcf2fa 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -36,7 +36,7 @@ def get_device_info():
 
 def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
     if isinstance(x, DTensor):
-        # DTensors do not support functional collectives
+        # functional collectives do not support DTensor inputs
         x = x.full_tensor()
     assert x.numel() == 1  # required by `.item()`
     return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
diff --git a/train.py b/train.py
index 6a09f60e..4fb3a56b 100644
--- a/train.py
+++ b/train.py
@@ -291,10 +291,11 @@ def loss_fn(pred, labels):
                     pp_schedule.step()
 
             # accumulate losses across pipeline microbatches
+            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses))
+                torch.mean(torch.stack(losses)).to(device)
                 if is_last_stage
-                else torch.Tensor([-1.0])
+                else torch.tensor([-1.0], device=device)
             )
         else:
             # Non-PP forward / backward