Do not aggregate the losses since last log step #779
@@ -228,7 +228,6 @@ def loss_fn(pred, labels):
     )

     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()
@@ -292,10 +291,11 @@ def loss_fn(pred, labels):
                         pp_schedule.step()

                 # accumulate losses across pipeline microbatches
+                # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
                 loss = (
-                    torch.mean(torch.stack(losses))
+                    torch.mean(torch.stack(losses)).to(device)
                     if is_last_stage
-                    else torch.Tensor([-1.0])
+                    else torch.tensor([-1.0], device=device)
                 )
             else:
                 # Non-PP forward / backward
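The .to(device) change above matters because collectives on a NCCL-backed process group require tensors on the accelerator: a loss that lands on the CPU (as reportedly happens under PP+FSDP) raises "RuntimeError: No backend type associated with device type cpu". Below is a minimal sketch of that pattern, assuming an already-initialized process group; the helper name and the sum-then-divide reduction are illustrative, not code from this PR:

import torch
import torch.distributed as dist

def reduce_loss_for_logging(loss: torch.Tensor, device: torch.device) -> float:
    # Move the (possibly CPU) loss onto the accelerator before the collective,
    # mirroring the .to(device) workaround in the hunk above.
    loss = loss.detach().to(device)
    # Backend-agnostic mean: sum across ranks, then divide by the world size.
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    return (loss / dist.get_world_size()).item()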
@@ -327,26 +327,23 @@ def loss_fn(pred, labels):
                 # it issues a single all-reduce for all parameters at once for better performance
                 float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

-                losses_since_last_log.append(loss)
-
                 # log metrics
                 if (
                     train_state.step == 1
                     or train_state.step % job_config.metrics.log_freq == 0
                 ):
-                    losses = [loss.item() for loss in losses_since_last_log]
-                    avg_loss, max_loss = sum(losses) / len(losses), max(losses)
                     if (
                         parallel_dims.dp_replicate_enabled
                         or parallel_dims.dp_shard_enabled
                         or parallel_dims.cp_enabled
                     ):
+                        loss = loss.detach()
                         global_avg_loss, global_max_loss = (
-                            utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                            utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                            utils.dist_mean(loss, world_mesh["dp_cp"]),
+                            utils.dist_max(loss, world_mesh["dp_cp"]),
                         )
                     else:
-                        global_avg_loss, global_max_loss = avg_loss, max_loss
+                        global_avg_loss = global_max_loss = loss.item()

                     # update train state
                     train_state.log_steps.append(train_state.step)

Review discussion on this change:

- The code passes PP and FSDP tests, but fails the PP+FSDP test with "RuntimeError: No backend type associated with device type cpu".
- Should I put back the ...
- I want to understand why the loss is not a CUDA tensor in the PP case first. @H-Huang can you take a look at this and see if it's intentional? For now I'd be OK with a workaround with a TODO comment about following up on the PP case mentioned above. wdyt?
- I implemented the suggested workaround.
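The utils.dist_mean and utils.dist_max helpers called in the hunk above are not part of this diff view. A plausible sketch of such helpers on top of PyTorch functional collectives follows; the actual torchtitan implementation may differ in names and details:

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh

def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # Average the scalar loss across every rank in the given mesh (e.g. dp_cp).
    return funcol.all_reduce(x, reduceOp="avg", group=mesh).item()

def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # Take the maximum of the scalar loss across every rank in the mesh.
    return funcol.all_reduce(x, reduceOp="max", group=mesh).item()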
@@ -396,7 +393,6 @@ def loss_fn(pred, labels):
                         f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
                     )

-                    losses_since_last_log.clear()
                     ntokens_since_last_log = 0
                     data_loading_times.clear()
                     time_last_log = time.perf_counter()
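Semantically, the hunks above change what gets logged: the reported loss is no longer the local average of every step since the last log, but the current step's loss, reduced across the dp_cp ranks. A toy single-process illustration of the difference, with made-up numbers:

step_losses = [2.0, 1.5, 1.2, 1.0]  # hypothetical per-step losses between two log points

# Old behaviour: average every loss collected since the last log step.
old_logged = sum(step_losses) / len(step_losses)   # 1.425

# New behaviour: report only the most recent step's loss
# (then dist_mean / dist_max across ranks).
new_logged = step_losses[-1]                       # 1.0

print(old_logged, new_logged)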
Review discussion on the reduction helpers:

- I am surprised by "DTensors do not support functional collectives". I thought DTensor used functional collectives under the hood already, as part of their effort to support compiler tracing. But in any case, doing a .full_tensor() and then using funcol is probably a worse way than calling some DTensor method that triggers DTensor to do a reduction, to save on one collective? cc @XilunWu @yifuwang
- Hm, also, what is the deal with a DTensor that has one item? (see assertion next line) - this feels like a pretty weird case to me.
- I think @carmocca meant that you can't pass a DTensor to a functional collective op? I don't think we've investigated the semantics yet. Obviously it doesn't make too much sense to just perform the collective on either the local tensor or the full tensor (e.g., performing an all-gather on a replicated tensor). Something we can think about.
- There was a request for it in #779 (comment). Maybe renaming the functions to scalar_reduce, scalar_max ... would make the intent clearer?
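For reference, here is a hedged sketch of the pattern under discussion: a scalar-only helper that materializes a DTensor loss with .full_tensor() before handing it to a functional collective. The name scalar_dist_mean is hypothetical, and this is not the merged torchtitan code:

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

def scalar_dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # These helpers are only meant for scalar metrics such as the loss.
    assert x.numel() == 1, "expected a single-element tensor"
    if isinstance(x, DTensor):
        # Functional collectives reportedly do not accept DTensor inputs, so
        # convert to a plain tensor first; as noted in the thread above, a
        # DTensor-native reduction could avoid this extra collective.
        x = x.full_tensor()
    return funcol.all_reduce(x, reduceOp="avg", group=mesh).item()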