Do not aggregate the losses since last log step #779

Status: merged (4 commits) on Jan 16, 2025. The diff below shows changes from all commits.
18 changes: 12 additions & 6 deletions torchtitan/utils.py
@@ -34,14 +34,20 @@ def get_device_info():
 device_type, device_module = get_device_info()


-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
+def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
+    if isinstance(x, DTensor):
+        # functional collectives do not support DTensor inputs
+        x = x.full_tensor()
Contributor:
I am surprised by "DTensors do not support functional collectives". I thought DTensor already used functional collectives under the hood, as part of the effort to support compiler tracing.

But in any case, doing a .full_tensor() and then using funcol is probably worse than calling some DTensor method that triggers DTensor to do the reduction, which would save one collective?

cc @XilunWu @yifuwang

Contributor:
Hm, also, what is the deal with a DTensor that has one item (see the assertion on the next line)? This feels like a pretty weird case to me.

Contributor:
I think @carmocca meant that you can't pass a DTensor to a functional collective op? I don't think we've investigated the semantics yet. Obviously it doesn't make much sense to just perform the collective on either the local tensor or the full tensor (e.g., performing an all-gather on a replicated tensor). Something we can think about.

Contributor Author:
> Hm, also, what is the deal with a DTensor that has one item (see the assertion on the next line)? This feels like a pretty weird case to me.

There was a request for it in #779 (comment). Maybe renaming the functions to scalar_reduce, scalar_max, ... would make the intent clearer?
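
For context, a minimal sketch (not from the PR) of the case being discussed: a scalar loss can arrive wrapped as a 1-element DTensor, e.g. when it is computed from sharded logits under tensor/loss parallelism, and the DTensor branch above materializes it before the functional all-reduce. Here `mesh` is assumed to be an already-initialized 1-D DeviceMesh.

```python
# Sketch only: the 1-element DTensor input that dist_reduce has to handle.
# Assumes torch.distributed is initialized and `mesh` is a 1-D DeviceMesh.
import torch
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

loss = distribute_tensor(torch.tensor(2.5), mesh, placements=[Replicate()])
assert isinstance(loss, DTensor) and loss.numel() == 1

# The functional all-reduce is given a plain tensor, materialized via
# .full_tensor(), exactly as the branch above does.
plain_loss = loss.full_tensor()
assert not isinstance(plain_loss, DTensor)
```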

+    assert x.numel() == 1  # required by `.item()`
+    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()


-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()
+def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh)


+def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh)


 def _warn_overwrite_env(env, val):
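As a usage illustration (a sketch under assumptions, not part of this PR), the refactored helpers take a per-rank scalar tensor and return a Python float reduced over a device mesh. Launched under torchrun on a CUDA build, this might look like:

```python
# Sketch: exercising dist_mean / dist_max as refactored above.
# Launch with: torchrun --nproc_per_node=<N> this_script.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from torchtitan import utils  # provides dist_mean / dist_max from this diff

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
loss = torch.rand(1, device="cuda")   # stand-in for the per-rank training loss

avg_loss = utils.dist_mean(loss, mesh)  # float: mean of the per-rank losses
max_loss = utils.dist_max(loss, mesh)   # float: max of the per-rank losses
print(f"rank {dist.get_rank()}: avg={avg_loss:.4f} max={max_loss:.4f}")
```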
18 changes: 7 additions & 11 deletions train.py
@@ -228,7 +228,6 @@ def loss_fn(pred, labels):
     )

     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()

@@ -292,10 +291,11 @@ def loss_fn(pred, labels):
                 pp_schedule.step()

             # accumulate losses across pipeline microbatches
+            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses))
+                torch.mean(torch.stack(losses)).to(device)
                 if is_last_stage
-                else torch.Tensor([-1.0])
+                else torch.tensor([-1.0], device=device)
             )
         else:
             # Non-PP forward / backward

@@ -327,26 +327,23 @@ def loss_fn(pred, labels):
         # it issues a single all-reduce for all parameters at once for better performance
         float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

-        losses_since_last_log.append(loss)
-
         # log metrics
         if (
             train_state.step == 1
             or train_state.step % job_config.metrics.log_freq == 0
         ):
-            losses = [loss.item() for loss in losses_since_last_log]
-            avg_loss, max_loss = sum(losses) / len(losses), max(losses)
             if (
                 parallel_dims.dp_replicate_enabled
                 or parallel_dims.dp_shard_enabled
                 or parallel_dims.cp_enabled
             ):
+                loss = loss.detach()
                 global_avg_loss, global_max_loss = (
-                    utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                    utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                    utils.dist_mean(loss, world_mesh["dp_cp"]),
Contributor:
The code passes the PP and FSDP tests, but fails the PP+FSDP test with "RuntimeError: No backend type associated with device type cpu".
It seems that only for PP+FSDP is the loss put back on the CPU early as a tensor. Is this intentional?
@wconstab

Contributor Author:
Should I put back the .to(device_type) in dist_reduce in the meantime? Whether it's intentional or not, it's probably a hairier change.

Contributor:
I want to understand why the loss is not a CUDA tensor in the PP case first. @H-Huang can you take a look at this and see if it's intentional?

For now I'd be OK with a workaround, with a TODO comment about following up on the PP case mentioned above:

  • Perhaps don't modify the dist_* functions; instead put the .to() into the train.py code where PP merges the per-microbatch losses into `losses` via torch.mean?

wdyt?

Contributor Author:
I implemented the suggested workaround.

+                    utils.dist_max(loss, world_mesh["dp_cp"]),
                 )
             else:
-                global_avg_loss, global_max_loss = avg_loss, max_loss
+                global_avg_loss = global_max_loss = loss.item()

             # update train state
             train_state.log_steps.append(train_state.step)
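
To spell out the failure mode from the thread above (a sketch with names taken from train.py, not code from the PR): the reduction inside dist_mean/dist_max runs an NCCL-backed collective on the dp_cp mesh, and NCCL only accepts device tensors, hence the error when the PP+FSDP path hands back a CPU loss. Moving the merged microbatch loss to the accelerator before the collective, as the diff above does, avoids it.

```python
# Sketch of the failure and the workaround; `losses`, `device`, `world_mesh`
# and `utils` are the names used in train.py above.
cpu_loss = torch.mean(torch.stack(losses))        # PP+FSDP left this on the CPU
# utils.dist_mean(cpu_loss, world_mesh["dp_cp"])  # -> RuntimeError: No backend type
#                                                 #    associated with device type cpu
device_loss = cpu_loss.to(device)                 # the workaround applied above
global_avg_loss = utils.dist_mean(device_loss.detach(), world_mesh["dp_cp"])
```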
@@ -396,7 +393,6 @@ def loss_fn(pred, labels):
f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
)

losses_since_last_log.clear()
ntokens_since_last_log = 0
data_loading_times.clear()
time_last_log = time.perf_counter()
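
Putting the behavioral change in one picture (a sketch with made-up numbers, not code from the PR): before this change, the logged values were the local average/max of all losses accumulated since the last log step, reduced across ranks as Python scalars; after it, only the current step's loss tensor is reduced across the dp/cp ranks.

```python
# Sketch of the logging-semantics change, for one rank and one log window.
window = [2.0, 1.5, 1.0]   # hypothetical local losses since the last log step

# Before this PR: aggregate the window locally, then dist-reduce the scalars.
old_avg, old_max = sum(window) / len(window), max(window)   # 1.5, 2.0

# After this PR: only the current step's loss is dist-reduced across ranks.
new_local = window[-1]                                       # 1.0
```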