Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not aggregate the losses since last log step #779

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchtitan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def get_device_info():
device_type, device_module = get_device_info()


def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
def dist_max(x: Union[int, float, torch.Tensor], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).to(device_type)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK these will result in no-ops if a tensor of with the same device type is passed.

return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()


def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
def dist_mean(x: Union[int, float, torch.Tensor], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).to(device_type)
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()

Expand Down
12 changes: 3 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ def loss_fn(pred, labels):
)

# variables used to keep info for metrics logging
losses_since_last_log = []
ntokens_since_last_log = 0
data_loading_times = []
time_last_log = time.perf_counter()
Expand Down Expand Up @@ -327,26 +326,22 @@ def loss_fn(pred, labels):
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

losses_since_last_log.append(loss)

# log metrics
if (
train_state.step == 1
or train_state.step % job_config.metrics.log_freq == 0
):
losses = [loss.item() for loss in losses_since_last_log]
avg_loss, max_loss = sum(losses) / len(losses), max(losses)
if (
parallel_dims.dp_replicate_enabled
or parallel_dims.dp_shard_enabled
or parallel_dims.cp_enabled
):
global_avg_loss, global_max_loss = (
utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
utils.dist_max(max_loss, world_mesh["dp_cp"]),
utils.dist_mean(loss, world_mesh["dp_cp"]),
utils.dist_max(loss, world_mesh["dp_cp"]),
Comment on lines +344 to +345
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These use functional collectives under the hood so there shouldn't be issues with passing in the same tensor reference

Comment on lines +344 to +345
Copy link
Contributor

@tianyu-l tianyu-l Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems with Tensor Parallel, the loss is a DTensor, which doesn't support functional collectives. Also we should not require gradients on this all-reduce.
Maybe it's still fine to do .item() outside as before? or use detach and full_tensor() ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the detach and full_tensor calls

)
else:
global_avg_loss, global_max_loss = avg_loss, max_loss
global_avg_loss = global_max_loss = loss.item()

# update train state
train_state.log_steps.append(train_state.step)
Expand Down Expand Up @@ -396,7 +391,6 @@ def loss_fn(pred, labels):
f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
)

losses_since_last_log.clear()
ntokens_since_last_log = 0
data_loading_times.clear()
time_last_log = time.perf_counter()
Expand Down
Loading