Do not aggregate the losses since last log step #779
base: main
```diff
@@ -11,6 +11,7 @@
 import torch
 
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor import DTensor
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
```
```diff
@@ -228,7 +229,6 @@ def loss_fn(pred, labels):
     )
 
     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()
```
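The effect of the change so far: instead of averaging every loss recorded since the previous log step, the logger will report only the most recent step's loss. A toy illustration with made-up values:

```python
# Toy illustration of the logging change (values are made up).
losses_since_last_log = [2.5, 2.3, 2.1, 1.9]

# Before this PR: average/max over the whole window since the last log.
avg_loss = sum(losses_since_last_log) / len(losses_since_last_log)  # 2.2
max_loss = max(losses_since_last_log)                               # 2.5

# After this PR: only the current step's loss is logged.
loss = losses_since_last_log[-1]  # 1.9
```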
```diff
@@ -327,26 +327,25 @@ def loss_fn(pred, labels):
             # it issues a single all-reduce for all parameters at once for better performance
             float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
 
-            losses_since_last_log.append(loss)
-
             # log metrics
             if (
                 train_state.step == 1
                 or train_state.step % job_config.metrics.log_freq == 0
             ):
-                losses = [loss.item() for loss in losses_since_last_log]
-                avg_loss, max_loss = sum(losses) / len(losses), max(losses)
                 if (
                     parallel_dims.dp_replicate_enabled
                     or parallel_dims.dp_shard_enabled
                     or parallel_dims.cp_enabled
                 ):
+                    loss = loss.detach()
+                    if isinstance(loss, DTensor):
+                        loss = loss.full_tensor()
                     global_avg_loss, global_max_loss = (
-                        utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                        utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                        utils.dist_mean(loss, world_mesh["dp_cp"]),
+                        utils.dist_max(loss, world_mesh["dp_cp"]),
                     )
                 else:
-                    global_avg_loss, global_max_loss = avg_loss, max_loss
+                    global_avg_loss = global_max_loss = loss.item()
 
                 # update train state
                 train_state.log_steps.append(train_state.step)
```
Comment on lines +344 to +345:

These use functional collectives under the hood, so there shouldn't be issues with passing in the same tensor reference.
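For context, a minimal sketch of what `utils.dist_mean` / `utils.dist_max` could look like on top of PyTorch's functional collectives; this illustrates the pattern the comment describes, not necessarily torchtitan's exact implementation:

```python
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh


def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # Functional collectives return a new tensor instead of mutating the
    # input, so passing the same tensor reference into several reductions
    # is safe.
    return funcol.all_reduce(x, reduceOp="avg", group=mesh).item()


def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
    return funcol.all_reduce(x, reduceOp="max", group=mesh).item()
```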
Comment on lines +344 to +345:

It seems with Tensor Parallel, the loss is a DTensor.

Reply: I added the detach and full_tensor calls.
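The reply above refers to the `loss.detach()` / `loss.full_tensor()` lines in the hunk. A self-contained sketch of that guard, wrapped in a hypothetical helper `as_plain_tensor` for illustration:

```python
import torch
from torch.distributed.tensor import DTensor


def as_plain_tensor(loss: torch.Tensor) -> torch.Tensor:
    """Prepare a (possibly distributed) loss for logging collectives."""
    # detach() drops the autograd graph; only the value is needed for logging.
    loss = loss.detach()
    if isinstance(loss, DTensor):
        # Under Tensor Parallel the loss may come back as a DTensor;
        # full_tensor() materializes it into a regular torch.Tensor.
        loss = loss.full_tensor()
    return loss
```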
```diff
@@ -396,7 +395,6 @@ def loss_fn(pred, labels):
                     f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
                 )
 
-                losses_since_last_log.clear()
                 ntokens_since_last_log = 0
                 data_loading_times.clear()
                 time_last_log = time.perf_counter()
```
Comment:

AFAIK these will result in no-ops if a tensor with the same device type is passed.
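One way to read this: `Tensor.to` and `Tensor.cuda` return the input object itself when the tensor already lives on the target device, so if the helpers move their input to the collective's device, passing an already-on-device tensor adds no copy. A quick check (assumes a CUDA device is available):

```python
import torch

t = torch.randn(4, device="cuda")
# .to() / .cuda() return the same object when no device or dtype
# conversion is needed, i.e. the move is a no-op.
assert t.to("cuda") is t
assert t.cuda() is t
```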